From 63cd4a0d21aab85678825240463cb6636ef27aff Mon Sep 17 00:00:00 2001 From: Akshat Shrivastava Date: Thu, 18 Sep 2025 07:03:39 +0000 Subject: [PATCH 01/77] initial isaac implementation --- src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 2 + .../models/auto/processing_auto.py | 1 + src/transformers/models/isaac/__init__.py | 28 + .../models/isaac/configuration_isaac.py | 141 ++ .../models/isaac/modeling_isaac.py | 1173 +++++++++++++ .../models/isaac/modular_isaac.py | 1467 +++++++++++++++++ .../models/isaac/processing_isaac.py | 509 ++++++ tests/models/isaac/__init__.py | 13 + 10 files changed, 3337 insertions(+) create mode 100644 src/transformers/models/isaac/__init__.py create mode 100644 src/transformers/models/isaac/configuration_isaac.py create mode 100644 src/transformers/models/isaac/modeling_isaac.py create mode 100644 src/transformers/models/isaac/modular_isaac.py create mode 100644 src/transformers/models/isaac/processing_isaac.py create mode 100644 tests/models/isaac/__init__.py diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 18d74ade4126..4377af520590 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -256,6 +256,7 @@ from .pegasus_x import * from .perceiver import * from .perception_lm import * + from .isaac import * from .persimmon import * from .phi import * from .phi3 import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index a9303913e861..d0a4d937f6f0 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -300,6 +300,7 @@ ("perceiver", "PerceiverConfig"), ("perception_encoder", "TimmWrapperConfig"), ("perception_lm", "PerceptionLMConfig"), + ("isaac", "IsaacConfig"), ("persimmon", "PersimmonConfig"), ("phi", "PhiConfig"), ("phi3", "Phi3Config"), @@ -744,6 +745,7 @@ ("perceiver", "Perceiver"), ("perception_encoder", "PerceptionEncoder"), ("perception_lm", "PerceptionLM"), + ("isaac", "Isaac"), ("persimmon", "Persimmon"), ("phi", "Phi"), ("phi3", "Phi3"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 571f654a9499..13e6ae32ab8f 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -299,6 +299,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("perceiver", "PerceiverModel"), ("perception_encoder", "PerceptionEncoder"), ("perception_lm", "PerceptionLMModel"), + ("isaac", "IsaacModel"), ("persimmon", "PersimmonModel"), ("phi", "PhiModel"), ("phi3", "Phi3Model"), @@ -1034,6 +1035,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ovis2", "Ovis2ForConditionalGeneration"), ("paligemma", "PaliGemmaForConditionalGeneration"), ("perception_lm", "PerceptionLMForConditionalGeneration"), + ("isaac", "IsaacForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"), ("pixtral", "LlavaForConditionalGeneration"), ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 13583c55002f..5ee7a913114c 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -112,6 +112,7 @@ ("owlvit", "OwlViTProcessor"), ("paligemma", 
"PaliGemmaProcessor"), ("perception_lm", "PerceptionLMProcessor"), + ("isaac", "IsaacProcessor"), ("phi4_multimodal", "Phi4MultimodalProcessor"), ("pix2struct", "Pix2StructProcessor"), ("pixtral", "PixtralProcessor"), diff --git a/src/transformers/models/isaac/__init__.py b/src/transformers/models/isaac/__init__.py new file mode 100644 index 000000000000..fbc25598385d --- /dev/null +++ b/src/transformers/models/isaac/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_isaac import * + from .modeling_isaac import * + from .processing_isaac import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) \ No newline at end of file diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py new file mode 100644 index 000000000000..d29e44b68e4d --- /dev/null +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -0,0 +1,141 @@ +# ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ +# This file was automatically generated from src/transformers/models/isaac/modular_isaac.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_isaac.py file directly. One of our CI enforces this. +# ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ + +from ...configuration_utils import PretrainedConfig, layer_type_validation +from ...modeling_rope_utils import rope_config_validation + + +class PixelShuffleSiglip2VisionConfig(PretrainedConfig): + """Vision configuration for Isaac with Pixel Shuffle support. + + Extends Siglip2VisionConfig with additional fields for pixel shuffle. 
+ """ + + model_type = "pixel_shuffle_siglip2" + base_config_key = "vision_config" + + def __init__( + self, + pixel_shuffle_scale_factor: int = 1, + num_patches: int = 256, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.num_patches = num_patches + + # Add our custom fields + self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor + + +class IsaacConfig(PretrainedConfig): + """Configuration class for Isaac multimodal model.""" + + model_type = "isaac" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Isaac` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + sub_configs = {"vision_config": PixelShuffleSiglip2VisionConfig} + + def __init__( + self, + vision_config=None, + vision_patch_size: int = 16, + vision_max_num_patches: int = 256, + vision_min_num_patches: int | None = None, + pixel_shuffle_scale: int = 1, + max_sequence_length: int = 16384, + vision_token: str = "", + **kwargs, + ): + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if self.use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. 
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + + # Handle vision config - either dict or PixelShuffleSiglip2VisionConfig instance + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + else: + self.vision_config = vision_config + + # EventStreamProcessor parameters (for backward compatibility) + self.video_patch_size = vision_patch_size + self.vision_max_num_patches = vision_max_num_patches + self.vision_min_num_patches = vision_min_num_patches + self.pixel_shuffle_scale = pixel_shuffle_scale + + # Processing parameters + self.max_sequence_length = max_sequence_length + self.vision_token = vision_token + + +__all__ = ["IsaacConfig"] diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py new file mode 100644 index 000000000000..f182f2bb6477 --- /dev/null +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -0,0 +1,1173 @@ +# ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ +# This file was automatically generated from src/transformers/models/isaac/modular_isaac.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_isaac.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+
+from collections import defaultdict
+from typing import Any, Callable, Optional, TypedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from perceptron.tensorstream import TensorStream, TextType, VisionType, group_streams
+from perceptron.tensorstream.ops import (
+    compute_mrope_pos_tensor,
+    modality_mask,
+    reconstruct_tensor_stream_from_compact_dict,
+)
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from ..siglip2.modeling_siglip2 import Siglip2MLP
+from .configuration_isaac import IsaacConfig, PixelShuffleSiglip2VisionConfig
+
+
+class Siglip2VariableSequenceEmbeddings(nn.Module):
+    def __init__(self, config: PixelShuffleSiglip2VisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.patch_size = config.patch_size
+
+        self.patch_embedding = nn.Linear(
+            in_features=config.num_channels * self.patch_size * self.patch_size,
+            out_features=self.embed_dim,
+        )
+
+        self.num_patches = config.num_patches
+        self.position_embedding_size = int(self.num_patches**0.5)
+        self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
+
+    def positional_embeddings(
+        self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+    ) -> torch.Tensor:
+        # Prepare positional embeddings grid: (1, embed_dim, h, w)
+        positional_embeddings = (
+            self.position_embedding.weight.reshape(self.position_embedding_size, self.position_embedding_size, -1)
+            .permute(2, 0, 1)
+            .unsqueeze(0)
+        )
+
+        _seq_patches, _seq_sizes, spatial_shapes = packed_seq_patches
+        pos_embeds_list = []
+        mode = "bilinear"
+        align_corners = False
+        antialias = True
+        for spatial_shape in spatial_shapes:
+            height, width = spatial_shape
+            # Guard to ensure height and width are positive for torch.compile
+            if height > 0 and width > 0:
+                resized_pos_embed = F.interpolate(
+                    positional_embeddings,
+                    size=(height, width),
+                    mode=mode,
+                    align_corners=align_corners,
+                    antialias=antialias,
+                )
+                # Reshape from (1, embed_dim, height, width) to (height*width, embed_dim)
+                resized_pos_embed = resized_pos_embed.reshape(self.embed_dim, height * width).transpose(0, 1)
+            else:
+                # Fallback - should never happen in practice
+                resized_pos_embed = positional_embeddings.reshape(
+                    self.embed_dim, self.position_embedding_size * self.position_embedding_size
+                ).transpose(0, 1)[: height * width]
+            pos_embeds_list.append(resized_pos_embed)
+
+        # Concatenate all positional embeddings along the sequence dimension
+        pos_embeds = torch.cat(pos_embeds_list, dim=0)
+        return pos_embeds
+
+    def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
+        seq_patches, _seq_sizes, _spatial_shapes = packed_seq_patches
+
+        # Apply patch embeddings
+        target_dtype = self.patch_embedding.weight.dtype
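+        # Cast packed patches to the projection weight's dtype (e.g. bf16) so the
+        # linear layer sees a consistent input dtype regardless of upstream precision.
+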
patch_embeds = self.patch_embedding(seq_patches.to(dtype=target_dtype)) + pos_embeds = self.positional_embeddings(packed_seq_patches) + + # Add positional embeddings to patch embeddings + embeddings = patch_embeds + pos_embeds + return embeddings + + +class Siglip2VariableLengthAttention(nn.Module): + """Custom attention that supports variable-length sequences with flash attention.""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None): + batch_size, seq_len, _ = hidden_states.size() + + # For variable-length attention, we need to reshape to (total_tokens, embed_dim) + if batch_size != 1: + raise ValueError("Variable-length attention expects batch_size=1 for packed sequences") + hidden_states = hidden_states.squeeze(0) # Remove batch dimension: (seq_len, embed_dim) + + # Store original dtype + orig_dtype = hidden_states.dtype + + # 1. Linear projections + Q = self.q_proj(hidden_states) # (seq_len, embed_dim) + K = self.k_proj(hidden_states) # (seq_len, embed_dim) + V = self.v_proj(hidden_states) # (seq_len, embed_dim) + + # 2. Reshape for multi-head attention: (seq_len, n_heads, head_dim) + Q = Q.view(-1, self.num_heads, self.embed_dim // self.num_heads) + K = K.view(-1, self.num_heads, self.embed_dim // self.num_heads) + V = V.view(-1, self.num_heads, self.embed_dim // self.num_heads) + + # 3. Apply variable-length attention using flash attention + attn_output, _, _, _, _ = torch.ops.aten._flash_attention_forward( + query=Q, + key=K, + value=V, + cum_seq_q=cu_seqlens, + cum_seq_k=cu_seqlens, + max_q=max_seqlen, + max_k=max_seqlen, + dropout_p=self.dropout if self.training else 0.0, + is_causal=False, + return_debug_mask=False, + scale=self.scale, + window_size_left=-1, + window_size_right=-1, + alibi_slopes=None, + ) + + # 4. Reshape attention output from (seq_len, n_heads, head_dim) to (seq_len, embed_dim) + attn_output = attn_output.reshape(seq_len, self.embed_dim) + + # 5. Convert back to original dtype if needed + if attn_output.dtype != orig_dtype: + attn_output = attn_output.to(orig_dtype) + + # 6. Project output + attn_output = self.out_proj(attn_output) # (seq_len, embed_dim) + + # 7. 
Add back batch dimension for compatibility + attn_output = attn_output.unsqueeze(0) # (1, seq_len, embed_dim) + + return attn_output, None + + +class IsaacSiglip2EncoderLayer(nn.Module): + """Siglip2 encoder layer with variable-length attention.""" + + def __init__(self, config: PixelShuffleSiglip2VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Siglip2VariableLengthAttention(config) + + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Siglip2MLP(config) # Use HF's Siglip2MLP + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor = None, + max_seqlen: int = None, + ) -> tuple[torch.FloatTensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return (hidden_states,) + + +class IsaacEncoder(nn.Module): + """Encoder using Isaac encoder layers with variable-length attention support.""" + + def __init__(self, config: PixelShuffleSiglip2VisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([IsaacSiglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + inputs_embeds, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, + output_hidden_states: bool = False, + ): + all_hidden_states = () if output_hidden_states else None + + hidden_states = inputs_embeds + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens, + max_seqlen, + ) + + hidden_states = layer_outputs[0] + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return hidden_states, all_hidden_states, None + + +def create_cumulative_seq_lengths(seq_sizes: torch.Tensor, device: torch.device) -> tuple[torch.Tensor, int]: + """Create cumulative sequence lengths for variable-length attention.""" + cu_seqlens = torch.zeros(len(seq_sizes) + 1, dtype=torch.int32, device=device) + cu_seqlens[1:] = seq_sizes.cumsum(0) + max_seqlen = int(seq_sizes.max().item()) if len(seq_sizes) > 0 else 0 + return cu_seqlens, max_seqlen + + +def create_pixel_shuffle_index_map( + seq_sizes: torch.Tensor, + token_grids: torch.Tensor, + scale_factor: int = 1, + device: torch.device | None = None, +) -> torch.Tensor: + """ + Build a gather-index map that tells us, for every *output* token after + pixel-shuffle, which `scale_factor**2` *input* tokens are being merged. + + Args + ---- + seq_sizes : (num_images,) - #patches in each image (row-major order) + token_grids : (num_images,2) - (height, width) for every image + scale_factor : spatial down-scale factor (โ‰ฅ2) + device : (optional) overrides `seq_sizes.device` + + Returns + ------- + gather_idx : (new_total_seq_len, scale_factor**2) int64 tensor. + gather_idx[i, j] is the *flat* index into the *original* + packed sequence for the j-th sub-patch that forms the + i-th output token. 
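+
+    Example (illustrative): a single 2x2 patch grid with scale_factor=2 collapses
+    its four patches into one output token:
+
+        >>> create_pixel_shuffle_index_map(torch.tensor([4]), torch.tensor([[2, 2]]), scale_factor=2)
+        tensor([[0, 1, 2, 3]])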
+ """ + if device is None: + device = seq_sizes.device + + r = int(scale_factor) + if r < 2: + raise ValueError("`scale_factor` must be โ‰ฅ 2") + + # Safety: all spatial dims must be divisible by r + # Cannot run under torch compile fullgraph mode hence + if not torch.compiler.is_compiling(): + if not ((token_grids[:, 0] % r == 0).all() and (token_grids[:, 1] % r == 0).all()): + raise AssertionError( + f"Every (H,W) in `token_grids` must be divisible by scale_factor={r}, got {token_grids.tolist()}" + ) + + gather_chunks: list[torch.Tensor] = [] + tok_offset = 0 + + for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist(), strict=False): + # Build the (H, W) grid of flat indices for this image + grid = torch.arange(seq_len, device=device, dtype=torch.int64) + tok_offset + grid = grid.view(h, w) # (H, W) + + # -------- identical ordering to your fixed-res routine -------- + # Step 1: split width into blocks of r + grid = grid.view(h, w // r, r) # (H, W/r, r) + # Step 2: now split height into blocks of r + grid = grid.view(h // r, r, w // r, r) # (H/r, r, W/r, r) + # Step 3: final permutation to (H/r, W/r, r, r) + grid = grid.permute(0, 2, 1, 3).contiguous() # (H/r, W/r, r, r) + # Step 4: each (r, r) block forms one output token + gather_chunks.append(grid.reshape(-1, r * r)) # (H*W / rยฒ, rยฒ) + + tok_offset += seq_len + + # Concatenate over all images in the packed batch + gather_idx = torch.cat(gather_chunks, dim=0) # (ฮฃ_i HแตขWแตข/rยฒ, rยฒ) + return gather_idx + + +def pixel_shuffle_varlen( + x: torch.Tensor, + token_grids: torch.Tensor, + scale_factor: int = 1, +) -> torch.Tensor: + r"""Apply pixel shuffle to a packed vision sequence without unpacking per image. + + Args: + x (`torch.Tensor`): + Concatenated vision embeddings. Accepts `(seq_len, hidden_size)` or `(1, seq_len, hidden_size)` shapes + produced by stacking image patches. + token_grids (`torch.Tensor`): + Integer tensor of shape `(num_images, 2)` whose rows give the `(height, width)` patch grid sizes + corresponding to each image segment inside `x`. + scale_factor (`int`, *optional*, defaults to 1): + Spatial down-sampling factor specific to pixel shuffle. Values greater than one merge `scale_factor**2` neighboring patches into a + single embedding channel-group. + + Returns: + `torch.Tensor`: Pixel-shuffled embeddings with shape matching the input convention: + `(seq_len, hidden_size * scale_factor**2)` when the input was 2D, or `(1, seq_len, hidden_size * scale_factor**2)` + if the singleton batch dimension was present. + + Raises: + ValueError: If more than one batch item is provided. 
+ """ + keep_batch_dim = x.dim() == 3 + if keep_batch_dim: + if x.size(0) != 1: + raise AssertionError("Packed sequence is expected to have batch_size == 1") + x_ = x.squeeze(0) # (seq, embed) + else: + x_ = x # (seq, embed) + + embed_dim = x_.size(-1) + r = int(scale_factor) + + # Calculate seq_sizes from token_grids + seq_sizes = torch.prod(token_grids, dim=-1) + + # Build index map and gather in one go + gather_idx = create_pixel_shuffle_index_map( + seq_sizes=seq_sizes, + token_grids=token_grids, + scale_factor=r, + device=x_.device, + ) # (new_seq, rยฒ) + + # Gather โ†’ (new_seq, rยฒ, embed_dim) + gathered = x_[gather_idx] # fancy indexing keeps gradient + + # Merge the rยฒ group dimension into channels to finish the shuffle + out = gathered.reshape(gathered.size(0), embed_dim * r * r) + + # Restore batch dimension if needed + if keep_batch_dim: + out = out.unsqueeze(0) + return out + + +class Siglip2SequenceVisionTransformer(nn.Module): + def __init__(self, config: PixelShuffleSiglip2VisionConfig): + super().__init__() + self.config = config + self.embeddings = Siglip2VariableSequenceEmbeddings(config) + self.encoder = IsaacEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor + + def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): + seq_patches, token_grids = packed_seq_patches + seq_sizes = torch.prod(token_grids, dim=-1) + + # Get embeddings from packed sequence + hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids)) + + # Add a pseudo batch dimension for the encoder + hidden_states = hidden_states.unsqueeze(0) + + # Generate cumulative sequence lengths for variable-length attention + cu_seqlens, max_seqlen = create_cumulative_seq_lengths(seq_sizes, hidden_states.device) + + # Pass through encoder with variable-length attention parameters + hidden_states, _, _ = self.encoder( + inputs_embeds=hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + # Apply final layer normalization + hidden_states = self.post_layernorm(hidden_states) + + if self.pixel_shuffle_scale_factor > 1: + hidden_states = pixel_shuffle_varlen( + x=hidden_states, + token_grids=token_grids, + scale_factor=self.pixel_shuffle_scale_factor, + ) + # Remove the pseudo batch dimension we added earlier + hidden_states = hidden_states.squeeze(0) + + # Return the full sequence of embeddings + return hidden_states + + +class RopeScaling(TypedDict, total=False): + rope_type: str + factor: float + mrope_section: list[int] + mrope_interleaved: bool + low_freq_factor: float + high_freq_factor: float + original_max_position_embeddings: int + + +def precompute_inv_freq(theta: float, dim: int) -> torch.Tensor: + """ + Returns shape (dim//2,). + """ + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + return inv_freq # type: ignore[return-value] + + +def precompute_cos_sin_3d( + position_ids: torch.Tensor, # shape (3, B, T) + inv_freq: torch.Tensor, # shape (dim//2,) + mrope_half_section: list[int], # sum to dim//2 +) -> tuple[torch.Tensor, torch.Tensor]: + r"""Generate 3D rotary embeddings for multi-axis positions. + + Args: + position_ids (`torch.Tensor`): + Tensor of shape `(3, batch_size, seq_len)` containing positional indices for the x/y/t axes. + inv_freq (`torch.Tensor`): + Precomputed inverse frequency vector used to derive rotary phases. + mrope_half_section (`list[int]`): + Sizes the axis-specific frequency blocks. 
+ + Returns: + `tuple[torch.Tensor, torch.Tensor]`: Cosine and sine tensors, each of shape `(batch_size, seq_len, dim)`, ready + to be passed into rotary attention layers. + """ + B = position_ids.shape[1] + T = position_ids.shape[2] + dim_half = inv_freq.shape[0] + device = position_ids.device + + # Initialize with full dimension (not half) to match LLaMA + cos_3d = torch.zeros((B, T, dim_half * 2), dtype=torch.float32, device=device) + sin_3d = torch.zeros((B, T, dim_half * 2), dtype=torch.float32, device=device) + + offset = 0 + for d in range(3): + block_size = mrope_half_section[d] + freq_slice = inv_freq[offset : offset + block_size] # shape => (block_size,) + # shape => (B, T, block_size) + phase = position_ids[d].unsqueeze(-1).float() * freq_slice + + cos_part = phase.cos() + sin_part = phase.sin() + + # Duplicate values for both halves of the dimension + cos_3d[:, :, offset : offset + block_size] = cos_part + cos_3d[:, :, dim_half + offset : dim_half + offset + block_size] = cos_part + sin_3d[:, :, offset : offset + block_size] = sin_part + sin_3d[:, :, dim_half + offset : dim_half + offset + block_size] = sin_part + + offset += block_size + + return cos_3d, sin_3d + + +class IsaacRotaryEmbedding(nn.Module): + def __init__(self, config: IsaacConfig, device=None): + super().__init__() + + # Extract dimensions from config + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + + # Get rope_scaling config - use direct access when available + rope_scaling = getattr(config, "rope_scaling", None) or {} + + # Read RopeScaling parameters + self.rope_type = rope_scaling.get("rope_type", "default") + + self.mrope_section = [ + self.head_dim // 4, # 2x more for temporal dim + self.head_dim // 8, + self.head_dim // 8, + ] + + rope_base = getattr(config, "rope_theta", 10000.0) + inv_freq = precompute_inv_freq(rope_base, self.head_dim) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, position_ids: torch.Tensor, modality_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + with torch.no_grad(): + # Ensure non-spatial tokens have 1D rotation equivalence + not_spatial = ~(modality_tensor == VisionType.image.value) + # shape is [N, 1] + data_1d = position_ids[not_spatial][..., 0].unsqueeze(-1) + # now broadcast it from [N, 1] -> [N, D] so it matches pos[not_spatial] exactly + data_1d = data_1d.expand(-1, position_ids.shape[-1]) # expand along the last dim + position_ids = position_ids.clone() # Clone to avoid warning about in-place operations on expanded tensors + position_ids[not_spatial] = data_1d + position_ids = position_ids.permute(2, 0, 1) # pos dim first -> (3, B, L) + cos, sin = precompute_cos_sin_3d(position_ids, self.inv_freq, self.mrope_section) + + return cos, sin + + +@use_kernel_forward_from_hub("RMSNorm") +class IsaacRMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + IsaacRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, 
eps={self.variance_epsilon}" + + +class IsaacMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class IsaacAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: IsaacConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
+ self.k_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class IsaacDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: IsaacConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = IsaacAttention(config=config, layer_idx=layer_idx) + + self.mlp = IsaacMLP(config) + self.input_layernorm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + 
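# Pre-norm residual MLP: the feed-forward output is added back onto the post-attention stream.
+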
residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+@auto_docstring
+class IsaacPreTrainedModel(PreTrainedModel):
+    config: IsaacConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["IsaacDecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+
+    _can_compile_fullgraph = True
+    _supports_attention_backend = True
+    _can_record_outputs = {
+        "hidden_states": IsaacDecoderLayer,
+        "attentions": IsaacAttention,
+    }
+
+
+# ============================================================================
+# Model
+# ============================================================================
+
+
+def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor:
+    r"""Create 3D positional indices for token input.
+
+    Args:
+        input_ids (`torch.Tensor`):
+            Tensor of shape `(batch_size, seq_len)` containing token ids.
+
+    Returns:
+        `torch.Tensor`: Positional indices with shape `(batch_size, seq_len, 3)` where each channel duplicates the
+        1D position so it can be consumed by the 3-axis MRoPE rotary embedding.
+    """
+    batch_size, seq_length = input_ids.shape
+    position_ids = torch.arange(seq_length, device=input_ids.device)
+    position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+    position_ids = position_ids.unsqueeze(2).expand(-1, -1, 3)  # Add 3D for MRoPE
+    return position_ids
+
+
+@auto_docstring
+class IsaacModel(IsaacPreTrainedModel):
+    def __init__(self, config: IsaacConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = torch.nn.ModuleList(
+            [IsaacDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device)
+        self.gradient_checkpointing = False
+        self.has_sliding_layers = "sliding_attention" in self.config.layer_types
+
+        vision_cfg = config.vision_config
+        if vision_cfg is None:
+            raise ValueError("IsaacConfig should always have vision_config")
+
+        hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2)
+        self.vision_embedding = nn.Sequential(
+            Siglip2SequenceVisionTransformer(vision_cfg),
+            nn.Linear(
+                hidden_dim,
+                4 * hidden_dim,
+                bias=False,
+            ),
+            nn.SiLU(),
+            nn.Linear(4 * hidden_dim, config.hidden_size, bias=False),
+        )
+
+        # Dispatch table for TensorStream balanced embedding (text + vision)
+        self.embed_fns = {
+            TextType: self.embed_text_tokens,
+            VisionType: self.embed_vision,
+        }
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @check_model_inputs
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: torch.LongTensor | None = None,
+        tensor_stream: TensorStream | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        modality_tensor: torch.LongTensor | None = None,
+        past_key_values: list[torch.FloatTensor] | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        use_cache: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        cache_position: torch.LongTensor | None = None,
+        **kwargs,
+    ) -> tuple | BaseModelOutputWithPast:
+        """
+        Forward pass with MRoPE position embeddings.
+
+        Computes position embeddings once and passes them through all layers.
+        """
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Get inputs
+        if tensor_stream is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both tensor_stream and inputs_embeds")
+        elif tensor_stream is not None:
+            # Embed TensorStream directly
+            inputs_embeds = self.embed_stream(tensor_stream)
+            # Create modality tensor if not provided
+            if modality_tensor is None:
+                modality_tensor = modality_mask(tensor_stream)
+        elif input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            inputs_embeds = self.embed_tokens(input_ids)
+            # Create text modality tensor if not provided
+            if modality_tensor is None:
+                batch_size, seq_length = input_ids.shape
+                modality_tensor = torch.full(
+                    (batch_size, seq_length), TextType.text.value, device=input_ids.device, dtype=torch.long
+                )
+        elif inputs_embeds is None:
+            raise ValueError("You have to specify either tensor_stream, input_ids or inputs_embeds")
+
+        # Create default position_ids if not provided
+        if position_ids is None:
+            if tensor_stream is not None:
+                position_ids = compute_mrope_pos_tensor(tensor_stream)  # (B,L,3)
+            else:
+                position_ids = compute_position_ids_input_ids(input_ids)
+
+        # Compute MRoPE position embeddings if we have custom rotary_emb
+        cos, sin = self.rotary_emb(position_ids, modality_tensor)
+        cos = cos.to(inputs_embeds.dtype)
+        sin = sin.to(inputs_embeds.dtype)
+
+        # Prepare attention mask
+        if attention_mask is not None:
+            attention_mask = self._update_causal_mask(
+                attention_mask, inputs_embeds, cache_position, past_key_values, False
+            )
+
+        # Initialize hidden states
+        hidden_states = inputs_embeds
+
+        for decoder_layer in self.layers:
+            # IsaacDecoderLayer returns the hidden states tensor directly, so no tuple indexing is needed
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=(cos, sin),
+                **kwargs,
+            )
+
+        # Final layer norm
+        hidden_states = self.norm(hidden_states)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+        )
+
+    def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
+        """Embed text tokens, squeezing singleton dimensions."""
+        # Text events are shaped as (..., 1); squeeze the singleton index dim
+        h = self.embed_tokens(token_ids)
+        if h.dim() >= 2 and h.size(-2) == 1:
+            h = h[..., 0, :]
+        return h
+
+    def embed_vision(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        """Embed vision tokens using the vision encoder."""
+        # vision tokens is (seq_patches, token_grids)
+        return self.vision_embedding(vision_tokens)
+
+    def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor:
+        """
+        Embed each modality stream independently, preserving the original TensorStream
+        structure.
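+
+        Text events are embedded through `embed_tokens`, while vision events run through the
+        `vision_embedding` stack; the per-modality results are stitched back into a single
+        `(B, T, D)` tensor via `reconstruct_tensor_stream_from_compact_dict` and `compact()`.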
+ """ + flat_stream = tensor_stream.flat_stream() + per_modality_stream = group_streams(flat_stream, group_fn=lambda ev: ev.type, schedule=False) + per_modality_compact_stream = {k: v.compact() for k, v in per_modality_stream.items()} + + # Collect per-event grids for vision tokens (H, W like dims sans time) + token_grids = defaultdict(list) + for stream in tensor_stream.streams: + for event in stream: + token_grids[event.type].append(event.dims(virtual=False)) + + embedded_compact = {} + for stream_type, modality_payload_tensor in per_modality_compact_stream.items(): + if stream_type.modality == VisionType: + # Build a (N_events, 2) grid tensor with spatial dims only + grids = token_grids.get(stream_type, []) + if len(grids) == 0: + input_tensor = modality_payload_tensor + else: + token_grids_tensor = torch.tensor(grids, dtype=torch.long, device=tensor_stream.device)[:, 1:] + input_tensor = (modality_payload_tensor, token_grids_tensor) + embedded_compact[stream_type] = self.embed_fns[stream_type.modality](input_tensor) + else: + embedded_compact[stream_type] = self.embed_fns[stream_type.modality](modality_payload_tensor) + + # Reconstruct a TensorStream with embedded payloads and compact + embedded_ts = reconstruct_tensor_stream_from_compact_dict(tensor_stream, embedded_compact) + h = embedded_ts.compact() # (B, T, D) + return h + + +@auto_docstring +class IsaacForConditionalGeneration(IsaacPreTrainedModel, GenerationMixin): + """Isaac multimodal model for conditional generation.""" + + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + config_class = IsaacConfig + + def __init__(self, config: IsaacConfig): + Qwen3PreTrainedModel.__init__(self, config) + self.model = IsaacModel(config) # Use our custom model + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Tracks rotary position offsets computed during a full forward pass so decode steps can reuse them. + self.rope_deltas = None + + self.config = config + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + tensor_stream: TensorStream | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs, + ) -> tuple | CausalLMOutputWithPast: + """ + Forward pass for conditional generation supporting both standard inputs and TensorStream. + Uses our embed_stream approach for multimodal inputs. + """ + + # Don't compute embeddings here - let the model handle it + if tensor_stream is not None: + input_ids = None + if input_ids is None and inputs_embeds is None and tensor_stream is None: + raise ValueError("Either input_ids, inputs_embeds, or tensor_stream must be provided.") + + # Build position ids (MRoPE) if needed and tensor_stream is available + # During decode we reuse `self.rope_deltas` computed on the initial forward pass; `rope_delta` captures how far + # cached rotary phases have progressed so we can advance `position_ids` without rebuilding the TensorStream. 
+ if position_ids is None and tensor_stream is not None: + position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask) + elif position_ids is None and input_ids is not None: + # For text inputs build position ids and modality tensor + position_ids = compute_position_ids_input_ids(input_ids) + if cache_position is not None and self.rope_deltas is not None: + # Combine the incremental decode step (`cache_position`) with cached offsets so hidden states continue + # rotating in lockstep across generation steps. + rope_delta = (cache_position[0] + self.rope_deltas).to(input_ids.device) + else: + rope_delta = 0 + if cache_position is not None and not isinstance(rope_delta, int): # otherwise `deltas` is an int `0` + batch_size = input_ids.shape[0] + rope_delta = rope_delta.repeat_interleave(batch_size // rope_delta.shape[0], dim=0) + position_ids = position_ids.add(rope_delta) + + if tensor_stream is not None: + modality_tensor = modality_mask(tensor_stream) + else: + batch_size, seq_len = input_ids.shape + modality_tensor = torch.empty(batch_size, seq_len, device=position_ids.device).fill_(TextType.text.value) + + outputs = self.model( + input_ids=input_ids, + tensor_stream=tensor_stream, + attention_mask=attention_mask, + position_ids=position_ids, + modality_tensor=modality_tensor, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=None, + ) + + def get_rope_index( + self, + input_ids: torch.Tensor | None, + tensor_stream: TensorStream | None, + attention_mask: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute MRoPE position ids from a TensorStream (or 1D fallback). + + Returns (position_ids, rope_deltas). position_ids is (B,L,3) for MRoPE. + rope_deltas is (B,1) used to advance positions in decode. 
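+
+        Example (illustrative): for a pure-text batch of length L with no padding, the maximum
+        position is L - 1, so rope_deltas = (L - 1) + 1 - L = 0; vision events typically compress
+        the positional span, producing a negative delta that decode steps later add back.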
+ """ + # tensor_stream present: compute 3D coords + if tensor_stream is None and input_ids is None: + raise ValueError("`tensor_stream` or `input_ids` must be provided to compute rope indices") + + if tensor_stream is not None: + pos_3d = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) + else: + pos_3d = compute_position_ids_input_ids(input_ids) + B, L, _ = pos_3d.shape + + # Max position per batch across the 3 planes and sequence dimension: (B,) + m_per_batch = pos_3d.amax(dim=(1, 2)) + + # Sequence lengths per batch: (B,) + if attention_mask is None: + seq_lens = torch.full_like(m_per_batch, L) + else: + seq_lens = attention_mask.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=m_per_batch.device) + + rope_deltas = (m_per_batch + 1 - seq_lens).to(dtype=pos_3d.dtype).unsqueeze(1) + return pos_3d, rope_deltas + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: list[torch.FloatTensor] | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + tensor_stream: TensorStream | None = None, + cache_position: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + use_cache: bool = True, + **kwargs, + ) -> dict[str, Any]: + """ + Prepare inputs for generation, handling TensorStream inputs properly. + """ + # Call parent preparation + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + **kwargs, + ) + + # Handle TensorStream for first forward pass only + if tensor_stream is not None and (cache_position is None or cache_position[0] == 0): + model_inputs["tensor_stream"] = tensor_stream + # Let forward rebuild position_ids using cached deltas during decode + model_inputs["position_ids"] = None + # Drop tensor_stream after step 0 + if cache_position is not None and cache_position[0] != 0: + model_inputs["tensor_stream"] = None + return model_inputs + + def can_generate(self) -> bool: + return True + + +__all__ = ["IsaacModel", "IsaacForConditionalGeneration"] diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py new file mode 100644 index 000000000000..4cc1e157c363 --- /dev/null +++ b/src/transformers/models/isaac/modular_isaac.py @@ -0,0 +1,1467 @@ +from __future__ import annotations + +from collections import defaultdict +from typing import Any, Union, TypedDict + +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import PIL.Image + + +from ...utils import logging +from ...processing_utils import ProcessorMixin, BatchFeature +from ...tokenization_utils_base import PreTrainedTokenizerBase +from ..auto import AutoTokenizer +from ..qwen3.configuration_qwen3 import Qwen3Config +from ..qwen3.modeling_qwen3 import ( + Qwen3ForCausalLM, + Qwen3PreTrainedModel, + Qwen3DecoderLayer, + Qwen3Model +) +from ...generation import GenerationMixin +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...tokenization_utils import TensorType +import re + +from ..siglip2.modeling_siglip2 import Siglip2MLP +from ..siglip2.configuration_siglip2 import Siglip2VisionConfig + +from perceptron.tensorstream import ( + Event, + Stream, + TensorStream, + TextType, + VisionType, + create_stream, + group_streams, +) +from perceptron.tensorstream.ops import ( + 
compute_mrope_pos_tensor, + modality_mask, + reconstruct_tensor_stream_from_compact_dict, + slice as ts_slice, + tensor_stream_token_view, +) + +logger = logging.get_logger(__name__) + + +class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig): + """Vision configuration for Isaac with Pixel Shuffle support. + + Extends Siglip2VisionConfig with additional fields for pixel shuffle. + """ + + model_type = "pixel_shuffle_siglip2" + base_config_key = "vision_config" + + def __init__( + self, + pixel_shuffle_scale_factor: int = 1, + num_patches: int = 256, + **kwargs, + ): + super().__init__(**kwargs) + + # Add our custom fields + self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor + self.num_patches = num_patches + + +def create_cumulative_seq_lengths(seq_sizes: torch.Tensor, device: torch.device) -> tuple[torch.Tensor, int]: + """Create cumulative sequence lengths for variable-length attention.""" + cu_seqlens = torch.zeros(len(seq_sizes) + 1, dtype=torch.int32, device=device) + cu_seqlens[1:] = seq_sizes.cumsum(0) + max_seqlen = int(seq_sizes.max().item()) if len(seq_sizes) > 0 else 0 + return cu_seqlens, max_seqlen + + +class Siglip2VariableSequenceEmbeddings(nn.Module): + def __init__(self, config: PixelShuffleSiglip2VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Linear( + in_features=config.num_channels * self.patch_size * self.patch_size, + out_features=self.embed_dim, + ) + + self.num_patches = config.num_patches + self.position_embedding_size = int(self.num_patches**0.5) + self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + + def positional_embeddings( + self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ) -> torch.Tensor: + # Prepare positional embeddings grid: (1, embed_dim, h, w) + positional_embeddings = ( + self.position_embedding.weight.reshape(self.position_embedding_size, self.position_embedding_size, -1) + .permute(2, 0, 1) + .unsqueeze(0) + ) + + _seq_patches, _seq_sizes, spatial_shapes = packed_seq_patches + pos_embeds_list = [] + mode = "bilinear" + align_corners = False + antialias = True + for spatial_shape in spatial_shapes: + height, width = spatial_shape + # Guard to ensure height and width are positive for torch.compile + if height > 0 and width > 0: + resized_pos_embed = F.interpolate( + positional_embeddings, + size=(height, width), + mode=mode, + align_corners=align_corners, + antialias=antialias, + ) + # Reshape from (1, embed_dim, height, width) to (height*width, embed_dim) + resized_pos_embed = resized_pos_embed.reshape(self.embed_dim, height * width).transpose(0, 1) + else: + # Fallback - should never happen in practice + resized_pos_embed = positional_embeddings.reshape( + self.embed_dim, self.position_embedding_size * self.position_embedding_size + ).transpose(0, 1)[: height * width] + pos_embeds_list.append(resized_pos_embed) + + # Concatenate all positional embeddings along the sequence dimension + pos_embeds = torch.cat(pos_embeds_list, dim=0) + return pos_embeds + + def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor]): + seq_patches, _seq_sizes, _spatial_shapes = packed_seq_patches + + # Apply patch embeddings + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(seq_patches.to(dtype=target_dtype)) + pos_embeds = self.positional_embeddings(packed_seq_patches) + + # Add positional embeddings to patch embeddings + 
embeddings = patch_embeds + pos_embeds + return embeddings + + +class Siglip2VariableLengthAttention(nn.Module): + """Custom attention that supports variable-length sequences with flash attention.""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None): + batch_size, seq_len, _ = hidden_states.size() + + # For variable-length attention, we need to reshape to (total_tokens, embed_dim) + if batch_size != 1: + raise ValueError("Variable-length attention expects batch_size=1 for packed sequences") + hidden_states = hidden_states.squeeze(0) # Remove batch dimension: (seq_len, embed_dim) + + # Store original dtype + orig_dtype = hidden_states.dtype + + # 1. Linear projections + Q = self.q_proj(hidden_states) # (seq_len, embed_dim) + K = self.k_proj(hidden_states) # (seq_len, embed_dim) + V = self.v_proj(hidden_states) # (seq_len, embed_dim) + + # 2. Reshape for multi-head attention: (seq_len, n_heads, head_dim) + Q = Q.view(-1, self.num_heads, self.embed_dim // self.num_heads) + K = K.view(-1, self.num_heads, self.embed_dim // self.num_heads) + V = V.view(-1, self.num_heads, self.embed_dim // self.num_heads) + + # 3. Apply variable-length attention using flash attention + attn_output, _, _, _, _ = torch.ops.aten._flash_attention_forward( + query=Q, + key=K, + value=V, + cum_seq_q=cu_seqlens, + cum_seq_k=cu_seqlens, + max_q=max_seqlen, + max_k=max_seqlen, + dropout_p=self.dropout if self.training else 0.0, + is_causal=False, + return_debug_mask=False, + scale=self.scale, + window_size_left=-1, + window_size_right=-1, + alibi_slopes=None, + ) + + # 4. Reshape attention output from (seq_len, n_heads, head_dim) to (seq_len, embed_dim) + attn_output = attn_output.reshape(seq_len, self.embed_dim) + + # 5. Convert back to original dtype if needed + if attn_output.dtype != orig_dtype: + attn_output = attn_output.to(orig_dtype) + + # 6. Project output + attn_output = self.out_proj(attn_output) # (seq_len, embed_dim) + + # 7. 
Add back batch dimension for compatibility + attn_output = attn_output.unsqueeze(0) # (1, seq_len, embed_dim) + + return attn_output, None + + +class IsaacSiglip2EncoderLayer(nn.Module): + """Siglip2 encoder layer with variable-length attention.""" + + def __init__(self, config: PixelShuffleSiglip2VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Siglip2VariableLengthAttention(config) + + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Siglip2MLP(config) # Use HF's Siglip2MLP + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor = None, + max_seqlen: int = None, + ) -> tuple[torch.FloatTensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return (hidden_states,) + + +class IsaacEncoder(nn.Module): + """Encoder using Isaac encoder layers with variable-length attention support.""" + + def __init__(self, config: PixelShuffleSiglip2VisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([IsaacSiglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + inputs_embeds, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, + output_hidden_states: bool = False, + ): + all_hidden_states = () if output_hidden_states else None + + hidden_states = inputs_embeds + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens, + max_seqlen, + ) + + hidden_states = layer_outputs[0] + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return hidden_states, all_hidden_states, None + + +def create_pixel_shuffle_index_map( + seq_sizes: torch.Tensor, + token_grids: torch.Tensor, + scale_factor: int = 1, + device: torch.device | None = None, +) -> torch.Tensor: + """ + Build a gather-index map that tells us, for every *output* token after + pixel-shuffle, which `scale_factor**2` *input* tokens are being merged. + + Args + ---- + seq_sizes : (num_images,) - #patches in each image (row-major order) + token_grids : (num_images,2) - (height, width) for every image + scale_factor : spatial down-scale factor (โ‰ฅ2) + device : (optional) overrides `seq_sizes.device` + + Returns + ------- + gather_idx : (new_total_seq_len, scale_factor**2) int64 tensor. + gather_idx[i, j] is the *flat* index into the *original* + packed sequence for the j-th sub-patch that forms the + i-th output token. 
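+
+ Example
+ -------
+ A minimal sketch (hypothetical sizes): one 4x4 patch grid with `scale_factor=2`,
+ so each output token gathers a 2x2 block of input patches.
+
+ ```python
+ >>> seq_sizes = torch.tensor([16])
+ >>> token_grids = torch.tensor([[4, 4]])
+ >>> idx = create_pixel_shuffle_index_map(seq_sizes, token_grids, scale_factor=2)
+ >>> idx.shape
+ torch.Size([4, 4])
+ >>> idx[0] # top-left 2x2 block of the 4x4 grid
+ tensor([0, 1, 4, 5])
+ ```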
+ """ + if device is None: + device = seq_sizes.device + + r = int(scale_factor) + if r < 2: + raise ValueError("`scale_factor` must be โ‰ฅ 2") + + # Safety: all spatial dims must be divisible by r + # Cannot run under torch compile fullgraph mode hence + if not torch.compiler.is_compiling(): + if not ((token_grids[:, 0] % r == 0).all() and (token_grids[:, 1] % r == 0).all()): + raise AssertionError( + f"Every (H,W) in `token_grids` must be divisible by scale_factor={r}, got {token_grids.tolist()}" + ) + + gather_chunks: list[torch.Tensor] = [] + tok_offset = 0 + + for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist(), strict=False): + # Build the (H, W) grid of flat indices for this image + grid = torch.arange(seq_len, device=device, dtype=torch.int64) + tok_offset + grid = grid.view(h, w) # (H, W) + + # -------- identical ordering to your fixed-res routine -------- + # Step 1: split width into blocks of r + grid = grid.view(h, w // r, r) # (H, W/r, r) + # Step 2: now split height into blocks of r + grid = grid.view(h // r, r, w // r, r) # (H/r, r, W/r, r) + # Step 3: final permutation to (H/r, W/r, r, r) + grid = grid.permute(0, 2, 1, 3).contiguous() # (H/r, W/r, r, r) + # Step 4: each (r, r) block forms one output token + gather_chunks.append(grid.reshape(-1, r * r)) # (H*W / rยฒ, rยฒ) + + tok_offset += seq_len + + # Concatenate over all images in the packed batch + gather_idx = torch.cat(gather_chunks, dim=0) # (ฮฃ_i HแตขWแตข/rยฒ, rยฒ) + return gather_idx + + +def pixel_shuffle_varlen( + x: torch.Tensor, + token_grids: torch.Tensor, + scale_factor: int = 1, +) -> torch.Tensor: + r"""Apply pixel shuffle to a packed vision sequence without unpacking per image. + + Args: + x (`torch.Tensor`): + Concatenated vision embeddings. Accepts `(seq_len, hidden_size)` or `(1, seq_len, hidden_size)` shapes + produced by stacking image patches. + token_grids (`torch.Tensor`): + Integer tensor of shape `(num_images, 2)` whose rows give the `(height, width)` patch grid sizes + corresponding to each image segment inside `x`. + scale_factor (`int`, *optional*, defaults to 1): + Spatial down-sampling factor specific to pixel shuffle. Values greater than one merge `scale_factor**2` neighboring patches into a + single embedding channel-group. + + Returns: + `torch.Tensor`: Pixel-shuffled embeddings with shape matching the input convention: + `(seq_len, hidden_size * scale_factor**2)` when the input was 2D, or `(1, seq_len, hidden_size * scale_factor**2)` + if the singleton batch dimension was present. + + Raises: + ValueError: If more than one batch item is provided. 
+ """ + keep_batch_dim = x.dim() == 3 + if keep_batch_dim: + if x.size(0) != 1: + raise AssertionError("Packed sequence is expected to have batch_size == 1") + x_ = x.squeeze(0) # (seq, embed) + else: + x_ = x # (seq, embed) + + embed_dim = x_.size(-1) + r = int(scale_factor) + + # Calculate seq_sizes from token_grids + seq_sizes = torch.prod(token_grids, dim=-1) + + # Build index map and gather in one go + gather_idx = create_pixel_shuffle_index_map( + seq_sizes=seq_sizes, + token_grids=token_grids, + scale_factor=r, + device=x_.device, + ) # (new_seq, rยฒ) + + # Gather โ†’ (new_seq, rยฒ, embed_dim) + gathered = x_[gather_idx] # fancy indexing keeps gradient + + # Merge the rยฒ group dimension into channels to finish the shuffle + out = gathered.reshape(gathered.size(0), embed_dim * r * r) + + # Restore batch dimension if needed + if keep_batch_dim: + out = out.unsqueeze(0) + return out + + +class Siglip2SequenceVisionTransformer(nn.Module): + def __init__(self, config: PixelShuffleSiglip2VisionConfig): + super().__init__() + self.config = config + self.embeddings = Siglip2VariableSequenceEmbeddings(config) + self.encoder = IsaacEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor + + def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): + seq_patches, token_grids = packed_seq_patches + seq_sizes = torch.prod(token_grids, dim=-1) + + # Get embeddings from packed sequence + hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids)) + + # Add a pseudo batch dimension for the encoder + hidden_states = hidden_states.unsqueeze(0) + + # Generate cumulative sequence lengths for variable-length attention + cu_seqlens, max_seqlen = create_cumulative_seq_lengths(seq_sizes, hidden_states.device) + + # Pass through encoder with variable-length attention parameters + hidden_states, _, _ = self.encoder( + inputs_embeds=hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + # Apply final layer normalization + hidden_states = self.post_layernorm(hidden_states) + + if self.pixel_shuffle_scale_factor > 1: + hidden_states = pixel_shuffle_varlen( + x=hidden_states, + token_grids=token_grids, + scale_factor=self.pixel_shuffle_scale_factor, + ) + # Remove the pseudo batch dimension we added earlier + hidden_states = hidden_states.squeeze(0) + + # Return the full sequence of embeddings + return hidden_states + + +# ============================================================================ +# Configuration +# ============================================================================ + +MAX_PIXELS = 60_000_000 # 60-megapixel ceiling โ‰ˆ 8200 ร— 7300 px + +# Vision preprocessing constants +VISION_MEAN = (0.5, 0.5, 0.5) +VISION_STD = (0.5, 0.5, 0.5) +VISION_SCALE = 1 / 255 + + +def _make_writeable(arr: np.ndarray) -> np.ndarray: + """Return *arr* itself if it is already writeable, otherwise try to flip the + write flag in-place and finally fall back to `arr.copy()`. + This guarantees the buffer handed to `torch.from_numpy()` is always + writeable, silencing the PyTorch warning about undefined behaviour. + """ + if arr.flags.writeable: + return arr + + # First, try the cheap path โ€” in-place flag toggle (works for mmap'd arrays + # and some shared memory buffers): + try: + arr.setflags(write=True) + return arr # success: no data copy + except ValueError: + # Buffer is inherently read-only (e.g. 
backed by PyAV / PIL): make copy
+ return arr.copy()
+
+
+def extract_image_pil(image: PIL.Image.Image) -> torch.Tensor | None:
+ if image.width * image.height > MAX_PIXELS:
+ raise ValueError(f"Image (w={image.width}, h={image.height}) > MAX=`{MAX_PIXELS}`")
+ img = image if image.mode == "RGB" else image.convert("RGB")
+ arr = np.asarray(img)
+ arr = _make_writeable(arr)
+ return torch.from_numpy(arr)
+
+
+def get_image_size_for_max_num_patches(
+ image_height: int,
+ image_width: int,
+ patch_size: int,
+ max_num_patches: int,
+ min_num_patches: int | None = None,
+ eps: float = 1e-5,
+ pixel_shuffle_scale: int = 1,
+) -> tuple[int, int]:
+ r"""Compute a target resolution whose patch grid satisfies the patching constraints.
+
+ Args:
+ image_height (`int`):
+ Height in pixels of the source image prior to any resizing.
+ image_width (`int`):
+ Width in pixels of the source image prior to any resizing.
+ patch_size (`int`):
+ Size of the square patch used by the vision encoder.
+ max_num_patches (`int`):
+ Upper bound on `(height / patch_size) * (width / patch_size)` after resizing.
+ min_num_patches (`int`, *optional*):
+ Lower bound on the number of patches. When provided the image will be scaled up if necessary.
+ eps (`float`, *optional*, defaults to 1e-5):
+ Convergence tolerance for the internal binary search to determine the target dimensions.
+ pixel_shuffle_scale (`int`, *optional*, defaults to 1):
+ Additional stride multiplier applied when pixel shuffle later reduces spatial resolution.
+
+ Returns:
+ `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale`
+ and respect both the maximum and optional minimum patch-count constraints.
+ """
+
+ def get_scaled_image_size(scale, original_size, patch_size, pixel_shuffle_scale):
+ scaled_size = scale * original_size
+ divisor = patch_size * pixel_shuffle_scale
+ scaled_size = math.ceil(scaled_size / divisor) * divisor
+ scaled_size = max(divisor, scaled_size)
+ return int(scaled_size)
+
+ # Ensure divisibility
+ divisor = patch_size * pixel_shuffle_scale
+ adjusted_height = math.ceil(image_height / divisor) * divisor
+ adjusted_height = max(divisor, adjusted_height)
+ adjusted_width = math.ceil(image_width / divisor) * divisor
+ adjusted_width = max(divisor, adjusted_width)
+
+ num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size)
+
+ if min_num_patches is not None and num_patches < min_num_patches:
+ # Scale up
+ scale_min, scale_max = 1.0, 100.0
+ while (scale_max - scale_min) >= eps:
+ scale = (scale_min + scale_max) / 2
+ target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+ target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+ num_patches = (target_height / patch_size) * (target_width / patch_size)
+ if num_patches >= min_num_patches:
+ scale_max = scale
+ else:
+ scale_min = scale
+ scale = scale_max
+ target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+ target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+ return target_height, target_width
+ elif num_patches <= max_num_patches:
+ return adjusted_height, adjusted_width
+ else:
+ # Scale down
+ scale_min, scale_max = eps / 10, 1.0
+ while (scale_max - scale_min) >= eps:
+ scale = (scale_min + scale_max) / 2
+ target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+ target_width = get_scaled_image_size(scale,
image_width, patch_size, pixel_shuffle_scale)
+ num_patches = (target_height / patch_size) * (target_width / patch_size)
+ if num_patches <= max_num_patches:
+ scale_min = scale
+ else:
+ scale_max = scale
+ scale = scale_min
+ target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+ target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+ return target_height, target_width
+
+
+_MEAN_TENSOR = torch.tensor(VISION_MEAN, dtype=torch.float32).view(1, 1, 1, -1)
+_STD_TENSOR = torch.tensor(VISION_STD, dtype=torch.float32).view(1, 1, 1, -1)
+
+
+def prepare_image_tensor(
+ image: torch.Tensor,
+ scale: float = VISION_SCALE,
+) -> torch.Tensor:
+ r"""Standardize RGB images prior to patch extraction via rescaling and whitening.
+
+ Args:
+ image (`torch.Tensor`):
+ Tensor with shape `(..., height, width, 3)` containing RGB values. The tensor is converted to floating
+ point if needed.
+ scale (`float`, *optional*, defaults to `VISION_SCALE`):
+ Scalar multiplier applied before normalization.
+ Returns:
+ `torch.Tensor`: Normalized tensor with the same shape as the input and dtype `torch.float32`.
+ """
+ if not torch.is_floating_point(image):
+ image = image.float()
+ rescaled = image * scale
+
+ # Use precomputed tensors and move to the correct device if needed
+ mean_tensor = _MEAN_TENSOR.to(image.device)
+ std_tensor = _STD_TENSOR.to(image.device)
+
+ normalized = (rescaled - mean_tensor) / std_tensor
+ return normalized
+
+
+def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor:
+ r"""Convert normalized images into flattened ViT-style patches.
+
+ Args:
+ image (`torch.Tensor`):
+ Tensor of shape `(num_images, height, width, channels)`.
+ patch_size (`int`):
+ Edge length of the square patches.
+
+ Returns:
+ `torch.Tensor`:
+ Patch tensor where each position stores the flattened pixels belonging to that patch.
+
+ Raises:
+ ValueError: If `height` or `width` is not divisible by `patch_size`.
+ """
+ num_images, height, width, channels = image.shape
+ if height % patch_size or width % patch_size:
+ raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.")
+ patches = image.reshape(num_images, height // patch_size, patch_size, width // patch_size, patch_size, channels)
+ patches = patches.permute(0, 1, 3, 2, 4, 5)
+ patches = patches.reshape(num_images, height // patch_size, width // patch_size, channels * patch_size * patch_size)
+ return patches
+
+
+def process_vision_for_patches(
+ images: torch.Tensor,
+ patch_size: int,
+ max_num_patches: int,
+ min_num_patches: int | None = None,
+ pixel_shuffle_scale: int = 1,
+) -> tuple[torch.Tensor, list[int]]:
+ r"""Resize, normalize, and patchify RGB images for the vision encoder.
+
+ Args:
+ images (`torch.Tensor`):
+ Either `(height, width, channels)` for a single image or `(num_images, height, width, channels)` for a
+ batch. Channels are expected to be RGB.
+ patch_size (`int`):
+ Edge length of square patches; implicitly controls resize grid granularity.
+ max_num_patches (`int`):
+ Maximum number of patches allowed after resizing.
+ min_num_patches (`int`, *optional*):
+ Minimum number of patches. If provided, the routine upsamples images as needed to satisfy the lower bound.
+ pixel_shuffle_scale (`int`, *optional*, defaults to 1):
+ Pixel shuffle scale factor; influences the target grid that the function produces.
+ + Returns: + `tuple[torch.Tensor, list[int]]`: A pair `(patches, dims_virtual)` where `patches` has shape + `(num_images, target_h / patch_size, target_w / patch_size, channels * patch_size**2)` and `dims_virtual` + encodes effective `(images, height, width)` dimensions after optional pixel shuffling. + """ + # Add batch dim if single image + if images.dim() == 3: + images = images.unsqueeze(0) + + # Permute to channel first for resize + images = images.permute(0, 3, 1, 2) + + # Get target dimensions + _, _, orig_height, orig_width = images.shape + target_height, target_width = get_image_size_for_max_num_patches( + orig_height, + orig_width, + patch_size, + max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + ) + + # Resize + images = F.interpolate( + images, + size=(target_height, target_width), + mode="bilinear", + align_corners=False, + ) + + # Back to channel last + images = images.permute(0, 2, 3, 1) + + # Normalize + images = prepare_image_tensor(images) + + # Patchify + patches = patchify_vision(images, patch_size=patch_size) + + # Calculate dimensions for the patches + n_images, h_patches, w_patches, _ = patches.shape + dims_virtual = ( + [1, h_patches, w_patches] + if pixel_shuffle_scale == 1 + else [1, h_patches // pixel_shuffle_scale, w_patches // pixel_shuffle_scale] + ) + + return patches, dims_virtual + + +def precompute_inv_freq(theta: float, dim: int) -> torch.Tensor: + """ + Returns shape (dim//2,). + """ + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + return inv_freq # type: ignore[return-value] + + +def precompute_cos_sin_3d( + position_ids: torch.Tensor, # shape (3, B, T) + inv_freq: torch.Tensor, # shape (dim//2,) + mrope_half_section: list[int], # sum to dim//2 +) -> tuple[torch.Tensor, torch.Tensor]: + r"""Generate 3D rotary embeddings for multi-axis positions. + + Args: + position_ids (`torch.Tensor`): + Tensor of shape `(3, batch_size, seq_len)` containing positional indices for the x/y/t axes. + inv_freq (`torch.Tensor`): + Precomputed inverse frequency vector used to derive rotary phases. + mrope_half_section (`list[int]`): + Sizes the axis-specific frequency blocks. + + Returns: + `tuple[torch.Tensor, torch.Tensor]`: Cosine and sine tensors, each of shape `(batch_size, seq_len, dim)`, ready + to be passed into rotary attention layers. 
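+
+ Example:
+
+ A minimal sketch assuming a hypothetical `head_dim` of 64, so the 32 frequency
+ slots split across the t/h/w axes as `[16, 8, 8]`:
+
+ ```python
+ >>> position_ids = torch.arange(10).expand(3, 1, 10) # identical coords on all 3 axes
+ >>> inv_freq = precompute_inv_freq(10000.0, 64)
+ >>> cos, sin = precompute_cos_sin_3d(position_ids, inv_freq, [16, 8, 8])
+ >>> cos.shape, sin.shape
+ (torch.Size([1, 10, 64]), torch.Size([1, 10, 64]))
+ ```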
+ """ + B = position_ids.shape[1] + T = position_ids.shape[2] + dim_half = inv_freq.shape[0] + device = position_ids.device + + # Initialize with full dimension (not half) to match LLaMA + cos_3d = torch.zeros((B, T, dim_half * 2), dtype=torch.float32, device=device) + sin_3d = torch.zeros((B, T, dim_half * 2), dtype=torch.float32, device=device) + + offset = 0 + for d in range(3): + block_size = mrope_half_section[d] + freq_slice = inv_freq[offset : offset + block_size] # shape => (block_size,) + # shape => (B, T, block_size) + phase = position_ids[d].unsqueeze(-1).float() * freq_slice + + cos_part = phase.cos() + sin_part = phase.sin() + + # Duplicate values for both halves of the dimension + cos_3d[:, :, offset : offset + block_size] = cos_part + cos_3d[:, :, dim_half + offset : dim_half + offset + block_size] = cos_part + sin_3d[:, :, offset : offset + block_size] = sin_part + sin_3d[:, :, dim_half + offset : dim_half + offset + block_size] = sin_part + + offset += block_size + + return cos_3d, sin_3d + + +class RopeScaling(TypedDict, total=False): + rope_type: str + factor: float + mrope_section: list[int] + mrope_interleaved: bool + low_freq_factor: float + high_freq_factor: float + original_max_position_embeddings: int + + +class IsaacConfig(Qwen3Config): + """Configuration class for Isaac multimodal model.""" + + model_type = "isaac" + sub_configs = {"vision_config": PixelShuffleSiglip2VisionConfig} + + def __init__( + self, + vision_config=None, + vision_patch_size: int = 16, + vision_max_num_patches: int = 256, + vision_min_num_patches: int | None = None, + pixel_shuffle_scale: int = 1, + max_sequence_length: int = 16384, + vision_token: str = "", + **kwargs, + ): + super().__init__(**kwargs) + + # Handle vision config - either dict or PixelShuffleSiglip2VisionConfig instance + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + else: + self.vision_config = vision_config + + # EventStreamProcessor parameters (for backward compatibility) + self.video_patch_size = vision_patch_size + self.vision_max_num_patches = vision_max_num_patches + self.vision_min_num_patches = vision_min_num_patches + self.pixel_shuffle_scale = pixel_shuffle_scale + + # Processing parameters + self.max_sequence_length = max_sequence_length + self.vision_token = vision_token + + +# ============================================================================ +# Processor Components +# ============================================================================ + + +def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> Event: + r"""Wrap a text into an `Event` compatible with the multimodal TensorStream. + + Args: + tokenizer (`AutoTokenizer`): + Tokenizer used to convert text into model vocabulary ids. + text (`str`): + Plain-text fragment to encode. + time (`float`, *optional*, defaults to 0.0): + Timeline coordinate associated with the event. Both start and end times use the same value because text + segments are instantaneous in the scheduler. + + Returns: + `Event`: Event carrying a `(num_tokens, 1)` tensor of token ids with matching + metadata so that downstream processors can compute modality-specific embeddings. 
+ """ + tokens = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt").squeeze(0) + + # Calculate dimensions for the event + num_tokens = len(tokens) + dims_virtual = [num_tokens, 1] # [sequence_length, 1] + dims_real = dims_virtual.copy() + + # Ensure tokens has the right shape for tensor_stream_token_view + # It expects a 2D tensor where sum(dim=-1) gives the token IDs + if tokens.dim() == 1: + tokens = tokens.unsqueeze(-1) + + return Event( + data=tokens, + type=TextType.text, + time=(time, time), + dims_virtual=dims_virtual, + dims_real=dims_real, + idx_range=(0, num_tokens), + ) + + +# ============================================================================ +# Processor +# ============================================================================ + + +class IsaacProcessor(ProcessorMixin): + attributes = [] + tokenizer_class = ("AutoTokenizer",) + + def __init__( + self, + tokenizer: AutoTokenizer, + config: IsaacConfig, + ): + super().__init__() + self.tokenizer = tokenizer + self.config = config + + # Use vision token from config + self.vision_token = config.vision_token + + # Processing parameters + self.max_sequence_length = config.max_sequence_length + + # Vision processing parameters + self.patch_size = config.video_patch_size + self.max_num_patches = config.vision_max_num_patches + self.min_num_patches = config.vision_min_num_patches + self.pixel_shuffle_scale = config.pixel_shuffle_scale + + def apply_chat_template( + self, + messages: list[dict[str, Any]], + tokenize: bool = False, + add_generation_prompt: bool = False, + **kwargs, + ) -> Any: + return self.tokenizer.apply_chat_template( + messages, tokenize=tokenize, add_generation_prompt=add_generation_prompt, **kwargs + ) + + def build_event_stream_simple( + self, + text: str, + images: list[PIL.Image.Image] | None = None, + ) -> Stream: + events = [] + # Process text and images + # Find all occurrences of vision token + + pattern = re.escape(self.vision_token) + parts = re.split(f"({pattern})", text) # Keep the delimiter in the result + + image_idx = 0 + for current_time, part in enumerate(parts): + if part == self.vision_token: + # Replace vision token with image event + if image_idx < len(images): + # Create vision event from PIL image + image_tensor = extract_image_pil(images[image_idx]) + if image_tensor is not None: + # Create a vision event with the image tensor + vision_event = Event( + data=image_tensor.unsqueeze(0), # HWC format from extract_image_pil + type=VisionType.image, # I-frame + time=(current_time, current_time), + ) + events.append(vision_event) + image_idx += 1 + elif part: # Non-empty text part + # tokens = self.text_processor.tokenize(part, add_special_tokens=False) + text_event = create_text_event(self.tokenizer, part, time=current_time) + events.append(text_event) + + # Process vision events if any + if any(event.type == VisionType.image for event in events): + # Separate text and vision events for processing + text_events = [event for event in events if event.type == TextType.text] + vision_events = [event for event in events if event.type == VisionType.image] + + # Process vision events using functional approach + processed_vision_events = [] + for vision_event in vision_events: + # Process the vision data + patches, dims_virtual = process_vision_for_patches( + vision_event.data.squeeze(0), # Remove the extra dimension + patch_size=self.patch_size, + max_num_patches=self.max_num_patches, + min_num_patches=self.min_num_patches, + pixel_shuffle_scale=self.pixel_shuffle_scale, + 
) + + # Update event with processed data + vision_event.data = patches.unsqueeze(1) # Add back frame dimension + vision_event.dims_virtual = dims_virtual + vision_event.dims_real = ( + dims_virtual + if self.pixel_shuffle_scale == 1 + else [ + dims_virtual[0], + dims_virtual[1] * self.pixel_shuffle_scale, + dims_virtual[2] * self.pixel_shuffle_scale, + ] + ) + vision_event.idx_range = (0, math.prod(dims_virtual)) + + # Flatten the patches + vision_event.data = vision_event.data.reshape(-1, vision_event.data.shape[-1]) + processed_vision_events.append(vision_event) + + events = text_events + processed_vision_events + + # Create stream without scheduling (events already in order) + return create_stream(events, priority=[TextType.text, VisionType.image], schedule=True) + + def __call__( + self, + text: Union[str, list[str]], + images: Union[PIL.Image.Image, list[PIL.Image.Image], None] = None, + return_tensors: str | TensorType | None = TensorType.PYTORCH, + **kwargs, + ) -> BatchFeature: + """ + Process text and images into TensorStream format. + Args: + text: Input text or list of texts with vision tokens + images: PIL image or list of images (optional) + return_tensors: Format for output tensors + + Returns: + BatchFeature with input_ids and tensor_stream + """ + # Normalize inputs to lists + if isinstance(text, str): + texts = [text] + else: + texts = text + + if images is not None: + if isinstance(images, PIL.Image.Image): + images_list = [images] + else: + images_list = images + else: + images_list = None + + if len(texts) != 1: + raise ValueError("IsaacProcessor currently supports batch_size=1") + if images_list is not None: + # Count vision tokens in text to validate image count + vision_token_count = texts[0].count(self.vision_token) + if vision_token_count != len(images_list): + raise ValueError( + f"Number of {self.vision_token} tokens in text ({vision_token_count}) " + f"must match number of images ({len(images_list)})" + ) + + # Build event stream + stream = self.build_event_stream_simple( + text=texts[0], + images=images_list, + ) + + # Create TensorStream + tensor_stream = TensorStream([stream]) + + # Slice to max length if needed + _, T = tensor_stream.shape + if T > self.max_sequence_length: + tensor_stream = ts_slice(tensor_stream, start=T - self.max_sequence_length, end=T) + + # Get token view + tokens = tensor_stream_token_view(tensor_stream) + if return_tensors in (TensorType.PYTORCH, "pt"): + input_ids = torch.as_tensor(tokens, dtype=torch.long) + else: + input_ids = tokens + + data = { + "input_ids": input_ids, + "tensor_stream": tensor_stream, + } + + return BatchFeature(data=data) + + +# ============================================================================ +# Model +# ============================================================================ + + +def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor: + r"""Create 3D positional indices for token input. + + Args: + input_ids (`torch.Tensor`): + Tensor of shape `(batch_size, seq_len)` containing token ids. + + Returns: + `torch.Tensor`: Positional indices with shape `(batch_size, seq_len, 3)` where each channel duplicates the + 1D position so it can be consumed by the 3-axis MRoPE rotary embedding. 
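+
+ Example:
+
+ ```python
+ >>> compute_position_ids_input_ids(torch.tensor([[11, 12, 13]]))
+ tensor([[[0, 0, 0],
+ [1, 1, 1],
+ [2, 2, 2]]])
+ ```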
+ """ + batch_size, seq_length = input_ids.shape + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.unsqueeze(2).expand(-1, -1, 3) # Add 3D for MRoPE + return position_ids + + +class IsaacRotaryEmbedding(nn.Module): + def __init__(self, config: IsaacConfig, device=None): + super().__init__() + + # Extract dimensions from config + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + + # Get rope_scaling config - use direct access when available + rope_scaling = getattr(config, "rope_scaling", None) or {} + + # Read RopeScaling parameters + self.rope_type = rope_scaling.get("rope_type", "default") + + self.mrope_section = [ + self.head_dim // 4, # 2x more for temporal dim + self.head_dim // 8, + self.head_dim // 8, + ] + + rope_base = getattr(config, "rope_theta", 10000.0) + inv_freq = precompute_inv_freq(rope_base, self.head_dim) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, position_ids: torch.Tensor, modality_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + with torch.no_grad(): + # Ensure non-spatial tokens have 1D rotation equivalence + not_spatial = ~(modality_tensor == VisionType.image.value) + # shape is [N, 1] + data_1d = position_ids[not_spatial][..., 0].unsqueeze(-1) + # now broadcast it from [N, 1] -> [N, D] so it matches pos[not_spatial] exactly + data_1d = data_1d.expand(-1, position_ids.shape[-1]) # expand along the last dim + position_ids = position_ids.clone() # Clone to avoid warning about in-place operations on expanded tensors + position_ids[not_spatial] = data_1d + position_ids = position_ids.permute(2, 0, 1) # pos dim first -> (3, B, L) + cos, sin = precompute_cos_sin_3d(position_ids, self.inv_freq, self.mrope_section) + + return cos, sin + + +class IsaacModel(Qwen3Model): + def __init__(self, config: IsaacConfig): + super().__init__(config) + self.layers = torch.nn.ModuleList( + [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) + + vision_cfg = config.vision_config + if vision_cfg is None: + raise ValueError("IsaacConfig should always have vision_config") + + hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2) + self.vision_embedding = nn.Sequential( + Siglip2SequenceVisionTransformer(vision_cfg), + nn.Linear( + hidden_dim, + 4 * hidden_dim, + bias=False, + ), + nn.SiLU(), + nn.Linear(4 * hidden_dim, config.hidden_size, bias=False), + ) + + # Dispatch table for TensorStream balanced embedding (text + vision) + self.embed_fns = { + TextType: self.embed_text_tokens, + VisionType: self.embed_vision, + } + + def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: + """Embed text tokens, squeezing singleton dimensions.""" + # Text events are shaped as (..., 1); squeeze the singleton index dim + h = self.embed_tokens(token_ids) + if h.dim() >= 2 and h.size(-2) == 1: + h = h[..., 0, :] + return h + + def embed_vision(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + """Embed vision tokens using the vision encoder.""" + # vision tokens is (seq_patches, token_grids) + return self.vision_embedding(vision_tokens) + + def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor: + """ + Embed each modality stream independently, preserving the original 
TensorStream + structure. + """ + flat_stream = tensor_stream.flat_stream() + per_modality_stream = group_streams(flat_stream, group_fn=lambda ev: ev.type, schedule=False) + per_modality_compact_stream = {k: v.compact() for k, v in per_modality_stream.items()} + + # Collect per-event grids for vision tokens (H, W like dims sans time) + token_grids = defaultdict(list) + for stream in tensor_stream.streams: + for event in stream: + token_grids[event.type].append(event.dims(virtual=False)) + + embedded_compact = {} + for stream_type, modality_payload_tensor in per_modality_compact_stream.items(): + if stream_type.modality == VisionType: + # Build a (N_events, 2) grid tensor with spatial dims only + grids = token_grids.get(stream_type, []) + if len(grids) == 0: + input_tensor = modality_payload_tensor + else: + token_grids_tensor = torch.tensor(grids, dtype=torch.long, device=tensor_stream.device)[:, 1:] + input_tensor = (modality_payload_tensor, token_grids_tensor) + embedded_compact[stream_type] = self.embed_fns[stream_type.modality](input_tensor) + else: + embedded_compact[stream_type] = self.embed_fns[stream_type.modality](modality_payload_tensor) + + # Reconstruct a TensorStream with embedded payloads and compact + embedded_ts = reconstruct_tensor_stream_from_compact_dict(tensor_stream, embedded_compact) + h = embedded_ts.compact() # (B, T, D) + return h + + def forward( + self, + input_ids: torch.LongTensor | None = None, + tensor_stream: TensorStream | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + modality_tensor: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs, + ) -> tuple | BaseModelOutputWithPast: + """ + Forward pass with MRoPE position embeddings. + + Computes position embeddings once and passes them through all layers. 
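+
+ In sketch form: `model(tensor_stream=ts)` embeds the stream, derives `(B, L, 3)`
+ MRoPE coordinates, and feeds one shared `(cos, sin)` table to every decoder layer,
+ while text-only calls fall back to `input_ids` with replicated 1D positions.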
+ """ + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Get inputs + if tensor_stream is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both tensor_stream and inputs_embeds") + elif tensor_stream is not None: + # Embed TensorStream directly + inputs_embeds = self.embed_stream(tensor_stream) + # Create modality tensor if not provided + if modality_tensor is None: + modality_tensor = modality_mask(tensor_stream) + elif input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + inputs_embeds = self.embed_tokens(input_ids) + # Create text modality tensor if not provided + if modality_tensor is None: + batch_size, seq_length = input_ids.shape + modality_tensor = torch.full( + (batch_size, seq_length), TextType.text.value, device=input_ids.device, dtype=torch.long + ) + elif inputs_embeds is None: + raise ValueError("You have to specify either tensor_stream, input_ids or inputs_embeds") + + # Create default position_ids if not provided + if position_ids is None: + if tensor_stream is not None: + position_ids = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) + else: + position_ids = compute_position_ids_input_ids(input_ids) + + # Compute MRoPE position embeddings if we have custom rotary_emb + cos, sin = self.rotary_emb(position_ids, modality_tensor) + cos = cos.to(inputs_embeds.dtype) + sin = sin.to(inputs_embeds.dtype) + + # Prepare attention mask + if attention_mask is not None: + attention_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, False + ) + + # Initialize hidden states + hidden_states = inputs_embeds + + for decoder_layer in self.layers: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=(cos, sin), + **kwargs, + ) + + hidden_states = layer_outputs[0] + + # Final layer norm + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin): + """Isaac multimodal model for conditional generation.""" + + config_class = IsaacConfig + + def __init__(self, config: IsaacConfig): + Qwen3PreTrainedModel.__init__(self, config) + self.model = IsaacModel(config) # Use our custom model + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Tracks rotary position offsets computed during a full forward pass so decode steps can reuse them. + self.rope_deltas = None + + self.config = config + + def get_rope_index( + self, + input_ids: torch.Tensor | None, + tensor_stream: TensorStream | None, + attention_mask: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute MRoPE position ids from a TensorStream (or 1D fallback). + + Returns (position_ids, rope_deltas). position_ids is (B,L,3) for MRoPE. + rope_deltas is (B,1) used to advance positions in decode. 
+ """ + # tensor_stream present: compute 3D coords + if tensor_stream is None and input_ids is None: + raise ValueError("`tensor_stream` or `input_ids` must be provided to compute rope indices") + + if tensor_stream is not None: + pos_3d = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) + else: + pos_3d = compute_position_ids_input_ids(input_ids) + B, L, _ = pos_3d.shape + + # Max position per batch across the 3 planes and sequence dimension: (B,) + m_per_batch = pos_3d.amax(dim=(1, 2)) + + # Sequence lengths per batch: (B,) + if attention_mask is None: + seq_lens = torch.full_like(m_per_batch, L) + else: + seq_lens = attention_mask.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=m_per_batch.device) + + rope_deltas = (m_per_batch + 1 - seq_lens).to(dtype=pos_3d.dtype).unsqueeze(1) + return pos_3d, rope_deltas + + def forward( + self, + input_ids: torch.LongTensor | None = None, + tensor_stream: TensorStream | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs, + ) -> tuple | CausalLMOutputWithPast: + """ + Forward pass for conditional generation supporting both standard inputs and TensorStream. + Uses our embed_stream approach for multimodal inputs. + """ + + # Don't compute embeddings here - let the model handle it + if tensor_stream is not None: + input_ids = None + if input_ids is None and inputs_embeds is None and tensor_stream is None: + raise ValueError("Either input_ids, inputs_embeds, or tensor_stream must be provided.") + + # Build position ids (MRoPE) if needed and tensor_stream is available + # During decode we reuse `self.rope_deltas` computed on the initial forward pass; `rope_delta` captures how far + # cached rotary phases have progressed so we can advance `position_ids` without rebuilding the TensorStream. + if position_ids is None and tensor_stream is not None: + position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask) + elif position_ids is None and input_ids is not None: + # For text inputs build position ids and modality tensor + position_ids = compute_position_ids_input_ids(input_ids) + if cache_position is not None and self.rope_deltas is not None: + # Combine the incremental decode step (`cache_position`) with cached offsets so hidden states continue + # rotating in lockstep across generation steps. 
+ rope_delta = (cache_position[0] + self.rope_deltas).to(input_ids.device) + else: + rope_delta = 0 + if cache_position is not None and not isinstance(rope_delta, int): # otherwise `deltas` is an int `0` + batch_size = input_ids.shape[0] + rope_delta = rope_delta.repeat_interleave(batch_size // rope_delta.shape[0], dim=0) + position_ids = position_ids.add(rope_delta) + + if tensor_stream is not None: + modality_tensor = modality_mask(tensor_stream) + else: + batch_size, seq_len = input_ids.shape + modality_tensor = torch.empty(batch_size, seq_len, device=position_ids.device).fill_(TextType.text.value) + + outputs = self.model( + input_ids=input_ids, + tensor_stream=tensor_stream, + attention_mask=attention_mask, + position_ids=position_ids, + modality_tensor=modality_tensor, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=None, + ) + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: list[torch.FloatTensor] | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + tensor_stream: TensorStream | None = None, + cache_position: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + use_cache: bool = True, + **kwargs, + ) -> dict[str, Any]: + """ + Prepare inputs for generation, handling TensorStream inputs properly. + """ + # Call parent preparation + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + **kwargs, + ) + + # Handle TensorStream for first forward pass only + if tensor_stream is not None and (cache_position is None or cache_position[0] == 0): + model_inputs["tensor_stream"] = tensor_stream + # Let forward rebuild position_ids using cached deltas during decode + model_inputs["position_ids"] = None + # Drop tensor_stream after step 0 + if cache_position is not None and cache_position[0] != 0: + model_inputs["tensor_stream"] = None + return model_inputs + + def can_generate(self) -> bool: + return True + + +__all__ = [ + "IsaacConfig", + "IsaacModel", + "IsaacForConditionalGeneration", + "IsaacProcessor", +] \ No newline at end of file diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py new file mode 100644 index 000000000000..45766db223d9 --- /dev/null +++ b/src/transformers/models/isaac/processing_isaac.py @@ -0,0 +1,509 @@ +# ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ +# This file was automatically generated from src/transformers/models/isaac/modular_isaac.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. 
If any change should be done, please apply the change to the
+# modular_isaac.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+import math
+import re
+from typing import Any, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from perceptron.tensorstream import Event, Stream, TensorStream, TextType, VisionType, create_stream
+from perceptron.tensorstream.ops import slice as ts_slice
+from perceptron.tensorstream.ops import tensor_stream_token_view
+
+from ...processing_utils import BatchFeature, ProcessorMixin
+from ...tokenization_utils import TensorType
+from ..auto import AutoTokenizer
+from .configuration_isaac import IsaacConfig
+
+
+# ============================================================================
+# Configuration
+# ============================================================================
+
+MAX_PIXELS = 60_000_000 # 60-megapixel ceiling ≈ 8200 × 7300 px
+
+# Vision preprocessing constants
+VISION_MEAN = (0.5, 0.5, 0.5)
+VISION_STD = (0.5, 0.5, 0.5)
+VISION_SCALE = 1 / 255
+
+
+def _make_writeable(arr: np.ndarray) -> np.ndarray:
+ """Return *arr* itself if it is already writeable, otherwise try to flip the
+ write flag in-place and finally fall back to `arr.copy()`.
+ This guarantees the buffer handed to `torch.from_numpy()` is always
+ writeable, silencing the PyTorch warning about undefined behaviour.
+ """
+ if arr.flags.writeable:
+ return arr
+
+ # First, try the cheap path — in-place flag toggle (works for mmap'd arrays
+ # and some shared memory buffers):
+ try:
+ arr.setflags(write=True)
+ return arr # success: no data copy
+ except ValueError:
+ # Buffer is inherently read-only (e.g. backed by PyAV / PIL): make copy
+ return arr.copy()
+
+
+def extract_image_pil(image: PIL.Image.Image) -> torch.Tensor | None:
+ if image.width * image.height > MAX_PIXELS:
+ raise ValueError(f"Image (w={image.width}, h={image.height}) > MAX=`{MAX_PIXELS}`")
+ img = image if image.mode == "RGB" else image.convert("RGB")
+ arr = np.asarray(img)
+ arr = _make_writeable(arr)
+ return torch.from_numpy(arr)
+
+
+def get_image_size_for_max_num_patches(
+ image_height: int,
+ image_width: int,
+ patch_size: int,
+ max_num_patches: int,
+ min_num_patches: int | None = None,
+ eps: float = 1e-5,
+ pixel_shuffle_scale: int = 1,
+) -> tuple[int, int]:
+ r"""Compute a target resolution whose patch grid satisfies the patching constraints.
+
+ Args:
+ image_height (`int`):
+ Height in pixels of the source image prior to any resizing.
+ image_width (`int`):
+ Width in pixels of the source image prior to any resizing.
+ patch_size (`int`):
+ Size of the square patch used by the vision encoder.
+ max_num_patches (`int`):
+ Upper bound on `(height / patch_size) * (width / patch_size)` after resizing.
+ min_num_patches (`int`, *optional*):
+ Lower bound on the number of patches. When provided the image will be scaled up if necessary.
+ eps (`float`, *optional*, defaults to 1e-5):
+ Convergence tolerance for the internal binary search to determine the target dimensions.
+ pixel_shuffle_scale (`int`, *optional*, defaults to 1):
+ Additional stride multiplier applied when pixel shuffle later reduces spatial resolution.
+ + Returns: + `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale` + and respect both the maximum and optional minimum patch-count constraints. + """ + + def get_scaled_image_size(scale, original_size, patch_size, pixel_shuffle_scale): + scaled_size = scale * original_size + divisor = patch_size * pixel_shuffle_scale + scaled_size = math.ceil(scaled_size / divisor) * divisor + scaled_size = max(divisor, scaled_size) + return int(scaled_size) + + # Ensure divisibility + divisor = patch_size * pixel_shuffle_scale + adjusted_height = math.ceil(image_height / divisor) * divisor + adjusted_height = max(divisor, adjusted_height) + adjusted_width = math.ceil(image_width / divisor) * divisor + adjusted_width = max(divisor, adjusted_width) + + num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size) + + if min_num_patches is not None and num_patches < min_num_patches: + # Scale up + scale_min, scale_max = 1.0, 100.0 + while (scale_max - scale_min) >= eps: + scale = (scale_min + scale_max) / 2 + target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) + target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + num_patches = (target_height / patch_size) * (target_width / patch_size) + if num_patches >= min_num_patches: + scale_max = scale + else: + scale_min = scale + scale = scale_max + target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) + target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + return target_height, target_width + elif num_patches <= max_num_patches: + return adjusted_height, adjusted_width + else: + # Scale down + scale_min, scale_max = eps / 10, 1.0 + while (scale_max - scale_min) >= eps: + scale = (scale_min + scale_max) / 2 + target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) + target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + num_patches = (target_height / patch_size) * (target_width / patch_size) + if num_patches <= max_num_patches: + scale_min = scale + else: + scale_max = scale + scale = scale_min + target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) + target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + return target_height, target_width + + +_MEAN_TENSOR = torch.tensor(VISION_MEAN, dtype=torch.float32).view(1, 1, 1, -1) +_STD_TENSOR = torch.tensor(VISION_STD, dtype=torch.float32).view(1, 1, 1, -1) + + +def prepare_image_tensor( + image: torch.Tensor, + scale: float = VISION_SCALE, +) -> torch.Tensor: + r"""Standardize RGB images prior to patch extraction via rescaling and whitening. + + Args: + image (`torch.Tensor`): + Tensor with shape `(..., height, width, 3)` containing RGB values. The tensor is converted to floating + point if needed. + scale (`float`, *optional*, defaults to `VISION_SCALE`): + Scalar multiplier applied before normalization. + Returns: + `torch.Tensor`: Normalized tensor with the same shape as the input and dtype `torch.float32`. 
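+
+ Example:
+
+ ```python
+ >>> image = torch.randint(0, 256, (1, 16, 16, 3)).float() # channels-last RGB
+ >>> out = prepare_image_tensor(image) # values now roughly in [-1, 1]
+ >>> out.shape
+ torch.Size([1, 16, 16, 3])
+ ```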
+ """ + if not torch.is_floating_point(image): + image = image.float() + rescaled = image * scale + + # Use precomputed tensors and move to the correct device if needed + mean_tensor = _MEAN_TENSOR.to(image.device) + std_tensor = _STD_TENSOR.to(image.device) + + normalized = (rescaled - mean_tensor) / std_tensor + return normalized + + +def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor: + r"""Convert normalized images into flattened ViT-style patches. + + Args: + image (`torch.Tensor`): + Tensor of shape `(num_images, height, width, channels)`. + patch_size (`int`): + Edge length of the square patches + + Returns: + `torch.Tensor`: + Patch tensor where each position stores the flattened pixels belonging to that patch. + + Raises: + ValueError: If `height` or `width` is not divisible by `patch_size`. + """ + num_images, height, width, channels = image.shape + if height % patch_size or width % patch_size: + raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.") + patches = image.reshape(num_images, height // patch_size, patch_size, width // patch_size, patch_size, channels) + patches = patches.permute(0, 1, 3, 2, 4, 5) + patches = patches.reshape( + num_images, height // patch_size, width // patch_size, channels * patch_size * patch_size + ) + return patches + + +def process_vision_for_patches( + images: torch.Tensor, + patch_size: int, + max_num_patches: int, + min_num_patches: int | None = None, + pixel_shuffle_scale: int = 1, +) -> tuple[torch.Tensor, list[int]]: + r"""Resize, normalize, and patchify RGB images for the vision encoder. + + Args: + images (`torch.Tensor`): + Either `(height, width, channels)` for a single image or `(num_images, height, width, channels)` for a + batch. Channels are expected to be RGB. + patch_size (`int`): + Edge length of square patches; implictly controls resize grid granularity. + max_num_patches (`int`): + Maximum number of patches allowed after resizing. + min_num_patches (`int`, *optional*): + Minimum number of patches. If provided, the routine upsamples images as needed to satisfy the lower bound. + pixel_shuffle_scale (`int`, *optional*, defaults to 1): + pixel shuffle scale factor; influences the target grid that the function produces. + + Returns: + `tuple[torch.Tensor, list[int]]`: A pair `(patches, dims_virtual)` where `patches` has shape + `(num_images, target_h / patch_size, target_w / patch_size, channels * patch_size**2)` and `dims_virtual` + encodes effective `(images, height, width)` dimensions after optional pixel shuffling. 
+ """ + # Add batch dim if single image + if images.dim() == 3: + images = images.unsqueeze(0) + + # Permute to channel first for resize + images = images.permute(0, 3, 1, 2) + + # Get target dimensions + _, _, orig_height, orig_width = images.shape + target_height, target_width = get_image_size_for_max_num_patches( + orig_height, + orig_width, + patch_size, + max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + ) + + # Resize + images = F.interpolate( + images, + size=(target_height, target_width), + mode="bilinear", + align_corners=False, + ) + + # Back to channel last + images = images.permute(0, 2, 3, 1) + + # Normalize + images = prepare_image_tensor(images) + + # Patchify + patches = patchify_vision(images, patch_size=patch_size) + + # Calculate dimensions for the patches + n_images, h_patches, w_patches, _ = patches.shape + dims_virtual = ( + [1, h_patches, w_patches] + if pixel_shuffle_scale == 1 + else [1, h_patches // pixel_shuffle_scale, w_patches // pixel_shuffle_scale] + ) + + return patches, dims_virtual + + +# ============================================================================ +# Processor Components +# ============================================================================ + + +def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> Event: + r"""Wrap a text into an `Event` compatible with the multimodal TensorStream. + + Args: + tokenizer (`AutoTokenizer`): + Tokenizer used to convert text into model vocabulary ids. + text (`str`): + Plain-text fragment to encode. + time (`float`, *optional*, defaults to 0.0): + Timeline coordinate associated with the event. Both start and end times use the same value because text + segments are instantaneous in the scheduler. + + Returns: + `Event`: Event carrying a `(num_tokens, 1)` tensor of token ids with matching + metadata so that downstream processors can compute modality-specific embeddings. 
+ """ + tokens = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt").squeeze(0) + + # Calculate dimensions for the event + num_tokens = len(tokens) + dims_virtual = [num_tokens, 1] # [sequence_length, 1] + dims_real = dims_virtual.copy() + + # Ensure tokens has the right shape for tensor_stream_token_view + # It expects a 2D tensor where sum(dim=-1) gives the token IDs + if tokens.dim() == 1: + tokens = tokens.unsqueeze(-1) + + return Event( + data=tokens, + type=TextType.text, + time=(time, time), + dims_virtual=dims_virtual, + dims_real=dims_real, + idx_range=(0, num_tokens), + ) + + +# ============================================================================ +# Processor +# ============================================================================ + + +class IsaacProcessor(ProcessorMixin): + attributes = [] + tokenizer_class = ("AutoTokenizer",) + + def __init__( + self, + tokenizer: AutoTokenizer, + config: IsaacConfig, + ): + super().__init__() + self.tokenizer = tokenizer + self.config = config + + # Use vision token from config + self.vision_token = config.vision_token + + # Processing parameters + self.max_sequence_length = config.max_sequence_length + + # Vision processing parameters + self.patch_size = config.video_patch_size + self.max_num_patches = config.vision_max_num_patches + self.min_num_patches = config.vision_min_num_patches + self.pixel_shuffle_scale = config.pixel_shuffle_scale + + def apply_chat_template( + self, + messages: list[dict[str, Any]], + tokenize: bool = False, + add_generation_prompt: bool = False, + **kwargs, + ) -> Any: + return self.tokenizer.apply_chat_template( + messages, tokenize=tokenize, add_generation_prompt=add_generation_prompt, **kwargs + ) + + def build_event_stream_simple( + self, + text: str, + images: list[PIL.Image.Image] | None = None, + ) -> Stream: + events = [] + # Process text and images + # Find all occurrences of vision token + + pattern = re.escape(self.vision_token) + parts = re.split(f"({pattern})", text) # Keep the delimiter in the result + + image_idx = 0 + for current_time, part in enumerate(parts): + if part == self.vision_token: + # Replace vision token with image event + if image_idx < len(images): + # Create vision event from PIL image + image_tensor = extract_image_pil(images[image_idx]) + if image_tensor is not None: + # Create a vision event with the image tensor + vision_event = Event( + data=image_tensor.unsqueeze(0), # HWC format from extract_image_pil + type=VisionType.image, # I-frame + time=(current_time, current_time), + ) + events.append(vision_event) + image_idx += 1 + elif part: # Non-empty text part + # tokens = self.text_processor.tokenize(part, add_special_tokens=False) + text_event = create_text_event(self.tokenizer, part, time=current_time) + events.append(text_event) + + # Process vision events if any + if any(event.type == VisionType.image for event in events): + # Separate text and vision events for processing + text_events = [event for event in events if event.type == TextType.text] + vision_events = [event for event in events if event.type == VisionType.image] + + # Process vision events using functional approach + processed_vision_events = [] + for vision_event in vision_events: + # Process the vision data + patches, dims_virtual = process_vision_for_patches( + vision_event.data.squeeze(0), # Remove the extra dimension + patch_size=self.patch_size, + max_num_patches=self.max_num_patches, + min_num_patches=self.min_num_patches, + pixel_shuffle_scale=self.pixel_shuffle_scale, + 
) + + # Update event with processed data + vision_event.data = patches.unsqueeze(1) # Add back frame dimension + vision_event.dims_virtual = dims_virtual + vision_event.dims_real = ( + dims_virtual + if self.pixel_shuffle_scale == 1 + else [ + dims_virtual[0], + dims_virtual[1] * self.pixel_shuffle_scale, + dims_virtual[2] * self.pixel_shuffle_scale, + ] + ) + vision_event.idx_range = (0, math.prod(dims_virtual)) + + # Flatten the patches + vision_event.data = vision_event.data.reshape(-1, vision_event.data.shape[-1]) + processed_vision_events.append(vision_event) + + events = text_events + processed_vision_events + + # Create stream without scheduling (events already in order) + return create_stream(events, priority=[TextType.text, VisionType.image], schedule=True) + + def __call__( + self, + text: Union[str, list[str]], + images: Union[PIL.Image.Image, list[PIL.Image.Image], None] = None, + return_tensors: str | TensorType | None = TensorType.PYTORCH, + **kwargs, + ) -> BatchFeature: + """ + Process text and images into TensorStream format. + Args: + text: Input text or list of texts with vision tokens + images: PIL image or list of images (optional) + return_tensors: Format for output tensors + + Returns: + BatchFeature with input_ids and tensor_stream + """ + # Normalize inputs to lists + if isinstance(text, str): + texts = [text] + else: + texts = text + + if images is not None: + if isinstance(images, PIL.Image.Image): + images_list = [images] + else: + images_list = images + else: + images_list = None + + if len(texts) != 1: + raise ValueError("IsaacProcessor currently supports batch_size=1") + if images_list is not None: + # Count vision tokens in text to validate image count + vision_token_count = texts[0].count(self.vision_token) + if vision_token_count != len(images_list): + raise ValueError( + f"Number of {self.vision_token} tokens in text ({vision_token_count}) " + f"must match number of images ({len(images_list)})" + ) + + # Build event stream + stream = self.build_event_stream_simple( + text=texts[0], + images=images_list, + ) + + # Create TensorStream + tensor_stream = TensorStream([stream]) + + # Slice to max length if needed + _, T = tensor_stream.shape + if T > self.max_sequence_length: + tensor_stream = ts_slice(tensor_stream, start=T - self.max_sequence_length, end=T) + + # Get token view + tokens = tensor_stream_token_view(tensor_stream) + if return_tensors in (TensorType.PYTORCH, "pt"): + input_ids = torch.as_tensor(tokens, dtype=torch.long) + else: + input_ids = tokens + + data = { + "input_ids": input_ids, + "tensor_stream": tensor_stream, + } + + return BatchFeature(data=data) + + +__all__ = ["IsaacProcessor"] diff --git a/tests/models/isaac/__init__.py b/tests/models/isaac/__init__.py new file mode 100644 index 000000000000..199f5353a864 --- /dev/null +++ b/tests/models/isaac/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
\ No newline at end of file From 63d1b1b63e1de8cbd3b732710ad0b6b0b779adec Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Fri, 10 Oct 2025 11:11:36 +0400 Subject: [PATCH 02/77] style: fixing assorted PR notes --- .../models/isaac/modular_isaac.py | 1526 ++++++++++++----- 1 file changed, 1119 insertions(+), 407 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 4cc1e157c363..313c3ce76d5e 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -1,36 +1,27 @@ from __future__ import annotations +import copy +import math from collections import defaultdict -from typing import Any, Union, TypedDict +from collections.abc import Sequence +from typing import Any, TypedDict -import math import numpy as np +import PIL.Image import torch import torch.nn as nn import torch.nn.functional as F -import PIL.Image -from ...utils import logging -from ...processing_utils import ProcessorMixin, BatchFeature -from ...tokenization_utils_base import PreTrainedTokenizerBase -from ..auto import AutoTokenizer -from ..qwen3.configuration_qwen3 import Qwen3Config -from ..qwen3.modeling_qwen3 import ( - Qwen3ForCausalLM, - Qwen3PreTrainedModel, - Qwen3DecoderLayer, - Qwen3Model -) -from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...tokenization_utils import TensorType -import re +try: + from torchvision.transforms.v2 import functional as TVF +except ImportError: + TVF = None -from ..siglip2.modeling_siglip2 import Siglip2MLP -from ..siglip2.configuration_siglip2 import Siglip2VisionConfig -from perceptron.tensorstream import ( +import re + +from genesis.public.tensorstream.tensor_stream import ( Event, Stream, TensorStream, @@ -39,24 +30,97 @@ create_stream, group_streams, ) -from perceptron.tensorstream.ops import ( +from genesis.public.tensorstream.tensor_stream_utils import ( compute_mrope_pos_tensor, modality_mask, reconstruct_tensor_stream_from_compact_dict, - slice as ts_slice, tensor_stream_token_view, ) +from genesis.public.tensorstream.tensor_stream_utils import ( + slice as ts_slice, +) + +from ...cache_utils import Cache, SlidingWindowCache, StaticCache +from ...feature_extraction_utils import BatchFeature +from ...generation.utils import GenerationMixin +from ...image_processing_utils import BaseImageProcessor +from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict +from ...image_transforms import convert_to_rgb +from ...image_utils import ( + ImageInput, + PILImageResampling, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...processing_utils import ImagesKwargs, ProcessorMixin, Unpack +from ...tokenization_utils import TensorType +from ...utils import auto_docstring, filter_out_non_signature_kwargs +from ...utils.import_utils import is_torchdynamo_compiling +from ..auto.image_processing_auto import AutoImageProcessor +from ..auto.modeling_auto import AutoModel +from ..auto.tokenization_auto import AutoTokenizer +from ..qwen2.tokenization_qwen2 import Qwen2Tokenizer +from ..qwen3.configuration_qwen3 import Qwen3Config +from ..qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3PreTrainedModel +from 
..siglip2.configuration_siglip2 import Siglip2VisionConfig +from ..siglip2.modeling_siglip2 import Siglip2EncoderLayer as HFSiglip2EncoderLayer + + +# Vision preprocessing constants +VISION_MEAN = (0.5, 0.5, 0.5) +VISION_STD = (0.5, 0.5, 0.5) +VISION_SCALE = 1 / 255 + + + + +def _normalize_rgb_values( + values: float | Sequence[float] | tuple[float, ...], + *, + name: str, +) -> tuple[float, float, float]: + """Coerce RGB normalization parameters into a 3-tuple of floats.""" + if isinstance(values, (list, tuple)): + if len(values) == 3: + return tuple(float(v) for v in values) + if len(values) == 1: + value = float(values[0]) + return (value, value, value) + raise ValueError(f"`{name}` must have length 1 or 3 when provided as a sequence. Got length {len(values)}.") + + value = float(values) + return (value, value, value) + + +def _make_writeable(arr: np.ndarray) -> np.ndarray: + if arr.flags.writeable: + return arr + try: + arr.setflags(write=True) + return arr + except ValueError: + return arr.copy() -logger = logging.get_logger(__name__) -class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig): +class IsaacVisionConfig(Siglip2VisionConfig): """Vision configuration for Isaac with Pixel Shuffle support. Extends Siglip2VisionConfig with additional fields for pixel shuffle. + + Args: + pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1): + Spatial factor applied before pixel shuffle reduces the resolution. + num_patches (`int`, *optional*, defaults to 256): + Maximum number of learnable positional embeddings to initialize. """ - model_type = "pixel_shuffle_siglip2" + model_type = "isaac_vision" base_config_key = "vision_config" def __init__( @@ -72,16 +136,493 @@ def __init__( self.num_patches = num_patches -def create_cumulative_seq_lengths(seq_sizes: torch.Tensor, device: torch.device) -> tuple[torch.Tensor, int]: - """Create cumulative sequence lengths for variable-length attention.""" - cu_seqlens = torch.zeros(len(seq_sizes) + 1, dtype=torch.int32, device=device) - cu_seqlens[1:] = seq_sizes.cumsum(0) - max_seqlen = int(seq_sizes.max().item()) if len(seq_sizes) > 0 else 0 - return cu_seqlens, max_seqlen +class IsaacImageProcessorKwargs(ImagesKwargs): + patch_size: int | None + max_num_patches: int | None + min_num_patches: int | None + pixel_shuffle_scale: int | None + do_rescale: bool | None + rescale_factor: float | None + do_normalize: bool | None + image_mean: float | Sequence[float] | None + image_std: float | Sequence[float] | None + do_convert_rgb: bool | None + + +@auto_docstring +class IsaacImageProcessorFast(BaseImageProcessorFast): + slow_image_processor_class = None + r"""Fast torch-based image processor for Isaac vision inputs.""" + + resample = PILImageResampling.BILINEAR + model_input_names = ["patches", "token_grids"] + valid_kwargs = IsaacImageProcessorKwargs + unused_kwargs = ["size", "do_center_crop", "crop_size"] + + def __init__( + self, + *, + patch_size: int = 16, + max_num_patches: int = 256, + min_num_patches: int | None = None, + pixel_shuffle_scale: int = 1, + do_rescale: bool = True, + rescale_factor: float | None = None, + do_normalize: bool = True, + image_mean: float | Sequence[float] | None = None, + image_std: float | Sequence[float] | None = None, + do_convert_rgb: bool = True, + **kwargs: Unpack[IsaacImageProcessorKwargs], + ) -> None: + super().__init__(**kwargs) + + if pixel_shuffle_scale < 1: + raise ValueError("`pixel_shuffle_scale` must be >= 1") + mean_values = _normalize_rgb_values( + image_mean if image_mean is not None else 
VISION_MEAN, name="image_mean" + ) + std_values = _normalize_rgb_values( + image_std if image_std is not None else VISION_STD, name="image_std" + ) -class Siglip2VariableSequenceEmbeddings(nn.Module): - def __init__(self, config: PixelShuffleSiglip2VisionConfig): + self.patch_size = patch_size + self.max_num_patches = max_num_patches + self.min_num_patches = min_num_patches + self.pixel_shuffle_scale = pixel_shuffle_scale + self.do_rescale = do_rescale + self.rescale_factor = VISION_SCALE if rescale_factor is None else float(rescale_factor) + self.do_normalize = do_normalize + self.image_mean = list(mean_values) + self.image_std = list(std_values) + self.do_convert_rgb = do_convert_rgb + + def _validate_preprocess_kwargs(self, **kwargs): + # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for. + kwargs.pop("do_resize", None) + return super()._validate_preprocess_kwargs(**kwargs) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + patch_size: int, + max_num_patches: int, + interpolation: Any | None, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: float | Sequence[float] | None, + image_std: float | Sequence[float] | None, + return_tensors: str | TensorType | None, + *, + min_num_patches: int | None = None, + pixel_shuffle_scale: int | None = None, + do_convert_rgb: bool | None = None, + **kwargs, + ) -> BatchFeature: + if TVF is None: + raise ImportError("torchvision is required for IsaacImageProcessorFast but is not installed.") + + min_num_patches = min_num_patches if min_num_patches is not None else self.min_num_patches + pixel_shuffle_scale = pixel_shuffle_scale if pixel_shuffle_scale is not None else self.pixel_shuffle_scale + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_rescale = self.do_rescale if do_rescale is None else do_rescale + do_normalize = self.do_normalize if do_normalize is None else do_normalize + rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor + + mean_values = _normalize_rgb_values( + image_mean if image_mean is not None else self.image_mean, name="image_mean" + ) + std_values = _normalize_rgb_values( + image_std if image_std is not None else self.image_std, name="image_std" + ) + + patches_list: list[torch.Tensor] = [] + token_grids: list[torch.Tensor] = [] + virtual_dims: list[list[int]] = [] + real_dims: list[list[int]] = [] + + for image in images: + if image.ndim != 3: + raise ValueError("Expected channel-first image tensor with shape (C, H, W).") + + channels, original_height, original_width = image.shape + if do_convert_rgb and channels == 1: + image = image.repeat(3, 1, 1) + channels = 3 + + if original_height * original_width > MAX_PIXELS: + raise ValueError( + f"Image (w={original_width}, h={original_height}) > MAX=`{MAX_PIXELS}`" + ) + + target_height, target_width = get_image_size_for_max_num_patches( + original_height, + original_width, + patch_size, + max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + ) + + if do_resize: + size_dict = SizeDict(height=target_height, width=target_width) + image = self.resize(image=image, size=size_dict, interpolation=interpolation) + else: + if ((original_height % patch_size) != 0) or ((original_width % patch_size) != 0): + raise ValueError( + "Image dimensions must be divisible by patch_size when resize is disabled." 
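+                        # (Illustrative: a 300x300 input with patch_size=16 would fail here,
+                        # since 300 % 16 == 12; enable resizing or pre-pad to a multiple of 16.)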
+ ) + + # Apply rescaling and normalization as needed + image = self.rescale_and_normalize( + image, + do_rescale, + rescale_factor, + do_normalize, + list(mean_values), + list(std_values), + ) + + # Convert to NHWC for residual P-frame adjustment and patch extraction + nhwc_image = image.permute(1, 2, 0).unsqueeze(0) + nhwc_image = _compute_residual_p_frames(nhwc_image, is_p_frame=[False]) + + patches = patchify_vision(nhwc_image, patch_size=patch_size).squeeze(0) + height_tokens, width_tokens, _ = patches.shape + + patches_list.append(patches.unsqueeze(0)) + token_grids.append( + torch.tensor([height_tokens, width_tokens], dtype=torch.long, device=patches.device) + ) + + real_dims.append([1, height_tokens, width_tokens]) + if pixel_shuffle_scale > 1: + if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): + raise ValueError( + "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled." + ) + virtual_dims.append( + [1, height_tokens // pixel_shuffle_scale, width_tokens // pixel_shuffle_scale] + ) + else: + virtual_dims.append([1, height_tokens, width_tokens]) + + patches_tensor = torch.cat(patches_list, dim=0) + token_grids_tensor = torch.stack(token_grids, dim=0) + virtual_dims_tensor = torch.tensor(virtual_dims, dtype=torch.long, device=patches_tensor.device) + real_dims_tensor = torch.tensor(real_dims, dtype=torch.long, device=patches_tensor.device) + + batch_feature = BatchFeature( + data={ + "patches": patches_tensor, + "token_grids": token_grids_tensor, + "virtual_pixel_size": virtual_dims_tensor, + "real_pixel_size": real_dims_tensor, + }, + tensor_type=return_tensors, + ) + return batch_feature + + + + +class IsaacImageProcessor(BaseImageProcessor): + """Image processor that prepares RGB frames for the Isaac vision encoder.""" + + model_input_names = ["patches", "token_grids"] + + def __init__( + self, + patch_size: int = 16, + max_num_patches: int = 256, + min_num_patches: int | None = None, + pixel_shuffle_scale: int = 1, + do_rescale: bool = True, + rescale_factor: float | None = None, + do_normalize: bool = True, + image_mean: float | Sequence[float] | None = None, + image_std: float | Sequence[float] | None = None, + do_convert_rgb: bool = True, + resize_mode: str = "bilinear", + align_corners: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + if pixel_shuffle_scale < 1: + raise ValueError("`pixel_shuffle_scale` must be >= 1") + + rescale_value = VISION_SCALE if rescale_factor is None else float(rescale_factor) + mean_value = VISION_MEAN if image_mean is None else image_mean + std_value = VISION_STD if image_std is None else image_std + + self.patch_size = patch_size + self.max_num_patches = max_num_patches + self.min_num_patches = min_num_patches + self.pixel_shuffle_scale = pixel_shuffle_scale + self.do_rescale = do_rescale + self.rescale_factor = rescale_value + self.do_normalize = do_normalize + self.image_mean = _normalize_rgb_values(mean_value, name="image_mean") + self.image_std = _normalize_rgb_values(std_value, name="image_std") + self.do_convert_rgb = do_convert_rgb + self.resize_mode = resize_mode + self.align_corners = align_corners + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + patch_size: int | None = None, + max_num_patches: int | None = None, + min_num_patches: int | None = None, + pixel_shuffle_scale: int | None = None, + do_rescale: bool | None = None, + rescale_factor: float | None = None, + do_normalize: bool | None = None, + 
image_mean: float | Sequence[float] | None = None, + image_std: float | Sequence[float] | None = None, + do_convert_rgb: bool | None = None, + return_tensors: str | TensorType | None = None, + ) -> BatchFeature: + patch_size = patch_size if patch_size is not None else self.patch_size + max_num_patches = max_num_patches if max_num_patches is not None else self.max_num_patches + min_num_patches = min_num_patches if min_num_patches is not None else self.min_num_patches + pixel_shuffle_scale = ( + pixel_shuffle_scale if pixel_shuffle_scale is not None else self.pixel_shuffle_scale + ) + do_rescale = self.do_rescale if do_rescale is None else do_rescale + rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor + do_normalize = self.do_normalize if do_normalize is None else do_normalize + image_mean = self.image_mean if image_mean is None else _normalize_rgb_values(image_mean, name="image_mean") + image_std = self.image_std if image_std is None else _normalize_rgb_values(image_std, name="image_std") + do_convert_rgb = self.do_convert_rgb if do_convert_rgb is None else do_convert_rgb + + images = self.fetch_images(images) + images = make_flat_list_of_images(images) + + if not images: + raise ValueError("Received an empty list of images for preprocessing.") + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + if not valid_images(images): + raise ValueError( + "Invalid image type. Expected PIL images, numpy arrays, or tensors convertible to numpy arrays." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) + + patches_list = [] + token_grids = [] + virtual_dims = [] + real_dims = [] + + for image in images: + np_image = to_numpy_array(image) + + if np_image.ndim == 2: + np_image = np.repeat(np_image[..., None], 3, axis=-1) + + height, width = np_image.shape[:2] + if height * width > MAX_PIXELS: + raise ValueError(f"Image (w={width}, h={height}) > MAX=`{MAX_PIXELS}`") + + torch_image = torch.from_numpy(_make_writeable(np_image)) + patches, vidims, rdims = self._process_single_image( + torch_image, + patch_size=patch_size, + max_num_patches=max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) + + patches_list.append(patches) + token_grids.append(torch.tensor([patches.size(1), patches.size(2)], dtype=torch.long)) + virtual_dims.append(vidims) + real_dims.append(rdims) + + patches_tensor = torch.cat(patches_list, dim=0) + token_grid_tensor = torch.stack(token_grids, dim=0) + virtual_dims_tensor = torch.tensor(virtual_dims, dtype=torch.long) + real_dims_tensor = torch.tensor(real_dims, dtype=torch.long) + + data = { + "patches": patches_tensor, + "token_grids": token_grid_tensor, + "virtual_pixel_size": virtual_dims_tensor, + "real_pixel_size": real_dims_tensor, + } + + return BatchFeature(data=data, tensor_type=return_tensors) + + def _process_single_image( + self, + image: torch.Tensor, + *, + patch_size: int, + max_num_patches: int, + min_num_patches: int | None, + pixel_shuffle_scale: int, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: tuple[float, ...], + image_std: tuple[float, ...], + ) -> tuple[torch.Tensor, list[int], list[int]]: + image_uint8 = image.unsqueeze(0) # (1, H, W, C) + image_chw = 
image_uint8.permute(0, 3, 1, 2) # (1, C, H, W) + + _, _, orig_height, orig_width = image_chw.shape + target_height, target_width = get_image_size_for_max_num_patches( + orig_height, + orig_width, + patch_size, + max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + ) + + if self.resize_mode in {"linear", "bilinear", "bicubic", "trilinear"}: + resized = F.interpolate( + image_chw, + size=(target_height, target_width), + mode=self.resize_mode, + align_corners=self.align_corners, + ) + else: + resized = F.interpolate( + image_chw, + size=(target_height, target_width), + mode=self.resize_mode, + ) + + resized = resized.permute(0, 2, 3, 1) # (1, H, W, C) + + scale = rescale_factor if do_rescale else 1.0 + mean = image_mean if do_normalize else (0.0, 0.0, 0.0) + std = image_std if do_normalize else (1.0, 1.0, 1.0) + resized = _prepare_image_tensor(resized, scale=scale, mean=mean, std=std) + + resized = _compute_residual_p_frames(resized, is_p_frame=[False]) + + patches = patchify_vision(resized, patch_size=patch_size) + _, h_patches, w_patches, _ = patches.shape + + real_dims = [1, h_patches, w_patches] + if pixel_shuffle_scale > 1: + if (h_patches % pixel_shuffle_scale) or (w_patches % pixel_shuffle_scale): + raise ValueError( + "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled." + ) + virtual_dims = [1, h_patches // pixel_shuffle_scale, w_patches // pixel_shuffle_scale] + else: + virtual_dims = real_dims.copy() + + return patches, virtual_dims, real_dims + + +def _max_from_cu(cu: torch.Tensor | None, fallback: int) -> int: + """Helper to compute max sequence length from cumulative sequence lengths.""" + if cu is None or len(cu) < 2: + return fallback + return int((cu[1:] - cu[:-1]).max().item()) + + +def flash_attention_document_mask_forward( + module: torch.nn.Module, + q_lhd: torch.Tensor, # (L, H, D) + k_lhd: torch.Tensor, # (L, H, D) + v_lhd: torch.Tensor, # (L, H, D) + attention_mask: torch.Tensor | None = None, # unused for FA path + dropout: float = 0.0, + scaling: float | None = None, + cum_seq_q: torch.Tensor | None = None, + cum_seq_k: torch.Tensor | None = None, + max_seqlen: int | None = None, + is_causal: bool = False, + **kwargs, +) -> tuple[torch.Tensor, None]: + """FlashAttention that consumes (L, H, D) directly to avoid layout churn.""" + L, H, D = q_lhd.shape + + # Compute max block length once (honor caller when provided) + if max_seqlen is not None: + max_q = max_k = int(max_seqlen) + else: + max_q = _max_from_cu(cum_seq_q, L) + max_k = _max_from_cu(cum_seq_k, L) + + # Ensure contiguity only if needed + if not q_lhd.is_contiguous(): + q_lhd = q_lhd.contiguous() + if not k_lhd.is_contiguous(): + k_lhd = k_lhd.contiguous() + if not v_lhd.is_contiguous(): + v_lhd = v_lhd.contiguous() + + out_lhd, *_ = torch.ops.aten._flash_attention_forward( + query=q_lhd, # (L, H, D) + key=k_lhd, # (L, H, D) + value=v_lhd, # (L, H, D) + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=max_q, + max_k=max_k, + dropout_p=dropout, + is_causal=is_causal, + return_debug_mask=False, + scale=scaling, + window_size_left=-1, + window_size_right=-1, + alibi_slopes=None, + ) + return out_lhd, None # (L, H, D) + + +def sdpa_document_mask_forward( + q_lhd: torch.Tensor, # (L, H, D) + k_lhd: torch.Tensor, # (L, H, D) + v_lhd: torch.Tensor, # (L, H, D) + dropout: float, + scaling: float | None, + cu_seqlens: torch.Tensor | None, +) -> torch.Tensor: + """SDPA with block-diagonal masking for variable-length sequences.""" + 
L, H, D = q_lhd.shape + + # Transpose to (1, H, L, D) format for SDPA + Q = q_lhd.permute(1, 0, 2).unsqueeze(0) + K = k_lhd.permute(1, 0, 2).unsqueeze(0) + V = v_lhd.permute(1, 0, 2).unsqueeze(0) + + # Build block-diagonal mask for variable-length sequences + attn_mask = None + if cu_seqlens is not None: + seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() + seg_ids = torch.repeat_interleave(torch.arange(len(seq_sizes), device=q_lhd.device), seq_sizes) + block_mask = seg_ids[:, None] != seg_ids[None, :] # Cross-document attention blocked + attn_mask = torch.where(block_mask, -torch.inf, 0.0).to(q_lhd.dtype).view(1, 1, L, L) + + Y = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask, dropout_p=dropout, scale=scaling) + return Y.squeeze(0).permute(1, 0, 2) # Back to (L, H, D) + + +class IsaacVisionEmbeddings(nn.Module): + def __init__(self, config: IsaacVisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -96,9 +637,7 @@ def __init__(self, config: PixelShuffleSiglip2VisionConfig): self.position_embedding_size = int(self.num_patches**0.5) self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) - def positional_embeddings( - self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor] - ) -> torch.Tensor: + def positional_embeddings(self, spatial_shapes: torch.Tensor) -> torch.Tensor: # Prepare positional embeddings grid: (1, embed_dim, h, w) positional_embeddings = ( self.position_embedding.weight.reshape(self.position_embedding_size, self.position_embedding_size, -1) @@ -106,11 +645,9 @@ def positional_embeddings( .unsqueeze(0) ) - _seq_patches, _seq_sizes, spatial_shapes = packed_seq_patches pos_embeds_list = [] mode = "bilinear" align_corners = False - antialias = True for spatial_shape in spatial_shapes: height, width = spatial_shape # Guard to ensure height and width are positive for torch.compile @@ -120,35 +657,33 @@ def positional_embeddings( size=(height, width), mode=mode, align_corners=align_corners, - antialias=antialias, + antialias=True, ) # Reshape from (1, embed_dim, height, width) to (height*width, embed_dim) resized_pos_embed = resized_pos_embed.reshape(self.embed_dim, height * width).transpose(0, 1) else: # Fallback - should never happen in practice - resized_pos_embed = positional_embeddings.reshape( - self.embed_dim, self.position_embedding_size * self.position_embedding_size - ).transpose(0, 1)[: height * width] + raise RuntimeError( + "Encountered non-positive spatial dimensions while computing positional embeddings." 
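+                    # (Comment for clarity: spatial_shapes come from token_grids, which the
+                    # image processors always populate with positive patch counts, so this
+                    # branch should be unreachable in normal operation.)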
+ ) pos_embeds_list.append(resized_pos_embed) # Concatenate all positional embeddings along the sequence dimension pos_embeds = torch.cat(pos_embeds_list, dim=0) return pos_embeds - def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor]): - seq_patches, _seq_sizes, _spatial_shapes = packed_seq_patches - + def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor): # Apply patch embeddings target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(seq_patches.to(dtype=target_dtype)) - pos_embeds = self.positional_embeddings(packed_seq_patches) + pos_embeds = self.positional_embeddings(spatial_shapes) # Add positional embeddings to patch embeddings embeddings = patch_embeds + pos_embeds return embeddings -class Siglip2VariableLengthAttention(nn.Module): +class IsaacVisionAttention(nn.Module): """Custom attention that supports variable-length sequences with flash attention.""" def __init__(self, config): @@ -171,71 +706,51 @@ def __init__(self, config): self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None): - batch_size, seq_len, _ = hidden_states.size() - - # For variable-length attention, we need to reshape to (total_tokens, embed_dim) + # Expect packed sequences with batch_size == 1 + batch_size, L, _ = hidden_states.shape if batch_size != 1: - raise ValueError("Variable-length attention expects batch_size=1 for packed sequences") - hidden_states = hidden_states.squeeze(0) # Remove batch dimension: (seq_len, embed_dim) - - # Store original dtype - orig_dtype = hidden_states.dtype - - # 1. Linear projections - Q = self.q_proj(hidden_states) # (seq_len, embed_dim) - K = self.k_proj(hidden_states) # (seq_len, embed_dim) - V = self.v_proj(hidden_states) # (seq_len, embed_dim) - - # 2. Reshape for multi-head attention: (seq_len, n_heads, head_dim) - Q = Q.view(-1, self.num_heads, self.embed_dim // self.num_heads) - K = K.view(-1, self.num_heads, self.embed_dim // self.num_heads) - V = V.view(-1, self.num_heads, self.embed_dim // self.num_heads) - - # 3. Apply variable-length attention using flash attention - attn_output, _, _, _, _ = torch.ops.aten._flash_attention_forward( - query=Q, - key=K, - value=V, - cum_seq_q=cu_seqlens, - cum_seq_k=cu_seqlens, - max_q=max_seqlen, - max_k=max_seqlen, - dropout_p=self.dropout if self.training else 0.0, - is_causal=False, - return_debug_mask=False, - scale=self.scale, - window_size_left=-1, - window_size_right=-1, - alibi_slopes=None, - ) - - # 4. Reshape attention output from (seq_len, n_heads, head_dim) to (seq_len, embed_dim) - attn_output = attn_output.reshape(seq_len, self.embed_dim) - - # 5. Convert back to original dtype if needed - if attn_output.dtype != orig_dtype: - attn_output = attn_output.to(orig_dtype) - - # 6. Project output - attn_output = self.out_proj(attn_output) # (seq_len, embed_dim) - - # 7. 
Add back batch dimension for compatibility - attn_output = attn_output.unsqueeze(0) # (1, seq_len, embed_dim) + raise ValueError("packed variable-length attention expects batch_size=1") + x = hidden_states[0] # (L, E) + + H = self.num_heads + D = self.head_dim + p_drop = self.dropout if self.training else 0.0 + + # Project and reshape to (L, H, D) + q = self.q_proj(x).view(L, H, D) + k = self.k_proj(x).view(L, H, D) + v = self.v_proj(x).view(L, H, D) + + attn_impl = getattr(self.config, "_attn_implementation", "flash_attention_3") + + if attn_impl in ("flash_attention_2", "flash_attention_3"): + y_lhd, _ = flash_attention_document_mask_forward( + self, + q, + k, + v, + attention_mask=None, + dropout=p_drop, + scaling=self.scale, + cum_seq_q=cu_seqlens, + cum_seq_k=cu_seqlens, + max_seqlen=max_seqlen, + is_causal=False, + ) + else: + y_lhd = sdpa_document_mask_forward(q, k, v, dropout=p_drop, scaling=self.scale, cu_seqlens=cu_seqlens) - return attn_output, None + # Merge heads and project + y = self.out_proj(y_lhd.reshape(L, self.embed_dim)) + return y.unsqueeze(0), None # (1, L, E) -class IsaacSiglip2EncoderLayer(nn.Module): - """Siglip2 encoder layer with variable-length attention.""" +class IsaacVisionEncoderLayer(HFSiglip2EncoderLayer): + """Isaac vision encoder layer with variable-length attention.""" - def __init__(self, config: PixelShuffleSiglip2VisionConfig): - super().__init__() - self.embed_dim = config.hidden_size - self.self_attn = Siglip2VariableLengthAttention(config) - - self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = Siglip2MLP(config) # Use HF's Siglip2MLP - self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + def __init__(self, config: IsaacVisionConfig): + super().__init__(config) + self.self_attn = IsaacVisionAttention(config) def forward( self, @@ -263,13 +778,13 @@ def forward( return (hidden_states,) -class IsaacEncoder(nn.Module): +class IsaacVisionEncoder(nn.Module): """Encoder using Isaac encoder layers with variable-length attention support.""" - def __init__(self, config: PixelShuffleSiglip2VisionConfig): + def __init__(self, config: IsaacVisionConfig): super().__init__() self.config = config - self.layers = nn.ModuleList([IsaacSiglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList([IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) def forward( self, @@ -327,16 +842,19 @@ def create_pixel_shuffle_index_map( if device is None: device = seq_sizes.device - r = int(scale_factor) - if r < 2: + scale_factor = int(scale_factor) + if scale_factor < 2: raise ValueError("`scale_factor` must be โ‰ฅ 2") - # Safety: all spatial dims must be divisible by r + # Safety: all spatial dims must be divisible by the scale factor # Cannot run under torch compile fullgraph mode hence - if not torch.compiler.is_compiling(): - if not ((token_grids[:, 0] % r == 0).all() and (token_grids[:, 1] % r == 0).all()): + if not is_torchdynamo_compiling(): + if not ( + (token_grids[:, 0] % scale_factor == 0).all() and (token_grids[:, 1] % scale_factor == 0).all() + ): raise AssertionError( - f"Every (H,W) in `token_grids` must be divisible by scale_factor={r}, got {token_grids.tolist()}" + "Every (H,W) in `token_grids` must be divisible by " + f"scale_factor={scale_factor}, got {token_grids.tolist()}" ) gather_chunks: list[torch.Tensor] = [] @@ -348,19 +866,21 @@ def create_pixel_shuffle_index_map( grid = grid.view(h, w) # (H, W) # -------- identical ordering to 
your fixed-res routine -------- - # Step 1: split width into blocks of r - grid = grid.view(h, w // r, r) # (H, W/r, r) - # Step 2: now split height into blocks of r - grid = grid.view(h // r, r, w // r, r) # (H/r, r, W/r, r) - # Step 3: final permutation to (H/r, W/r, r, r) - grid = grid.permute(0, 2, 1, 3).contiguous() # (H/r, W/r, r, r) - # Step 4: each (r, r) block forms one output token - gather_chunks.append(grid.reshape(-1, r * r)) # (H*W / rยฒ, rยฒ) + # Step 1: split width into blocks of scale_factor + grid = grid.view(h, w // scale_factor, scale_factor) # (H, W/scale_factor, scale_factor) + # Step 2: now split height into blocks of scale_factor + grid = grid.view(h // scale_factor, scale_factor, w // scale_factor, scale_factor) + # (H/scale_factor, scale_factor, W/scale_factor, scale_factor) + # Step 3: final permutation to (H/scale_factor, W/scale_factor, scale_factor, scale_factor) + grid = grid.permute(0, 2, 1, 3).contiguous() # (H/scale_factor, W/scale_factor, scale_factor, scale_factor) + # Step 4: each (scale_factor, scale_factor) block forms one output token + gather_chunks.append(grid.reshape(-1, scale_factor * scale_factor)) + # (H*W / scale_factor**2, scale_factor**2) tok_offset += seq_len # Concatenate over all images in the packed batch - gather_idx = torch.cat(gather_chunks, dim=0) # (ฮฃ_i HแตขWแตข/rยฒ, rยฒ) + gather_idx = torch.cat(gather_chunks, dim=0) # (ฮฃ_i HแตขWแตข/scale_factor**2, scale_factor**2) return gather_idx @@ -399,7 +919,7 @@ def pixel_shuffle_varlen( x_ = x # (seq, embed) embed_dim = x_.size(-1) - r = int(scale_factor) + scale_factor = int(scale_factor) # Calculate seq_sizes from token_grids seq_sizes = torch.prod(token_grids, dim=-1) @@ -408,15 +928,15 @@ def pixel_shuffle_varlen( gather_idx = create_pixel_shuffle_index_map( seq_sizes=seq_sizes, token_grids=token_grids, - scale_factor=r, + scale_factor=scale_factor, device=x_.device, - ) # (new_seq, rยฒ) + ) # (new_seq, scale_factor**2) - # Gather โ†’ (new_seq, rยฒ, embed_dim) + # Gather โ†’ (new_seq, scale_factor**2, embed_dim) gathered = x_[gather_idx] # fancy indexing keeps gradient - # Merge the rยฒ group dimension into channels to finish the shuffle - out = gathered.reshape(gathered.size(0), embed_dim * r * r) + # Merge the scale_factor**2 group dimension into channels to finish the shuffle + out = gathered.reshape(gathered.size(0), embed_dim * scale_factor * scale_factor) # Restore batch dimension if needed if keep_batch_dim: @@ -424,12 +944,12 @@ def pixel_shuffle_varlen( return out -class Siglip2SequenceVisionTransformer(nn.Module): - def __init__(self, config: PixelShuffleSiglip2VisionConfig): +class IsaacVisionTransformer(nn.Module): + def __init__(self, config: IsaacVisionConfig): super().__init__() self.config = config - self.embeddings = Siglip2VariableSequenceEmbeddings(config) - self.encoder = IsaacEncoder(config) + self.embeddings = IsaacVisionEmbeddings(config) + self.encoder = IsaacVisionEncoder(config) self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor @@ -438,13 +958,15 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): seq_sizes = torch.prod(token_grids, dim=-1) # Get embeddings from packed sequence - hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids)) + hidden_states = self.embeddings(seq_patches, token_grids) # Add a pseudo batch dimension for the encoder hidden_states = hidden_states.unsqueeze(0) # Generate cumulative sequence lengths 
for variable-length attention
-        cu_seqlens, max_seqlen = create_cumulative_seq_lengths(seq_sizes, hidden_states.device)
+        cu_seqlens = torch.zeros(seq_sizes.size(0) + 1, dtype=torch.int32, device=hidden_states.device)
+        cu_seqlens[1:] = seq_sizes.cumsum(0)
+        max_seqlen = int(seq_sizes.max().item()) if seq_sizes.numel() > 0 else 0
 
         # Pass through encoder with variable-length attention parameters
         hidden_states, _, _ = self.encoder(
@@ -473,40 +995,19 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]):
 # Configuration
 # ============================================================================
 
-MAX_PIXELS = 60_000_000  # 60-megapixel ceiling ≈ 8200 × 7300 px
-
-# Vision preprocessing constants
-VISION_MEAN = (0.5, 0.5, 0.5)
-VISION_STD = (0.5, 0.5, 0.5)
-VISION_SCALE = 1 / 255
-
+MAX_PIXELS = 60_000_000  # 60-megapixel ceiling ≈ 8200 × 7300 px
 
-def _make_writeable(arr: np.ndarray) -> np.ndarray:
-    """Return *arr* itself if it is already writeable, otherwise try to flip the
-    write flag in-place and finally fall back to `arr.copy()`.
-    This guarantees the buffer handed to `torch.from_numpy()` is always
-    writeable, silencing the PyTorch warning about undefined behaviour.
-    """
-    if arr.flags.writeable:
-        return arr
-
-    # First, try the cheap path — in-place flag toggle (works for mmap'd arrays
-    # and some shared memory buffers):
-    try:
-        arr.setflags(write=True)
-        return arr  # success: no data copy
-    except ValueError:
-        # Buffer is inherently read-only (e.g. backed by PyAV / PIL): make copy
-        return arr.copy()
-
-
-def extract_image_pil(image: PIL.Image.Image) -> torch.Tensor | None:
-    if image.width * image.height > MAX_PIXELS:
-        raise ValueError(f"Image (w={image.width}, h={image.height}) > MAX=`{MAX_PIXELS}`")
-    img = image if image.mode == "RGB" else image.convert("RGB")
-    arr = np.asarray(img)
-    arr = _make_writeable(arr)
-    return torch.from_numpy(arr)
+def get_scaled_image_size(
+    scale: float,
+    original_size: int,
+    patch_size: int,
+    pixel_shuffle_scale: int,
+) -> int:
+    scaled_size = scale * original_size
+    divisor = patch_size * pixel_shuffle_scale
+    scaled_size = math.ceil(scaled_size / divisor) * divisor
+    scaled_size = max(divisor, scaled_size)
+    return int(scaled_size)
 
 
 def get_image_size_for_max_num_patches(
@@ -541,13 +1042,6 @@ def get_image_size_for_max_num_patches(
     and respect both the maximum and optional minimum patch-count constraints.
     """
 
-    def get_scaled_image_size(scale, original_size, patch_size, pixel_shuffle_scale):
-        scaled_size = scale * original_size
-        divisor = patch_size * pixel_shuffle_scale
-        scaled_size = math.ceil(scaled_size / divisor) * divisor
-        scaled_size = max(divisor, scaled_size)
-        return int(scaled_size)
-
     # Ensure divisibility
     divisor = patch_size * pixel_shuffle_scale
     adjusted_height = math.ceil(image_height / divisor) * divisor
@@ -593,37 +1087,6 @@ def get_scaled_image_size(scale, original_size, patch_size, pixel_shuffle_scale)
     return target_height, target_width
 
 
-_MEAN_TENSOR = torch.tensor(VISION_MEAN, dtype=torch.float32).view(1, 1, 1, -1)
-_STD_TENSOR = torch.tensor(VISION_STD, dtype=torch.float32).view(1, 1, 1, -1)
-
-
-def prepare_image_tensor(
-    image: torch.Tensor,
-    scale: float = VISION_SCALE,
-) -> torch.Tensor:
-    r"""Standardize RGB images prior to patch extraction via rescaling and whitening.
-
-    Args:
-        image (`torch.Tensor`):
-            Tensor with shape `(..., height, width, 3)` containing RGB values. The tensor is converted to floating
-            point if needed.
- scale (`float`, *optional*, defaults to `VISION_SCALE`): - Scalar multiplier applied before normalization. - Returns: - `torch.Tensor`: Normalized tensor with the same shape as the input and dtype `torch.float32`. - """ - if not torch.is_floating_point(image): - image = image.float() - rescaled = image * scale - - # Use precomputed tensors and move to the correct device if needed - mean_tensor = _MEAN_TENSOR.to(image.device) - std_tensor = _STD_TENSOR.to(image.device) - - normalized = (rescaled - mean_tensor) / std_tensor - return normalized - - def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor: r"""Convert normalized images into flattened ViT-style patches. @@ -649,87 +1112,6 @@ def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor: return patches -def process_vision_for_patches( - images: torch.Tensor, - patch_size: int, - max_num_patches: int, - min_num_patches: int | None = None, - pixel_shuffle_scale: int = 1, -) -> tuple[torch.Tensor, list[int]]: - r"""Resize, normalize, and patchify RGB images for the vision encoder. - - Args: - images (`torch.Tensor`): - Either `(height, width, channels)` for a single image or `(num_images, height, width, channels)` for a - batch. Channels are expected to be RGB. - patch_size (`int`): - Edge length of square patches; implictly controls resize grid granularity. - max_num_patches (`int`): - Maximum number of patches allowed after resizing. - min_num_patches (`int`, *optional*): - Minimum number of patches. If provided, the routine upsamples images as needed to satisfy the lower bound. - pixel_shuffle_scale (`int`, *optional*, defaults to 1): - pixel shuffle scale factor; influences the target grid that the function produces. - - Returns: - `tuple[torch.Tensor, list[int]]`: A pair `(patches, dims_virtual)` where `patches` has shape - `(num_images, target_h / patch_size, target_w / patch_size, channels * patch_size**2)` and `dims_virtual` - encodes effective `(images, height, width)` dimensions after optional pixel shuffling. - """ - # Add batch dim if single image - if images.dim() == 3: - images = images.unsqueeze(0) - - # Permute to channel first for resize - images = images.permute(0, 3, 1, 2) - - # Get target dimensions - _, _, orig_height, orig_width = images.shape - target_height, target_width = get_image_size_for_max_num_patches( - orig_height, - orig_width, - patch_size, - max_num_patches, - min_num_patches=min_num_patches, - pixel_shuffle_scale=pixel_shuffle_scale, - ) - - # Resize - images = F.interpolate( - images, - size=(target_height, target_width), - mode="bilinear", - align_corners=False, - ) - - # Back to channel last - images = images.permute(0, 2, 3, 1) - - # Normalize - images = prepare_image_tensor(images) - - # Patchify - patches = patchify_vision(images, patch_size=patch_size) - - # Calculate dimensions for the patches - n_images, h_patches, w_patches, _ = patches.shape - dims_virtual = ( - [1, h_patches, w_patches] - if pixel_shuffle_scale == 1 - else [1, h_patches // pixel_shuffle_scale, w_patches // pixel_shuffle_scale] - ) - - return patches, dims_virtual - - -def precompute_inv_freq(theta: float, dim: int) -> torch.Tensor: - """ - Returns shape (dim//2,). 
- """ - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - return inv_freq # type: ignore[return-value] - - def precompute_cos_sin_3d( position_ids: torch.Tensor, # shape (3, B, T) inv_freq: torch.Tensor, # shape (dim//2,) @@ -793,22 +1175,41 @@ class IsaacConfig(Qwen3Config): """Configuration class for Isaac multimodal model.""" model_type = "isaac" - sub_configs = {"vision_config": PixelShuffleSiglip2VisionConfig} + sub_configs = {"vision_config": IsaacVisionConfig, "text_config": Qwen3Config} + image_processor_type = "IsaacImageProcessor" def __init__( self, vision_config=None, + text_config: Qwen3Config | dict | None = None, vision_patch_size: int = 16, vision_max_num_patches: int = 256, vision_min_num_patches: int | None = None, pixel_shuffle_scale: int = 1, + vision_rescale_factor: float = VISION_SCALE, + vision_mean: float | Sequence[float] = VISION_MEAN, + vision_std: float | Sequence[float] = VISION_STD, max_sequence_length: int = 16384, vision_token: str = "", + vision_attn_implementation: str | None = None, **kwargs, ): - super().__init__(**kwargs) + resolved_text_config = kwargs.pop("text_config", text_config) + if isinstance(resolved_text_config, Qwen3Config): + text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict()) + elif isinstance(resolved_text_config, dict): + text_config_kwargs = copy.deepcopy(resolved_text_config) + elif resolved_text_config is None: + text_config_kwargs = {} + else: + raise TypeError("`text_config` must be a mapping or `Qwen3Config` instance when provided.") + + text_config_kwargs.update(kwargs) - # Handle vision config - either dict or PixelShuffleSiglip2VisionConfig instance + super().__init__(**text_config_kwargs) + self.text_config = Qwen3Config(**text_config_kwargs) + + # Handle vision config - either dict or IsaacVisionConfig instance if isinstance(vision_config, dict): self.vision_config = self.sub_configs["vision_config"](**vision_config) elif vision_config is None: @@ -822,9 +1223,21 @@ def __init__( self.vision_min_num_patches = vision_min_num_patches self.pixel_shuffle_scale = pixel_shuffle_scale + # Vision normalization parameters + self.vision_rescale_factor = float(vision_rescale_factor) + self.vision_mean = _normalize_rgb_values(vision_mean, name="vision_mean") + self.vision_std = _normalize_rgb_values(vision_std, name="vision_std") + # Processing parameters self.max_sequence_length = max_sequence_length self.vision_token = vision_token + self.vision_attn_implementation = vision_attn_implementation + + def get_text_config(self, *_, **kwargs) -> Qwen3Config: + # Accept optional decoder/encoder flags to align with HF composite configs + kwargs.pop("decoder", None) + kwargs.pop("encoder", None) + return self.text_config # ============================================================================ @@ -875,41 +1288,111 @@ def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> # ============================================================================ +IsaacImageProcessorFast.slow_image_processor_class = IsaacImageProcessor + + + class IsaacProcessor(ProcessorMixin): - attributes = [] - tokenizer_class = ("AutoTokenizer",) + attributes = ["image_processor", "tokenizer"] + image_processor_class = ("IsaacImageProcessor", "IsaacImageProcessorFast") + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") def __init__( self, - tokenizer: AutoTokenizer, - config: IsaacConfig, - ): - super().__init__() - self.tokenizer = tokenizer - self.config = config + image_processor: 
IsaacImageProcessor | IsaacImageProcessorFast | None = None, + tokenizer: Qwen2Tokenizer | None = None, + *, + vision_token: str = "", + max_sequence_length: int = 16384, + vision_patch_size: int = 16, + vision_max_num_patches: int = 256, + vision_min_num_patches: int | None = None, + pixel_shuffle_scale: int = 1, + rescale_factor: float | None = None, + image_mean: float | Sequence[float] | None = None, + image_std: float | Sequence[float] | None = None, + vision_attn_implementation: str | None = None, + config: IsaacConfig | dict | None = None, + **kwargs, + ) -> None: + if tokenizer is None: + raise ValueError("`tokenizer` must be provided to initialize IsaacProcessor.") + + if isinstance(config, dict): + config = IsaacConfig(**config) + + if config is not None: + vision_patch_size = config.video_patch_size + vision_max_num_patches = config.vision_max_num_patches + vision_min_num_patches = config.vision_min_num_patches + pixel_shuffle_scale = config.pixel_shuffle_scale + max_sequence_length = config.max_sequence_length + vision_token = config.vision_token + vision_attn_implementation = config.vision_attn_implementation + rescale_factor = config.vision_rescale_factor + image_mean = tuple(config.vision_mean) + image_std = tuple(config.vision_std) + + resolved_rescale_factor = ( + float(rescale_factor) if rescale_factor is not None else float(VISION_SCALE) + ) + resolved_image_mean = _normalize_rgb_values( + image_mean if image_mean is not None else VISION_MEAN, + name="image_mean", + ) + resolved_image_std = _normalize_rgb_values( + image_std if image_std is not None else VISION_STD, + name="image_std", + ) - # Use vision token from config - self.vision_token = config.vision_token + if image_processor is None: + image_processor = IsaacImageProcessor( + patch_size=vision_patch_size, + max_num_patches=vision_max_num_patches, + min_num_patches=vision_min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + rescale_factor=resolved_rescale_factor, + image_mean=resolved_image_mean, + image_std=resolved_image_std, + ) + else: + vision_patch_size = getattr(image_processor, "patch_size", vision_patch_size) + vision_max_num_patches = getattr(image_processor, "max_num_patches", vision_max_num_patches) + vision_min_num_patches = getattr(image_processor, "min_num_patches", vision_min_num_patches) + pixel_shuffle_scale = getattr(image_processor, "pixel_shuffle_scale", pixel_shuffle_scale) + resolved_rescale_factor = getattr(image_processor, "rescale_factor", resolved_rescale_factor) + resolved_image_mean = _normalize_rgb_values( + getattr(image_processor, "image_mean", resolved_image_mean), + name="image_mean", + ) + resolved_image_std = _normalize_rgb_values( + getattr(image_processor, "image_std", resolved_image_std), + name="image_std", + ) - # Processing parameters - self.max_sequence_length = config.max_sequence_length + if config is not None: + config.vision_rescale_factor = resolved_rescale_factor + config.vision_mean = resolved_image_mean + config.vision_std = resolved_image_std - # Vision processing parameters - self.patch_size = config.video_patch_size - self.max_num_patches = config.vision_max_num_patches - self.min_num_patches = config.vision_min_num_patches - self.pixel_shuffle_scale = config.pixel_shuffle_scale + super().__init__(image_processor, tokenizer) + self.current_processor = self.image_processor + self.config = config - def apply_chat_template( - self, - messages: list[dict[str, Any]], - tokenize: bool = False, - add_generation_prompt: bool = False, - **kwargs, - ) -> Any: - 
return self.tokenizer.apply_chat_template( - messages, tokenize=tokenize, add_generation_prompt=add_generation_prompt, **kwargs - ) + # Mirror tokenizer chat template so ProcessorMixin.apply_chat_template works. + self.chat_template = getattr(self.tokenizer, "chat_template", None) + + self.vision_token = vision_token + self.max_sequence_length = max_sequence_length + self.vision_attn_implementation = vision_attn_implementation + + self.patch_size = getattr(self.image_processor, "patch_size", vision_patch_size) + self.max_num_patches = getattr(self.image_processor, "max_num_patches", vision_max_num_patches) + self.min_num_patches = getattr(self.image_processor, "min_num_patches", vision_min_num_patches) + self.pixel_shuffle_scale = getattr(self.image_processor, "pixel_shuffle_scale", pixel_shuffle_scale) + self.rescale_factor = getattr(self.image_processor, "rescale_factor", resolved_rescale_factor) + self.image_mean = tuple(getattr(self.image_processor, "image_mean", resolved_image_mean)) + self.image_std = tuple(getattr(self.image_processor, "image_std", resolved_image_std)) def build_event_stream_simple( self, @@ -927,68 +1410,40 @@ def build_event_stream_simple( for current_time, part in enumerate(parts): if part == self.vision_token: # Replace vision token with image event - if image_idx < len(images): - # Create vision event from PIL image - image_tensor = extract_image_pil(images[image_idx]) - if image_tensor is not None: - # Create a vision event with the image tensor - vision_event = Event( - data=image_tensor.unsqueeze(0), # HWC format from extract_image_pil - type=VisionType.image, # I-frame - time=(current_time, current_time), - ) - events.append(vision_event) - image_idx += 1 - elif part: # Non-empty text part - # tokens = self.text_processor.tokenize(part, add_special_tokens=False) - text_event = create_text_event(self.tokenizer, part, time=current_time) - events.append(text_event) + if images is None or image_idx >= len(images): + raise ValueError("Encountered vision token without a corresponding image.") - # Process vision events if any - if any(event.type == VisionType.image for event in events): - # Separate text and vision events for processing - text_events = [event for event in events if event.type == TextType.text] - vision_events = [event for event in events if event.type == VisionType.image] - - # Process vision events using functional approach - processed_vision_events = [] - for vision_event in vision_events: - # Process the vision data - patches, dims_virtual = process_vision_for_patches( - vision_event.data.squeeze(0), # Remove the extra dimension - patch_size=self.patch_size, - max_num_patches=self.max_num_patches, - min_num_patches=self.min_num_patches, - pixel_shuffle_scale=self.pixel_shuffle_scale, + features = self.image_processor( + images=images[image_idx], + return_tensors=TensorType.PYTORCH, ) - # Update event with processed data - vision_event.data = patches.unsqueeze(1) # Add back frame dimension - vision_event.dims_virtual = dims_virtual - vision_event.dims_real = ( - dims_virtual - if self.pixel_shuffle_scale == 1 - else [ - dims_virtual[0], - dims_virtual[1] * self.pixel_shuffle_scale, - dims_virtual[2] * self.pixel_shuffle_scale, - ] + patches = features["patches"][0] # (H_tokens, W_tokens, embed) + virtual_dims = features["virtual_pixel_size"][0].tolist() + real_dims = features["real_pixel_size"][0].tolist() + + vision_event = Event( + data=patches.reshape(-1, patches.shape[-1]), + type=VisionType.image, + time=(current_time, current_time), + 
dims_virtual=virtual_dims, + dims_real=real_dims, + idx_range=(0, math.prod(virtual_dims)), ) - vision_event.idx_range = (0, math.prod(dims_virtual)) - - # Flatten the patches - vision_event.data = vision_event.data.reshape(-1, vision_event.data.shape[-1]) - processed_vision_events.append(vision_event) - - events = text_events + processed_vision_events + events.append(vision_event) + image_idx += 1 + elif part: # Non-empty text part + # tokens = self.text_processor.tokenize(part, add_special_tokens=False) + text_event = create_text_event(self.tokenizer, part, time=current_time) + events.append(text_event) # Create stream without scheduling (events already in order) return create_stream(events, priority=[TextType.text, VisionType.image], schedule=True) def __call__( self, - text: Union[str, list[str]], - images: Union[PIL.Image.Image, list[PIL.Image.Image], None] = None, + text: str | list[str], + images: PIL.Image.Image | list[PIL.Image.Image] | None = None, return_tensors: str | TensorType | None = TensorType.PYTORCH, **kwargs, ) -> BatchFeature: @@ -1080,61 +1535,108 @@ def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor: class IsaacRotaryEmbedding(nn.Module): + EXTRA_ROPE_KEYS = {"mrope_section", "mrope_interleaved"} + def __init__(self, config: IsaacConfig, device=None): super().__init__() - # Extract dimensions from config - self.hidden_size = config.hidden_size - self.num_attention_heads = config.num_attention_heads - self.head_dim = config.head_dim - - # Get rope_scaling config - use direct access when available + self.config = config rope_scaling = getattr(config, "rope_scaling", None) or {} + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) + if rope_type not in ROPE_INIT_FUNCTIONS: + raise ValueError(f"Unsupported rope_type '{rope_type}' for IsaacRotaryEmbedding") - # Read RopeScaling parameters - self.rope_type = rope_scaling.get("rope_type", "default") + self.rope_type = rope_type + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] - self.mrope_section = [ - self.head_dim // 4, # 2x more for temporal dim - self.head_dim // 8, - self.head_dim // 8, - ] + sanitized_scaling = {k: v for k, v in rope_scaling.items() if k not in self.EXTRA_ROPE_KEYS} + if sanitized_scaling != rope_scaling: + config_for_rope = copy.copy(config) + config_for_rope.rope_scaling = sanitized_scaling + else: + config_for_rope = config - rope_base = getattr(config, "rope_theta", 10000.0) - inv_freq = precompute_inv_freq(rope_base, self.head_dim) + init_device = device if device is not None and getattr(device, "type", None) != "meta" else None + inv_freq, attention_scaling = rope_init_fn(config_for_rope, device=init_device) self.register_buffer("inv_freq", inv_freq, persistent=False) + self.attention_scaling = self._normalize_scale(attention_scaling) + + rotary_half_dim = self.inv_freq.shape[0] + self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim) + + @staticmethod + def _normalize_scale(scale: torch.Tensor | float) -> torch.Tensor | float: + if isinstance(scale, torch.Tensor): + return scale.detach().clone() + return float(scale) + + @staticmethod + def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> list[int]: + if section is None: + weights = (2, 1, 1) + base = [rotary_half_dim * w // sum(weights) for w in weights] + base[0] += rotary_half_dim - sum(base) + return base + + section = [int(v) for v in section] + if len(section) != 3: + raise ValueError("`mrope_section` must contain 
exactly three elements (temporal, height, width)") + if sum(section) != rotary_half_dim: + raise ValueError( + f"`mrope_section` must sum to the rotary half-dimension ({rotary_half_dim}). Received {section}." + ) + return section def forward(self, position_ids: torch.Tensor, modality_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: with torch.no_grad(): - # Ensure non-spatial tokens have 1D rotation equivalence - not_spatial = ~(modality_tensor == VisionType.image.value) - # shape is [N, 1] - data_1d = position_ids[not_spatial][..., 0].unsqueeze(-1) - # now broadcast it from [N, 1] -> [N, D] so it matches pos[not_spatial] exactly - data_1d = data_1d.expand(-1, position_ids.shape[-1]) # expand along the last dim - position_ids = position_ids.clone() # Clone to avoid warning about in-place operations on expanded tensors - position_ids[not_spatial] = data_1d - position_ids = position_ids.permute(2, 0, 1) # pos dim first -> (3, B, L) + position_ids = position_ids.clone() + not_spatial = modality_tensor != VisionType.image.value + if not_spatial.any(): + data_1d = position_ids[not_spatial][..., 0].unsqueeze(-1) + position_ids[not_spatial] = data_1d.expand(-1, position_ids.shape[-1]) + + position_ids = position_ids.permute(2, 0, 1) cos, sin = precompute_cos_sin_3d(position_ids, self.inv_freq, self.mrope_section) + scale = self.attention_scaling + if isinstance(scale, torch.Tensor): + scale = scale.to(device=cos.device, dtype=cos.dtype) + elif scale != 1.0: + scale = cos.new_tensor(scale) + if isinstance(scale, torch.Tensor) or scale != 1.0: + cos = cos * scale + sin = sin * scale return cos, sin -class IsaacModel(Qwen3Model): +class IsaacModel(Qwen3PreTrainedModel): + supports_gradient_checkpointing = True + def __init__(self, config: IsaacConfig): super().__init__(config) - self.layers = torch.nn.ModuleList( - [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) + + text_cfg_source = getattr(config, "get_text_config", lambda: config)() + text_cfg = copy.deepcopy(text_cfg_source) + text_cfg._attn_implementation = config._attn_implementation + self.text_model = AutoModel.from_config(text_cfg) + # Ensure downstream callers observe the composed config + self.text_model.config = config + self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) vision_cfg = config.vision_config + # Use vision_attn_implementation if specified, otherwise fall back to general attn_implementation + vision_cfg._attn_implementation = ( + config.vision_attn_implementation + if config.vision_attn_implementation is not None + else config._attn_implementation + ) if vision_cfg is None: raise ValueError("IsaacConfig should always have vision_config") hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2) self.vision_embedding = nn.Sequential( - Siglip2SequenceVisionTransformer(vision_cfg), + IsaacVisionTransformer(vision_cfg), nn.Linear( hidden_dim, 4 * hidden_dim, @@ -1150,10 +1652,37 @@ def __init__(self, config: IsaacConfig): VisionType: self.embed_vision, } + def get_input_embeddings(self) -> nn.Module: + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value: nn.Module) -> None: + self.text_model.set_input_embeddings(value) + + @property + def embed_tokens(self) -> nn.Module: + return self.text_model.embed_tokens + + @embed_tokens.setter + def embed_tokens(self, value: nn.Module) -> None: + self.text_model.embed_tokens = value + + @property + def layers(self) -> nn.ModuleList: + return 
self.text_model.layers + + @property + def norm(self) -> nn.Module: + return self.text_model.norm + + def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None): + self.text_model._set_gradient_checkpointing( + enable=enable, gradient_checkpointing_func=gradient_checkpointing_func + ) + def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: """Embed text tokens, squeezing singleton dimensions.""" # Text events are shaped as (..., 1); squeeze the singleton index dim - h = self.embed_tokens(token_ids) + h = self.text_model.embed_tokens(token_ids) if h.dim() >= 2 and h.size(-2) == 1: h = h[..., 0, :] return h @@ -1235,7 +1764,7 @@ def forward( elif input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.text_model.embed_tokens(input_ids) # Create text modality tensor if not provided if modality_tensor is None: batch_size, seq_length = input_ids.shape @@ -1266,7 +1795,7 @@ def forward( # Initialize hidden states hidden_states = inputs_embeds - for decoder_layer in self.layers: + for decoder_layer in self.text_model.layers: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, @@ -1278,16 +1807,166 @@ def forward( **kwargs, ) - hidden_states = layer_outputs[0] + hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs # Final layer norm - hidden_states = self.norm(hidden_states) + hidden_states = self.text_model.norm(hidden_states) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, ) + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
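+        # A minimal sketch of the dispatch below, assuming a DynamicCache: with
+        # `attn_implementation="sdpa"`, no static or sliding-window cache, and no
+        # explicit padding, `_ignore_causal_mask_sdpa` lets us return None so
+        # SDPA takes its fused `is_causal=True` path instead of materializing a
+        # 4D mask.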
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen3Config, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+ config (`Qwen3Config`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin): """Isaac multimodal model for conditional generation.""" @@ -1459,9 +2138,42 @@ def can_generate(self) -> bool: return True +AutoImageProcessor.register( + IsaacConfig, + slow_image_processor_class=IsaacImageProcessor, + fast_image_processor_class=IsaacImageProcessorFast, + exist_ok=True, +) + + __all__ = [ "IsaacConfig", "IsaacModel", "IsaacForConditionalGeneration", + "IsaacImageProcessor", + "IsaacImageProcessorFast", "IsaacProcessor", -] \ No newline at end of file +] +def _prepare_image_tensor(image: torch.Tensor, scale: float, mean: tuple[float, ...], std: tuple[float, ...]) -> torch.Tensor: + """Mirror the prepare_image_tensor utility used in the training pipelines.""" + if not torch.is_floating_point(image): + image = image.float() + + rescaled = image * scale + mean_tensor = torch.tensor(mean, dtype=torch.float32, device=rescaled.device).view(1, 1, 1, -1) + std_tensor = torch.tensor(std, dtype=torch.float32, device=rescaled.device).view(1, 1, 1, -1) + normalized = (rescaled - mean_tensor) / std_tensor + return normalized + + +def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> torch.Tensor: + """Compute residuals for P-frames to stay in sync with the training pipeline.""" + if not any(is_p_frame): + return frames + + frame_indices = torch.arange(len(is_p_frame), device=frames.device) + i_frame_mask = torch.tensor([not flag for flag in is_p_frame], device=frames.device) + last_i_indices = torch.cummax((i_frame_mask * (1 + frame_indices)), dim=0).values.long() - 1 + p_indices = frame_indices[torch.tensor(is_p_frame, device=frames.device)] + frames[p_indices] = frames[p_indices] - 
frames[last_i_indices[p_indices]] + return frames From d72311d8e103ec1ab8ccb0c918981eec868ebe1d Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Fri, 10 Oct 2025 11:38:43 +0400 Subject: [PATCH 03/77] fix: get modular convert utility working --- .../models/isaac/modular_isaac.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 313c3ce76d5e..a51950f4a7dc 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -151,8 +151,8 @@ class IsaacImageProcessorKwargs(ImagesKwargs): @auto_docstring class IsaacImageProcessorFast(BaseImageProcessorFast): - slow_image_processor_class = None r"""Fast torch-based image processor for Isaac vision inputs.""" + slow_image_processor_class = "IsaacImageProcessor" resample = PILImageResampling.BILINEAR model_input_names = ["patches", "token_grids"] @@ -1287,11 +1287,6 @@ def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> # Processor # ============================================================================ - -IsaacImageProcessorFast.slow_image_processor_class = IsaacImageProcessor - - - class IsaacProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = ("IsaacImageProcessor", "IsaacImageProcessorFast") @@ -1299,7 +1294,7 @@ class IsaacProcessor(ProcessorMixin): def __init__( self, - image_processor: IsaacImageProcessor | IsaacImageProcessorFast | None = None, + image_processor: IsaacImageProcessor | BaseImageProcessorFast | None = None, tokenizer: Qwen2Tokenizer | None = None, *, vision_token: str = "", @@ -1613,7 +1608,7 @@ class IsaacModel(Qwen3PreTrainedModel): supports_gradient_checkpointing = True def __init__(self, config: IsaacConfig): - super().__init__(config) + Qwen3PreTrainedModel.__init__(self, config) text_cfg_source = getattr(config, "get_text_config", lambda: config)() text_cfg = copy.deepcopy(text_cfg_source) @@ -1974,15 +1969,14 @@ class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin): config_class = IsaacConfig def __init__(self, config: IsaacConfig): - Qwen3PreTrainedModel.__init__(self, config) + super().__init__(config) + self.model = IsaacModel(config) # Use our custom model self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Tracks rotary position offsets computed during a full forward pass so decode steps can reuse them. 
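+        # (assumption, mirroring other mrope models such as Qwen2-VL: the prefill
+        # forward derives the deltas via `get_rope_index`, and single-token decode
+        # steps reuse them as a fixed offset instead of recomputing positions)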
self.rope_deltas = None - self.config = config - def get_rope_index( self, input_ids: torch.Tensor | None, @@ -2138,10 +2132,18 @@ def can_generate(self) -> bool: return True + +def _load_isaac_fast_image_processor(): + try: + from .image_processing_isaac_fast import IsaacImageProcessorFast as fast_cls + except ImportError: + fast_cls = None + return fast_cls + AutoImageProcessor.register( IsaacConfig, slow_image_processor_class=IsaacImageProcessor, - fast_image_processor_class=IsaacImageProcessorFast, + fast_image_processor_class=_load_isaac_fast_image_processor(), exist_ok=True, ) From d6ed8440839800f93100f80a54a07d93a2fa2c8a Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Fri, 10 Oct 2025 12:01:27 +0400 Subject: [PATCH 04/77] feat: modular convert utility outputs --- .../models/isaac/configuration_isaac.py | 83 +- .../models/isaac/image_processing_isaac.py | 432 ++++++++++ .../isaac/image_processing_isaac_fast.py | 347 ++++++++ .../models/isaac/modeling_isaac.py | 750 ++++++++++++------ .../models/isaac/processing_isaac.py | 471 ++++------- 5 files changed, 1508 insertions(+), 575 deletions(-) create mode 100644 src/transformers/models/isaac/image_processing_isaac.py create mode 100644 src/transformers/models/isaac/image_processing_isaac_fast.py diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index d29e44b68e4d..3ffac505e2ee 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -5,17 +5,27 @@ # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ +import copy +from collections.abc import Sequence + +# Build the list of all image processors from ...configuration_utils import PretrainedConfig, layer_type_validation from ...modeling_rope_utils import rope_config_validation -class PixelShuffleSiglip2VisionConfig(PretrainedConfig): +class IsaacVisionConfig(PretrainedConfig): """Vision configuration for Isaac with Pixel Shuffle support. Extends Siglip2VisionConfig with additional fields for pixel shuffle. + + Args: + pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1): + Spatial factor applied before pixel shuffle reduces the resolution. + num_patches (`int`, *optional*, defaults to 256): + Maximum number of learnable positional embeddings to initialize. """ - model_type = "pixel_shuffle_siglip2" + model_type = "isaac_vision" base_config_key = "vision_config" def __init__( @@ -41,6 +51,30 @@ def __init__( self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor +# Vision preprocessing constants +VISION_MEAN = (0.5, 0.5, 0.5) +VISION_STD = (0.5, 0.5, 0.5) +VISION_SCALE = 1 / 255 + + +def _normalize_rgb_values( + values: float | Sequence[float] | tuple[float, ...], + *, + name: str, +) -> tuple[float, float, float]: + """Coerce RGB normalization parameters into a 3-tuple of floats.""" + if isinstance(values, (list, tuple)): + if len(values) == 3: + return tuple(float(v) for v in values) + if len(values) == 1: + value = float(values[0]) + return (value, value, value) + raise ValueError(f"`{name}` must have length 1 or 3 when provided as a sequence. 
Got length {len(values)}.") + + value = float(values) + return (value, value, value) + + class IsaacConfig(PretrainedConfig): """Configuration class for Isaac multimodal model.""" @@ -62,23 +96,36 @@ class IsaacConfig(PretrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } - sub_configs = {"vision_config": PixelShuffleSiglip2VisionConfig} + sub_configs = {"vision_config": IsaacVisionConfig, "text_config": Qwen3Config} + image_processor_type = "IsaacImageProcessor" def __init__( self, vision_config=None, + text_config: Qwen3Config | dict | None = None, vision_patch_size: int = 16, vision_max_num_patches: int = 256, vision_min_num_patches: int | None = None, pixel_shuffle_scale: int = 1, + vision_rescale_factor: float = VISION_SCALE, + vision_mean: float | Sequence[float] = VISION_MEAN, + vision_std: float | Sequence[float] = VISION_STD, max_sequence_length: int = 16384, vision_token: str = "", + vision_attn_implementation: str | None = None, **kwargs, ): - super().__init__( - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) + resolved_text_config = kwargs.pop("text_config", text_config) + if isinstance(resolved_text_config, Qwen3Config): + text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict()) + elif isinstance(resolved_text_config, dict): + text_config_kwargs = copy.deepcopy(resolved_text_config) + elif resolved_text_config is None: + text_config_kwargs = {} + else: + raise TypeError("`text_config` must be a mapping or `Qwen3Config` instance when provided.") + + text_config_kwargs.update(kwargs) self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size @@ -117,9 +164,15 @@ def __init__( else "full_attention" for i in range(self.num_hidden_layers) ] - layer_type_validation(self.layer_types, self.num_hidden_layers) + layer_type_validation(self.layer_types) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.text_config = Qwen3Config(**text_config_kwargs) - # Handle vision config - either dict or PixelShuffleSiglip2VisionConfig instance + # Handle vision config - either dict or IsaacVisionConfig instance if isinstance(vision_config, dict): self.vision_config = self.sub_configs["vision_config"](**vision_config) elif vision_config is None: @@ -133,9 +186,21 @@ def __init__( self.vision_min_num_patches = vision_min_num_patches self.pixel_shuffle_scale = pixel_shuffle_scale + # Vision normalization parameters + self.vision_rescale_factor = float(vision_rescale_factor) + self.vision_mean = _normalize_rgb_values(vision_mean, name="vision_mean") + self.vision_std = _normalize_rgb_values(vision_std, name="vision_std") + # Processing parameters self.max_sequence_length = max_sequence_length self.vision_token = vision_token + self.vision_attn_implementation = vision_attn_implementation + + def get_text_config(self, *_, **kwargs) -> Qwen3Config: + # Accept optional decoder/encoder flags to align with HF composite configs + kwargs.pop("decoder", None) + kwargs.pop("encoder", None) + return self.text_config __all__ = ["IsaacConfig"] diff --git a/src/transformers/models/isaac/image_processing_isaac.py b/src/transformers/models/isaac/image_processing_isaac.py new file mode 100644 index 000000000000..d38740a4cfaf --- /dev/null +++ b/src/transformers/models/isaac/image_processing_isaac.py @@ -0,0 +1,432 @@ +# 
๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ +# This file was automatically generated from src/transformers/models/isaac/modular_isaac.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_isaac.py file directly. One of our CI enforces this. +# ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ +import math +from collections.abc import Sequence + +import numpy as np +import torch +import torch.nn.functional as F + +from ...feature_extraction_utils import BatchFeature +from ...image_processing_utils import BaseImageProcessor +from ...image_transforms import convert_to_rgb +from ...image_utils import ( + ImageInput, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...processing_utils import ImagesKwargs +from ...tokenization_utils import TensorType +from ...utils import filter_out_non_signature_kwargs + + +class IsaacImageProcessorKwargs(ImagesKwargs): + patch_size: int | None + max_num_patches: int | None + min_num_patches: int | None + pixel_shuffle_scale: int | None + do_rescale: bool | None + rescale_factor: float | None + do_normalize: bool | None + image_mean: float | Sequence[float] | None + image_std: float | Sequence[float] | None + do_convert_rgb: bool | None + + +# Vision preprocessing constants +VISION_MEAN = (0.5, 0.5, 0.5) +VISION_STD = (0.5, 0.5, 0.5) +VISION_SCALE = 1 / 255 + + +def _normalize_rgb_values( + values: float | Sequence[float] | tuple[float, ...], + *, + name: str, +) -> tuple[float, float, float]: + """Coerce RGB normalization parameters into a 3-tuple of floats.""" + if isinstance(values, (list, tuple)): + if len(values) == 3: + return tuple(float(v) for v in values) + if len(values) == 1: + value = float(values[0]) + return (value, value, value) + raise ValueError(f"`{name}` must have length 1 or 3 when provided as a sequence. Got length {len(values)}.") + + value = float(values) + return (value, value, value) + + +def _make_writeable(arr: np.ndarray) -> np.ndarray: + if arr.flags.writeable: + return arr + try: + arr.setflags(write=True) + return arr + except ValueError: + return arr.copy() + + +# ============================================================================ +# Configuration +# ============================================================================ + +MAX_PIXELS = 60_000_000 # 60โ€‘megapixel ceiling โ‰ˆ 8200 ร— 7300 px + + +def get_scaled_image_size( + scale: float, + original_size: int, + patch_size: int, + pixel_shuffle_scale: int, +) -> int: + scaled_size = scale * original_size + divisor = patch_size * pixel_shuffle_scale + scaled_size = math.ceil(scaled_size / divisor) * divisor + scaled_size = max(divisor, scaled_size) + return int(scaled_size) + + +def get_image_size_for_max_num_patches( + image_height: int, + image_width: int, + patch_size: int, + max_num_patches: int, + min_num_patches: int | None = None, + eps: float = 1e-5, + pixel_shuffle_scale: int = 1, +) -> tuple[int, int]: + r"""Compute a target resolution whose patch grid satisfies patching parametrization. 
+
+    Args:
+        image_height (`int`):
+            Height in pixels of the source image prior to any resizing.
+        image_width (`int`):
+            Width in pixels of the source image prior to any resizing.
+        patch_size (`int`):
+            Size of the square patch used by the vision encoder.
+        max_num_patches (`int`):
+            Upper bound on `(height / patch_size) * (width / patch_size)` after resizing.
+        min_num_patches (`int`, *optional*):
+            Lower bound on the number of patches. When provided, the image will be scaled up if necessary.
+        eps (`float`, *optional*, defaults to 1e-5):
+            Convergence tolerance for the internal binary search to determine the target dimensions.
+        pixel_shuffle_scale (`int`, *optional*, defaults to 1):
+            Additional stride multiplier applied when pixel shuffle later reduces spatial resolution.
+
+    Returns:
+        `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale`
+        and respect both the maximum and optional minimum patch-count constraints.
+    """
+
+    # Ensure divisibility
+    divisor = patch_size * pixel_shuffle_scale
+    adjusted_height = math.ceil(image_height / divisor) * divisor
+    adjusted_height = max(divisor, adjusted_height)
+    adjusted_width = math.ceil(image_width / divisor) * divisor
+    adjusted_width = max(divisor, adjusted_width)
+
+    num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size)
+
+    if min_num_patches is not None and num_patches < min_num_patches:
+        # Scale up
+        scale_min, scale_max = 1.0, 100.0
+        while (scale_max - scale_min) >= eps:
+            scale = (scale_min + scale_max) / 2
+            target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+            target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+            num_patches = (target_height / patch_size) * (target_width / patch_size)
+            if num_patches >= min_num_patches:
+                scale_max = scale
+            else:
+                scale_min = scale
+        scale = scale_max
+        target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+        target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+        return target_height, target_width
+    elif num_patches <= max_num_patches:
+        return adjusted_height, adjusted_width
+    else:
+        # Scale down
+        scale_min, scale_max = eps / 10, 1.0
+        while (scale_max - scale_min) >= eps:
+            scale = (scale_min + scale_max) / 2
+            target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+            target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+            num_patches = (target_height / patch_size) * (target_width / patch_size)
+            if num_patches <= max_num_patches:
+                scale_min = scale
+            else:
+                scale_max = scale
+        scale = scale_min
+        target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+        target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+        return target_height, target_width
+
+
+def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor:
+    r"""Convert normalized images into flattened ViT-style patches.
+
+    Args:
+        image (`torch.Tensor`):
+            Tensor of shape `(num_images, height, width, channels)`.
+        patch_size (`int`):
+            Edge length of the square patches.
+
+    Returns:
+        `torch.Tensor`:
+            Patch tensor where each position stores the flattened pixels belonging to that patch.
+
+    Raises:
+        ValueError: If `height` or `width` is not divisible by `patch_size`.
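+
+    Example (illustrative shapes, not part of the original docstring): an input
+    of shape `(1, 32, 32, 3)` with `patch_size=16` produces patches of shape
+    `(1, 2, 2, 768)`, since each patch flattens `3 * 16 * 16 = 768` values.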
+ """ + num_images, height, width, channels = image.shape + if height % patch_size or width % patch_size: + raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.") + patches = image.reshape(num_images, height // patch_size, patch_size, width // patch_size, patch_size, channels) + patches = patches.permute(0, 1, 3, 2, 4, 5) + patches = patches.reshape( + num_images, height // patch_size, width // patch_size, channels * patch_size * patch_size + ) + return patches + + +def _prepare_image_tensor( + image: torch.Tensor, scale: float, mean: tuple[float, ...], std: tuple[float, ...] +) -> torch.Tensor: + """Mirror the prepare_image_tensor utility used in the training pipelines.""" + if not torch.is_floating_point(image): + image = image.float() + + rescaled = image * scale + mean_tensor = torch.tensor(mean, dtype=torch.float32, device=rescaled.device).view(1, 1, 1, -1) + std_tensor = torch.tensor(std, dtype=torch.float32, device=rescaled.device).view(1, 1, 1, -1) + normalized = (rescaled - mean_tensor) / std_tensor + return normalized + + +def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> torch.Tensor: + """Compute residuals for P-frames to stay in sync with the training pipeline.""" + if not any(is_p_frame): + return frames + + frame_indices = torch.arange(len(is_p_frame), device=frames.device) + i_frame_mask = torch.tensor([not flag for flag in is_p_frame], device=frames.device) + last_i_indices = torch.cummax((i_frame_mask * (1 + frame_indices)), dim=0).values.long() - 1 + p_indices = frame_indices[torch.tensor(is_p_frame, device=frames.device)] + frames[p_indices] = frames[p_indices] - frames[last_i_indices[p_indices]] + return frames + + +class IsaacImageProcessor(BaseImageProcessor): + """Image processor that prepares RGB frames for the Isaac vision encoder.""" + + model_input_names = ["patches", "token_grids"] + + def __init__( + self, + patch_size: int = 16, + max_num_patches: int = 256, + min_num_patches: int | None = None, + pixel_shuffle_scale: int = 1, + do_rescale: bool = True, + rescale_factor: float | None = None, + do_normalize: bool = True, + image_mean: float | Sequence[float] | None = None, + image_std: float | Sequence[float] | None = None, + do_convert_rgb: bool = True, + resize_mode: str = "bilinear", + align_corners: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + if pixel_shuffle_scale < 1: + raise ValueError("`pixel_shuffle_scale` must be >= 1") + + rescale_value = VISION_SCALE if rescale_factor is None else float(rescale_factor) + mean_value = VISION_MEAN if image_mean is None else image_mean + std_value = VISION_STD if image_std is None else image_std + + self.patch_size = patch_size + self.max_num_patches = max_num_patches + self.min_num_patches = min_num_patches + self.pixel_shuffle_scale = pixel_shuffle_scale + self.do_rescale = do_rescale + self.rescale_factor = rescale_value + self.do_normalize = do_normalize + self.image_mean = _normalize_rgb_values(mean_value, name="image_mean") + self.image_std = _normalize_rgb_values(std_value, name="image_std") + self.do_convert_rgb = do_convert_rgb + self.resize_mode = resize_mode + self.align_corners = align_corners + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + patch_size: int | None = None, + max_num_patches: int | None = None, + min_num_patches: int | None = None, + pixel_shuffle_scale: int | None = None, + do_rescale: bool | None = None, + rescale_factor: float | None = None, + 
do_normalize: bool | None = None, + image_mean: float | Sequence[float] | None = None, + image_std: float | Sequence[float] | None = None, + do_convert_rgb: bool | None = None, + return_tensors: str | TensorType | None = None, + ) -> BatchFeature: + patch_size = patch_size if patch_size is not None else self.patch_size + max_num_patches = max_num_patches if max_num_patches is not None else self.max_num_patches + min_num_patches = min_num_patches if min_num_patches is not None else self.min_num_patches + pixel_shuffle_scale = pixel_shuffle_scale if pixel_shuffle_scale is not None else self.pixel_shuffle_scale + do_rescale = self.do_rescale if do_rescale is None else do_rescale + rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor + do_normalize = self.do_normalize if do_normalize is None else do_normalize + image_mean = self.image_mean if image_mean is None else _normalize_rgb_values(image_mean, name="image_mean") + image_std = self.image_std if image_std is None else _normalize_rgb_values(image_std, name="image_std") + do_convert_rgb = self.do_convert_rgb if do_convert_rgb is None else do_convert_rgb + + images = self.fetch_images(images) + images = make_flat_list_of_images(images) + + if not images: + raise ValueError("Received an empty list of images for preprocessing.") + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + if not valid_images(images): + raise ValueError( + "Invalid image type. Expected PIL images, numpy arrays, or tensors convertible to numpy arrays." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) + + patches_list = [] + token_grids = [] + virtual_dims = [] + real_dims = [] + + for image in images: + np_image = to_numpy_array(image) + + if np_image.ndim == 2: + np_image = np.repeat(np_image[..., None], 3, axis=-1) + + height, width = np_image.shape[:2] + if height * width > MAX_PIXELS: + raise ValueError(f"Image (w={width}, h={height}) > MAX=`{MAX_PIXELS}`") + + torch_image = torch.from_numpy(_make_writeable(np_image)) + patches, vidims, rdims = self._process_single_image( + torch_image, + patch_size=patch_size, + max_num_patches=max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) + + patches_list.append(patches) + token_grids.append(torch.tensor([patches.size(1), patches.size(2)], dtype=torch.long)) + virtual_dims.append(vidims) + real_dims.append(rdims) + + patches_tensor = torch.cat(patches_list, dim=0) + token_grid_tensor = torch.stack(token_grids, dim=0) + virtual_dims_tensor = torch.tensor(virtual_dims, dtype=torch.long) + real_dims_tensor = torch.tensor(real_dims, dtype=torch.long) + + data = { + "patches": patches_tensor, + "token_grids": token_grid_tensor, + "virtual_pixel_size": virtual_dims_tensor, + "real_pixel_size": real_dims_tensor, + } + + return BatchFeature(data=data, tensor_type=return_tensors) + + def _process_single_image( + self, + image: torch.Tensor, + *, + patch_size: int, + max_num_patches: int, + min_num_patches: int | None, + pixel_shuffle_scale: int, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: tuple[float, ...], + image_std: tuple[float, ...], + ) -> tuple[torch.Tensor, list[int], list[int]]: + image_uint8 = image.unsqueeze(0) # (1, H, W, C) 
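+        # NHWC -> NCHW: `F.interpolate` below expects channel-first input; the
+        # result is permuted back to NHWC before normalization and patching.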
+ image_chw = image_uint8.permute(0, 3, 1, 2) # (1, C, H, W) + + _, _, orig_height, orig_width = image_chw.shape + target_height, target_width = get_image_size_for_max_num_patches( + orig_height, + orig_width, + patch_size, + max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + ) + + if self.resize_mode in {"linear", "bilinear", "bicubic", "trilinear"}: + resized = F.interpolate( + image_chw, + size=(target_height, target_width), + mode=self.resize_mode, + align_corners=self.align_corners, + ) + else: + resized = F.interpolate( + image_chw, + size=(target_height, target_width), + mode=self.resize_mode, + ) + + resized = resized.permute(0, 2, 3, 1) # (1, H, W, C) + + scale = rescale_factor if do_rescale else 1.0 + mean = image_mean if do_normalize else (0.0, 0.0, 0.0) + std = image_std if do_normalize else (1.0, 1.0, 1.0) + resized = _prepare_image_tensor(resized, scale=scale, mean=mean, std=std) + + resized = _compute_residual_p_frames(resized, is_p_frame=[False]) + + patches = patchify_vision(resized, patch_size=patch_size) + _, h_patches, w_patches, _ = patches.shape + + real_dims = [1, h_patches, w_patches] + if pixel_shuffle_scale > 1: + if (h_patches % pixel_shuffle_scale) or (w_patches % pixel_shuffle_scale): + raise ValueError( + "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled." + ) + virtual_dims = [1, h_patches // pixel_shuffle_scale, w_patches // pixel_shuffle_scale] + else: + virtual_dims = real_dims.copy() + + return patches, virtual_dims, real_dims + + +__all__ = ["IsaacImageProcessor"] diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py b/src/transformers/models/isaac/image_processing_isaac_fast.py new file mode 100644 index 000000000000..72c93f1c3a81 --- /dev/null +++ b/src/transformers/models/isaac/image_processing_isaac_fast.py @@ -0,0 +1,347 @@ +# ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ +# This file was automatically generated from src/transformers/models/isaac/modular_isaac.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_isaac.py file directly. One of our CI enforces this. 
+# ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ +import math +from collections.abc import Sequence +from typing import Any + +import torch + +from ...feature_extraction_utils import BatchFeature +from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict +from ...image_utils import PILImageResampling +from ...processing_utils import Unpack +from ...tokenization_utils import TensorType +from ...utils import auto_docstring +from .image_processing_isaac import IsaacImageProcessorKwargs + + +# Vision preprocessing constants +VISION_MEAN = (0.5, 0.5, 0.5) +VISION_STD = (0.5, 0.5, 0.5) +VISION_SCALE = 1 / 255 + + +def _normalize_rgb_values( + values: float | Sequence[float] | tuple[float, ...], + *, + name: str, +) -> tuple[float, float, float]: + """Coerce RGB normalization parameters into a 3-tuple of floats.""" + if isinstance(values, (list, tuple)): + if len(values) == 3: + return tuple(float(v) for v in values) + if len(values) == 1: + value = float(values[0]) + return (value, value, value) + raise ValueError(f"`{name}` must have length 1 or 3 when provided as a sequence. Got length {len(values)}.") + + value = float(values) + return (value, value, value) + + +# ============================================================================ +# Configuration +# ============================================================================ + +MAX_PIXELS = 60_000_000 # 60โ€‘megapixel ceiling โ‰ˆ 8200 ร— 7300 px + + +def get_scaled_image_size( + scale: float, + original_size: int, + patch_size: int, + pixel_shuffle_scale: int, +) -> int: + scaled_size = scale * original_size + divisor = patch_size * pixel_shuffle_scale + scaled_size = math.ceil(scaled_size / divisor) * divisor + scaled_size = max(divisor, scaled_size) + return int(scaled_size) + + +def get_image_size_for_max_num_patches( + image_height: int, + image_width: int, + patch_size: int, + max_num_patches: int, + min_num_patches: int | None = None, + eps: float = 1e-5, + pixel_shuffle_scale: int = 1, +) -> tuple[int, int]: + r"""Compute a target resolution whose patch grid satisfies patching parametrization. + + Args: + image_height (`int`): + Height in pixels of the source image prior to any resizing. + image_width (`int`): + Width in pixels of the source image prior to any resizing. + patch_size (`int`): + Size of the square patch used by the vision encoder. + max_num_patches (`int`): + Upper bound on `(height / patch_size) * (width / patch_size)` after resizing. + min_num_patches (`int`, *optional*): + Lower bound on the number of patches. When provided the image will be scaled up if necessary. + eps (`float`, *optional*, defaults to 1e-5): + Convergence tolerance for the internal binary search to determing the target dimensions. + pixel_shuffle_scale (`int`, *optional*, defaults to 1): + Additional stride multiplier applied when pixel shuffle later reduces spatial resolution. + + Returns: + `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale` + and respect both the maximum and optional minimum patch-count constraints. 
+ """ + + # Ensure divisibility + divisor = patch_size * pixel_shuffle_scale + adjusted_height = math.ceil(image_height / divisor) * divisor + adjusted_height = max(divisor, adjusted_height) + adjusted_width = math.ceil(image_width / divisor) * divisor + adjusted_width = max(divisor, adjusted_width) + + num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size) + + if min_num_patches is not None and num_patches < min_num_patches: + # Scale up + scale_min, scale_max = 1.0, 100.0 + while (scale_max - scale_min) >= eps: + scale = (scale_min + scale_max) / 2 + target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) + target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + num_patches = (target_height / patch_size) * (target_width / patch_size) + if num_patches >= min_num_patches: + scale_max = scale + else: + scale_min = scale + scale = scale_max + target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) + target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + return target_height, target_width + elif num_patches <= max_num_patches: + return adjusted_height, adjusted_width + else: + # Scale down + scale_min, scale_max = eps / 10, 1.0 + while (scale_max - scale_min) >= eps: + scale = (scale_min + scale_max) / 2 + target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) + target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + num_patches = (target_height / patch_size) * (target_width / patch_size) + if num_patches <= max_num_patches: + scale_min = scale + else: + scale_max = scale + scale = scale_min + target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) + target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + return target_height, target_width + + +def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor: + r"""Convert normalized images into flattened ViT-style patches. + + Args: + image (`torch.Tensor`): + Tensor of shape `(num_images, height, width, channels)`. + patch_size (`int`): + Edge length of the square patches + + Returns: + `torch.Tensor`: + Patch tensor where each position stores the flattened pixels belonging to that patch. + + Raises: + ValueError: If `height` or `width` is not divisible by `patch_size`. 
+ """ + num_images, height, width, channels = image.shape + if height % patch_size or width % patch_size: + raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.") + patches = image.reshape(num_images, height // patch_size, patch_size, width // patch_size, patch_size, channels) + patches = patches.permute(0, 1, 3, 2, 4, 5) + patches = patches.reshape( + num_images, height // patch_size, width // patch_size, channels * patch_size * patch_size + ) + return patches + + +def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> torch.Tensor: + """Compute residuals for P-frames to stay in sync with the training pipeline.""" + if not any(is_p_frame): + return frames + + frame_indices = torch.arange(len(is_p_frame), device=frames.device) + i_frame_mask = torch.tensor([not flag for flag in is_p_frame], device=frames.device) + last_i_indices = torch.cummax((i_frame_mask * (1 + frame_indices)), dim=0).values.long() - 1 + p_indices = frame_indices[torch.tensor(is_p_frame, device=frames.device)] + frames[p_indices] = frames[p_indices] - frames[last_i_indices[p_indices]] + return frames + + +@auto_docstring +class IsaacImageProcessorFast(BaseImageProcessorFast): + r"""Fast torch-based image processor for Isaac vision inputs.""" + + slow_image_processor_class = "IsaacImageProcessor" + + resample = PILImageResampling.BILINEAR + model_input_names = ["patches", "token_grids"] + valid_kwargs = IsaacImageProcessorKwargs + unused_kwargs = ["size", "do_center_crop", "crop_size"] + + def __init__( + self, + *, + patch_size: int = 16, + max_num_patches: int = 256, + min_num_patches: int | None = None, + pixel_shuffle_scale: int = 1, + do_rescale: bool = True, + rescale_factor: float | None = None, + do_normalize: bool = True, + image_mean: float | Sequence[float] | None = None, + image_std: float | Sequence[float] | None = None, + do_convert_rgb: bool = True, + **kwargs: Unpack[IsaacImageProcessorKwargs], + ) -> None: + super().__init__(**kwargs) + + if pixel_shuffle_scale < 1: + raise ValueError("`pixel_shuffle_scale` must be >= 1") + + mean_values = _normalize_rgb_values(image_mean if image_mean is not None else VISION_MEAN, name="image_mean") + std_values = _normalize_rgb_values(image_std if image_std is not None else VISION_STD, name="image_std") + + self.patch_size = patch_size + self.max_num_patches = max_num_patches + self.min_num_patches = min_num_patches + self.pixel_shuffle_scale = pixel_shuffle_scale + self.do_rescale = do_rescale + self.rescale_factor = VISION_SCALE if rescale_factor is None else float(rescale_factor) + self.do_normalize = do_normalize + self.image_mean = list(mean_values) + self.image_std = list(std_values) + self.do_convert_rgb = do_convert_rgb + + def _validate_preprocess_kwargs(self, **kwargs): + # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for. 
+ kwargs.pop("do_resize", None) + return super()._validate_preprocess_kwargs(**kwargs) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + patch_size: int, + max_num_patches: int, + interpolation: Any | None, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: float | Sequence[float] | None, + image_std: float | Sequence[float] | None, + return_tensors: str | TensorType | None, + *, + min_num_patches: int | None = None, + pixel_shuffle_scale: int | None = None, + do_convert_rgb: bool | None = None, + **kwargs, + ) -> BatchFeature: + if TVF is None: + raise ImportError("torchvision is required for IsaacImageProcessorFast but is not installed.") + + min_num_patches = min_num_patches if min_num_patches is not None else self.min_num_patches + pixel_shuffle_scale = pixel_shuffle_scale if pixel_shuffle_scale is not None else self.pixel_shuffle_scale + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_rescale = self.do_rescale if do_rescale is None else do_rescale + do_normalize = self.do_normalize if do_normalize is None else do_normalize + rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor + + mean_values = _normalize_rgb_values( + image_mean if image_mean is not None else self.image_mean, name="image_mean" + ) + std_values = _normalize_rgb_values(image_std if image_std is not None else self.image_std, name="image_std") + + patches_list: list[torch.Tensor] = [] + token_grids: list[torch.Tensor] = [] + virtual_dims: list[list[int]] = [] + real_dims: list[list[int]] = [] + + for image in images: + if image.ndim != 3: + raise ValueError("Expected channel-first image tensor with shape (C, H, W).") + + channels, original_height, original_width = image.shape + if do_convert_rgb and channels == 1: + image = image.repeat(3, 1, 1) + channels = 3 + + if original_height * original_width > MAX_PIXELS: + raise ValueError(f"Image (w={original_width}, h={original_height}) > MAX=`{MAX_PIXELS}`") + + target_height, target_width = get_image_size_for_max_num_patches( + original_height, + original_width, + patch_size, + max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + ) + + if do_resize: + size_dict = SizeDict(height=target_height, width=target_width) + image = self.resize(image=image, size=size_dict, interpolation=interpolation) + else: + if ((original_height % patch_size) != 0) or ((original_width % patch_size) != 0): + raise ValueError("Image dimensions must be divisible by patch_size when resize is disabled.") + + # Apply rescaling and normalization as needed + image = self.rescale_and_normalize( + image, + do_rescale, + rescale_factor, + do_normalize, + list(mean_values), + list(std_values), + ) + + # Convert to NHWC for residual P-frame adjustment and patch extraction + nhwc_image = image.permute(1, 2, 0).unsqueeze(0) + nhwc_image = _compute_residual_p_frames(nhwc_image, is_p_frame=[False]) + + patches = patchify_vision(nhwc_image, patch_size=patch_size).squeeze(0) + height_tokens, width_tokens, _ = patches.shape + + patches_list.append(patches.unsqueeze(0)) + token_grids.append(torch.tensor([height_tokens, width_tokens], dtype=torch.long, device=patches.device)) + + real_dims.append([1, height_tokens, width_tokens]) + if pixel_shuffle_scale > 1: + if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): + raise ValueError( + "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle 
is enabled." + ) + virtual_dims.append([1, height_tokens // pixel_shuffle_scale, width_tokens // pixel_shuffle_scale]) + else: + virtual_dims.append([1, height_tokens, width_tokens]) + + patches_tensor = torch.cat(patches_list, dim=0) + token_grids_tensor = torch.stack(token_grids, dim=0) + virtual_dims_tensor = torch.tensor(virtual_dims, dtype=torch.long, device=patches_tensor.device) + real_dims_tensor = torch.tensor(real_dims, dtype=torch.long, device=patches_tensor.device) + + batch_feature = BatchFeature( + data={ + "patches": patches_tensor, + "token_grids": token_grids_tensor, + "virtual_pixel_size": virtual_dims_tensor, + "real_pixel_size": real_dims_tensor, + }, + tensor_type=return_tensors, + ) + return batch_feature + + +__all__ = ["IsaacImageProcessorFast"] diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index f182f2bb6477..4ebf79738d7a 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -5,36 +5,41 @@ # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ +import copy from collections import defaultdict -from typing import Any, Callable, Optional, TypedDict +from collections.abc import Callable +from typing import Any, Optional, TypedDict import torch import torch.nn as nn import torch.nn.functional as F -from perceptron.tensorstream import TensorStream, TextType, VisionType, group_streams -from perceptron.tensorstream.ops import ( +from genesis.public.tensorstream.tensor_stream import TensorStream, TextType, VisionType, group_streams +from genesis.public.tensorstream.tensor_stream_utils import ( compute_mrope_pos_tensor, modality_mask, reconstruct_tensor_stream_from_compact_dict, ) from ...activations import ACT2FN -from ...cache_utils import Cache -from ...generation import GenerationMixin +from ...cache_utils import Cache, SlidingWindowCache, StaticCache +from ...generation.utils import GenerationMixin from ...integrations import use_kernel_forward_from_hub +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.deprecation import deprecate_kwarg -from ...utils.generic import check_model_inputs -from .configuration_isaac import IsaacConfig, PixelShuffleSiglip2VisionConfig +from ...utils.import_utils import is_torchdynamo_compiling +from ..auto.modeling_auto import AutoModel +from .configuration_isaac import IsaacConfig, IsaacVisionConfig -class Siglip2VariableSequenceEmbeddings(nn.Module): - def __init__(self, config: PixelShuffleSiglip2VisionConfig): +class IsaacVisionEmbeddings(nn.Module): + def __init__(self, config: IsaacVisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -49,9 +54,7 @@ def __init__(self, config: PixelShuffleSiglip2VisionConfig): self.position_embedding_size = int(self.num_patches**0.5) self.position_embedding = 
nn.Embedding(self.num_patches, self.embed_dim) - def positional_embeddings( - self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor] - ) -> torch.Tensor: + def positional_embeddings(self, spatial_shapes: torch.Tensor) -> torch.Tensor: # Prepare positional embeddings grid: (1, embed_dim, h, w) positional_embeddings = ( self.position_embedding.weight.reshape(self.position_embedding_size, self.position_embedding_size, -1) @@ -59,11 +62,9 @@ def positional_embeddings( .unsqueeze(0) ) - _seq_patches, _seq_sizes, spatial_shapes = packed_seq_patches pos_embeds_list = [] mode = "bilinear" align_corners = False - antialias = True for spatial_shape in spatial_shapes: height, width = spatial_shape # Guard to ensure height and width are positive for torch.compile @@ -73,35 +74,119 @@ def positional_embeddings( size=(height, width), mode=mode, align_corners=align_corners, - antialias=antialias, + antialias=True, ) # Reshape from (1, embed_dim, height, width) to (height*width, embed_dim) resized_pos_embed = resized_pos_embed.reshape(self.embed_dim, height * width).transpose(0, 1) else: # Fallback - should never happen in practice - resized_pos_embed = positional_embeddings.reshape( - self.embed_dim, self.position_embedding_size * self.position_embedding_size - ).transpose(0, 1)[: height * width] + raise RuntimeError( + "Encountered non-positive spatial dimensions while computing positional embeddings." + ) pos_embeds_list.append(resized_pos_embed) # Concatenate all positional embeddings along the sequence dimension pos_embeds = torch.cat(pos_embeds_list, dim=0) return pos_embeds - def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor]): - seq_patches, _seq_sizes, _spatial_shapes = packed_seq_patches - + def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor): # Apply patch embeddings target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(seq_patches.to(dtype=target_dtype)) - pos_embeds = self.positional_embeddings(packed_seq_patches) + pos_embeds = self.positional_embeddings(spatial_shapes) # Add positional embeddings to patch embeddings embeddings = patch_embeds + pos_embeds return embeddings -class Siglip2VariableLengthAttention(nn.Module): +def _max_from_cu(cu: torch.Tensor | None, fallback: int) -> int: + """Helper to compute max sequence length from cumulative sequence lengths.""" + if cu is None or len(cu) < 2: + return fallback + return int((cu[1:] - cu[:-1]).max().item()) + + +def flash_attention_document_mask_forward( + module: torch.nn.Module, + q_lhd: torch.Tensor, # (L, H, D) + k_lhd: torch.Tensor, # (L, H, D) + v_lhd: torch.Tensor, # (L, H, D) + attention_mask: torch.Tensor | None = None, # unused for FA path + dropout: float = 0.0, + scaling: float | None = None, + cum_seq_q: torch.Tensor | None = None, + cum_seq_k: torch.Tensor | None = None, + max_seqlen: int | None = None, + is_causal: bool = False, + **kwargs, +) -> tuple[torch.Tensor, None]: + """FlashAttention that consumes (L, H, D) directly to avoid layout churn.""" + L, H, D = q_lhd.shape + + # Compute max block length once (honor caller when provided) + if max_seqlen is not None: + max_q = max_k = int(max_seqlen) + else: + max_q = _max_from_cu(cum_seq_q, L) + max_k = _max_from_cu(cum_seq_k, L) + + # Ensure contiguity only if needed + if not q_lhd.is_contiguous(): + q_lhd = q_lhd.contiguous() + if not k_lhd.is_contiguous(): + k_lhd = k_lhd.contiguous() + if not v_lhd.is_contiguous(): + v_lhd = v_lhd.contiguous() + + out_lhd, *_ = 
torch.ops.aten._flash_attention_forward( + query=q_lhd, # (L, H, D) + key=k_lhd, # (L, H, D) + value=v_lhd, # (L, H, D) + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=max_q, + max_k=max_k, + dropout_p=dropout, + is_causal=is_causal, + return_debug_mask=False, + scale=scaling, + window_size_left=-1, + window_size_right=-1, + alibi_slopes=None, + ) + return out_lhd, None # (L, H, D) + + +def sdpa_document_mask_forward( + q_lhd: torch.Tensor, # (L, H, D) + k_lhd: torch.Tensor, # (L, H, D) + v_lhd: torch.Tensor, # (L, H, D) + dropout: float, + scaling: float | None, + cu_seqlens: torch.Tensor | None, +) -> torch.Tensor: + """SDPA with block-diagonal masking for variable-length sequences.""" + L, H, D = q_lhd.shape + + # Transpose to (1, H, L, D) format for SDPA + Q = q_lhd.permute(1, 0, 2).unsqueeze(0) + K = k_lhd.permute(1, 0, 2).unsqueeze(0) + V = v_lhd.permute(1, 0, 2).unsqueeze(0) + + # Build block-diagonal mask for variable-length sequences + attn_mask = None + if cu_seqlens is not None: + seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() + seg_ids = torch.repeat_interleave(torch.arange(len(seq_sizes), device=q_lhd.device), seq_sizes) + block_mask = seg_ids[:, None] != seg_ids[None, :] # Cross-document attention blocked + attn_mask = torch.where(block_mask, -torch.inf, 0.0).to(q_lhd.dtype).view(1, 1, L, L) + + Y = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask, dropout_p=dropout, scale=scaling) + return Y.squeeze(0).permute(1, 0, 2) # Back to (L, H, D) + + +class IsaacVisionAttention(nn.Module): """Custom attention that supports variable-length sequences with flash attention.""" def __init__(self, config): @@ -124,71 +209,51 @@ def __init__(self, config): self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None): - batch_size, seq_len, _ = hidden_states.size() - - # For variable-length attention, we need to reshape to (total_tokens, embed_dim) + # Expect packed sequences with batch_size == 1 + batch_size, L, _ = hidden_states.shape if batch_size != 1: - raise ValueError("Variable-length attention expects batch_size=1 for packed sequences") - hidden_states = hidden_states.squeeze(0) # Remove batch dimension: (seq_len, embed_dim) - - # Store original dtype - orig_dtype = hidden_states.dtype - - # 1. Linear projections - Q = self.q_proj(hidden_states) # (seq_len, embed_dim) - K = self.k_proj(hidden_states) # (seq_len, embed_dim) - V = self.v_proj(hidden_states) # (seq_len, embed_dim) - - # 2. Reshape for multi-head attention: (seq_len, n_heads, head_dim) - Q = Q.view(-1, self.num_heads, self.embed_dim // self.num_heads) - K = K.view(-1, self.num_heads, self.embed_dim // self.num_heads) - V = V.view(-1, self.num_heads, self.embed_dim // self.num_heads) - - # 3. Apply variable-length attention using flash attention - attn_output, _, _, _, _ = torch.ops.aten._flash_attention_forward( - query=Q, - key=K, - value=V, - cum_seq_q=cu_seqlens, - cum_seq_k=cu_seqlens, - max_q=max_seqlen, - max_k=max_seqlen, - dropout_p=self.dropout if self.training else 0.0, - is_causal=False, - return_debug_mask=False, - scale=self.scale, - window_size_left=-1, - window_size_right=-1, - alibi_slopes=None, - ) - - # 4. Reshape attention output from (seq_len, n_heads, head_dim) to (seq_len, embed_dim) - attn_output = attn_output.reshape(seq_len, self.embed_dim) - - # 5. Convert back to original dtype if needed - if attn_output.dtype != orig_dtype: - attn_output = attn_output.to(orig_dtype) - - # 6. 
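# Illustrative sketch (not from the patch itself): how the block-diagonal mask
# in sdpa_document_mask_forward behaves for two packed documents of lengths 3
# and 2.
import torch
import torch.nn.functional as F

cu_seqlens = torch.tensor([0, 3, 5])  # document boundaries in a packed length-5 sequence
L, H, D = 5, 2, 4
q = k = v = torch.rand(L, H, D)

seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long()
seg_ids = torch.repeat_interleave(torch.arange(len(seq_sizes)), seq_sizes)
block_mask = seg_ids[:, None] != seg_ids[None, :]  # True = cross-document, blocked
attn_mask = torch.where(block_mask, -torch.inf, 0.0).view(1, 1, L, L)

out = F.scaled_dot_product_attention(
    q.permute(1, 0, 2).unsqueeze(0),  # (1, H, L, D)
    k.permute(1, 0, 2).unsqueeze(0),
    v.permute(1, 0, 2).unsqueeze(0),
    attn_mask=attn_mask,
)
print(block_mask.int())  # zeros form two diagonal blocks: 3x3 and 2x2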
Project output - attn_output = self.out_proj(attn_output) # (seq_len, embed_dim) - - # 7. Add back batch dimension for compatibility - attn_output = attn_output.unsqueeze(0) # (1, seq_len, embed_dim) - - return attn_output, None + raise ValueError("packed variable-length attention expects batch_size=1") + x = hidden_states[0] # (L, E) + + H = self.num_heads + D = self.head_dim + p_drop = self.dropout if self.training else 0.0 + + # Project and reshape to (L, H, D) + q = self.q_proj(x).view(L, H, D) + k = self.k_proj(x).view(L, H, D) + v = self.v_proj(x).view(L, H, D) + + attn_impl = getattr(self.config, "_attn_implementation", "flash_attention_3") + + if attn_impl in ("flash_attention_2", "flash_attention_3"): + y_lhd, _ = flash_attention_document_mask_forward( + self, + q, + k, + v, + attention_mask=None, + dropout=p_drop, + scaling=self.scale, + cum_seq_q=cu_seqlens, + cum_seq_k=cu_seqlens, + max_seqlen=max_seqlen, + is_causal=False, + ) + else: + y_lhd = sdpa_document_mask_forward(q, k, v, dropout=p_drop, scaling=self.scale, cu_seqlens=cu_seqlens) + # Merge heads and project + y = self.out_proj(y_lhd.reshape(L, self.embed_dim)) + return y.unsqueeze(0), None # (1, L, E) -class IsaacSiglip2EncoderLayer(nn.Module): - """Siglip2 encoder layer with variable-length attention.""" - def __init__(self, config: PixelShuffleSiglip2VisionConfig): - super().__init__() - self.embed_dim = config.hidden_size - self.self_attn = Siglip2VariableLengthAttention(config) +class IsaacVisionEncoderLayer(HFSiglip2EncoderLayer): + """Isaac vision encoder layer with variable-length attention.""" - self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = Siglip2MLP(config) # Use HF's Siglip2MLP - self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + def __init__(self, config: IsaacVisionConfig): + super().__init__(config) + self.self_attn = IsaacVisionAttention(config) def forward( self, @@ -216,13 +281,13 @@ def forward( return (hidden_states,) -class IsaacEncoder(nn.Module): +class IsaacVisionEncoder(nn.Module): """Encoder using Isaac encoder layers with variable-length attention support.""" - def __init__(self, config: PixelShuffleSiglip2VisionConfig): + def __init__(self, config: IsaacVisionConfig): super().__init__() self.config = config - self.layers = nn.ModuleList([IsaacSiglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList([IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) def forward( self, @@ -253,14 +318,6 @@ def forward( return hidden_states, all_hidden_states, None -def create_cumulative_seq_lengths(seq_sizes: torch.Tensor, device: torch.device) -> tuple[torch.Tensor, int]: - """Create cumulative sequence lengths for variable-length attention.""" - cu_seqlens = torch.zeros(len(seq_sizes) + 1, dtype=torch.int32, device=device) - cu_seqlens[1:] = seq_sizes.cumsum(0) - max_seqlen = int(seq_sizes.max().item()) if len(seq_sizes) > 0 else 0 - return cu_seqlens, max_seqlen - - def create_pixel_shuffle_index_map( seq_sizes: torch.Tensor, token_grids: torch.Tensor, @@ -288,16 +345,17 @@ def create_pixel_shuffle_index_map( if device is None: device = seq_sizes.device - r = int(scale_factor) - if r < 2: + scale_factor = int(scale_factor) + if scale_factor < 2: raise ValueError("`scale_factor` must be โ‰ฅ 2") - # Safety: all spatial dims must be divisible by r + # Safety: all spatial dims must be divisible by the scale factor # Cannot run under torch compile fullgraph mode hence - if not 
torch.compiler.is_compiling(): - if not ((token_grids[:, 0] % r == 0).all() and (token_grids[:, 1] % r == 0).all()): + if not is_torchdynamo_compiling(): + if not ((token_grids[:, 0] % scale_factor == 0).all() and (token_grids[:, 1] % scale_factor == 0).all()): raise AssertionError( - f"Every (H,W) in `token_grids` must be divisible by scale_factor={r}, got {token_grids.tolist()}" + "Every (H,W) in `token_grids` must be divisible by " + f"scale_factor={scale_factor}, got {token_grids.tolist()}" ) gather_chunks: list[torch.Tensor] = [] @@ -309,19 +367,21 @@ def create_pixel_shuffle_index_map( grid = grid.view(h, w) # (H, W) # -------- identical ordering to your fixed-res routine -------- - # Step 1: split width into blocks of r - grid = grid.view(h, w // r, r) # (H, W/r, r) - # Step 2: now split height into blocks of r - grid = grid.view(h // r, r, w // r, r) # (H/r, r, W/r, r) - # Step 3: final permutation to (H/r, W/r, r, r) - grid = grid.permute(0, 2, 1, 3).contiguous() # (H/r, W/r, r, r) - # Step 4: each (r, r) block forms one output token - gather_chunks.append(grid.reshape(-1, r * r)) # (H*W / rยฒ, rยฒ) + # Step 1: split width into blocks of scale_factor + grid = grid.view(h, w // scale_factor, scale_factor) # (H, W/scale_factor, scale_factor) + # Step 2: now split height into blocks of scale_factor + grid = grid.view(h // scale_factor, scale_factor, w // scale_factor, scale_factor) + # (H/scale_factor, scale_factor, W/scale_factor, scale_factor) + # Step 3: final permutation to (H/scale_factor, W/scale_factor, scale_factor, scale_factor) + grid = grid.permute(0, 2, 1, 3).contiguous() # (H/scale_factor, W/scale_factor, scale_factor, scale_factor) + # Step 4: each (scale_factor, scale_factor) block forms one output token + gather_chunks.append(grid.reshape(-1, scale_factor * scale_factor)) + # (H*W / scale_factor**2, scale_factor**2) tok_offset += seq_len # Concatenate over all images in the packed batch - gather_idx = torch.cat(gather_chunks, dim=0) # (ฮฃ_i HแตขWแตข/rยฒ, rยฒ) + gather_idx = torch.cat(gather_chunks, dim=0) # (ฮฃ_i HแตขWแตข/scale_factor**2, scale_factor**2) return gather_idx @@ -360,7 +420,7 @@ def pixel_shuffle_varlen( x_ = x # (seq, embed) embed_dim = x_.size(-1) - r = int(scale_factor) + scale_factor = int(scale_factor) # Calculate seq_sizes from token_grids seq_sizes = torch.prod(token_grids, dim=-1) @@ -369,15 +429,15 @@ def pixel_shuffle_varlen( gather_idx = create_pixel_shuffle_index_map( seq_sizes=seq_sizes, token_grids=token_grids, - scale_factor=r, + scale_factor=scale_factor, device=x_.device, - ) # (new_seq, rยฒ) + ) # (new_seq, scale_factor**2) - # Gather โ†’ (new_seq, rยฒ, embed_dim) + # Gather โ†’ (new_seq, scale_factor**2, embed_dim) gathered = x_[gather_idx] # fancy indexing keeps gradient - # Merge the rยฒ group dimension into channels to finish the shuffle - out = gathered.reshape(gathered.size(0), embed_dim * r * r) + # Merge the scale_factor**2 group dimension into channels to finish the shuffle + out = gathered.reshape(gathered.size(0), embed_dim * scale_factor * scale_factor) # Restore batch dimension if needed if keep_batch_dim: @@ -385,12 +445,12 @@ def pixel_shuffle_varlen( return out -class Siglip2SequenceVisionTransformer(nn.Module): - def __init__(self, config: PixelShuffleSiglip2VisionConfig): +class IsaacVisionTransformer(nn.Module): + def __init__(self, config: IsaacVisionConfig): super().__init__() self.config = config - self.embeddings = Siglip2VariableSequenceEmbeddings(config) - self.encoder = IsaacEncoder(config) + self.embeddings = 
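# Illustrative sketch (not from the patch itself): the gather-index
# construction above applied to a single 4x4 token grid with scale_factor=2.
import torch

h, w, r = 4, 4, 2
grid = torch.arange(h * w).view(h, w)         # token ids in raster order
grid = grid.view(h, w // r, r)                # split width into r-blocks
grid = grid.view(h // r, r, w // r, r)        # then split height
grid = grid.permute(0, 2, 1, 3).contiguous()  # (H/r, W/r, r, r)
gather_idx = grid.reshape(-1, r * r)          # each row feeds one shuffled token
print(gather_idx)
# tensor([[ 0,  1,  4,  5],
#         [ 2,  3,  6,  7],
#         [ 8,  9, 12, 13],
#         [10, 11, 14, 15]])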
IsaacVisionEmbeddings(config) + self.encoder = IsaacVisionEncoder(config) self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor @@ -399,13 +459,15 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): seq_sizes = torch.prod(token_grids, dim=-1) # Get embeddings from packed sequence - hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids)) + hidden_states = self.embeddings(seq_patches, token_grids) # Add a pseudo batch dimension for the encoder hidden_states = hidden_states.unsqueeze(0) # Generate cumulative sequence lengths for variable-length attention - cu_seqlens, max_seqlen = create_cumulative_seq_lengths(seq_sizes, hidden_states.device) + cu_seqlens = torch.zeros(seq_sizes.size(0) + 1, dtype=torch.int32, device=hidden_states.device) + cu_seqlens[1:] = seq_sizes.cumsum(0) + max_seqlen = int(seq_sizes.max().item()) if seq_sizes.numel() > 0 else 0 # Pass through encoder with variable-length attention parameters hidden_states, _, _ = self.encoder( @@ -440,14 +502,6 @@ class RopeScaling(TypedDict, total=False): original_max_position_embeddings: int -def precompute_inv_freq(theta: float, dim: int) -> torch.Tensor: - """ - Returns shape (dim//2,). - """ - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - return inv_freq # type: ignore[return-value] - - def precompute_cos_sin_3d( position_ids: torch.Tensor, # shape (3, B, T) inv_freq: torch.Tensor, # shape (dim//2,) @@ -498,42 +552,76 @@ def precompute_cos_sin_3d( class IsaacRotaryEmbedding(nn.Module): + EXTRA_ROPE_KEYS = {"mrope_section", "mrope_interleaved"} + def __init__(self, config: IsaacConfig, device=None): super().__init__() - # Extract dimensions from config - self.hidden_size = config.hidden_size - self.num_attention_heads = config.num_attention_heads - self.head_dim = config.head_dim - - # Get rope_scaling config - use direct access when available + self.config = config rope_scaling = getattr(config, "rope_scaling", None) or {} + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) + if rope_type not in ROPE_INIT_FUNCTIONS: + raise ValueError(f"Unsupported rope_type '{rope_type}' for IsaacRotaryEmbedding") - # Read RopeScaling parameters - self.rope_type = rope_scaling.get("rope_type", "default") + self.rope_type = rope_type + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] - self.mrope_section = [ - self.head_dim // 4, # 2x more for temporal dim - self.head_dim // 8, - self.head_dim // 8, - ] + sanitized_scaling = {k: v for k, v in rope_scaling.items() if k not in self.EXTRA_ROPE_KEYS} + if sanitized_scaling != rope_scaling: + config_for_rope = copy.copy(config) + config_for_rope.rope_scaling = sanitized_scaling + else: + config_for_rope = config - rope_base = getattr(config, "rope_theta", 10000.0) - inv_freq = precompute_inv_freq(rope_base, self.head_dim) + init_device = device if device is not None and getattr(device, "type", None) != "meta" else None + inv_freq, attention_scaling = rope_init_fn(config_for_rope, device=init_device) self.register_buffer("inv_freq", inv_freq, persistent=False) + self.attention_scaling = self._normalize_scale(attention_scaling) + + rotary_half_dim = self.inv_freq.shape[0] + self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim) + + @staticmethod + def _normalize_scale(scale: torch.Tensor | float) -> torch.Tensor | float: + if isinstance(scale, torch.Tensor): + 
return scale.detach().clone() + return float(scale) + + @staticmethod + def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> list[int]: + if section is None: + weights = (2, 1, 1) + base = [rotary_half_dim * w // sum(weights) for w in weights] + base[0] += rotary_half_dim - sum(base) + return base + + section = [int(v) for v in section] + if len(section) != 3: + raise ValueError("`mrope_section` must contain exactly three elements (temporal, height, width)") + if sum(section) != rotary_half_dim: + raise ValueError( + f"`mrope_section` must sum to the rotary half-dimension ({rotary_half_dim}). Received {section}." + ) + return section def forward(self, position_ids: torch.Tensor, modality_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: with torch.no_grad(): - # Ensure non-spatial tokens have 1D rotation equivalence - not_spatial = ~(modality_tensor == VisionType.image.value) - # shape is [N, 1] - data_1d = position_ids[not_spatial][..., 0].unsqueeze(-1) - # now broadcast it from [N, 1] -> [N, D] so it matches pos[not_spatial] exactly - data_1d = data_1d.expand(-1, position_ids.shape[-1]) # expand along the last dim - position_ids = position_ids.clone() # Clone to avoid warning about in-place operations on expanded tensors - position_ids[not_spatial] = data_1d - position_ids = position_ids.permute(2, 0, 1) # pos dim first -> (3, B, L) + position_ids = position_ids.clone() + not_spatial = modality_tensor != VisionType.image.value + if not_spatial.any(): + data_1d = position_ids[not_spatial][..., 0].unsqueeze(-1) + position_ids[not_spatial] = data_1d.expand(-1, position_ids.shape[-1]) + + position_ids = position_ids.permute(2, 0, 1) cos, sin = precompute_cos_sin_3d(position_ids, self.inv_freq, self.mrope_section) + scale = self.attention_scaling + if isinstance(scale, torch.Tensor): + scale = scale.to(device=cos.device, dtype=cos.dtype) + elif scale != 1.0: + scale = cos.new_tensor(scale) + if isinstance(scale, torch.Tensor) or scale != 1.0: + cos = cos * scale + sin = sin * scale return cos, sin @@ -769,25 +857,6 @@ def forward( return hidden_states -@auto_docstring -class IsaacPreTrainedModel(PreTrainedModel): - config: IsaacConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["IsaacDecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn = True - _supports_sdpa = True - _supports_flex_attn = True - - _can_compile_fullgraph = True - _supports_attention_backend = True - _can_record_outputs = { - "hidden_states": IsaacDecoderLayer, - "attentions": IsaacAttention, - } - - # ============================================================================ # Model # ============================================================================ @@ -812,28 +881,48 @@ def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor: @auto_docstring -class IsaacModel(IsaacPreTrainedModel): +class IsaacModel(PreTrainedModel): + config: IsaacConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["IsaacDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": IsaacDecoderLayer, + "attentions": IsaacAttention, + } + def __init__(self, config: IsaacConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - 
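# Illustrative sketch (not from the patch itself): the default
# temporal/height/width frequency split produced by _resolve_mrope_section,
# assuming head_dim=128 and hence a rotary half-dimension of 64.
rotary_half_dim = 64
weights = (2, 1, 1)  # temporal gets twice the frequencies of height/width
base = [rotary_half_dim * w // sum(weights) for w in weights]
base[0] += rotary_half_dim - sum(base)  # rounding remainder goes to temporal
print(base, sum(base))  # [32, 16, 16] 64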
self.vocab_size = config.vocab_size
+        Qwen3PreTrainedModel.__init__(self, config)
+
+        text_cfg_source = getattr(config, "get_text_config", lambda: config)()
+        text_cfg = copy.deepcopy(text_cfg_source)
+        text_cfg._attn_implementation = config._attn_implementation
+        self.text_model = AutoModel.from_config(text_cfg)
+        # Ensure downstream callers observe the composed config
+        self.text_model.config = config

-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-        self.layers = torch.nn.ModuleList(
-            [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
-        )
-        self.norm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device)
-        self.gradient_checkpointing = False
-        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

         vision_cfg = config.vision_config
         if vision_cfg is None:
             raise ValueError("IsaacConfig should always have vision_config")
+        # Use vision_attn_implementation if specified, otherwise fall back to general attn_implementation
+        vision_cfg._attn_implementation = (
+            config.vision_attn_implementation
+            if config.vision_attn_implementation is not None
+            else config._attn_implementation
+        )
         hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2)
         self.vision_embedding = nn.Sequential(
-            Siglip2SequenceVisionTransformer(vision_cfg),
+            IsaacVisionTransformer(vision_cfg),
             nn.Linear(
                 hidden_dim,
                 4 * hidden_dim,
@@ -849,11 +938,80 @@ def __init__(self, config: IsaacConfig):
             VisionType: self.embed_vision,
         }

-        # Initialize weights and apply final processing
-        self.post_init()
+    def get_input_embeddings(self) -> nn.Module:
+        return self.text_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value: nn.Module) -> None:
+        self.text_model.set_input_embeddings(value)
+
+    @property
+    def embed_tokens(self) -> nn.Module:
+        return self.text_model.embed_tokens
+
+    @embed_tokens.setter
+    def embed_tokens(self, value: nn.Module) -> None:
+        self.text_model.embed_tokens = value
+
+    @property
+    def layers(self) -> nn.ModuleList:
+        return self.text_model.layers
+
+    @property
+    def norm(self) -> nn.Module:
+        return self.text_model.norm
+
+    def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
+        self.text_model._set_gradient_checkpointing(
+            enable=enable, gradient_checkpointing_func=gradient_checkpointing_func
+        )
+
+    def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
+        """Embed text tokens, squeezing singleton dimensions."""
+        # Text events are shaped as (..., 1); squeeze the singleton index dim
+        h = self.text_model.embed_tokens(token_ids)
+        if h.dim() >= 2 and h.size(-2) == 1:
+            h = h[..., 0, :]
+        return h
+
+    def embed_vision(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        """Embed vision tokens using the vision encoder."""
+        # vision tokens is (seq_patches, token_grids)
+        return self.vision_embedding(vision_tokens)
+
+    def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor:
+        """
+        Embed each modality stream independently, preserving the original TensorStream
+        structure.
+ """ + flat_stream = tensor_stream.flat_stream() + per_modality_stream = group_streams(flat_stream, group_fn=lambda ev: ev.type, schedule=False) + per_modality_compact_stream = {k: v.compact() for k, v in per_modality_stream.items()} + + # Collect per-event grids for vision tokens (H, W like dims sans time) + token_grids = defaultdict(list) + for stream in tensor_stream.streams: + for event in stream: + token_grids[event.type].append(event.dims(virtual=False)) + + embedded_compact = {} + for stream_type, modality_payload_tensor in per_modality_compact_stream.items(): + if stream_type.modality == VisionType: + # Build a (N_events, 2) grid tensor with spatial dims only + grids = token_grids.get(stream_type, []) + if len(grids) == 0: + input_tensor = modality_payload_tensor + else: + token_grids_tensor = torch.tensor(grids, dtype=torch.long, device=tensor_stream.device)[:, 1:] + input_tensor = (modality_payload_tensor, token_grids_tensor) + embedded_compact[stream_type] = self.embed_fns[stream_type.modality](input_tensor) + else: + embedded_compact[stream_type] = self.embed_fns[stream_type.modality](modality_payload_tensor) + + # Reconstruct a TensorStream with embedded payloads and compact + embedded_ts = reconstruct_tensor_stream_from_compact_dict(tensor_stream, embedded_compact) + h = embedded_ts.compact() # (B, T, D) + return h - @check_model_inputs - @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -892,7 +1050,7 @@ def forward( elif input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.text_model.embed_tokens(input_ids) # Create text modality tensor if not provided if modality_tensor is None: batch_size, seq_length = input_ids.shape @@ -923,7 +1081,7 @@ def forward( # Initialize hidden states hidden_states = inputs_embeds - for decoder_layer in self.layers: + for decoder_layer in self.text_model.layers: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, @@ -935,62 +1093,186 @@ def forward( **kwargs, ) - hidden_states = layer_outputs[0] + hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs # Final layer norm - hidden_states = self.norm(hidden_states) + hidden_states = self.text_model.norm(hidden_states) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, ) - def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: - """Embed text tokens, squeezing singleton dimensions.""" - # Text events are shaped as (..., 1); squeeze the singleton index dim - h = self.embed_tokens(token_ids) - if h.dim() >= 2 and h.size(-2) == 1: - h = h[..., 0, :] - return h + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
" + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) - def embed_vision(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - """Embed vision tokens using the vision encoder.""" - # vision tokens is (seq_patches, token_grids) - return self.vision_embedding(vision_tokens) + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) - def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor: + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen3Config, + past_key_values: Cache, + ): """ - Embed each modality stream independently, preserving the original TensorStream - structure. + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. 
+ + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Qwen3Config`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate """ - flat_stream = tensor_stream.flat_stream() - per_modality_stream = group_streams(flat_stream, group_fn=lambda ev: ev.type, schedule=False) - per_modality_compact_stream = {k: v.compact() for k, v in per_modality_stream.items()} + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask - # Collect per-event grids for vision tokens (H, W like dims sans time) - token_grids = defaultdict(list) - for stream in tensor_stream.streams: - for event in stream: - token_grids[event.type].append(event.dims(virtual=False)) - embedded_compact = {} - for stream_type, modality_payload_tensor in per_modality_compact_stream.items(): - if stream_type.modality == VisionType: - # Build a (N_events, 2) grid tensor with spatial dims only - grids = token_grids.get(stream_type, []) - if len(grids) == 0: - input_tensor = modality_payload_tensor - else: - token_grids_tensor = torch.tensor(grids, dtype=torch.long, device=tensor_stream.device)[:, 1:] - input_tensor = (modality_payload_tensor, token_grids_tensor) - embedded_compact[stream_type] = 
self.embed_fns[stream_type.modality](input_tensor) - else: - embedded_compact[stream_type] = self.embed_fns[stream_type.modality](modality_payload_tensor) +@auto_docstring +class IsaacPreTrainedModel(PreTrainedModel): + config: IsaacConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["IsaacDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True - # Reconstruct a TensorStream with embedded payloads and compact - embedded_ts = reconstruct_tensor_stream_from_compact_dict(tensor_stream, embedded_compact) - h = embedded_ts.compact() # (B, T, D) - return h + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": IsaacDecoderLayer, + "attentions": IsaacAttention, + } @auto_docstring @@ -1004,14 +1286,16 @@ class IsaacForConditionalGeneration(IsaacPreTrainedModel, GenerationMixin): config_class = IsaacConfig def __init__(self, config: IsaacConfig): - Qwen3PreTrainedModel.__init__(self, config) + super().__init__(config) + self.model = IsaacModel(config) # Use our custom model self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Tracks rotary position offsets computed during a full forward pass so decode steps can reuse them. self.rope_deltas = None - self.config = config + # Initialize weights and apply final processing + self.post_init() @can_return_tuple @auto_docstring diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 45766db223d9..c2781962fe6d 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -6,275 +6,44 @@ # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ import math import re -from typing import Any, Union +from collections.abc import Sequence -import numpy as np import PIL.Image import torch -import torch.nn.functional as F -from perceptron.tensorstream import Event, Stream, TensorStream, TextType, VisionType, create_stream -from perceptron.tensorstream.ops import slice as ts_slice -from perceptron.tensorstream.ops import tensor_stream_token_view +from genesis.public.tensorstream.tensor_stream import Event, Stream, TensorStream, TextType, VisionType, create_stream +from genesis.public.tensorstream.tensor_stream_utils import slice as ts_slice +from genesis.public.tensorstream.tensor_stream_utils import tensor_stream_token_view -from ...processing_utils import BatchFeature, ProcessorMixin +from ...feature_extraction_utils import BatchFeature +from ...image_processing_utils_fast import BaseImageProcessorFast +from ...processing_utils import ProcessorMixin from ...tokenization_utils import TensorType -from ..auto import AutoTokenizer from .configuration_isaac import IsaacConfig +from .image_processing_isaac import IsaacImageProcessor -# ============================================================================ -# Configuration -# ============================================================================ - -MAX_PIXELS = 60_000_000 # 60-megapixel ceiling โ‰ˆ 8200 ร— 7300 px - # Vision preprocessing constants VISION_MEAN = (0.5, 0.5, 0.5) VISION_STD = (0.5, 0.5, 0.5) VISION_SCALE = 1 / 255 -def _make_writeable(arr: np.ndarray) -> np.ndarray: - """Return 
*arr* itself if it is already writeable, otherwise try to flip the - write flag in-place and finally fall back to `arr.copy()`. - This guarantees the buffer handed to `torch.from_numpy()` is always - writeable, silencing the PyTorch warning about undefined behaviour. - """ - if arr.flags.writeable: - return arr - - # First, try the cheap path โ€” in-place flag toggle (works for mmap'd arrays - # and some shared memory buffers): - try: - arr.setflags(write=True) - return arr # success: no data copy - except ValueError: - # Buffer is inherently read-only (e.g. backed by PyAV / PIL): make copy - return arr.copy() - - -def extract_image_pil(image: PIL.Image.Image) -> torch.Tensor | None: - if image.width * image.height > MAX_PIXELS: - raise ValueError(f"Image (w={image.width}, h={image.height}) > MAX=`{MAX_PIXELS}`") - img = image if image.mode == "RGB" else image.convert("RGB") - arr = np.asarray(img) - arr = _make_writeable(arr) - return torch.from_numpy(arr) - - -def get_image_size_for_max_num_patches( - image_height: int, - image_width: int, - patch_size: int, - max_num_patches: int, - min_num_patches: int | None = None, - eps: float = 1e-5, - pixel_shuffle_scale: int = 1, -) -> tuple[int, int]: - r"""Compute a target resolution whose patch grid satisfies patching parametrization. - - Args: - image_height (`int`): - Height in pixels of the source image prior to any resizing. - image_width (`int`): - Width in pixels of the source image prior to any resizing. - patch_size (`int`): - Size of the square patch used by the vision encoder. - max_num_patches (`int`): - Upper bound on `(height / patch_size) * (width / patch_size)` after resizing. - min_num_patches (`int`, *optional*): - Lower bound on the number of patches. When provided the image will be scaled up if necessary. - eps (`float`, *optional*, defaults to 1e-5): - Convergence tolerance for the internal binary search to determing the target dimensions. - pixel_shuffle_scale (`int`, *optional*, defaults to 1): - Additional stride multiplier applied when pixel shuffle later reduces spatial resolution. - - Returns: - `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale` - and respect both the maximum and optional minimum patch-count constraints. 
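# Illustrative sketch (not from the patch itself): the divisor rounding
# performed above for a 500x375 image with patch_size=16 and
# pixel_shuffle_scale=2.
import math

image_height, image_width = 375, 500
patch_size, pixel_shuffle_scale = 16, 2
divisor = patch_size * pixel_shuffle_scale  # 32

adjusted_height = max(divisor, math.ceil(image_height / divisor) * divisor)  # 384
adjusted_width = max(divisor, math.ceil(image_width / divisor) * divisor)    # 512
num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size)
print(adjusted_height, adjusted_width, num_patches)
# 384 512 768.0 -> with e.g. max_num_patches=256 this exceeds the cap,
# so the binary search below scales the image down instead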
- """ - - def get_scaled_image_size(scale, original_size, patch_size, pixel_shuffle_scale): - scaled_size = scale * original_size - divisor = patch_size * pixel_shuffle_scale - scaled_size = math.ceil(scaled_size / divisor) * divisor - scaled_size = max(divisor, scaled_size) - return int(scaled_size) - - # Ensure divisibility - divisor = patch_size * pixel_shuffle_scale - adjusted_height = math.ceil(image_height / divisor) * divisor - adjusted_height = max(divisor, adjusted_height) - adjusted_width = math.ceil(image_width / divisor) * divisor - adjusted_width = max(divisor, adjusted_width) - - num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size) - - if min_num_patches is not None and num_patches < min_num_patches: - # Scale up - scale_min, scale_max = 1.0, 100.0 - while (scale_max - scale_min) >= eps: - scale = (scale_min + scale_max) / 2 - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) - num_patches = (target_height / patch_size) * (target_width / patch_size) - if num_patches >= min_num_patches: - scale_max = scale - else: - scale_min = scale - scale = scale_max - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) - return target_height, target_width - elif num_patches <= max_num_patches: - return adjusted_height, adjusted_width - else: - # Scale down - scale_min, scale_max = eps / 10, 1.0 - while (scale_max - scale_min) >= eps: - scale = (scale_min + scale_max) / 2 - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) - num_patches = (target_height / patch_size) * (target_width / patch_size) - if num_patches <= max_num_patches: - scale_min = scale - else: - scale_max = scale - scale = scale_min - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) - return target_height, target_width - - -_MEAN_TENSOR = torch.tensor(VISION_MEAN, dtype=torch.float32).view(1, 1, 1, -1) -_STD_TENSOR = torch.tensor(VISION_STD, dtype=torch.float32).view(1, 1, 1, -1) - - -def prepare_image_tensor( - image: torch.Tensor, - scale: float = VISION_SCALE, -) -> torch.Tensor: - r"""Standardize RGB images prior to patch extraction via rescaling and whitening. - - Args: - image (`torch.Tensor`): - Tensor with shape `(..., height, width, 3)` containing RGB values. The tensor is converted to floating - point if needed. - scale (`float`, *optional*, defaults to `VISION_SCALE`): - Scalar multiplier applied before normalization. - Returns: - `torch.Tensor`: Normalized tensor with the same shape as the input and dtype `torch.float32`. - """ - if not torch.is_floating_point(image): - image = image.float() - rescaled = image * scale - - # Use precomputed tensors and move to the correct device if needed - mean_tensor = _MEAN_TENSOR.to(image.device) - std_tensor = _STD_TENSOR.to(image.device) - - normalized = (rescaled - mean_tensor) / std_tensor - return normalized - - -def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor: - r"""Convert normalized images into flattened ViT-style patches. 
- - Args: - image (`torch.Tensor`): - Tensor of shape `(num_images, height, width, channels)`. - patch_size (`int`): - Edge length of the square patches - - Returns: - `torch.Tensor`: - Patch tensor where each position stores the flattened pixels belonging to that patch. - - Raises: - ValueError: If `height` or `width` is not divisible by `patch_size`. - """ - num_images, height, width, channels = image.shape - if height % patch_size or width % patch_size: - raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.") - patches = image.reshape(num_images, height // patch_size, patch_size, width // patch_size, patch_size, channels) - patches = patches.permute(0, 1, 3, 2, 4, 5) - patches = patches.reshape( - num_images, height // patch_size, width // patch_size, channels * patch_size * patch_size - ) - return patches - - -def process_vision_for_patches( - images: torch.Tensor, - patch_size: int, - max_num_patches: int, - min_num_patches: int | None = None, - pixel_shuffle_scale: int = 1, -) -> tuple[torch.Tensor, list[int]]: - r"""Resize, normalize, and patchify RGB images for the vision encoder. - - Args: - images (`torch.Tensor`): - Either `(height, width, channels)` for a single image or `(num_images, height, width, channels)` for a - batch. Channels are expected to be RGB. - patch_size (`int`): - Edge length of square patches; implictly controls resize grid granularity. - max_num_patches (`int`): - Maximum number of patches allowed after resizing. - min_num_patches (`int`, *optional*): - Minimum number of patches. If provided, the routine upsamples images as needed to satisfy the lower bound. - pixel_shuffle_scale (`int`, *optional*, defaults to 1): - pixel shuffle scale factor; influences the target grid that the function produces. - - Returns: - `tuple[torch.Tensor, list[int]]`: A pair `(patches, dims_virtual)` where `patches` has shape - `(num_images, target_h / patch_size, target_w / patch_size, channels * patch_size**2)` and `dims_virtual` - encodes effective `(images, height, width)` dimensions after optional pixel shuffling. - """ - # Add batch dim if single image - if images.dim() == 3: - images = images.unsqueeze(0) - - # Permute to channel first for resize - images = images.permute(0, 3, 1, 2) - - # Get target dimensions - _, _, orig_height, orig_width = images.shape - target_height, target_width = get_image_size_for_max_num_patches( - orig_height, - orig_width, - patch_size, - max_num_patches, - min_num_patches=min_num_patches, - pixel_shuffle_scale=pixel_shuffle_scale, - ) - - # Resize - images = F.interpolate( - images, - size=(target_height, target_width), - mode="bilinear", - align_corners=False, - ) +def _normalize_rgb_values( + values: float | Sequence[float] | tuple[float, ...], + *, + name: str, +) -> tuple[float, float, float]: + """Coerce RGB normalization parameters into a 3-tuple of floats.""" + if isinstance(values, (list, tuple)): + if len(values) == 3: + return tuple(float(v) for v in values) + if len(values) == 1: + value = float(values[0]) + return (value, value, value) + raise ValueError(f"`{name}` must have length 1 or 3 when provided as a sequence. 
Got length {len(values)}.") - # Back to channel last - images = images.permute(0, 2, 3, 1) - - # Normalize - images = prepare_image_tensor(images) - - # Patchify - patches = patchify_vision(images, patch_size=patch_size) - - # Calculate dimensions for the patches - n_images, h_patches, w_patches, _ = patches.shape - dims_virtual = ( - [1, h_patches, w_patches] - if pixel_shuffle_scale == 1 - else [1, h_patches // pixel_shuffle_scale, w_patches // pixel_shuffle_scale] - ) - - return patches, dims_virtual + value = float(values) + return (value, value, value) # ============================================================================ @@ -326,40 +95,104 @@ def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> class IsaacProcessor(ProcessorMixin): - attributes = [] - tokenizer_class = ("AutoTokenizer",) + attributes = ["image_processor", "tokenizer"] + image_processor_class = ("IsaacImageProcessor", "IsaacImageProcessorFast") + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") def __init__( self, - tokenizer: AutoTokenizer, - config: IsaacConfig, - ): - super().__init__() - self.tokenizer = tokenizer - self.config = config + image_processor: IsaacImageProcessor | BaseImageProcessorFast | None = None, + tokenizer: Qwen2Tokenizer | None = None, + *, + vision_token: str = "", + max_sequence_length: int = 16384, + vision_patch_size: int = 16, + vision_max_num_patches: int = 256, + vision_min_num_patches: int | None = None, + pixel_shuffle_scale: int = 1, + rescale_factor: float | None = None, + image_mean: float | Sequence[float] | None = None, + image_std: float | Sequence[float] | None = None, + vision_attn_implementation: str | None = None, + config: IsaacConfig | dict | None = None, + **kwargs, + ) -> None: + if tokenizer is None: + raise ValueError("`tokenizer` must be provided to initialize IsaacProcessor.") + + if isinstance(config, dict): + config = IsaacConfig(**config) + + if config is not None: + vision_patch_size = config.video_patch_size + vision_max_num_patches = config.vision_max_num_patches + vision_min_num_patches = config.vision_min_num_patches + pixel_shuffle_scale = config.pixel_shuffle_scale + max_sequence_length = config.max_sequence_length + vision_token = config.vision_token + vision_attn_implementation = config.vision_attn_implementation + rescale_factor = config.vision_rescale_factor + image_mean = tuple(config.vision_mean) + image_std = tuple(config.vision_std) + + resolved_rescale_factor = float(rescale_factor) if rescale_factor is not None else float(VISION_SCALE) + resolved_image_mean = _normalize_rgb_values( + image_mean if image_mean is not None else VISION_MEAN, + name="image_mean", + ) + resolved_image_std = _normalize_rgb_values( + image_std if image_std is not None else VISION_STD, + name="image_std", + ) - # Use vision token from config - self.vision_token = config.vision_token + if image_processor is None: + image_processor = IsaacImageProcessor( + patch_size=vision_patch_size, + max_num_patches=vision_max_num_patches, + min_num_patches=vision_min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + rescale_factor=resolved_rescale_factor, + image_mean=resolved_image_mean, + image_std=resolved_image_std, + ) + else: + vision_patch_size = getattr(image_processor, "patch_size", vision_patch_size) + vision_max_num_patches = getattr(image_processor, "max_num_patches", vision_max_num_patches) + vision_min_num_patches = getattr(image_processor, "min_num_patches", vision_min_num_patches) + pixel_shuffle_scale = 
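# Illustrative sketch (not from the patch itself): how _normalize_rgb_values
# coerces scalar or single-element inputs into a broadcastable RGB triple.
# Condensed standalone copy of the helper above.
def normalize_rgb(values, name="image_mean"):
    if isinstance(values, (list, tuple)):
        if len(values) == 3:
            return tuple(float(v) for v in values)
        if len(values) == 1:
            return (float(values[0]),) * 3
        raise ValueError(f"`{name}` must have length 1 or 3.")
    return (float(values),) * 3

print(normalize_rgb(0.5))                 # (0.5, 0.5, 0.5)
print(normalize_rgb([0.5]))               # (0.5, 0.5, 0.5)
print(normalize_rgb((0.48, 0.46, 0.41)))  # (0.48, 0.46, 0.41)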
getattr(image_processor, "pixel_shuffle_scale", pixel_shuffle_scale) + resolved_rescale_factor = getattr(image_processor, "rescale_factor", resolved_rescale_factor) + resolved_image_mean = _normalize_rgb_values( + getattr(image_processor, "image_mean", resolved_image_mean), + name="image_mean", + ) + resolved_image_std = _normalize_rgb_values( + getattr(image_processor, "image_std", resolved_image_std), + name="image_std", + ) + + if config is not None: + config.vision_rescale_factor = resolved_rescale_factor + config.vision_mean = resolved_image_mean + config.vision_std = resolved_image_std + + super().__init__(image_processor, tokenizer) + self.current_processor = self.image_processor + self.config = config - # Processing parameters - self.max_sequence_length = config.max_sequence_length + # Mirror tokenizer chat template so ProcessorMixin.apply_chat_template works. + self.chat_template = getattr(self.tokenizer, "chat_template", None) - # Vision processing parameters - self.patch_size = config.video_patch_size - self.max_num_patches = config.vision_max_num_patches - self.min_num_patches = config.vision_min_num_patches - self.pixel_shuffle_scale = config.pixel_shuffle_scale + self.vision_token = vision_token + self.max_sequence_length = max_sequence_length + self.vision_attn_implementation = vision_attn_implementation - def apply_chat_template( - self, - messages: list[dict[str, Any]], - tokenize: bool = False, - add_generation_prompt: bool = False, - **kwargs, - ) -> Any: - return self.tokenizer.apply_chat_template( - messages, tokenize=tokenize, add_generation_prompt=add_generation_prompt, **kwargs - ) + self.patch_size = getattr(self.image_processor, "patch_size", vision_patch_size) + self.max_num_patches = getattr(self.image_processor, "max_num_patches", vision_max_num_patches) + self.min_num_patches = getattr(self.image_processor, "min_num_patches", vision_min_num_patches) + self.pixel_shuffle_scale = getattr(self.image_processor, "pixel_shuffle_scale", pixel_shuffle_scale) + self.rescale_factor = getattr(self.image_processor, "rescale_factor", resolved_rescale_factor) + self.image_mean = tuple(getattr(self.image_processor, "image_mean", resolved_image_mean)) + self.image_std = tuple(getattr(self.image_processor, "image_std", resolved_image_std)) def build_event_stream_simple( self, @@ -377,68 +210,40 @@ def build_event_stream_simple( for current_time, part in enumerate(parts): if part == self.vision_token: # Replace vision token with image event - if image_idx < len(images): - # Create vision event from PIL image - image_tensor = extract_image_pil(images[image_idx]) - if image_tensor is not None: - # Create a vision event with the image tensor - vision_event = Event( - data=image_tensor.unsqueeze(0), # HWC format from extract_image_pil - type=VisionType.image, # I-frame - time=(current_time, current_time), - ) - events.append(vision_event) - image_idx += 1 - elif part: # Non-empty text part - # tokens = self.text_processor.tokenize(part, add_special_tokens=False) - text_event = create_text_event(self.tokenizer, part, time=current_time) - events.append(text_event) + if images is None or image_idx >= len(images): + raise ValueError("Encountered vision token without a corresponding image.") - # Process vision events if any - if any(event.type == VisionType.image for event in events): - # Separate text and vision events for processing - text_events = [event for event in events if event.type == TextType.text] - vision_events = [event for event in events if event.type == VisionType.image] 
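# Illustrative sketch (not from the patch itself): how a prompt is split around
# the vision placeholder before build_event_stream_simple walks the parts. The
# "<image>" token and the capturing re.split pattern are assumptions here.
import re

vision_token = "<image>"
text = "Describe <image> and compare it with <image>."
parts = re.split(f"({re.escape(vision_token)})", text)

n_placeholders = sum(part == vision_token for part in parts)
print(parts)           # ['Describe ', '<image>', ' and compare it with ', '<image>', '.']
print(n_placeholders)  # 2 -> must match len(images) or the processor raises ValueError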
- - # Process vision events using functional approach - processed_vision_events = [] - for vision_event in vision_events: - # Process the vision data - patches, dims_virtual = process_vision_for_patches( - vision_event.data.squeeze(0), # Remove the extra dimension - patch_size=self.patch_size, - max_num_patches=self.max_num_patches, - min_num_patches=self.min_num_patches, - pixel_shuffle_scale=self.pixel_shuffle_scale, + features = self.image_processor( + images=images[image_idx], + return_tensors=TensorType.PYTORCH, ) - # Update event with processed data - vision_event.data = patches.unsqueeze(1) # Add back frame dimension - vision_event.dims_virtual = dims_virtual - vision_event.dims_real = ( - dims_virtual - if self.pixel_shuffle_scale == 1 - else [ - dims_virtual[0], - dims_virtual[1] * self.pixel_shuffle_scale, - dims_virtual[2] * self.pixel_shuffle_scale, - ] + patches = features["patches"][0] # (H_tokens, W_tokens, embed) + virtual_dims = features["virtual_pixel_size"][0].tolist() + real_dims = features["real_pixel_size"][0].tolist() + + vision_event = Event( + data=patches.reshape(-1, patches.shape[-1]), + type=VisionType.image, + time=(current_time, current_time), + dims_virtual=virtual_dims, + dims_real=real_dims, + idx_range=(0, math.prod(virtual_dims)), ) - vision_event.idx_range = (0, math.prod(dims_virtual)) - - # Flatten the patches - vision_event.data = vision_event.data.reshape(-1, vision_event.data.shape[-1]) - processed_vision_events.append(vision_event) - - events = text_events + processed_vision_events + events.append(vision_event) + image_idx += 1 + elif part: # Non-empty text part + # tokens = self.text_processor.tokenize(part, add_special_tokens=False) + text_event = create_text_event(self.tokenizer, part, time=current_time) + events.append(text_event) # Create stream without scheduling (events already in order) return create_stream(events, priority=[TextType.text, VisionType.image], schedule=True) def __call__( self, - text: Union[str, list[str]], - images: Union[PIL.Image.Image, list[PIL.Image.Image], None] = None, + text: str | list[str], + images: PIL.Image.Image | list[PIL.Image.Image] | None = None, return_tensors: str | TensorType | None = TensorType.PYTORCH, **kwargs, ) -> BatchFeature: From c3cc42d0315c17d562f92e5716f16dd048718147 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Wed, 15 Oct 2025 12:54:40 +0400 Subject: [PATCH 05/77] chore: port updates --- .../models/isaac/modular_isaac.py | 1195 +++++++++-------- 1 file changed, 651 insertions(+), 544 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index a51950f4a7dc..726c55aa56f6 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -1,26 +1,96 @@ +# Perceptron, Inc. Non-Production License + +### 1. Scope and acceptance + +# **1.1. Scope of the Agreement.** +# This Agreement applies to any use, modification, or Distribution of any Perceptron Model by You, regardless of the source You obtained a copy of such Perceptron Model. +# +# **1.2. Acceptance.** By accessing, using, modifying, Distributing a Perceptron Model, or by creating, using or distributing a Derivative of the Perceptron Model, You agree to be bound by this Agreement. +# +# **1.3. 
Acceptance on behalf of a third-party.** If You accept this Agreement on behalf of Your employer or another person or entity, You warrant and represent that You have the authority to act and accept this Agreement on their behalf. In such a case, the word โ€œYouโ€ in this Agreement will refer to Your employer or such other person or entity. +# +# ## 2. License +# **2.1. Grant of rights.** Subject to Section 3 below, Perceptron, Inc. hereby grants You a non-exclusive, royalty-free, worldwide, non-sublicensable, non-transferable, limited license to use, copy, modify, and Distribute under the conditions provided in Section 2.2 below, the Perceptron Model and any Derivatives made by or for Perceptron, Inc. and to create Derivatives of the Perceptron Model. +# +# **2.2. Distribution of Perceptron Model and Derivatives made by or for Perceptron, Inc..** Subject to Section 3 below, You may Distribute copies of the Perceptron Model and/or Derivatives made by or for Perceptron, Inc., under the following conditions: +# - You must make available a copy of this Agreement to third-party recipients of the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. you Distribute, it being specified that any rights to use the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. shall be directly granted by Perceptron, Inc. to said third-party recipients pursuant to the Perceptron, Inc. Non-Production License agreement executed between these parties; +# - You must retain in all copies of the Perceptron Models the following attribution notice within a โ€œNoticeโ€ text file distributed as part of such copies: โ€œLicensed by Perceptron, Inc. under the Perceptron, Inc. Non-Production Licenseโ€. +# +# **2.3. Distribution of Derivatives made by or for You.** Subject to Section 3 below, You may Distribute any Derivatives made by or for You under additional or different terms and conditions, provided that: +# - In any event, the use and modification of Perceptron Model and/or Derivatives made by or for Perceptron, Inc. shall remain governed by the terms and conditions of this Agreement; +# - You include in any such Derivatives made by or for You prominent notices stating that You modified the concerned Perceptron Model; and +# - Any terms and conditions You impose on any third-party recipients relating to Derivatives made by or for You shall neither limit such third-party recipientsโ€™ use of the Perceptron Model or any Derivatives made by or for Perceptron, Inc. in accordance with the Perceptron, Inc. Non-Production License nor conflict with any of its terms and conditions. +# +# ## 3. Limitations +# **3.1. Misrepresentation.** You must not misrepresent or imply, through any means, that the Derivatives made by or for You and/or any modified version of the Perceptron Model You Distribute under your name and responsibility is an official product of Perceptron, Inc. or has been endorsed, approved or validated by Perceptron, Inc., unless You are authorized by Us to do so in writing. +# +# **3.2. Usage Limitation** +# - You shall only use the Perceptron Models and Derivatives (whether or not created by Perceptron, Inc.) for testing, research, Personal, or evaluation purposes in Non-Production Environments; +# - Subject to the foregoing, You shall not supply the Perceptron Models or Derivatives in the course of a commercial activity, whether in return for payment or free of charge, in any medium or form, including but not limited to through a hosted or managed service (e.g. 
SaaS, cloud instances, etc.), or behind a software layer. +# +# **3.3. Usage not permitted under this Agreement.** If You want to use a Perceptron Model or a Derivative for any purpose that is not expressly authorized under this Agreement, You must request a license from Perceptron, Inc., which Perceptron, Inc. may grant to You in Perceptron, Inc.โ€™s sole discretion. Please contact Perceptron, Inc. at the following e-mail address if You want to discuss such a license: sales@perceptron.inc +# +# ## 4. Intellectual Property +# **4.1. Trademarks.** No trademark licenses are granted under this Agreement, and in connection with the Perceptron Models, You may not use any name or mark owned by or associated with Perceptron, Inc. or any of its affiliates, except (i) as required for reasonable and customary use in describing and Distributing the Perceptron Models and Derivatives made by or for Perceptron, Inc. and (ii) for attribution purposes as required by this Agreement. +# +# **4.2. Outputs.** We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs You generate and their subsequent uses in accordance with this Agreement. +# +# **4.3. Derivatives.** By entering into this Agreement, You accept that any Derivatives that You may create or that may be created for You shall be subject to the restrictions set out in Section 3 of this Agreement. +# +# # 5. Liability +# **5.1. Limitation of liability.** In no event, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall Perceptron, Inc. be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use the Perceptron Models and Derivatives (including but not limited to damages for loss of data, loss of goodwill, loss of expected profit or savings, work stoppage, computer failure or malfunction, or any damage caused by malware or security breaches), even if Perceptron, Inc. has been advised of the possibility of such damages. +# +# **5.2. Indemnification.** You agree to indemnify and hold harmless Perceptron, Inc. from and against any claims, damages, or losses arising out of or related to Your use or Distribution of the Perceptron Models and Derivatives. +# +# ## 6. Warranty +# **6.1. Disclaimer.** Unless required by applicable law or agreed to in writing, Perceptron, Inc. provides the Perceptron Models and Derivatives on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. Perceptron, Inc. does not represent nor warrant that the Perceptron Models and Derivatives will be error-free, meet Your or any third partyโ€™s requirements, be secure or will allow You or any third party to achieve any kind of result or generate any kind of content. You are solely responsible for determining the appropriateness of using or Distributing the Perceptron Models and Derivatives and assume any risks associated with Your exercise of rights under this Agreement. +# +# # 7. Termination +# **7.1. Term.** This Agreement is effective as of the date of your acceptance of this Agreement or access to the concerned Perceptron Models or Derivatives and will continue until terminated in accordance with the following terms. +# +# **7.2. Termination.** Perceptron, Inc. 
may terminate this Agreement at any time if You are in breach of this Agreement. Upon termination of this Agreement, You must cease to use all Perceptron Models and Derivatives and shall permanently delete any copy thereof. Sections 5, 6, 7 and 8 shall survive the termination of this Agreement. +# +# **7.3. Litigation.** If You initiate any legal action or proceedings against Us or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the Model or a Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement will immediately terminate as of the date such legal action or claim is filed or initiated. +# +# # 8. General provisions +# 8.1. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the State of Washington, without regard to its conflict of law principles. +# +# 8.2. Jurisdiction. The state and federal courts located in King County, Washington shall have exclusive jurisdiction over any dispute arising out of or relating to this Agreement, and You and We consent to personal jurisdiction and venue in such courts. +# +# **8.3. Severability.** If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. +# +# # 9. Definitions +# **โ€œAgreementโ€**: means this Perceptron, Inc. Non-Production License agreement governing the access, use, and Distribution of the Perceptron Models and Derivatives. +# +# **โ€œDerivativeโ€**: means any (i) modified version of the Perceptron Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the Perceptron Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered as Derivatives under this Agreement. +# +# **โ€œDistributionโ€**, **โ€œDistributingโ€**, **โ€œDistributeโ€** or **โ€œDistributedโ€**: means providing or making available, by any means, a copy of the Perceptron Models and/or the Derivatives as the case may be, subject to Section 3 of this Agreement. +# +# **โ€œPerceptron, Inc.โ€**, **โ€œWeโ€** or **โ€œUsโ€**: means Perceptron, Inc., a Delaware corporation with its principal place of business at 10900 NE 8th St Suite 613, Bellevue, WA 98004. +# +# **โ€œPerceptron Modelโ€**: means the foundational large language model(s), and its elements which include algorithms, software, instructed checkpoints, parameters, source code (inference code, evaluation code and, if applicable, fine-tuning code) and any other elements associated thereto made available by Perceptron, Inc. under this Agreement, including, if any, the technical documentation, manuals and instructions for the use and operation thereof. +# +# **โ€œNon-Production Environmentโ€**: means any setting, use case, or application of the Perceptron Models or Derivatives that expressly excludes live, real-world conditions, commercial operations, revenue-generating activities, or direct interactions with or impacts on end users (such as, for instance, Your employees or customers). 
Non-Production Environment may include, but is not limited to, any setting, use case, or application for research, development, testing, quality assurance, training, internal evaluation (other than any internal usage by employees in the context of the companyโ€™s business activities), and demonstration purposes. +# +# **โ€œOutputsโ€**: means any content generated by the operation of the Perceptron Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a Perceptron Models, such as any fine-tuned versions of the Perceptron Models, the weights, or parameters. +# +# **โ€œPersonalโ€**: means any use of a Perceptron Model or a Derivative that is (i) solely for personal, non-profit and non-commercial purposes and (ii) not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities. For illustration purposes, Personal use of a Model or a Derivative does not include any usage by individuals employed in companies in the context of their daily tasks, any activity that is intended to generate revenue, or that is performed on behalf of a commercial entity. +# +# **โ€œYouโ€**: means the individual or entity entering into this Agreement with Perceptron, Inc.. + from __future__ import annotations import copy import math +import re from collections import defaultdict from collections.abc import Sequence -from typing import Any, TypedDict +from typing import Any, Callable, Optional, TypedDict, Union -import numpy as np import PIL.Image import torch import torch.nn as nn import torch.nn.functional as F - - -try: - from torchvision.transforms.v2 import functional as TVF -except ImportError: - TVF = None - - -import re - from genesis.public.tensorstream.tensor_stream import ( Event, Stream, @@ -40,72 +110,50 @@ slice as ts_slice, ) -from ...cache_utils import Cache, SlidingWindowCache, StaticCache -from ...feature_extraction_utils import BatchFeature -from ...generation.utils import GenerationMixin -from ...image_processing_utils import BaseImageProcessor -from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict -from ...image_transforms import convert_to_rgb -from ...image_utils import ( - ImageInput, +from transformers import ( + AutoImageProcessor, + AutoModel, + AutoTokenizer, + BatchFeature, + Cache, + Qwen3Config, + Qwen3ForCausalLM, + Qwen3PreTrainedModel, +) +from transformers.cache_utils import SlidingWindowCache, StaticCache +from transformers.generation.utils import GenerationMixin +from transformers.image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + SizeDict, + group_images_by_shape, + reorder_images, +) +from transformers.image_utils import ( + ChannelDimension, PILImageResampling, - make_flat_list_of_images, - to_numpy_array, - valid_images, - validate_preprocess_arguments, ) -from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...processing_utils import ImagesKwargs, ProcessorMixin, Unpack -from ...tokenization_utils import TensorType -from ...utils import auto_docstring, filter_out_non_signature_kwargs -from ...utils.import_utils import is_torchdynamo_compiling -from ..auto.image_processing_auto import AutoImageProcessor -from ..auto.modeling_auto import AutoModel -from ..auto.tokenization_auto import AutoTokenizer -from 
..qwen2.tokenization_qwen2 import Qwen2Tokenizer -from ..qwen3.configuration_qwen3 import Qwen3Config -from ..qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3PreTrainedModel -from ..siglip2.configuration_siglip2 import Siglip2VisionConfig -from ..siglip2.modeling_siglip2 import Siglip2EncoderLayer as HFSiglip2EncoderLayer - +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer +from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig +from transformers.models.siglip2.modeling_siglip2 import Siglip2EncoderLayer as HFSiglip2EncoderLayer +from transformers.processing_utils import ProcessorMixin, Unpack +from transformers.tokenization_utils import TensorType +from transformers.utils import auto_docstring # Vision preprocessing constants -VISION_MEAN = (0.5, 0.5, 0.5) -VISION_STD = (0.5, 0.5, 0.5) -VISION_SCALE = 1 / 255 - - - +from transformers.utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN +from transformers.utils.constants import IMAGENET_STANDARD_STD as VISION_STD +from transformers.utils.import_utils import is_torchdynamo_compiling -def _normalize_rgb_values( - values: float | Sequence[float] | tuple[float, ...], - *, - name: str, -) -> tuple[float, float, float]: - """Coerce RGB normalization parameters into a 3-tuple of floats.""" - if isinstance(values, (list, tuple)): - if len(values) == 3: - return tuple(float(v) for v in values) - if len(values) == 1: - value = float(values[0]) - return (value, value, value) - raise ValueError(f"`{name}` must have length 1 or 3 when provided as a sequence. 
Got length {len(values)}.") - - value = float(values) - return (value, value, value) - - -def _make_writeable(arr: np.ndarray) -> np.ndarray: - if arr.flags.writeable: - return arr - try: - arr.setflags(write=True) - return arr - except ValueError: - return arr.copy() +_ORIGINAL_ATTENTION_FUNCTIONS: dict[str, Callable[..., tuple[torch.Tensor, Optional[torch.Tensor]]]] = {} +for _attn_name in ("flash_attention_2", "sdpa", "eager"): + if _attn_name in ALL_ATTENTION_FUNCTIONS: + _ORIGINAL_ATTENTION_FUNCTIONS[_attn_name] = ALL_ATTENTION_FUNCTIONS[_attn_name] class IsaacVisionConfig(Siglip2VisionConfig): @@ -122,6 +170,7 @@ class IsaacVisionConfig(Siglip2VisionConfig): model_type = "isaac_vision" base_config_key = "vision_config" + _attn_implementation: str | None = None def __init__( self, @@ -135,126 +184,164 @@ def __init__( self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor self.num_patches = num_patches + if self._attn_implementation is None: + self._attn_implementation = "flash_attention_2" + + @property + def attn_implementation(self) -> str | None: + return self._attn_implementation + + @attn_implementation.setter + def attn_implementation(self, value: str | None) -> None: + self._attn_implementation = value + -class IsaacImageProcessorKwargs(ImagesKwargs): +class IsaacImageProcessorKwargs(DefaultFastImageProcessorKwargs, total=False): patch_size: int | None max_num_patches: int | None min_num_patches: int | None pixel_shuffle_scale: int | None - do_rescale: bool | None - rescale_factor: float | None - do_normalize: bool | None - image_mean: float | Sequence[float] | None - image_std: float | Sequence[float] | None - do_convert_rgb: bool | None @auto_docstring class IsaacImageProcessorFast(BaseImageProcessorFast): + MAX_PIXELS = 60_000_000 # 60โ€‘megapixel ceiling โ‰ˆ 8200 ร— 7300 px r"""Fast torch-based image processor for Isaac vision inputs.""" - slow_image_processor_class = "IsaacImageProcessor" resample = PILImageResampling.BILINEAR model_input_names = ["patches", "token_grids"] valid_kwargs = IsaacImageProcessorKwargs unused_kwargs = ["size", "do_center_crop", "crop_size"] + do_resize = True + size: SizeDict | None = None + default_to_square: bool | None = None + do_center_crop = False + crop_size: SizeDict | None = None + patch_size: int | None = 16 + max_num_patches: int | None = 256 + min_num_patches: int | None = None + pixel_shuffle_scale: int | None = 1 + do_pad = False + pad_size: SizeDict | None = None + do_rescale = True + rescale_factor = 1 / 255 + do_normalize = True + image_mean = list(VISION_MEAN) + image_std = list(VISION_STD) + do_convert_rgb = True + return_tensors = None + data_format = ChannelDimension.FIRST + input_data_format = None + device = None + disable_grouping = False + size_divisor: int | None = None + def __init__( self, - *, - patch_size: int = 16, - max_num_patches: int = 256, - min_num_patches: int | None = None, - pixel_shuffle_scale: int = 1, - do_rescale: bool = True, - rescale_factor: float | None = None, - do_normalize: bool = True, - image_mean: float | Sequence[float] | None = None, - image_std: float | Sequence[float] | None = None, - do_convert_rgb: bool = True, **kwargs: Unpack[IsaacImageProcessorKwargs], ) -> None: super().__init__(**kwargs) + pixel_shuffle_scale = 1 if self.pixel_shuffle_scale is None else int(self.pixel_shuffle_scale) if pixel_shuffle_scale < 1: raise ValueError("`pixel_shuffle_scale` must be >= 1") - - mean_values = _normalize_rgb_values( - image_mean if image_mean is not None else VISION_MEAN, 
name="image_mean" - ) - std_values = _normalize_rgb_values( - image_std if image_std is not None else VISION_STD, name="image_std" - ) - - self.patch_size = patch_size - self.max_num_patches = max_num_patches - self.min_num_patches = min_num_patches self.pixel_shuffle_scale = pixel_shuffle_scale - self.do_rescale = do_rescale - self.rescale_factor = VISION_SCALE if rescale_factor is None else float(rescale_factor) - self.do_normalize = do_normalize - self.image_mean = list(mean_values) - self.image_std = list(std_values) - self.do_convert_rgb = do_convert_rgb + def _validate_preprocess_kwargs(self, **kwargs): # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for. kwargs.pop("do_resize", None) + kwargs.pop("size", None) + kwargs.pop("do_center_crop", None) + kwargs.pop("crop_size", None) + kwargs.pop("disable_grouping", None) return super()._validate_preprocess_kwargs(**kwargs) + def resize( + self, + image: torch.Tensor, + size: SizeDict, + interpolation: Optional[Any] = None, + antialias: bool = True, + **kwargs, + ) -> torch.Tensor: + if size.height is None or size.width is None: + raise ValueError("IsaacImageProcessorFast requires explicit `height` and `width` when resizing.") + + resize_mode: Any = interpolation + if hasattr(resize_mode, "value"): + resize_mode = resize_mode.value + elif hasattr(resize_mode, "name"): + resize_mode = resize_mode.name.lower() + elif resize_mode is None: + resize_mode = "bilinear" + + if isinstance(resize_mode, str): + mode_key = resize_mode.lower() + else: + mode_key = resize_mode + + resize_kwargs: dict[str, Any] = {} + if mode_key in {"linear", "bilinear", "bicubic", "trilinear"}: + resize_kwargs["align_corners"] = False + + return F.interpolate( + image, + size=(size.height, size.width), + mode=resize_mode, + **resize_kwargs, + ) + def _preprocess( self, - images: list["torch.Tensor"], + images: list[torch.Tensor], do_resize: bool, - patch_size: int, - max_num_patches: int, - interpolation: Any | None, - do_rescale: bool, - rescale_factor: float, - do_normalize: bool, - image_mean: float | Sequence[float] | None, - image_std: float | Sequence[float] | None, - return_tensors: str | TensorType | None, + size: Optional[SizeDict], + interpolation: Optional[Any], + do_center_crop: bool, + crop_size: Optional[SizeDict], + do_rescale: Optional[bool], + rescale_factor: Optional[float], + do_normalize: Optional[bool], + image_mean: Optional[Union[float, Sequence[float]]], + image_std: Optional[Union[float, Sequence[float]]], + disable_grouping: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + do_pad: Optional[bool] = None, + pad_size: Optional[SizeDict] = None, *, + patch_size: int | None = None, + max_num_patches: int | None = None, min_num_patches: int | None = None, pixel_shuffle_scale: int | None = None, - do_convert_rgb: bool | None = None, **kwargs, ) -> BatchFeature: - if TVF is None: - raise ImportError("torchvision is required for IsaacImageProcessorFast but is not installed.") - - min_num_patches = min_num_patches if min_num_patches is not None else self.min_num_patches - pixel_shuffle_scale = pixel_shuffle_scale if pixel_shuffle_scale is not None else self.pixel_shuffle_scale - do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - do_rescale = self.do_rescale if do_rescale is None else do_rescale - do_normalize = self.do_normalize if do_normalize is None else do_normalize - rescale_factor = self.rescale_factor if rescale_factor is None else 
rescale_factor - - mean_values = _normalize_rgb_values( - image_mean if image_mean is not None else self.image_mean, name="image_mean" - ) - std_values = _normalize_rgb_values( - image_std if image_std is not None else self.image_std, name="image_std" - ) + if do_center_crop: + raise ValueError("`do_center_crop` is not supported by IsaacImageProcessorFast.") + if do_pad: + raise ValueError("`do_pad` is not supported by IsaacImageProcessorFast.") - patches_list: list[torch.Tensor] = [] - token_grids: list[torch.Tensor] = [] - virtual_dims: list[list[int]] = [] - real_dims: list[list[int]] = [] - for image in images: - if image.ndim != 3: - raise ValueError("Expected channel-first image tensor with shape (C, H, W).") + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + processed_patches_grouped: dict[tuple[int, ...], torch.Tensor] = {} + token_grids_grouped: dict[tuple[int, ...], torch.Tensor] = {} + virtual_dims_grouped: dict[tuple[int, ...], torch.Tensor] = {} + real_dims_grouped: dict[tuple[int, ...], torch.Tensor] = {} - channels, original_height, original_width = image.shape - if do_convert_rgb and channels == 1: - image = image.repeat(3, 1, 1) + for shape, stacked_images in grouped_images.items(): + if stacked_images.ndim != 4: + raise ValueError("Expected batched channel-first image tensors.") + + batch_size, channels, original_height, original_width = stacked_images.shape + + if bool(self.do_convert_rgb) and channels == 1: + stacked_images = stacked_images.repeat(1, 3, 1, 1) channels = 3 - if original_height * original_width > MAX_PIXELS: + if original_height * original_width > self.MAX_PIXELS: raise ValueError( - f"Image (w={original_width}, h={original_height}) > MAX=`{MAX_PIXELS}`" + f"Image (w={original_width}, h={original_height}) > MAX=`{self.MAX_PIXELS}`" ) target_height, target_width = get_image_size_for_max_num_patches( @@ -267,54 +354,81 @@ def _preprocess( ) if do_resize: - size_dict = SizeDict(height=target_height, width=target_width) - image = self.resize(image=image, size=size_dict, interpolation=interpolation) + resize_size = SizeDict(height=target_height, width=target_width) + image_batch = self.resize( + image=stacked_images, + size=resize_size, + interpolation=interpolation, + ) else: if ((original_height % patch_size) != 0) or ((original_width % patch_size) != 0): raise ValueError( "Image dimensions must be divisible by patch_size when resize is disabled." 
)
+            image_batch = stacked_images
+            target_height, target_width = original_height, original_width
+
+        # `rescale_and_normalize` applies the do_rescale/do_normalize flags itself,
+        # so it must run unconditionally; gating it on do_rescale alone would skip
+        # normalization whenever rescaling is disabled.
+        image_batch = self.rescale_and_normalize(
+            image_batch,
+            do_rescale=do_rescale,
+            rescale_factor=rescale_factor,
+            do_normalize=do_normalize,
+            image_mean=image_mean,
+            image_std=image_std,
+        )
 
-        # Apply rescaling and normalization as needed
-        image = self.rescale_and_normalize(
-            image,
-            do_rescale,
-            rescale_factor,
-            do_normalize,
-            list(mean_values),
-            list(std_values),
-        )
+        nhwc_images = image_batch.permute(0, 2, 3, 1)
+        nhwc_images = _compute_residual_p_frames(nhwc_images, is_p_frame=[False] * batch_size)
 
-        # Convert to NHWC for residual P-frame adjustment and patch extraction
-        nhwc_image = image.permute(1, 2, 0).unsqueeze(0)
-        nhwc_image = _compute_residual_p_frames(nhwc_image, is_p_frame=[False])
+        patches = patchify_vision(nhwc_images, patch_size=patch_size)
+        _, height_tokens, width_tokens, _ = patches.shape
 
-        patches = patchify_vision(nhwc_image, patch_size=patch_size).squeeze(0)
-        height_tokens, width_tokens, _ = patches.shape
+        token_grid = torch.tensor(
+            [height_tokens, width_tokens],
+            dtype=torch.long,
+            device=patches.device,
+        ).unsqueeze(0).repeat(batch_size, 1)
 
-        patches_list.append(patches.unsqueeze(0))
-        token_grids.append(
-            torch.tensor([height_tokens, width_tokens], dtype=torch.long, device=patches.device)
-        )
+        real_dim = torch.tensor(
+            [1, height_tokens, width_tokens],
+            dtype=torch.long,
+            device=patches.device,
+        ).unsqueeze(0).repeat(batch_size, 1)
 
-        real_dims.append([1, height_tokens, width_tokens])
         if pixel_shuffle_scale > 1:
             if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale):
                 raise ValueError(
                     "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled."
                 )
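+            # Pixel shuffle trades spatial tokens for channel depth: an s x s
+            # block of patch tokens is folded into a single token, so the
+            # "virtual" grid exposed downstream is (H/s, W/s) while the "real"
+            # patch grid stays (H, W).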
-            virtual_dims.append(
-                [1, height_tokens // pixel_shuffle_scale, width_tokens // pixel_shuffle_scale]
-            )
+            virtual_height = height_tokens // pixel_shuffle_scale
+            virtual_width = width_tokens // pixel_shuffle_scale
         else:
-            virtual_dims.append([1, height_tokens, width_tokens])
-
-        patches_tensor = torch.cat(patches_list, dim=0)
-        token_grids_tensor = torch.stack(token_grids, dim=0)
-        virtual_dims_tensor = torch.tensor(virtual_dims, dtype=torch.long, device=patches_tensor.device)
-        real_dims_tensor = torch.tensor(real_dims, dtype=torch.long, device=patches_tensor.device)
-
-        batch_feature = BatchFeature(
+            virtual_height = height_tokens
+            virtual_width = width_tokens
+
+        virtual_dim = torch.tensor(
+            [1, virtual_height, virtual_width],
+            dtype=torch.long,
+            device=patches.device,
+        ).unsqueeze(0).repeat(batch_size, 1)
+
+        processed_patches_grouped[shape] = patches
+        token_grids_grouped[shape] = token_grid
+        virtual_dims_grouped[shape] = virtual_dim
+        real_dims_grouped[shape] = real_dim
+
+        patches_slices = reorder_images(processed_patches_grouped, grouped_images_index)
+        token_grid_slices = reorder_images(token_grids_grouped, grouped_images_index)
+        virtual_dim_slices = reorder_images(virtual_dims_grouped, grouped_images_index)
+        real_dim_slices = reorder_images(real_dims_grouped, grouped_images_index)
+
+        patches_tensor = torch.stack(patches_slices, dim=0)
+        token_grids_tensor = torch.stack(token_grid_slices, dim=0)
+        virtual_dims_tensor = torch.stack(virtual_dim_slices, dim=0)
+        real_dims_tensor = torch.stack(real_dim_slices, dim=0)
+
+        return BatchFeature(
             data={
                 "patches": patches_tensor,
                 "token_grids": token_grids_tensor,
@@ -323,223 +437,57 @@ def _preprocess(
             },
             tensor_type=return_tensors,
         )
-        return batch_feature
 
 
-class IsaacImageProcessor(BaseImageProcessor):
-    """Image processor that prepares RGB frames for the Isaac vision encoder."""
-
-    model_input_names = ["patches", "token_grids"]
-
-    def __init__(
-        self,
-        patch_size: int = 16,
-        max_num_patches: int = 256,
-        min_num_patches: int | None = None,
-        pixel_shuffle_scale: int = 1,
-        do_rescale: bool = True,
-        rescale_factor: float | None = None,
-        do_normalize: bool = True,
-        image_mean: float | Sequence[float] | None = None,
-        image_std: float | Sequence[float] | None = None,
-        do_convert_rgb: bool = True,
-        resize_mode: str = "bilinear",
-        align_corners: bool = False,
-        **kwargs,
-    ) -> None:
-        super().__init__(**kwargs)
-
-        if pixel_shuffle_scale < 1:
-            raise ValueError("`pixel_shuffle_scale` must be >= 1")
-
-        rescale_value = VISION_SCALE if rescale_factor is None else float(rescale_factor)
-        mean_value = VISION_MEAN if image_mean is None else image_mean
-        std_value = VISION_STD if image_std is None else image_std
-
-        self.patch_size = patch_size
-        self.max_num_patches = max_num_patches
-        self.min_num_patches = min_num_patches
-        self.pixel_shuffle_scale = pixel_shuffle_scale
-        self.do_rescale = do_rescale
-        self.rescale_factor = rescale_value
-        self.do_normalize = do_normalize
-        self.image_mean = _normalize_rgb_values(mean_value, name="image_mean")
-        self.image_std = _normalize_rgb_values(std_value, name="image_std")
-        self.do_convert_rgb = do_convert_rgb
-        self.resize_mode = resize_mode
-        self.align_corners = align_corners
-
-    @filter_out_non_signature_kwargs()
-    def preprocess(
-        self,
-        images: ImageInput,
-        patch_size: int | None = None,
-        max_num_patches: int | None = None,
-        min_num_patches: int | None = None,
-        pixel_shuffle_scale: int | None = None,
-        do_rescale: bool | None = None,
-        rescale_factor: float | None 
= None, - do_normalize: bool | None = None, - image_mean: float | Sequence[float] | None = None, - image_std: float | Sequence[float] | None = None, - do_convert_rgb: bool | None = None, - return_tensors: str | TensorType | None = None, - ) -> BatchFeature: - patch_size = patch_size if patch_size is not None else self.patch_size - max_num_patches = max_num_patches if max_num_patches is not None else self.max_num_patches - min_num_patches = min_num_patches if min_num_patches is not None else self.min_num_patches - pixel_shuffle_scale = ( - pixel_shuffle_scale if pixel_shuffle_scale is not None else self.pixel_shuffle_scale - ) - do_rescale = self.do_rescale if do_rescale is None else do_rescale - rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor - do_normalize = self.do_normalize if do_normalize is None else do_normalize - image_mean = self.image_mean if image_mean is None else _normalize_rgb_values(image_mean, name="image_mean") - image_std = self.image_std if image_std is None else _normalize_rgb_values(image_std, name="image_std") - do_convert_rgb = self.do_convert_rgb if do_convert_rgb is None else do_convert_rgb - - images = self.fetch_images(images) - images = make_flat_list_of_images(images) - - if not images: - raise ValueError("Received an empty list of images for preprocessing.") - if do_convert_rgb: - images = [convert_to_rgb(image) for image in images] - - if not valid_images(images): - raise ValueError( - "Invalid image type. Expected PIL images, numpy arrays, or tensors convertible to numpy arrays." - ) - - validate_preprocess_arguments( - do_rescale=do_rescale, - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - ) - - patches_list = [] - token_grids = [] - virtual_dims = [] - real_dims = [] - - for image in images: - np_image = to_numpy_array(image) - - if np_image.ndim == 2: - np_image = np.repeat(np_image[..., None], 3, axis=-1) - - height, width = np_image.shape[:2] - if height * width > MAX_PIXELS: - raise ValueError(f"Image (w={width}, h={height}) > MAX=`{MAX_PIXELS}`") - - torch_image = torch.from_numpy(_make_writeable(np_image)) - patches, vidims, rdims = self._process_single_image( - torch_image, - patch_size=patch_size, - max_num_patches=max_num_patches, - min_num_patches=min_num_patches, - pixel_shuffle_scale=pixel_shuffle_scale, - do_rescale=do_rescale, - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - ) - - patches_list.append(patches) - token_grids.append(torch.tensor([patches.size(1), patches.size(2)], dtype=torch.long)) - virtual_dims.append(vidims) - real_dims.append(rdims) - - patches_tensor = torch.cat(patches_list, dim=0) - token_grid_tensor = torch.stack(token_grids, dim=0) - virtual_dims_tensor = torch.tensor(virtual_dims, dtype=torch.long) - real_dims_tensor = torch.tensor(real_dims, dtype=torch.long) - - data = { - "patches": patches_tensor, - "token_grids": token_grid_tensor, - "virtual_pixel_size": virtual_dims_tensor, - "real_pixel_size": real_dims_tensor, - } +def _max_from_cu(cu: torch.Tensor | None, fallback: int) -> int: + """Helper to compute max sequence length from cumulative sequence lengths.""" + if cu is None or len(cu) < 2: + return fallback + return int((cu[1:] - cu[:-1]).max().item()) - return BatchFeature(data=data, tensor_type=return_tensors) - def _process_single_image( - self, - image: torch.Tensor, - *, - patch_size: int, - max_num_patches: int, - min_num_patches: int | None, - 
pixel_shuffle_scale: int, - do_rescale: bool, - rescale_factor: float, - do_normalize: bool, - image_mean: tuple[float, ...], - image_std: tuple[float, ...], - ) -> tuple[torch.Tensor, list[int], list[int]]: - image_uint8 = image.unsqueeze(0) # (1, H, W, C) - image_chw = image_uint8.permute(0, 3, 1, 2) # (1, C, H, W) - - _, _, orig_height, orig_width = image_chw.shape - target_height, target_width = get_image_size_for_max_num_patches( - orig_height, - orig_width, - patch_size, - max_num_patches, - min_num_patches=min_num_patches, - pixel_shuffle_scale=pixel_shuffle_scale, - ) - - if self.resize_mode in {"linear", "bilinear", "bicubic", "trilinear"}: - resized = F.interpolate( - image_chw, - size=(target_height, target_width), - mode=self.resize_mode, - align_corners=self.align_corners, - ) - else: - resized = F.interpolate( - image_chw, - size=(target_height, target_width), - mode=self.resize_mode, - ) +def build_document_attention_mask( + cu_seqlens: torch.Tensor | None, + total_tokens: int, + dtype: torch.dtype, + device: torch.device, +) -> torch.Tensor | None: + """Creates an additive attention mask that blocks cross-document attention.""" - resized = resized.permute(0, 2, 3, 1) # (1, H, W, C) + if cu_seqlens is None: + return None - scale = rescale_factor if do_rescale else 1.0 - mean = image_mean if do_normalize else (0.0, 0.0, 0.0) - std = image_std if do_normalize else (1.0, 1.0, 1.0) - resized = _prepare_image_tensor(resized, scale=scale, mean=mean, std=std) + if cu_seqlens.numel() < 2: + return None - resized = _compute_residual_p_frames(resized, is_p_frame=[False]) + seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() + if seq_sizes.numel() == 0: + return None - patches = patchify_vision(resized, patch_size=patch_size) - _, h_patches, w_patches, _ = patches.shape + seg_ids = torch.repeat_interleave(torch.arange(seq_sizes.numel(), device=device), seq_sizes) + block_mask = seg_ids[:, None] != seg_ids[None, :] + additive_mask = torch.zeros((total_tokens, total_tokens), dtype=dtype, device=device) + additive_mask.masked_fill_(block_mask, float("-inf")) + return additive_mask.view(1, 1, total_tokens, total_tokens) - real_dims = [1, h_patches, w_patches] - if pixel_shuffle_scale > 1: - if (h_patches % pixel_shuffle_scale) or (w_patches % pixel_shuffle_scale): - raise ValueError( - "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled." 
- ) - virtual_dims = [1, h_patches // pixel_shuffle_scale, w_patches // pixel_shuffle_scale] - else: - virtual_dims = real_dims.copy() - return patches, virtual_dims, real_dims +def repeat_kv(hidden_states: torch.Tensor, num_key_value_groups: int) -> torch.Tensor: + """Repeat key/value heads for grouped-query attention.""" + if num_key_value_groups == 1: + return hidden_states -def _max_from_cu(cu: torch.Tensor | None, fallback: int) -> int: - """Helper to compute max sequence length from cumulative sequence lengths.""" - if cu is None or len(cu) < 2: - return fallback - return int((cu[1:] - cu[:-1]).max().item()) + batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, + num_key_value_heads, + num_key_value_groups, + seq_len, + head_dim, + ) + return hidden_states.reshape(batch, num_key_value_heads * num_key_value_groups, seq_len, head_dim) def flash_attention_document_mask_forward( @@ -599,7 +547,8 @@ def sdpa_document_mask_forward( v_lhd: torch.Tensor, # (L, H, D) dropout: float, scaling: float | None, - cu_seqlens: torch.Tensor | None, + attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, ) -> torch.Tensor: """SDPA with block-diagonal masking for variable-length sequences.""" L, H, D = q_lhd.shape @@ -610,12 +559,17 @@ def sdpa_document_mask_forward( V = v_lhd.permute(1, 0, 2).unsqueeze(0) # Build block-diagonal mask for variable-length sequences - attn_mask = None - if cu_seqlens is not None: - seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() - seg_ids = torch.repeat_interleave(torch.arange(len(seq_sizes), device=q_lhd.device), seq_sizes) - block_mask = seg_ids[:, None] != seg_ids[None, :] # Cross-document attention blocked - attn_mask = torch.where(block_mask, -torch.inf, 0.0).to(q_lhd.dtype).view(1, 1, L, L) + attn_mask = attention_mask + if attn_mask is None: + attn_mask = build_document_attention_mask( + cu_seqlens=cu_seqlens, + total_tokens=L, + dtype=q_lhd.dtype, + device=q_lhd.device, + ) + + if attn_mask is not None and attn_mask.dtype != Q.dtype: + attn_mask = attn_mask.to(Q.dtype) Y = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask, dropout_p=dropout, scale=scaling) return Y.squeeze(0).permute(1, 0, 2) # Back to (L, H, D) @@ -686,11 +640,23 @@ def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor): class IsaacVisionAttention(nn.Module): """Custom attention that supports variable-length sequences with flash attention.""" - def __init__(self, config): + ATTENTION_KEY_MAP: dict[str, str] = { + "flash_attention_2": "isaac_flash_attention_2", + "flash_attention_3": "isaac_flash_attention_3", + "isaac_flash_attention_2": "isaac_flash_attention_2", + "isaac_flash_attention_3": "isaac_flash_attention_3", + "sdpa": "isaac_sdpa", + "isaac_sdpa": "isaac_sdpa", + "eager": "isaac_eager", + "isaac_eager": "isaac_eager", + } + + def __init__(self, vision_config): super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads + self.vision_config = vision_config + self.config = vision_config + self.embed_dim = vision_config.hidden_size + self.num_heads = vision_config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( @@ -698,7 +664,9 @@ def __init__(self, config): f" {self.num_heads})." 
) self.scale = self.head_dim**-0.5 - self.dropout = config.attention_dropout + self.dropout = vision_config.attention_dropout + self.is_causal = False + self.num_key_value_groups = 1 self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) @@ -721,24 +689,46 @@ def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None): k = self.k_proj(x).view(L, H, D) v = self.v_proj(x).view(L, H, D) - attn_impl = getattr(self.config, "_attn_implementation", "flash_attention_3") - - if attn_impl in ("flash_attention_2", "flash_attention_3"): - y_lhd, _ = flash_attention_document_mask_forward( - self, - q, - k, - v, - attention_mask=None, - dropout=p_drop, - scaling=self.scale, - cum_seq_q=cu_seqlens, - cum_seq_k=cu_seqlens, - max_seqlen=max_seqlen, - is_causal=False, - ) - else: - y_lhd = sdpa_document_mask_forward(q, k, v, dropout=p_drop, scaling=self.scale, cu_seqlens=cu_seqlens) + attn_impl = getattr(self.vision_config, "_attn_implementation", "flash_attention_3") + + attn_mask = build_document_attention_mask( + cu_seqlens=cu_seqlens, + total_tokens=L, + dtype=q.dtype, + device=q.device, + ) + + resolved_key = self.ATTENTION_KEY_MAP.get(attn_impl) + attention_fn = ALL_ATTENTION_FUNCTIONS.get(resolved_key) if resolved_key is not None else None + if attention_fn is None: + raise ValueError(f"Attention implementation {attn_impl} not found.") + + query_states = q.transpose(0, 1).unsqueeze(0) + key_states = k.transpose(0, 1).unsqueeze(0) + value_states = v.transpose(0, 1).unsqueeze(0) + + attention_kwargs: dict[str, Any] = { + "dropout": p_drop, + "scaling": self.scale, + "is_causal": False, + } + if cu_seqlens is not None: + attention_kwargs["cu_seq_lens_q"] = cu_seqlens + attention_kwargs["cu_seq_lens_k"] = cu_seqlens + if max_seqlen is not None: + attention_kwargs["max_length_q"] = max_seqlen + attention_kwargs["max_length_k"] = max_seqlen + + attn_output, _ = attention_fn( + self, + query_states, + key_states, + value_states, + attn_mask, + **attention_kwargs, + ) + + y_lhd = attn_output.squeeze(0).permute(1, 0, 2).contiguous() # Merge heads and project y = self.out_proj(y_lhd.reshape(L, self.embed_dim)) @@ -748,21 +738,21 @@ def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None): class IsaacVisionEncoderLayer(HFSiglip2EncoderLayer): """Isaac vision encoder layer with variable-length attention.""" - def __init__(self, config: IsaacVisionConfig): - super().__init__(config) - self.self_attn = IsaacVisionAttention(config) + def __init__(self, vision_config: IsaacVisionConfig): + super().__init__(vision_config) + self.self_attn = IsaacVisionAttention(vision_config) def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor = None, max_seqlen: int = None, - ) -> tuple[torch.FloatTensor]: + ) -> torch.Tensor: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states, attn_weights = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, @@ -775,7 +765,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - return (hidden_states,) + return hidden_states class IsaacVisionEncoder(nn.Module): @@ -801,20 +791,208 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - layer_outputs = encoder_layer( + hidden_states = encoder_layer( hidden_states, cu_seqlens, max_seqlen, ) - hidden_states = layer_outputs[0] - if output_hidden_states: 
all_hidden_states = all_hidden_states + (hidden_states,) return hidden_states, all_hidden_states, None +def _isaac_flash_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + is_causal: bool = False, + **kwargs, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + base_fn = _ORIGINAL_ATTENTION_FUNCTIONS.get("flash_attention_2") + if not isinstance(module, IsaacVisionAttention) or base_fn is None: + if base_fn is None: + raise ValueError("Base flash attention function unavailable for fallback.") + return base_fn( + module, + query, + key, + value, + attention_mask, + dropout=dropout, + scaling=scaling, + is_causal=is_causal, + **kwargs, + ) + + if query.dim() != 4 or query.size(0) != 1: + raise ValueError("IsaacVisionAttention expects packed sequences with batch size 1 when using packed attention.") + + _, num_heads, seq_len, head_dim = query.shape + q_lhd = query.transpose(1, 2).reshape(seq_len, num_heads, head_dim) + k_lhd = key.transpose(1, 2).reshape(seq_len, num_heads, head_dim) + v_lhd = value.transpose(1, 2).reshape(seq_len, num_heads, head_dim) + + cum_seq_q = kwargs.get("cu_seq_lens_q") + cum_seq_k = kwargs.get("cu_seq_lens_k", cum_seq_q) + max_seqlen = kwargs.get("max_length_q") + + effective_dropout = dropout if dropout is not None else (module.dropout if module.training else 0.0) + effective_scaling = module.scale if scaling is None else scaling + + attn_mask = attention_mask + if attn_mask is None: + attn_mask = build_document_attention_mask( + cu_seqlens=cum_seq_q, + total_tokens=seq_len, + dtype=q_lhd.dtype, + device=q_lhd.device, + ) + + attn_output_lhd, attn_weights = flash_attention_document_mask_forward( + module, + q_lhd, + k_lhd, + v_lhd, + attention_mask=attn_mask, + dropout=effective_dropout, + scaling=effective_scaling, + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_seqlen=max_seqlen, + is_causal=is_causal, + ) + + attn_output = attn_output_lhd.permute(1, 0, 2).unsqueeze(0) + return attn_output, attn_weights + + +def _isaac_sdpa_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + is_causal: bool = False, + **kwargs, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + base_fn = _ORIGINAL_ATTENTION_FUNCTIONS.get("sdpa") + if not isinstance(module, IsaacVisionAttention) or base_fn is None: + if base_fn is None: + raise ValueError("Base SDPA function unavailable for fallback.") + return base_fn( + module, + query, + key, + value, + attention_mask, + dropout=dropout, + scaling=scaling, + is_causal=is_causal, + **kwargs, + ) + + if query.dim() != 4 or query.size(0) != 1: + raise ValueError("IsaacVisionAttention expects packed sequences with batch size 1 when using packed attention.") + + _, num_heads, seq_len, head_dim = query.shape + q_lhd = query.transpose(1, 2).reshape(seq_len, num_heads, head_dim) + k_lhd = key.transpose(1, 2).reshape(seq_len, num_heads, head_dim) + v_lhd = value.transpose(1, 2).reshape(seq_len, num_heads, head_dim) + + cum_seq = kwargs.get("cu_seq_lens_q") + effective_dropout = dropout if dropout is not None else (module.dropout if module.training else 0.0) + effective_scaling = module.scale if scaling is None else scaling + + attn_mask = attention_mask + if attn_mask is None: + attn_mask = build_document_attention_mask( + cu_seqlens=cum_seq, + 
total_tokens=seq_len,
+            dtype=q_lhd.dtype,
+            device=q_lhd.device,
+        )
+
+    attn_output_lhd = sdpa_document_mask_forward(
+        q_lhd,
+        k_lhd,
+        v_lhd,
+        dropout=effective_dropout,
+        scaling=effective_scaling,
+        attention_mask=attn_mask,
+        cu_seqlens=cum_seq,
+    )
+
+    attn_output = attn_output_lhd.permute(1, 0, 2).unsqueeze(0)
+    return attn_output, None
+
+
+def _isaac_eager_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    is_causal: bool = False,
+    **kwargs,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    base_fn = _ORIGINAL_ATTENTION_FUNCTIONS.get("eager")
+    if not isinstance(module, IsaacVisionAttention) or base_fn is None:
+        if base_fn is None:
+            raise ValueError("Base eager attention function unavailable for fallback.")
+        return base_fn(
+            module,
+            query,
+            key,
+            value,
+            attention_mask,
+            dropout=dropout,
+            scaling=scaling,
+            is_causal=is_causal,
+            **kwargs,
+        )
+
+    if query.dim() != 4 or query.size(0) != 1:
+        raise ValueError("IsaacVisionAttention expects packed sequences with batch size 1 when using packed attention.")
+
+    effective_scaling = module.scale if scaling is None else scaling
+    # Per-head attention scores over positions: (1, H, L, D) @ (1, H, D, L) -> (1, H, L, L).
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * effective_scaling
+
+    if attention_mask is not None:
+        mask = attention_mask
+        if mask.dim() == 2:
+            # Expand an additive (L, L) mask so it broadcasts over batch and heads.
+            mask = mask.view(1, 1, *mask.shape)
+        attn_weights = attn_weights + mask
+
+    attn_weights = torch.softmax(attn_weights, dim=-1)
+    if dropout and module.training:
+        attn_weights = F.dropout(attn_weights, p=dropout, training=True)
+
+    attn_output = torch.matmul(attn_weights, value)  # (1, H, L, D), same layout as the other wrappers
+    return attn_output, attn_weights
+
+
+ALL_ATTENTION_FUNCTIONS.register("isaac_flash_attention_2", _isaac_flash_attention_forward)
+ALL_ATTENTION_FUNCTIONS.register("isaac_flash_attention_3", _isaac_flash_attention_forward)
+ALL_ATTENTION_FUNCTIONS.register("isaac_sdpa", _isaac_sdpa_forward)
+ALL_ATTENTION_FUNCTIONS.register("isaac_eager", _isaac_eager_forward)
+
+
 def create_pixel_shuffle_index_map(
     seq_sizes: torch.Tensor,
     token_grids: torch.Tensor,
@@ -991,12 +1169,6 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]):
         return hidden_states
 
 
-# ============================================================================
-# Configuration
-# ============================================================================
-
-MAX_PIXELS = 60_000_000  # 60‑megapixel ceiling ≈ 8200 × 7300 px
-
 def get_scaled_image_size(
     scale: float,
     original_size: int,
@@ -1180,20 +1352,14 @@ class IsaacConfig(Qwen3Config):
 
     def __init__(
         self,
-        vision_config=None,
+        vision_config: IsaacVisionConfig | None = None,
         text_config: Qwen3Config | dict | None = None,
-        vision_patch_size: int = 16,
-        vision_max_num_patches: int = 256,
-        vision_min_num_patches: int | None = None,
-        pixel_shuffle_scale: int = 1,
-        vision_rescale_factor: float = VISION_SCALE,
-        vision_mean: float | Sequence[float] = VISION_MEAN,
-        vision_std: float | Sequence[float] = VISION_STD,
+        vision_rescale_factor: float = 1 / 255,
         max_sequence_length: int = 16384,
         vision_token: str = "",
-        vision_attn_implementation: str | None = 
None, **kwargs, ): + self._rope_scaling: dict[str, Any] | None = None resolved_text_config = kwargs.pop("text_config", text_config) if isinstance(resolved_text_config, Qwen3Config): text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict()) @@ -1208,30 +1374,25 @@ def __init__( super().__init__(**text_config_kwargs) self.text_config = Qwen3Config(**text_config_kwargs) + if self._rope_scaling is None: + self._rope_scaling = getattr(self.text_config, "rope_scaling", None) + else: + self.text_config.rope_scaling = self._rope_scaling # Handle vision config - either dict or IsaacVisionConfig instance if isinstance(vision_config, dict): self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif isinstance(vision_config, IsaacVisionConfig): + self.vision_config = vision_config elif vision_config is None: self.vision_config = self.sub_configs["vision_config"]() - else: - self.vision_config = vision_config - - # EventStreamProcessor parameters (for backward compatibility) - self.video_patch_size = vision_patch_size - self.vision_max_num_patches = vision_max_num_patches - self.vision_min_num_patches = vision_min_num_patches - self.pixel_shuffle_scale = pixel_shuffle_scale # Vision normalization parameters self.vision_rescale_factor = float(vision_rescale_factor) - self.vision_mean = _normalize_rgb_values(vision_mean, name="vision_mean") - self.vision_std = _normalize_rgb_values(vision_std, name="vision_std") # Processing parameters self.max_sequence_length = max_sequence_length self.vision_token = vision_token - self.vision_attn_implementation = vision_attn_implementation def get_text_config(self, *_, **kwargs) -> Qwen3Config: # Accept optional decoder/encoder flags to align with HF composite configs @@ -1239,6 +1400,34 @@ def get_text_config(self, *_, **kwargs) -> Qwen3Config: kwargs.pop("encoder", None) return self.text_config + @property + def rope_scaling(self): + if hasattr(self, "text_config") and self.text_config is not None: + return getattr(self.text_config, "rope_scaling", None) + return self._rope_scaling + + @rope_scaling.setter + def rope_scaling(self, value): + self._rope_scaling = value + if hasattr(self, "text_config") and self.text_config is not None: + self.text_config.rope_scaling = value + + @property + def vision_attn_implementation(self) -> str | None: + + value = getattr(self.vision_config, "_attn_implementation", None) + if value is None: + value = getattr(self.vision_config, "attn_implementation", None) + return value + + @vision_attn_implementation.setter + def vision_attn_implementation(self, value: str | None) -> None: + self.vision_config._attn_implementation = value + if value is not None: + self.vision_config.attn_implementation = value + elif hasattr(self.vision_config, "attn_implementation"): + delattr(self.vision_config, "attn_implementation") + # ============================================================================ # Processor Components @@ -1287,28 +1476,21 @@ def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> # Processor # ============================================================================ + class IsaacProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] - image_processor_class = ("IsaacImageProcessor", "IsaacImageProcessorFast") + image_processor_class = ("IsaacImageProcessorFast",) tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") def __init__( self, - image_processor: IsaacImageProcessor | BaseImageProcessorFast | None = None, + image_processor: 
IsaacImageProcessorFast | None = None, tokenizer: Qwen2Tokenizer | None = None, *, vision_token: str = "", max_sequence_length: int = 16384, - vision_patch_size: int = 16, - vision_max_num_patches: int = 256, - vision_min_num_patches: int | None = None, - pixel_shuffle_scale: int = 1, rescale_factor: float | None = None, - image_mean: float | Sequence[float] | None = None, - image_std: float | Sequence[float] | None = None, - vision_attn_implementation: str | None = None, config: IsaacConfig | dict | None = None, - **kwargs, ) -> None: if tokenizer is None: raise ValueError("`tokenizer` must be provided to initialize IsaacProcessor.") @@ -1317,58 +1499,18 @@ def __init__( config = IsaacConfig(**config) if config is not None: - vision_patch_size = config.video_patch_size - vision_max_num_patches = config.vision_max_num_patches - vision_min_num_patches = config.vision_min_num_patches - pixel_shuffle_scale = config.pixel_shuffle_scale max_sequence_length = config.max_sequence_length vision_token = config.vision_token - vision_attn_implementation = config.vision_attn_implementation rescale_factor = config.vision_rescale_factor - image_mean = tuple(config.vision_mean) - image_std = tuple(config.vision_std) resolved_rescale_factor = ( - float(rescale_factor) if rescale_factor is not None else float(VISION_SCALE) - ) - resolved_image_mean = _normalize_rgb_values( - image_mean if image_mean is not None else VISION_MEAN, - name="image_mean", - ) - resolved_image_std = _normalize_rgb_values( - image_std if image_std is not None else VISION_STD, - name="image_std", + float(rescale_factor) if rescale_factor is not None else float(1/255) ) - if image_processor is None: - image_processor = IsaacImageProcessor( - patch_size=vision_patch_size, - max_num_patches=vision_max_num_patches, - min_num_patches=vision_min_num_patches, - pixel_shuffle_scale=pixel_shuffle_scale, - rescale_factor=resolved_rescale_factor, - image_mean=resolved_image_mean, - image_std=resolved_image_std, - ) - else: - vision_patch_size = getattr(image_processor, "patch_size", vision_patch_size) - vision_max_num_patches = getattr(image_processor, "max_num_patches", vision_max_num_patches) - vision_min_num_patches = getattr(image_processor, "min_num_patches", vision_min_num_patches) - pixel_shuffle_scale = getattr(image_processor, "pixel_shuffle_scale", pixel_shuffle_scale) - resolved_rescale_factor = getattr(image_processor, "rescale_factor", resolved_rescale_factor) - resolved_image_mean = _normalize_rgb_values( - getattr(image_processor, "image_mean", resolved_image_mean), - name="image_mean", - ) - resolved_image_std = _normalize_rgb_values( - getattr(image_processor, "image_std", resolved_image_std), - name="image_std", - ) - if config is not None: config.vision_rescale_factor = resolved_rescale_factor - config.vision_mean = resolved_image_mean - config.vision_std = resolved_image_std + + self.image_processor = image_processor super().__init__(image_processor, tokenizer) self.current_processor = self.image_processor @@ -1379,15 +1521,6 @@ def __init__( self.vision_token = vision_token self.max_sequence_length = max_sequence_length - self.vision_attn_implementation = vision_attn_implementation - - self.patch_size = getattr(self.image_processor, "patch_size", vision_patch_size) - self.max_num_patches = getattr(self.image_processor, "max_num_patches", vision_max_num_patches) - self.min_num_patches = getattr(self.image_processor, "min_num_patches", vision_min_num_patches) - self.pixel_shuffle_scale = getattr(self.image_processor, 
"pixel_shuffle_scale", pixel_shuffle_scale) - self.rescale_factor = getattr(self.image_processor, "rescale_factor", resolved_rescale_factor) - self.image_mean = tuple(getattr(self.image_processor, "image_mean", resolved_image_mean)) - self.image_std = tuple(getattr(self.image_processor, "image_std", resolved_image_std)) def build_event_stream_simple( self, @@ -1536,7 +1669,8 @@ def __init__(self, config: IsaacConfig, device=None): super().__init__() self.config = config - rope_scaling = getattr(config, "rope_scaling", None) or {} + rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config + rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) if rope_type not in ROPE_INIT_FUNCTIONS: raise ValueError(f"Unsupported rope_type '{rope_type}' for IsaacRotaryEmbedding") @@ -1546,10 +1680,10 @@ def __init__(self, config: IsaacConfig, device=None): sanitized_scaling = {k: v for k, v in rope_scaling.items() if k not in self.EXTRA_ROPE_KEYS} if sanitized_scaling != rope_scaling: - config_for_rope = copy.copy(config) + config_for_rope = copy.copy(rope_source_cfg) config_for_rope.rope_scaling = sanitized_scaling else: - config_for_rope = config + config_for_rope = rope_source_cfg init_device = device if device is not None and getattr(device, "type", None) != "meta" else None inv_freq, attention_scaling = rope_init_fn(config_for_rope, device=init_device) @@ -1619,19 +1753,12 @@ def __init__(self, config: IsaacConfig): self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) - vision_cfg = config.vision_config - # Use vision_attn_implementation if specified, otherwise fall back to general attn_implementation - vision_cfg._attn_implementation = ( - config.vision_attn_implementation - if config.vision_attn_implementation is not None - else config._attn_implementation - ) - if vision_cfg is None: + if config.vision_config is None: raise ValueError("IsaacConfig should always have vision_config") - hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2) + hidden_dim = config.vision_config.hidden_size * (config.vision_config.pixel_shuffle_scale_factor**2) self.vision_embedding = nn.Sequential( - IsaacVisionTransformer(vision_cfg), + IsaacVisionTransformer(config.vision_config), nn.Linear( hidden_dim, 4 * hidden_dim, @@ -1970,7 +2097,6 @@ class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin): def __init__(self, config: IsaacConfig): super().__init__(config) - self.model = IsaacModel(config) # Use our custom model self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -2132,18 +2258,9 @@ def can_generate(self) -> bool: return True - -def _load_isaac_fast_image_processor(): - try: - from .image_processing_isaac_fast import IsaacImageProcessorFast as fast_cls - except ImportError: - fast_cls = None - return fast_cls - AutoImageProcessor.register( IsaacConfig, - slow_image_processor_class=IsaacImageProcessor, - fast_image_processor_class=_load_isaac_fast_image_processor(), + fast_image_processor_class=IsaacImageProcessorFast, exist_ok=True, ) @@ -2152,20 +2269,9 @@ def _load_isaac_fast_image_processor(): "IsaacConfig", "IsaacModel", "IsaacForConditionalGeneration", - "IsaacImageProcessor", "IsaacImageProcessorFast", "IsaacProcessor", ] -def _prepare_image_tensor(image: torch.Tensor, scale: float, mean: tuple[float, ...], std: tuple[float, ...]) -> torch.Tensor: - """Mirror the 
prepare_image_tensor utility used in the training pipelines.""" - if not torch.is_floating_point(image): - image = image.float() - - rescaled = image * scale - mean_tensor = torch.tensor(mean, dtype=torch.float32, device=rescaled.device).view(1, 1, 1, -1) - std_tensor = torch.tensor(std, dtype=torch.float32, device=rescaled.device).view(1, 1, 1, -1) - normalized = (rescaled - mean_tensor) / std_tensor - return normalized def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> torch.Tensor: @@ -2179,3 +2285,4 @@ def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> p_indices = frame_indices[torch.tensor(is_p_frame, device=frames.device)] frames[p_indices] = frames[p_indices] - frames[last_i_indices[p_indices]] return frames + From 965215cc72016b564c63c3f473f6c8cc8716d1da Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Wed, 15 Oct 2025 13:12:18 +0400 Subject: [PATCH 06/77] fix: update imports --- .../models/isaac/modular_isaac.py | 50 ++++++++----------- 1 file changed, 20 insertions(+), 30 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 726c55aa56f6..1972243d4830 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -110,44 +110,34 @@ slice as ts_slice, ) -from transformers import ( - AutoImageProcessor, - AutoModel, - AutoTokenizer, - BatchFeature, - Cache, - Qwen3Config, - Qwen3ForCausalLM, - Qwen3PreTrainedModel, -) -from transformers.cache_utils import SlidingWindowCache, StaticCache -from transformers.generation.utils import GenerationMixin -from transformers.image_processing_utils_fast import ( +from ...cache_utils import Cache, SlidingWindowCache, StaticCache +from ...generation.utils import GenerationMixin +from ...image_processing_utils_fast import ( BaseImageProcessorFast, + BatchFeature, DefaultFastImageProcessorKwargs, SizeDict, group_images_by_shape, reorder_images, ) -from transformers.image_utils import ( - ChannelDimension, - PILImageResampling, -) -from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer -from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig -from transformers.models.siglip2.modeling_siglip2 import Siglip2EncoderLayer as HFSiglip2EncoderLayer -from transformers.processing_utils import ProcessorMixin, Unpack -from transformers.tokenization_utils import TensorType -from transformers.utils import auto_docstring +from ...image_utils import ChannelDimension, PILImageResampling +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import ProcessorMixin, Unpack +from ...utils import TensorType, auto_docstring # Vision preprocessing constants -from transformers.utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN -from transformers.utils.constants import IMAGENET_STANDARD_STD as VISION_STD -from transformers.utils.import_utils import is_torchdynamo_compiling +from ...utils.constants import 
IMAGENET_STANDARD_MEAN as VISION_MEAN +from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD +from ...utils.import_utils import is_torchdynamo_compiling +from ..auto import AutoImageProcessor, AutoModel, AutoTokenizer +from ..qwen2.tokenization_qwen2 import Qwen2Tokenizer +from ..qwen3.configuration_qwen3 import Qwen3Config +from ..qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3PreTrainedModel +from ..siglip2.configuration_siglip2 import Siglip2VisionConfig +from ..siglip2.modeling_siglip2 import Siglip2EncoderLayer as HFSiglip2EncoderLayer _ORIGINAL_ATTENTION_FUNCTIONS: dict[str, Callable[..., tuple[torch.Tensor, Optional[torch.Tensor]]]] = {} From 4c4f1c9481a1c574759d1b89b50de80aea2ad044 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Wed, 15 Oct 2025 16:47:28 +0400 Subject: [PATCH 07/77] fix: adjust typing to get modular convert script working --- src/transformers/models/isaac/modular_isaac.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 1972243d4830..e750a2bc839f 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -1469,12 +1469,12 @@ def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> class IsaacProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] - image_processor_class = ("IsaacImageProcessorFast",) + image_processor_class = "IsaacImageProcessorFast" tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") def __init__( self, - image_processor: IsaacImageProcessorFast | None = None, + image_processor: "IsaacImageProcessorFast | None" = None, tokenizer: Qwen2Tokenizer | None = None, *, vision_token: str = "", From 021a1aee6756203e79d56e1d92cf680951ec363e Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Wed, 15 Oct 2025 16:47:39 +0400 Subject: [PATCH 08/77] feat: modular convert utility outputs --- .../models/isaac/configuration_isaac.py | 178 +++++--- .../isaac/image_processing_isaac_fast.py | 394 ++++++++++++------ .../models/isaac/modeling_isaac.py | 326 +++++++++------ .../models/isaac/processing_isaac.py | 182 ++++---- 4 files changed, 676 insertions(+), 404 deletions(-) diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index 3ffac505e2ee..bc3d59b58cef 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -4,11 +4,90 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ +# Perceptron, Inc. Non-Production License + +### 1. Scope and acceptance + +# **1.1. Scope of the Agreement.** +# This Agreement applies to any use, modification, or Distribution of any Perceptron Model by You, regardless of the source You obtained a copy of such Perceptron Model. +# +# **1.2. Acceptance.** By accessing, using, modifying, Distributing a Perceptron Model, or by creating, using or distributing a Derivative of the Perceptron Model, You agree to be bound by this Agreement. +# +# **1.3. 
Acceptance on behalf of a third-party.** If You accept this Agreement on behalf of Your employer or another person or entity, You warrant and represent that You have the authority to act and accept this Agreement on their behalf. In such a case, the word โ€œYouโ€ in this Agreement will refer to Your employer or such other person or entity. +# +# ## 2. License +# **2.1. Grant of rights.** Subject to Section 3 below, Perceptron, Inc. hereby grants You a non-exclusive, royalty-free, worldwide, non-sublicensable, non-transferable, limited license to use, copy, modify, and Distribute under the conditions provided in Section 2.2 below, the Perceptron Model and any Derivatives made by or for Perceptron, Inc. and to create Derivatives of the Perceptron Model. +# +# **2.2. Distribution of Perceptron Model and Derivatives made by or for Perceptron, Inc..** Subject to Section 3 below, You may Distribute copies of the Perceptron Model and/or Derivatives made by or for Perceptron, Inc., under the following conditions: +# - You must make available a copy of this Agreement to third-party recipients of the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. you Distribute, it being specified that any rights to use the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. shall be directly granted by Perceptron, Inc. to said third-party recipients pursuant to the Perceptron, Inc. Non-Production License agreement executed between these parties; +# - You must retain in all copies of the Perceptron Models the following attribution notice within a โ€œNoticeโ€ text file distributed as part of such copies: โ€œLicensed by Perceptron, Inc. under the Perceptron, Inc. Non-Production Licenseโ€. +# +# **2.3. Distribution of Derivatives made by or for You.** Subject to Section 3 below, You may Distribute any Derivatives made by or for You under additional or different terms and conditions, provided that: +# - In any event, the use and modification of Perceptron Model and/or Derivatives made by or for Perceptron, Inc. shall remain governed by the terms and conditions of this Agreement; +# - You include in any such Derivatives made by or for You prominent notices stating that You modified the concerned Perceptron Model; and +# - Any terms and conditions You impose on any third-party recipients relating to Derivatives made by or for You shall neither limit such third-party recipientsโ€™ use of the Perceptron Model or any Derivatives made by or for Perceptron, Inc. in accordance with the Perceptron, Inc. Non-Production License nor conflict with any of its terms and conditions. +# +# ## 3. Limitations +# **3.1. Misrepresentation.** You must not misrepresent or imply, through any means, that the Derivatives made by or for You and/or any modified version of the Perceptron Model You Distribute under your name and responsibility is an official product of Perceptron, Inc. or has been endorsed, approved or validated by Perceptron, Inc., unless You are authorized by Us to do so in writing. +# +# **3.2. Usage Limitation** +# - You shall only use the Perceptron Models and Derivatives (whether or not created by Perceptron, Inc.) for testing, research, Personal, or evaluation purposes in Non-Production Environments; +# - Subject to the foregoing, You shall not supply the Perceptron Models or Derivatives in the course of a commercial activity, whether in return for payment or free of charge, in any medium or form, including but not limited to through a hosted or managed service (e.g. 
SaaS, cloud instances, etc.), or behind a software layer. +# +# **3.3. Usage not permitted under this Agreement.** If You want to use a Perceptron Model or a Derivative for any purpose that is not expressly authorized under this Agreement, You must request a license from Perceptron, Inc., which Perceptron, Inc. may grant to You in Perceptron, Inc.โ€™s sole discretion. Please contact Perceptron, Inc. at the following e-mail address if You want to discuss such a license: sales@perceptron.inc +# +# ## 4. Intellectual Property +# **4.1. Trademarks.** No trademark licenses are granted under this Agreement, and in connection with the Perceptron Models, You may not use any name or mark owned by or associated with Perceptron, Inc. or any of its affiliates, except (i) as required for reasonable and customary use in describing and Distributing the Perceptron Models and Derivatives made by or for Perceptron, Inc. and (ii) for attribution purposes as required by this Agreement. +# +# **4.2. Outputs.** We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs You generate and their subsequent uses in accordance with this Agreement. +# +# **4.3. Derivatives.** By entering into this Agreement, You accept that any Derivatives that You may create or that may be created for You shall be subject to the restrictions set out in Section 3 of this Agreement. +# +# # 5. Liability +# **5.1. Limitation of liability.** In no event, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall Perceptron, Inc. be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use the Perceptron Models and Derivatives (including but not limited to damages for loss of data, loss of goodwill, loss of expected profit or savings, work stoppage, computer failure or malfunction, or any damage caused by malware or security breaches), even if Perceptron, Inc. has been advised of the possibility of such damages. +# +# **5.2. Indemnification.** You agree to indemnify and hold harmless Perceptron, Inc. from and against any claims, damages, or losses arising out of or related to Your use or Distribution of the Perceptron Models and Derivatives. +# +# ## 6. Warranty +# **6.1. Disclaimer.** Unless required by applicable law or agreed to in writing, Perceptron, Inc. provides the Perceptron Models and Derivatives on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. Perceptron, Inc. does not represent nor warrant that the Perceptron Models and Derivatives will be error-free, meet Your or any third partyโ€™s requirements, be secure or will allow You or any third party to achieve any kind of result or generate any kind of content. You are solely responsible for determining the appropriateness of using or Distributing the Perceptron Models and Derivatives and assume any risks associated with Your exercise of rights under this Agreement. +# +# # 7. Termination +# **7.1. Term.** This Agreement is effective as of the date of your acceptance of this Agreement or access to the concerned Perceptron Models or Derivatives and will continue until terminated in accordance with the following terms. +# +# **7.2. Termination.** Perceptron, Inc. 
may terminate this Agreement at any time if You are in breach of this Agreement. Upon termination of this Agreement, You must cease to use all Perceptron Models and Derivatives and shall permanently delete any copy thereof. Sections 5, 6, 7 and 8 shall survive the termination of this Agreement. +# +# **7.3. Litigation.** If You initiate any legal action or proceedings against Us or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the Model or a Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement will immediately terminate as of the date such legal action or claim is filed or initiated. +# +# # 8. General provisions +# 8.1. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the State of Washington, without regard to its conflict of law principles. +# +# 8.2. Jurisdiction. The state and federal courts located in King County, Washington shall have exclusive jurisdiction over any dispute arising out of or relating to this Agreement, and You and We consent to personal jurisdiction and venue in such courts. +# +# **8.3. Severability.** If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. +# +# # 9. Definitions +# **โ€œAgreementโ€**: means this Perceptron, Inc. Non-Production License agreement governing the access, use, and Distribution of the Perceptron Models and Derivatives. +# +# **โ€œDerivativeโ€**: means any (i) modified version of the Perceptron Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the Perceptron Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered as Derivatives under this Agreement. +# +# **โ€œDistributionโ€**, **โ€œDistributingโ€**, **โ€œDistributeโ€** or **โ€œDistributedโ€**: means providing or making available, by any means, a copy of the Perceptron Models and/or the Derivatives as the case may be, subject to Section 3 of this Agreement. +# +# **โ€œPerceptron, Inc.โ€**, **โ€œWeโ€** or **โ€œUsโ€**: means Perceptron, Inc., a Delaware corporation with its principal place of business at 10900 NE 8th St Suite 613, Bellevue, WA 98004. +# +# **โ€œPerceptron Modelโ€**: means the foundational large language model(s), and its elements which include algorithms, software, instructed checkpoints, parameters, source code (inference code, evaluation code and, if applicable, fine-tuning code) and any other elements associated thereto made available by Perceptron, Inc. under this Agreement, including, if any, the technical documentation, manuals and instructions for the use and operation thereof. +# +# **โ€œNon-Production Environmentโ€**: means any setting, use case, or application of the Perceptron Models or Derivatives that expressly excludes live, real-world conditions, commercial operations, revenue-generating activities, or direct interactions with or impacts on end users (such as, for instance, Your employees or customers). 
Non-Production Environment may include, but is not limited to, any setting, use case, or application for research, development, testing, quality assurance, training, internal evaluation (other than any internal usage by employees in the context of the companyโ€™s business activities), and demonstration purposes. +# +# **โ€œOutputsโ€**: means any content generated by the operation of the Perceptron Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a Perceptron Models, such as any fine-tuned versions of the Perceptron Models, the weights, or parameters. +# +# **โ€œPersonalโ€**: means any use of a Perceptron Model or a Derivative that is (i) solely for personal, non-profit and non-commercial purposes and (ii) not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities. For illustration purposes, Personal use of a Model or a Derivative does not include any usage by individuals employed in companies in the context of their daily tasks, any activity that is intended to generate revenue, or that is performed on behalf of a commercial entity. +# +# **โ€œYouโ€**: means the individual or entity entering into this Agreement with Perceptron, Inc.. + import copy -from collections.abc import Sequence +from typing import Any -# Build the list of all image processors from ...configuration_utils import PretrainedConfig, layer_type_validation from ...modeling_rope_utils import rope_config_validation @@ -27,6 +106,7 @@ class IsaacVisionConfig(PretrainedConfig): model_type = "isaac_vision" base_config_key = "vision_config" + _attn_implementation: str | None = None def __init__( self, @@ -50,29 +130,16 @@ def __init__( # Add our custom fields self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor + if self._attn_implementation is None: + self._attn_implementation = "flash_attention_2" -# Vision preprocessing constants -VISION_MEAN = (0.5, 0.5, 0.5) -VISION_STD = (0.5, 0.5, 0.5) -VISION_SCALE = 1 / 255 - - -def _normalize_rgb_values( - values: float | Sequence[float] | tuple[float, ...], - *, - name: str, -) -> tuple[float, float, float]: - """Coerce RGB normalization parameters into a 3-tuple of floats.""" - if isinstance(values, (list, tuple)): - if len(values) == 3: - return tuple(float(v) for v in values) - if len(values) == 1: - value = float(values[0]) - return (value, value, value) - raise ValueError(f"`{name}` must have length 1 or 3 when provided as a sequence. 
Got length {len(values)}.") + @property + def attn_implementation(self) -> str | None: + return self._attn_implementation - value = float(values) - return (value, value, value) + @attn_implementation.setter + def attn_implementation(self, value: str | None) -> None: + self._attn_implementation = value class IsaacConfig(PretrainedConfig): @@ -101,20 +168,18 @@ class IsaacConfig(PretrainedConfig): def __init__( self, - vision_config=None, + vision_config: IsaacVisionConfig | None = None, text_config: Qwen3Config | dict | None = None, - vision_patch_size: int = 16, - vision_max_num_patches: int = 256, - vision_min_num_patches: int | None = None, - pixel_shuffle_scale: int = 1, - vision_rescale_factor: float = VISION_SCALE, - vision_mean: float | Sequence[float] = VISION_MEAN, - vision_std: float | Sequence[float] = VISION_STD, + vision_rescale_factor: float = 1 / 255, max_sequence_length: int = 16384, vision_token: str = "", - vision_attn_implementation: str | None = None, **kwargs, ): + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self._rope_scaling: dict[str, Any] | None = None resolved_text_config = kwargs.pop("text_config", text_config) if isinstance(resolved_text_config, Qwen3Config): text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict()) @@ -165,36 +230,26 @@ def __init__( for i in range(self.num_hidden_layers) ] layer_type_validation(self.layer_types) - - super().__init__( - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) self.text_config = Qwen3Config(**text_config_kwargs) + if self._rope_scaling is None: + self._rope_scaling = getattr(self.text_config, "rope_scaling", None) + else: + self.text_config.rope_scaling = self._rope_scaling # Handle vision config - either dict or IsaacVisionConfig instance if isinstance(vision_config, dict): self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif isinstance(vision_config, IsaacVisionConfig): + self.vision_config = vision_config elif vision_config is None: self.vision_config = self.sub_configs["vision_config"]() - else: - self.vision_config = vision_config - - # EventStreamProcessor parameters (for backward compatibility) - self.video_patch_size = vision_patch_size - self.vision_max_num_patches = vision_max_num_patches - self.vision_min_num_patches = vision_min_num_patches - self.pixel_shuffle_scale = pixel_shuffle_scale # Vision normalization parameters self.vision_rescale_factor = float(vision_rescale_factor) - self.vision_mean = _normalize_rgb_values(vision_mean, name="vision_mean") - self.vision_std = _normalize_rgb_values(vision_std, name="vision_std") # Processing parameters self.max_sequence_length = max_sequence_length self.vision_token = vision_token - self.vision_attn_implementation = vision_attn_implementation def get_text_config(self, *_, **kwargs) -> Qwen3Config: # Accept optional decoder/encoder flags to align with HF composite configs @@ -202,5 +257,32 @@ def get_text_config(self, *_, **kwargs) -> Qwen3Config: kwargs.pop("encoder", None) return self.text_config + @property + def rope_scaling(self): + if hasattr(self, "text_config") and self.text_config is not None: + return getattr(self.text_config, "rope_scaling", None) + return self._rope_scaling + + @rope_scaling.setter + def rope_scaling(self, value): + self._rope_scaling = value + if hasattr(self, "text_config") and self.text_config is not None: + self.text_config.rope_scaling = value + + @property + def vision_attn_implementation(self) -> str | None: + value = getattr(self.vision_config, 
"_attn_implementation", None) + if value is None: + value = getattr(self.vision_config, "attn_implementation", None) + return value + + @vision_attn_implementation.setter + def vision_attn_implementation(self, value: str | None) -> None: + self.vision_config._attn_implementation = value + if value is not None: + self.vision_config.attn_implementation = value + elif hasattr(self.vision_config, "attn_implementation"): + delattr(self.vision_config, "attn_implementation") + __all__ = ["IsaacConfig"] diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py b/src/transformers/models/isaac/image_processing_isaac_fast.py index 72c93f1c3a81..ce778ef509da 100644 --- a/src/transformers/models/isaac/image_processing_isaac_fast.py +++ b/src/transformers/models/isaac/image_processing_isaac_fast.py @@ -4,50 +4,108 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ +# Perceptron, Inc. Non-Production License + +### 1. Scope and acceptance + +# **1.1. Scope of the Agreement.** +# This Agreement applies to any use, modification, or Distribution of any Perceptron Model by You, regardless of the source You obtained a copy of such Perceptron Model. +# +# **1.2. Acceptance.** By accessing, using, modifying, Distributing a Perceptron Model, or by creating, using or distributing a Derivative of the Perceptron Model, You agree to be bound by this Agreement. +# +# **1.3. Acceptance on behalf of a third-party.** If You accept this Agreement on behalf of Your employer or another person or entity, You warrant and represent that You have the authority to act and accept this Agreement on their behalf. In such a case, the word โ€œYouโ€ in this Agreement will refer to Your employer or such other person or entity. +# +# ## 2. License +# **2.1. Grant of rights.** Subject to Section 3 below, Perceptron, Inc. hereby grants You a non-exclusive, royalty-free, worldwide, non-sublicensable, non-transferable, limited license to use, copy, modify, and Distribute under the conditions provided in Section 2.2 below, the Perceptron Model and any Derivatives made by or for Perceptron, Inc. and to create Derivatives of the Perceptron Model. +# +# **2.2. Distribution of Perceptron Model and Derivatives made by or for Perceptron, Inc..** Subject to Section 3 below, You may Distribute copies of the Perceptron Model and/or Derivatives made by or for Perceptron, Inc., under the following conditions: +# - You must make available a copy of this Agreement to third-party recipients of the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. you Distribute, it being specified that any rights to use the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. shall be directly granted by Perceptron, Inc. to said third-party recipients pursuant to the Perceptron, Inc. Non-Production License agreement executed between these parties; +# - You must retain in all copies of the Perceptron Models the following attribution notice within a โ€œNoticeโ€ text file distributed as part of such copies: โ€œLicensed by Perceptron, Inc. under the Perceptron, Inc. Non-Production Licenseโ€. +# +# **2.3. 
Distribution of Derivatives made by or for You.** Subject to Section 3 below, You may Distribute any Derivatives made by or for You under additional or different terms and conditions, provided that: +# - In any event, the use and modification of Perceptron Model and/or Derivatives made by or for Perceptron, Inc. shall remain governed by the terms and conditions of this Agreement; +# - You include in any such Derivatives made by or for You prominent notices stating that You modified the concerned Perceptron Model; and +# - Any terms and conditions You impose on any third-party recipients relating to Derivatives made by or for You shall neither limit such third-party recipientsโ€™ use of the Perceptron Model or any Derivatives made by or for Perceptron, Inc. in accordance with the Perceptron, Inc. Non-Production License nor conflict with any of its terms and conditions. +# +# ## 3. Limitations +# **3.1. Misrepresentation.** You must not misrepresent or imply, through any means, that the Derivatives made by or for You and/or any modified version of the Perceptron Model You Distribute under your name and responsibility is an official product of Perceptron, Inc. or has been endorsed, approved or validated by Perceptron, Inc., unless You are authorized by Us to do so in writing. +# +# **3.2. Usage Limitation** +# - You shall only use the Perceptron Models and Derivatives (whether or not created by Perceptron, Inc.) for testing, research, Personal, or evaluation purposes in Non-Production Environments; +# - Subject to the foregoing, You shall not supply the Perceptron Models or Derivatives in the course of a commercial activity, whether in return for payment or free of charge, in any medium or form, including but not limited to through a hosted or managed service (e.g. SaaS, cloud instances, etc.), or behind a software layer. +# +# **3.3. Usage not permitted under this Agreement.** If You want to use a Perceptron Model or a Derivative for any purpose that is not expressly authorized under this Agreement, You must request a license from Perceptron, Inc., which Perceptron, Inc. may grant to You in Perceptron, Inc.โ€™s sole discretion. Please contact Perceptron, Inc. at the following e-mail address if You want to discuss such a license: sales@perceptron.inc +# +# ## 4. Intellectual Property +# **4.1. Trademarks.** No trademark licenses are granted under this Agreement, and in connection with the Perceptron Models, You may not use any name or mark owned by or associated with Perceptron, Inc. or any of its affiliates, except (i) as required for reasonable and customary use in describing and Distributing the Perceptron Models and Derivatives made by or for Perceptron, Inc. and (ii) for attribution purposes as required by this Agreement. +# +# **4.2. Outputs.** We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs You generate and their subsequent uses in accordance with this Agreement. +# +# **4.3. Derivatives.** By entering into this Agreement, You accept that any Derivatives that You may create or that may be created for You shall be subject to the restrictions set out in Section 3 of this Agreement. +# +# # 5. Liability +# **5.1. Limitation of liability.** In no event, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall Perceptron, Inc. 
be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use the Perceptron Models and Derivatives (including but not limited to damages for loss of data, loss of goodwill, loss of expected profit or savings, work stoppage, computer failure or malfunction, or any damage caused by malware or security breaches), even if Perceptron, Inc. has been advised of the possibility of such damages. +# +# **5.2. Indemnification.** You agree to indemnify and hold harmless Perceptron, Inc. from and against any claims, damages, or losses arising out of or related to Your use or Distribution of the Perceptron Models and Derivatives. +# +# ## 6. Warranty +# **6.1. Disclaimer.** Unless required by applicable law or agreed to in writing, Perceptron, Inc. provides the Perceptron Models and Derivatives on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. Perceptron, Inc. does not represent nor warrant that the Perceptron Models and Derivatives will be error-free, meet Your or any third partyโ€™s requirements, be secure or will allow You or any third party to achieve any kind of result or generate any kind of content. You are solely responsible for determining the appropriateness of using or Distributing the Perceptron Models and Derivatives and assume any risks associated with Your exercise of rights under this Agreement. +# +# # 7. Termination +# **7.1. Term.** This Agreement is effective as of the date of your acceptance of this Agreement or access to the concerned Perceptron Models or Derivatives and will continue until terminated in accordance with the following terms. +# +# **7.2. Termination.** Perceptron, Inc. may terminate this Agreement at any time if You are in breach of this Agreement. Upon termination of this Agreement, You must cease to use all Perceptron Models and Derivatives and shall permanently delete any copy thereof. Sections 5, 6, 7 and 8 shall survive the termination of this Agreement. +# +# **7.3. Litigation.** If You initiate any legal action or proceedings against Us or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the Model or a Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement will immediately terminate as of the date such legal action or claim is filed or initiated. +# +# # 8. General provisions +# 8.1. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the State of Washington, without regard to its conflict of law principles. +# +# 8.2. Jurisdiction. The state and federal courts located in King County, Washington shall have exclusive jurisdiction over any dispute arising out of or relating to this Agreement, and You and We consent to personal jurisdiction and venue in such courts. +# +# **8.3. Severability.** If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. +# +# # 9. Definitions +# **โ€œAgreementโ€**: means this Perceptron, Inc. 
Non-Production License agreement governing the access, use, and Distribution of the Perceptron Models and Derivatives. +# +# **โ€œDerivativeโ€**: means any (i) modified version of the Perceptron Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the Perceptron Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered as Derivatives under this Agreement. +# +# **โ€œDistributionโ€**, **โ€œDistributingโ€**, **โ€œDistributeโ€** or **โ€œDistributedโ€**: means providing or making available, by any means, a copy of the Perceptron Models and/or the Derivatives as the case may be, subject to Section 3 of this Agreement. +# +# **โ€œPerceptron, Inc.โ€**, **โ€œWeโ€** or **โ€œUsโ€**: means Perceptron, Inc., a Delaware corporation with its principal place of business at 10900 NE 8th St Suite 613, Bellevue, WA 98004. +# +# **โ€œPerceptron Modelโ€**: means the foundational large language model(s), and its elements which include algorithms, software, instructed checkpoints, parameters, source code (inference code, evaluation code and, if applicable, fine-tuning code) and any other elements associated thereto made available by Perceptron, Inc. under this Agreement, including, if any, the technical documentation, manuals and instructions for the use and operation thereof. +# +# **โ€œNon-Production Environmentโ€**: means any setting, use case, or application of the Perceptron Models or Derivatives that expressly excludes live, real-world conditions, commercial operations, revenue-generating activities, or direct interactions with or impacts on end users (such as, for instance, Your employees or customers). Non-Production Environment may include, but is not limited to, any setting, use case, or application for research, development, testing, quality assurance, training, internal evaluation (other than any internal usage by employees in the context of the companyโ€™s business activities), and demonstration purposes. +# +# **โ€œOutputsโ€**: means any content generated by the operation of the Perceptron Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a Perceptron Models, such as any fine-tuned versions of the Perceptron Models, the weights, or parameters. +# +# **โ€œPersonalโ€**: means any use of a Perceptron Model or a Derivative that is (i) solely for personal, non-profit and non-commercial purposes and (ii) not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities. For illustration purposes, Personal use of a Model or a Derivative does not include any usage by individuals employed in companies in the context of their daily tasks, any activity that is intended to generate revenue, or that is performed on behalf of a commercial entity. +# +# **โ€œYouโ€**: means the individual or entity entering into this Agreement with Perceptron, Inc.. 
+ import math from collections.abc import Sequence -from typing import Any +from typing import Any, Optional, Union import torch - -from ...feature_extraction_utils import BatchFeature -from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict -from ...image_utils import PILImageResampling +import torch.nn.functional as F + +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + BatchFeature, + SizeDict, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ChannelDimension, PILImageResampling from ...processing_utils import Unpack -from ...tokenization_utils import TensorType -from ...utils import auto_docstring -from .image_processing_isaac import IsaacImageProcessorKwargs - +from ...utils import TensorType, auto_docstring # Vision preprocessing constants -VISION_MEAN = (0.5, 0.5, 0.5) -VISION_STD = (0.5, 0.5, 0.5) -VISION_SCALE = 1 / 255 - - -def _normalize_rgb_values( - values: float | Sequence[float] | tuple[float, ...], - *, - name: str, -) -> tuple[float, float, float]: - """Coerce RGB normalization parameters into a 3-tuple of floats.""" - if isinstance(values, (list, tuple)): - if len(values) == 3: - return tuple(float(v) for v in values) - if len(values) == 1: - value = float(values[0]) - return (value, value, value) - raise ValueError(f"`{name}` must have length 1 or 3 when provided as a sequence. Got length {len(values)}.") - - value = float(values) - return (value, value, value) - - -# ============================================================================ -# Configuration -# ============================================================================ - -MAX_PIXELS = 60_000_000 # 60โ€‘megapixel ceiling โ‰ˆ 8200 ร— 7300 px +from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN +from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD +from .processing_isaac import IsaacImageProcessorKwargs def get_scaled_image_size( @@ -182,104 +240,140 @@ def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> @auto_docstring class IsaacImageProcessorFast(BaseImageProcessorFast): + MAX_PIXELS = 60_000_000 # 60โ€‘megapixel ceiling โ‰ˆ 8200 ร— 7300 px r"""Fast torch-based image processor for Isaac vision inputs.""" - slow_image_processor_class = "IsaacImageProcessor" - resample = PILImageResampling.BILINEAR model_input_names = ["patches", "token_grids"] valid_kwargs = IsaacImageProcessorKwargs unused_kwargs = ["size", "do_center_crop", "crop_size"] + do_resize = True + size: SizeDict | None = None + default_to_square: bool | None = None + do_center_crop = False + crop_size: SizeDict | None = None + patch_size: int | None = 16 + max_num_patches: int | None = 256 + min_num_patches: int | None = None + pixel_shuffle_scale: int | None = 1 + do_pad = False + pad_size: SizeDict | None = None + do_rescale = True + rescale_factor = 1 / 255 + do_normalize = True + image_mean = list(VISION_MEAN) + image_std = list(VISION_STD) + do_convert_rgb = True + return_tensors = None + data_format = ChannelDimension.FIRST + input_data_format = None + device = None + disable_grouping = False + size_divisor: int | None = None + def __init__( self, - *, - patch_size: int = 16, - max_num_patches: int = 256, - min_num_patches: int | None = None, - pixel_shuffle_scale: int = 1, - do_rescale: bool = True, - rescale_factor: float | None = None, - do_normalize: bool = True, - image_mean: float | Sequence[float] | None = None, - image_std: float | Sequence[float] | None = None, - do_convert_rgb: bool = True, **kwargs: 
Unpack[IsaacImageProcessorKwargs], ) -> None: super().__init__(**kwargs) + pixel_shuffle_scale = 1 if self.pixel_shuffle_scale is None else int(self.pixel_shuffle_scale) if pixel_shuffle_scale < 1: raise ValueError("`pixel_shuffle_scale` must be >= 1") - - mean_values = _normalize_rgb_values(image_mean if image_mean is not None else VISION_MEAN, name="image_mean") - std_values = _normalize_rgb_values(image_std if image_std is not None else VISION_STD, name="image_std") - - self.patch_size = patch_size - self.max_num_patches = max_num_patches - self.min_num_patches = min_num_patches self.pixel_shuffle_scale = pixel_shuffle_scale - self.do_rescale = do_rescale - self.rescale_factor = VISION_SCALE if rescale_factor is None else float(rescale_factor) - self.do_normalize = do_normalize - self.image_mean = list(mean_values) - self.image_std = list(std_values) - self.do_convert_rgb = do_convert_rgb def _validate_preprocess_kwargs(self, **kwargs): # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for. kwargs.pop("do_resize", None) + kwargs.pop("size", None) + kwargs.pop("do_center_crop", None) + kwargs.pop("crop_size", None) + kwargs.pop("disable_grouping", None) return super()._validate_preprocess_kwargs(**kwargs) + def resize( + self, + image: torch.Tensor, + size: SizeDict, + interpolation: Optional[Any] = None, + antialias: bool = True, + **kwargs, + ) -> torch.Tensor: + if size.height is None or size.width is None: + raise ValueError("IsaacImageProcessorFast requires explicit `height` and `width` when resizing.") + + resize_mode: Any = interpolation + if hasattr(resize_mode, "value"): + resize_mode = resize_mode.value + elif hasattr(resize_mode, "name"): + resize_mode = resize_mode.name.lower() + elif resize_mode is None: + resize_mode = "bilinear" + + if isinstance(resize_mode, str): + mode_key = resize_mode.lower() + else: + mode_key = resize_mode + + resize_kwargs: dict[str, Any] = {} + if mode_key in {"linear", "bilinear", "bicubic", "trilinear"}: + resize_kwargs["align_corners"] = False + + return F.interpolate( + image, + size=(size.height, size.width), + mode=resize_mode, + **resize_kwargs, + ) + def _preprocess( self, - images: list["torch.Tensor"], + images: list[torch.Tensor], do_resize: bool, - patch_size: int, - max_num_patches: int, - interpolation: Any | None, - do_rescale: bool, - rescale_factor: float, - do_normalize: bool, - image_mean: float | Sequence[float] | None, - image_std: float | Sequence[float] | None, - return_tensors: str | TensorType | None, + size: Optional[SizeDict], + interpolation: Optional[Any], + do_center_crop: bool, + crop_size: Optional[SizeDict], + do_rescale: Optional[bool], + rescale_factor: Optional[float], + do_normalize: Optional[bool], + image_mean: Optional[Union[float, Sequence[float]]], + image_std: Optional[Union[float, Sequence[float]]], + disable_grouping: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + do_pad: Optional[bool] = None, + pad_size: Optional[SizeDict] = None, *, + patch_size: int | None = None, + max_num_patches: int | None = None, min_num_patches: int | None = None, pixel_shuffle_scale: int | None = None, - do_convert_rgb: bool | None = None, **kwargs, ) -> BatchFeature: - if TVF is None: - raise ImportError("torchvision is required for IsaacImageProcessorFast but is not installed.") - - min_num_patches = min_num_patches if min_num_patches is not None else self.min_num_patches - pixel_shuffle_scale = pixel_shuffle_scale if pixel_shuffle_scale is not None 
else self.pixel_shuffle_scale - do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - do_rescale = self.do_rescale if do_rescale is None else do_rescale - do_normalize = self.do_normalize if do_normalize is None else do_normalize - rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor - - mean_values = _normalize_rgb_values( - image_mean if image_mean is not None else self.image_mean, name="image_mean" - ) - std_values = _normalize_rgb_values(image_std if image_std is not None else self.image_std, name="image_std") + if do_center_crop: + raise ValueError("`do_center_crop` is not supported by IsaacImageProcessorFast.") + if do_pad: + raise ValueError("`do_pad` is not supported by IsaacImageProcessorFast.") + + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + processed_patches_grouped: dict[tuple[int, ...], torch.Tensor] = {} + token_grids_grouped: dict[tuple[int, ...], torch.Tensor] = {} + virtual_dims_grouped: dict[tuple[int, ...], torch.Tensor] = {} + real_dims_grouped: dict[tuple[int, ...], torch.Tensor] = {} - patches_list: list[torch.Tensor] = [] - token_grids: list[torch.Tensor] = [] - virtual_dims: list[list[int]] = [] - real_dims: list[list[int]] = [] + for shape, stacked_images in grouped_images.items(): + if stacked_images.ndim != 4: + raise ValueError("Expected batched channel-first image tensors.") - for image in images: - if image.ndim != 3: - raise ValueError("Expected channel-first image tensor with shape (C, H, W).") + batch_size, channels, original_height, original_width = stacked_images.shape - channels, original_height, original_width = image.shape - if do_convert_rgb and channels == 1: - image = image.repeat(3, 1, 1) + if bool(self.do_convert_rgb) and channels == 1: + stacked_images = stacked_images.repeat(1, 3, 1, 1) channels = 3 - if original_height * original_width > MAX_PIXELS: - raise ValueError(f"Image (w={original_width}, h={original_height}) > MAX=`{MAX_PIXELS}`") + if original_height * original_width > self.MAX_PIXELS: + raise ValueError(f"Image (w={original_width}, h={original_height}) > MAX=`{self.MAX_PIXELS}`") target_height, target_width = get_image_size_for_max_num_patches( original_height, @@ -291,48 +385,91 @@ def _preprocess( ) if do_resize: - size_dict = SizeDict(height=target_height, width=target_width) - image = self.resize(image=image, size=size_dict, interpolation=interpolation) + resize_size = SizeDict(height=target_height, width=target_width) + image_batch = self.resize( + image=stacked_images, + size=resize_size, + interpolation=interpolation, + ) else: if ((original_height % patch_size) != 0) or ((original_width % patch_size) != 0): raise ValueError("Image dimensions must be divisible by patch_size when resize is disabled.") - - # Apply rescaling and normalization as needed - image = self.rescale_and_normalize( - image, - do_rescale, - rescale_factor, - do_normalize, - list(mean_values), - list(std_values), + image_batch = stacked_images + target_height, target_width = original_height, original_width + + if do_rescale: + image_batch = self.rescale_and_normalize( + image_batch, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) + + nhwc_images = image_batch.permute(0, 2, 3, 1) + nhwc_images = _compute_residual_p_frames(nhwc_images, is_p_frame=[False] * batch_size) + + patches = patchify_vision(nhwc_images, patch_size=patch_size) + _, 
height_tokens, width_tokens, _ = patches.shape + + token_grid = ( + torch.tensor( + [height_tokens, width_tokens], + dtype=torch.long, + device=patches.device, + ) + .unsqueeze(0) + .repeat(batch_size, 1) ) - # Convert to NHWC for residual P-frame adjustment and patch extraction - nhwc_image = image.permute(1, 2, 0).unsqueeze(0) - nhwc_image = _compute_residual_p_frames(nhwc_image, is_p_frame=[False]) - - patches = patchify_vision(nhwc_image, patch_size=patch_size).squeeze(0) - height_tokens, width_tokens, _ = patches.shape - - patches_list.append(patches.unsqueeze(0)) - token_grids.append(torch.tensor([height_tokens, width_tokens], dtype=torch.long, device=patches.device)) + real_dim = ( + torch.tensor( + [1, height_tokens, width_tokens], + dtype=torch.long, + device=patches.device, + ) + .unsqueeze(0) + .repeat(batch_size, 1) + ) - real_dims.append([1, height_tokens, width_tokens]) if pixel_shuffle_scale > 1: if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): raise ValueError( "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled." ) - virtual_dims.append([1, height_tokens // pixel_shuffle_scale, width_tokens // pixel_shuffle_scale]) + virtual_height = height_tokens // pixel_shuffle_scale + virtual_width = width_tokens // pixel_shuffle_scale else: - virtual_dims.append([1, height_tokens, width_tokens]) + virtual_height = height_tokens + virtual_width = width_tokens + + virtual_dim = ( + torch.tensor( + [1, virtual_height, virtual_width], + dtype=torch.long, + device=patches.device, + ) + .unsqueeze(0) + .repeat(batch_size, 1) + ) + + processed_patches_grouped[shape] = patches + token_grids_grouped[shape] = token_grid + virtual_dims_grouped[shape] = virtual_dim + real_dims_grouped[shape] = real_dim + + patches_slices = reorder_images(processed_patches_grouped, grouped_images_index) + token_grid_slices = reorder_images(token_grids_grouped, grouped_images_index) + virtual_dim_slices = reorder_images(virtual_dims_grouped, grouped_images_index) + real_dim_slices = reorder_images(real_dims_grouped, grouped_images_index) - patches_tensor = torch.cat(patches_list, dim=0) - token_grids_tensor = torch.stack(token_grids, dim=0) - virtual_dims_tensor = torch.tensor(virtual_dims, dtype=torch.long, device=patches_tensor.device) - real_dims_tensor = torch.tensor(real_dims, dtype=torch.long, device=patches_tensor.device) + patches_tensor = torch.stack(patches_slices, dim=0) + token_grids_tensor = torch.stack(token_grid_slices, dim=0) + virtual_dims_tensor = torch.stack(virtual_dim_slices, dim=0) + real_dims_tensor = torch.stack(real_dim_slices, dim=0) - batch_feature = BatchFeature( + return BatchFeature( data={ "patches": patches_tensor, "token_grids": token_grids_tensor, @@ -341,7 +478,6 @@ def _preprocess( }, tensor_type=return_tensors, ) - return batch_feature __all__ = ["IsaacImageProcessorFast"] diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 4ebf79738d7a..3b26d5ba7769 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -4,11 +4,90 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ +# Perceptron, Inc. 
Non-Production License (header text identical to configuration_isaac.py above; omitted here)
+# +# **โ€œNon-Production Environmentโ€**: means any setting, use case, or application of the Perceptron Models or Derivatives that expressly excludes live, real-world conditions, commercial operations, revenue-generating activities, or direct interactions with or impacts on end users (such as, for instance, Your employees or customers). Non-Production Environment may include, but is not limited to, any setting, use case, or application for research, development, testing, quality assurance, training, internal evaluation (other than any internal usage by employees in the context of the companyโ€™s business activities), and demonstration purposes. +# +# **โ€œOutputsโ€**: means any content generated by the operation of the Perceptron Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a Perceptron Models, such as any fine-tuned versions of the Perceptron Models, the weights, or parameters. +# +# **โ€œPersonalโ€**: means any use of a Perceptron Model or a Derivative that is (i) solely for personal, non-profit and non-commercial purposes and (ii) not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities. For illustration purposes, Personal use of a Model or a Derivative does not include any usage by individuals employed in companies in the context of their daily tasks, any activity that is intended to generate revenue, or that is performed on behalf of a commercial entity. +# +# **โ€œYouโ€**: means the individual or entity entering into this Agreement with Perceptron, Inc.. + import copy from collections import defaultdict -from collections.abc import Callable -from typing import Any, Optional, TypedDict +from typing import Any, Callable, Optional, TypedDict import torch import torch.nn as nn @@ -34,7 +113,7 @@ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_torchdynamo_compiling -from ..auto.modeling_auto import AutoModel +from ..auto import AutoModel from .configuration_isaac import IsaacConfig, IsaacVisionConfig @@ -100,100 +179,51 @@ def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor): return embeddings -def _max_from_cu(cu: torch.Tensor | None, fallback: int) -> int: - """Helper to compute max sequence length from cumulative sequence lengths.""" - if cu is None or len(cu) < 2: - return fallback - return int((cu[1:] - cu[:-1]).max().item()) - - -def flash_attention_document_mask_forward( - module: torch.nn.Module, - q_lhd: torch.Tensor, # (L, H, D) - k_lhd: torch.Tensor, # (L, H, D) - v_lhd: torch.Tensor, # (L, H, D) - attention_mask: torch.Tensor | None = None, # unused for FA path - dropout: float = 0.0, - scaling: float | None = None, - cum_seq_q: torch.Tensor | None = None, - cum_seq_k: torch.Tensor | None = None, - max_seqlen: int | None = None, - is_causal: bool = False, - **kwargs, -) -> tuple[torch.Tensor, None]: - """FlashAttention that consumes (L, H, D) directly to avoid layout churn.""" - L, H, D = q_lhd.shape - - # Compute max block length once (honor caller when provided) - if max_seqlen is not None: - max_q = max_k = int(max_seqlen) - else: - max_q = _max_from_cu(cum_seq_q, L) - max_k = _max_from_cu(cum_seq_k, L) - - # Ensure contiguity only if needed - if not q_lhd.is_contiguous(): - q_lhd = q_lhd.contiguous() - if not k_lhd.is_contiguous(): - k_lhd = k_lhd.contiguous() - if not 
v_lhd.is_contiguous(): - v_lhd = v_lhd.contiguous() - - out_lhd, *_ = torch.ops.aten._flash_attention_forward( - query=q_lhd, # (L, H, D) - key=k_lhd, # (L, H, D) - value=v_lhd, # (L, H, D) - cum_seq_q=cum_seq_q, - cum_seq_k=cum_seq_k, - max_q=max_q, - max_k=max_k, - dropout_p=dropout, - is_causal=is_causal, - return_debug_mask=False, - scale=scaling, - window_size_left=-1, - window_size_right=-1, - alibi_slopes=None, - ) - return out_lhd, None # (L, H, D) - - -def sdpa_document_mask_forward( - q_lhd: torch.Tensor, # (L, H, D) - k_lhd: torch.Tensor, # (L, H, D) - v_lhd: torch.Tensor, # (L, H, D) - dropout: float, - scaling: float | None, +def build_document_attention_mask( cu_seqlens: torch.Tensor | None, -) -> torch.Tensor: - """SDPA with block-diagonal masking for variable-length sequences.""" - L, H, D = q_lhd.shape + total_tokens: int, + dtype: torch.dtype, + device: torch.device, +) -> torch.Tensor | None: + """Creates an additive attention mask that blocks cross-document attention.""" - # Transpose to (1, H, L, D) format for SDPA - Q = q_lhd.permute(1, 0, 2).unsqueeze(0) - K = k_lhd.permute(1, 0, 2).unsqueeze(0) - V = v_lhd.permute(1, 0, 2).unsqueeze(0) + if cu_seqlens is None: + return None - # Build block-diagonal mask for variable-length sequences - attn_mask = None - if cu_seqlens is not None: - seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() - seg_ids = torch.repeat_interleave(torch.arange(len(seq_sizes), device=q_lhd.device), seq_sizes) - block_mask = seg_ids[:, None] != seg_ids[None, :] # Cross-document attention blocked - attn_mask = torch.where(block_mask, -torch.inf, 0.0).to(q_lhd.dtype).view(1, 1, L, L) + if cu_seqlens.numel() < 2: + return None - Y = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask, dropout_p=dropout, scale=scaling) - return Y.squeeze(0).permute(1, 0, 2) # Back to (L, H, D) + seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() + if seq_sizes.numel() == 0: + return None + + seg_ids = torch.repeat_interleave(torch.arange(seq_sizes.numel(), device=device), seq_sizes) + block_mask = seg_ids[:, None] != seg_ids[None, :] + additive_mask = torch.zeros((total_tokens, total_tokens), dtype=dtype, device=device) + additive_mask.masked_fill_(block_mask, float("-inf")) + return additive_mask.view(1, 1, total_tokens, total_tokens) class IsaacVisionAttention(nn.Module): """Custom attention that supports variable-length sequences with flash attention.""" - def __init__(self, config): + ATTENTION_KEY_MAP: dict[str, str] = { + "flash_attention_2": "isaac_flash_attention_2", + "flash_attention_3": "isaac_flash_attention_3", + "isaac_flash_attention_2": "isaac_flash_attention_2", + "isaac_flash_attention_3": "isaac_flash_attention_3", + "sdpa": "isaac_sdpa", + "isaac_sdpa": "isaac_sdpa", + "eager": "isaac_eager", + "isaac_eager": "isaac_eager", + } + + def __init__(self, vision_config): super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads + self.vision_config = vision_config + self.config = vision_config + self.embed_dim = vision_config.hidden_size + self.num_heads = vision_config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( @@ -201,7 +231,9 @@ def __init__(self, config): f" {self.num_heads})." 
) self.scale = self.head_dim**-0.5 - self.dropout = config.attention_dropout + self.dropout = vision_config.attention_dropout + self.is_causal = False + self.num_key_value_groups = 1 self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) @@ -224,24 +256,46 @@ def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None): k = self.k_proj(x).view(L, H, D) v = self.v_proj(x).view(L, H, D) - attn_impl = getattr(self.config, "_attn_implementation", "flash_attention_3") - - if attn_impl in ("flash_attention_2", "flash_attention_3"): - y_lhd, _ = flash_attention_document_mask_forward( - self, - q, - k, - v, - attention_mask=None, - dropout=p_drop, - scaling=self.scale, - cum_seq_q=cu_seqlens, - cum_seq_k=cu_seqlens, - max_seqlen=max_seqlen, - is_causal=False, - ) - else: - y_lhd = sdpa_document_mask_forward(q, k, v, dropout=p_drop, scaling=self.scale, cu_seqlens=cu_seqlens) + attn_impl = getattr(self.vision_config, "_attn_implementation", "flash_attention_3") + + attn_mask = build_document_attention_mask( + cu_seqlens=cu_seqlens, + total_tokens=L, + dtype=q.dtype, + device=q.device, + ) + + resolved_key = self.ATTENTION_KEY_MAP.get(attn_impl) + attention_fn = ALL_ATTENTION_FUNCTIONS.get(resolved_key) if resolved_key is not None else None + if attention_fn is None: + raise ValueError(f"Attention implementation {attn_impl} not found.") + + query_states = q.transpose(0, 1).unsqueeze(0) + key_states = k.transpose(0, 1).unsqueeze(0) + value_states = v.transpose(0, 1).unsqueeze(0) + + attention_kwargs: dict[str, Any] = { + "dropout": p_drop, + "scaling": self.scale, + "is_causal": False, + } + if cu_seqlens is not None: + attention_kwargs["cu_seq_lens_q"] = cu_seqlens + attention_kwargs["cu_seq_lens_k"] = cu_seqlens + if max_seqlen is not None: + attention_kwargs["max_length_q"] = max_seqlen + attention_kwargs["max_length_k"] = max_seqlen + + attn_output, _ = attention_fn( + self, + query_states, + key_states, + value_states, + attn_mask, + **attention_kwargs, + ) + + y_lhd = attn_output.squeeze(0).permute(1, 0, 2).contiguous() # Merge heads and project y = self.out_proj(y_lhd.reshape(L, self.embed_dim)) @@ -251,21 +305,21 @@ def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None): class IsaacVisionEncoderLayer(HFSiglip2EncoderLayer): """Isaac vision encoder layer with variable-length attention.""" - def __init__(self, config: IsaacVisionConfig): - super().__init__(config) - self.self_attn = IsaacVisionAttention(config) + def __init__(self, vision_config: IsaacVisionConfig): + super().__init__(vision_config) + self.self_attn = IsaacVisionAttention(vision_config) def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor = None, max_seqlen: int = None, - ) -> tuple[torch.FloatTensor]: + ) -> torch.Tensor: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states, attn_weights = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, @@ -278,7 +332,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - return (hidden_states,) + return hidden_states class IsaacVisionEncoder(nn.Module): @@ -304,14 +358,12 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - layer_outputs = encoder_layer( + hidden_states = encoder_layer( hidden_states, cu_seqlens, max_seqlen, ) - hidden_states = layer_outputs[0] - if output_hidden_states: 
all_hidden_states = all_hidden_states + (hidden_states,) @@ -558,7 +610,8 @@ def __init__(self, config: IsaacConfig, device=None): super().__init__() self.config = config - rope_scaling = getattr(config, "rope_scaling", None) or {} + rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config + rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) if rope_type not in ROPE_INIT_FUNCTIONS: raise ValueError(f"Unsupported rope_type '{rope_type}' for IsaacRotaryEmbedding") @@ -568,10 +621,10 @@ def __init__(self, config: IsaacConfig, device=None): sanitized_scaling = {k: v for k, v in rope_scaling.items() if k not in self.EXTRA_ROPE_KEYS} if sanitized_scaling != rope_scaling: - config_for_rope = copy.copy(config) + config_for_rope = copy.copy(rope_source_cfg) config_for_rope.rope_scaling = sanitized_scaling else: - config_for_rope = config + config_for_rope = rope_source_cfg init_device = device if device is not None and getattr(device, "type", None) != "meta" else None inv_freq, attention_scaling = rope_init_fn(config_for_rope, device=init_device) @@ -697,16 +750,21 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: +def repeat_kv(hidden_states: torch.Tensor, num_key_value_groups: int) -> torch.Tensor: + """Repeat key/value heads for grouped-query attention.""" + + if num_key_value_groups == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, + num_key_value_heads, + num_key_value_groups, + seq_len, + head_dim, + ) + return hidden_states.reshape(batch, num_key_value_heads * num_key_value_groups, seq_len, head_dim) def eager_attention_forward( @@ -910,19 +968,12 @@ def __init__(self, config: IsaacConfig): self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) - vision_cfg = config.vision_config - # Use vision_attn_implementation if specified, otherwise fall back to general attn_implementation - vision_cfg._attn_implementation = ( - config.vision_attn_implementation - if config.vision_attn_implementation is not None - else config._attn_implementation - ) - if vision_cfg is None: + if config.vision_config is None: raise ValueError("IsaacConfig should always have vision_config") - hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2) + hidden_dim = config.vision_config.hidden_size * (config.vision_config.pixel_shuffle_scale_factor**2) self.vision_embedding = nn.Sequential( - IsaacVisionTransformer(vision_cfg), + IsaacVisionTransformer(config.vision_config), nn.Linear( hidden_dim, 4 * hidden_dim, @@ -1287,7 +1338,6 @@ class IsaacForConditionalGeneration(IsaacPreTrainedModel, GenerationMixin): def __init__(self, config: IsaacConfig): super().__init__(config) - self.model = IsaacModel(config) # Use our custom model 
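
A note on the refactor above: the bespoke flash/SDPA helpers are replaced by a single `build_document_attention_mask`, which turns packed-sequence boundaries (`cu_seqlens`) into an additive block-diagonal mask, while the rewritten `repeat_kv` expands shared key/value heads for grouped-query attention. The following standalone sketch mirrors both hunks so the shapes can be sanity-checked in isolation; the names are local to the snippet, not the module's public API.

```python
import torch

def block_diagonal_mask(cu_seqlens: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # cu_seqlens = [0, n1, n1 + n2, ...]: cumulative boundaries of packed documents.
    total = int(cu_seqlens[-1])
    seq_sizes = cu_seqlens[1:] - cu_seqlens[:-1]
    # Tag each token with the index of the document it came from.
    seg_ids = torch.repeat_interleave(torch.arange(seq_sizes.numel()), seq_sizes)
    # Cross-document pairs get -inf so softmax assigns them zero weight.
    blocked = seg_ids[:, None] != seg_ids[None, :]
    mask = torch.zeros((total, total), dtype=dtype)
    mask.masked_fill_(blocked, float("-inf"))
    return mask.view(1, 1, total, total)  # additive mask, broadcast over heads

def repeat_kv(x: torch.Tensor, groups: int) -> torch.Tensor:
    # (batch, kv_heads, seq, head_dim) -> (batch, kv_heads * groups, seq, head_dim)
    if groups == 1:
        return x
    b, kv, s, d = x.shape
    return x[:, :, None, :, :].expand(b, kv, groups, s, d).reshape(b, kv * groups, s, d)

# Two packed documents of 3 and 2 tokens: token 0 may see token 2 but not token 4.
mask = block_diagonal_mask(torch.tensor([0, 3, 5]))
assert mask[0, 0, 0, 2] == 0 and mask[0, 0, 0, 4] == float("-inf")

# 2 KV heads shared across 8 query heads (4 groups).
assert repeat_kv(torch.randn(1, 2, 5, 64), 4).shape == (1, 8, 5, 64)
```

As the attention hunk wires it, the dense mask serves the eager and SDPA backends, while the varlen kwargs (`cu_seq_lens_q`/`cu_seq_lens_k`, `max_length_q`/`max_length_k`) serve the flash paths, so all backends dispatch through one `ALL_ATTENTION_FUNCTIONS` code path.
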
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py
index c2781962fe6d..894832a36a9d 100644
--- a/src/transformers/models/isaac/processing_isaac.py
+++ b/src/transformers/models/isaac/processing_isaac.py
@@ -4,9 +4,88 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_isaac.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Perceptron, Inc. Non-Production License
+# [Full license text identical, word for word, to the header added to modeling_isaac.py above; omitted here.]
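
One knob threads through both files in this patch: `pixel_shuffle_scale`. The image-processor hunk at the top records each image twice, as the real patch grid (`real_dims`) and as the smaller grid the language model actually sees after pixel shuffle (`virtual_dims`), and the processor kwargs below expose the same scale factor. A minimal sketch of that bookkeeping, where `virtual_token_grid` is an illustrative helper and not part of the module:

```python
def virtual_token_grid(height_tokens: int, width_tokens: int, pixel_shuffle_scale: int) -> tuple[int, int]:
    # Pixel shuffle folds an s x s block of vision tokens into a single token,
    # so each spatial side must divide evenly by the scale factor.
    if height_tokens % pixel_shuffle_scale or width_tokens % pixel_shuffle_scale:
        raise ValueError(
            "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled."
        )
    return height_tokens // pixel_shuffle_scale, width_tokens // pixel_shuffle_scale

# A 16 x 24 patch grid at scale 2 costs only 8 x 12 = 96 language-model tokens,
# while the vision tower still processes all 16 x 24 = 384 patches.
assert virtual_token_grid(16, 24, 2) == (8, 12)
```

The channel dimension grows by the same factor s squared, which is why the modeling hunk above sizes the vision projector at `hidden_size * pixel_shuffle_scale_factor**2`.
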
+ import math import re -from collections.abc import Sequence import PIL.Image import torch @@ -14,36 +93,18 @@ from genesis.public.tensorstream.tensor_stream_utils import slice as ts_slice from genesis.public.tensorstream.tensor_stream_utils import tensor_stream_token_view -from ...feature_extraction_utils import BatchFeature -from ...image_processing_utils_fast import BaseImageProcessorFast +from ...image_processing_utils_fast import BatchFeature, DefaultFastImageProcessorKwargs from ...processing_utils import ProcessorMixin -from ...tokenization_utils import TensorType +from ...utils import TensorType +from ..auto import AutoTokenizer from .configuration_isaac import IsaacConfig -from .image_processing_isaac import IsaacImageProcessor - - -# Vision preprocessing constants -VISION_MEAN = (0.5, 0.5, 0.5) -VISION_STD = (0.5, 0.5, 0.5) -VISION_SCALE = 1 / 255 -def _normalize_rgb_values( - values: float | Sequence[float] | tuple[float, ...], - *, - name: str, -) -> tuple[float, float, float]: - """Coerce RGB normalization parameters into a 3-tuple of floats.""" - if isinstance(values, (list, tuple)): - if len(values) == 3: - return tuple(float(v) for v in values) - if len(values) == 1: - value = float(values[0]) - return (value, value, value) - raise ValueError(f"`{name}` must have length 1 or 3 when provided as a sequence. Got length {len(values)}.") - - value = float(values) - return (value, value, value) +class IsaacImageProcessorKwargs(DefaultFastImageProcessorKwargs, total=False): + patch_size: int | None + max_num_patches: int | None + min_num_patches: int | None + pixel_shuffle_scale: int | None # ============================================================================ @@ -96,26 +157,18 @@ def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> class IsaacProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] - image_processor_class = ("IsaacImageProcessor", "IsaacImageProcessorFast") + image_processor_class = "IsaacImageProcessorFast" tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") def __init__( self, - image_processor: IsaacImageProcessor | BaseImageProcessorFast | None = None, + image_processor: "IsaacImageProcessorFast | None" = None, tokenizer: Qwen2Tokenizer | None = None, *, vision_token: str = "", max_sequence_length: int = 16384, - vision_patch_size: int = 16, - vision_max_num_patches: int = 256, - vision_min_num_patches: int | None = None, - pixel_shuffle_scale: int = 1, rescale_factor: float | None = None, - image_mean: float | Sequence[float] | None = None, - image_std: float | Sequence[float] | None = None, - vision_attn_implementation: str | None = None, config: IsaacConfig | dict | None = None, - **kwargs, ) -> None: if tokenizer is None: raise ValueError("`tokenizer` must be provided to initialize IsaacProcessor.") @@ -124,56 +177,16 @@ def __init__( config = IsaacConfig(**config) if config is not None: - vision_patch_size = config.video_patch_size - vision_max_num_patches = config.vision_max_num_patches - vision_min_num_patches = config.vision_min_num_patches - pixel_shuffle_scale = config.pixel_shuffle_scale max_sequence_length = config.max_sequence_length vision_token = config.vision_token - vision_attn_implementation = config.vision_attn_implementation rescale_factor = config.vision_rescale_factor - image_mean = tuple(config.vision_mean) - image_std = tuple(config.vision_std) - resolved_rescale_factor = float(rescale_factor) if rescale_factor is not None else float(VISION_SCALE) - resolved_image_mean = 
_normalize_rgb_values( - image_mean if image_mean is not None else VISION_MEAN, - name="image_mean", - ) - resolved_image_std = _normalize_rgb_values( - image_std if image_std is not None else VISION_STD, - name="image_std", - ) - - if image_processor is None: - image_processor = IsaacImageProcessor( - patch_size=vision_patch_size, - max_num_patches=vision_max_num_patches, - min_num_patches=vision_min_num_patches, - pixel_shuffle_scale=pixel_shuffle_scale, - rescale_factor=resolved_rescale_factor, - image_mean=resolved_image_mean, - image_std=resolved_image_std, - ) - else: - vision_patch_size = getattr(image_processor, "patch_size", vision_patch_size) - vision_max_num_patches = getattr(image_processor, "max_num_patches", vision_max_num_patches) - vision_min_num_patches = getattr(image_processor, "min_num_patches", vision_min_num_patches) - pixel_shuffle_scale = getattr(image_processor, "pixel_shuffle_scale", pixel_shuffle_scale) - resolved_rescale_factor = getattr(image_processor, "rescale_factor", resolved_rescale_factor) - resolved_image_mean = _normalize_rgb_values( - getattr(image_processor, "image_mean", resolved_image_mean), - name="image_mean", - ) - resolved_image_std = _normalize_rgb_values( - getattr(image_processor, "image_std", resolved_image_std), - name="image_std", - ) + resolved_rescale_factor = float(rescale_factor) if rescale_factor is not None else float(1 / 255) if config is not None: config.vision_rescale_factor = resolved_rescale_factor - config.vision_mean = resolved_image_mean - config.vision_std = resolved_image_std + + self.image_processor = image_processor super().__init__(image_processor, tokenizer) self.current_processor = self.image_processor @@ -184,15 +197,6 @@ def __init__( self.vision_token = vision_token self.max_sequence_length = max_sequence_length - self.vision_attn_implementation = vision_attn_implementation - - self.patch_size = getattr(self.image_processor, "patch_size", vision_patch_size) - self.max_num_patches = getattr(self.image_processor, "max_num_patches", vision_max_num_patches) - self.min_num_patches = getattr(self.image_processor, "min_num_patches", vision_min_num_patches) - self.pixel_shuffle_scale = getattr(self.image_processor, "pixel_shuffle_scale", pixel_shuffle_scale) - self.rescale_factor = getattr(self.image_processor, "rescale_factor", resolved_rescale_factor) - self.image_mean = tuple(getattr(self.image_processor, "image_mean", resolved_image_mean)) - self.image_std = tuple(getattr(self.image_processor, "image_std", resolved_image_std)) def build_event_stream_simple( self, From 3d8b786ab5e01d6e14d0f50c0578c87d1767da1c Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 09:41:58 +0400 Subject: [PATCH 09/77] feat: port updates to isaac --- .../models/isaac/modular_isaac.py | 608 +++++++++--------- 1 file changed, 304 insertions(+), 304 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index e750a2bc839f..d1c82fe30c53 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -1,20 +1,22 @@ -# Perceptron, Inc. Non-Production License +# Copyright (c) 2024 Perceptron, Inc. All rights reserved. +# Perceptron, Inc. Non-Production License (2024-01-01) + ### 1. Scope and acceptance # **1.1. Scope of the Agreement.** -# This Agreement applies to any use, modification, or Distribution of any Perceptron Model by You, regardless of the source You obtained a copy of such Perceptron Model. 
[... the remainder of this hunk is a whitespace-only cleanup of the license header: each following -/+ pair differs solely in trailing spaces, and the license text itself is unchanged ...]

@@ -84,60 +86,81 @@
 import math
 import re
 from collections import defaultdict
-from collections.abc import Sequence
-from typing import Any, Callable, Optional, TypedDict, Union
+from typing import Any, Callable, Optional, Sequence, Union
 
 import PIL.Image
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from genesis.public.tensorstream.tensor_stream import (
-    Event,
-    Stream,
-    TensorStream,
-    TextType,
-    VisionType,
-    create_stream,
-    group_streams,
-)
-from genesis.public.tensorstream.tensor_stream_utils import (
-    compute_mrope_pos_tensor,
-    modality_mask,
-    reconstruct_tensor_stream_from_compact_dict,
-    tensor_stream_token_view,
-)
-from genesis.public.tensorstream.tensor_stream_utils import (
-    slice as ts_slice,
+from transformers import (
+    AutoImageProcessor,
+    AutoModel,
+    AutoTokenizer,
+    BatchFeature,
+    Cache,
+    Qwen3Config,
+    Qwen3ForCausalLM,
+    Qwen3PreTrainedModel,
 )
-
-from ...cache_utils import Cache, SlidingWindowCache, StaticCache
-from ...generation.utils import GenerationMixin
-from ...image_processing_utils_fast import (
+from transformers.cache_utils import SlidingWindowCache, StaticCache
+from transformers.generation.utils import GenerationMixin
+from transformers.image_processing_utils_fast import (
     BaseImageProcessorFast,
-    BatchFeature,
     DefaultFastImageProcessorKwargs,
     SizeDict,
     group_images_by_shape,
     reorder_images,
 )
-from ...image_utils import ChannelDimension, PILImageResampling
-from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
-from ...processing_utils import ProcessorMixin, Unpack
-from ...utils import TensorType, auto_docstring
+from transformers.image_utils import (
+    ChannelDimension,
+    PILImageResampling,
+)
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding
+from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
+from transformers.models.siglip2.modeling_siglip2 import (
+    Siglip2Attention,
+    Siglip2Encoder as HFSiglip2Encoder,
+    Siglip2EncoderLayer as HFSiglip2EncoderLayer,
+    Siglip2VisionEmbeddings as HFSiglip2VisionEmbeddings,
+)
+from transformers.processing_utils import ProcessorMixin, Unpack
+from transformers.tokenization_utils import TensorType
+from transformers.utils import auto_docstring
+from transformers.utils.generic import can_return_tuple
 
 # Vision preprocessing constants
-from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN
-from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD
-from ...utils.import_utils import is_torchdynamo_compiling
-from ..auto import AutoImageProcessor, AutoModel, AutoTokenizer
-from ..qwen2.tokenization_qwen2 import Qwen2Tokenizer
-from ..qwen3.configuration_qwen3 import Qwen3Config -from ..qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3PreTrainedModel -from ..siglip2.configuration_siglip2 import Siglip2VisionConfig -from ..siglip2.modeling_siglip2 import Siglip2EncoderLayer as HFSiglip2EncoderLayer +from transformers.utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN +from transformers.utils.constants import IMAGENET_STANDARD_STD as VISION_STD +from transformers.utils.import_utils import is_torchdynamo_compiling + +try: + from genesis.public.tensorstream.tensor_stream import ( + Event, + Stream, + TensorStream, + TextType, + VisionType, + create_stream, + group_streams, + ) + from genesis.public.tensorstream.tensor_stream_utils import ( + compute_mrope_pos_tensor, + modality_mask, + reconstruct_tensor_stream_from_compact_dict, + tensor_stream_token_view, + ) + from genesis.public.tensorstream.tensor_stream_utils import ( + slice as ts_slice, + ) +except ModuleNotFoundError as exc: # pragma: no cover - import guard + raise ModuleNotFoundError( + "genesis.public.tensorstream is required for the Isaac HuggingFace integration. " + "Ensure the TensorStream package is installed and on PYTHONPATH." + ) from exc _ORIGINAL_ATTENTION_FUNCTIONS: dict[str, Callable[..., tuple[torch.Tensor, Optional[torch.Tensor]]]] = {} @@ -177,14 +200,6 @@ def __init__( if self._attn_implementation is None: self._attn_implementation = "flash_attention_2" - @property - def attn_implementation(self) -> str | None: - return self._attn_implementation - - @attn_implementation.setter - def attn_implementation(self, value: str | None) -> None: - self._attn_implementation = value - class IsaacImageProcessorKwargs(DefaultFastImageProcessorKwargs, total=False): patch_size: int | None @@ -250,7 +265,7 @@ def _validate_preprocess_kwargs(self, **kwargs): def resize( self, - image: torch.Tensor, + image: "torch.Tensor", size: SizeDict, interpolation: Optional[Any] = None, antialias: bool = True, @@ -285,7 +300,7 @@ def resize( def _preprocess( self, - images: list[torch.Tensor], + images: list["torch.Tensor"], do_resize: bool, size: Optional[SizeDict], interpolation: Optional[Any], @@ -463,21 +478,24 @@ def build_document_attention_mask( return additive_mask.view(1, 1, total_tokens, total_tokens) -def repeat_kv(hidden_states: torch.Tensor, num_key_value_groups: int) -> torch.Tensor: - """Repeat key/value heads for grouped-query attention.""" - if num_key_value_groups == 1: - return hidden_states - batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, - num_key_value_heads, - num_key_value_groups, - seq_len, - head_dim, +def ensure_document_attention_mask( + attention_mask: Optional[torch.Tensor], + cu_seqlens: Optional[torch.Tensor], + total_tokens: int, + dtype: torch.dtype, + device: torch.device, +) -> Optional[torch.Tensor]: + if attention_mask is not None or cu_seqlens is None: + return attention_mask + + return build_document_attention_mask( + cu_seqlens=cu_seqlens, + total_tokens=total_tokens, + dtype=dtype, + device=device, ) - return hidden_states.reshape(batch, num_key_value_heads * num_key_value_groups, seq_len, head_dim) def flash_attention_document_mask_forward( @@ -565,69 +583,72 @@ def sdpa_document_mask_forward( return Y.squeeze(0).permute(1, 0, 2) # Back to (L, H, D) -class IsaacVisionEmbeddings(nn.Module): +class IsaacVisionEmbeddings(HFSiglip2VisionEmbeddings): + """Adapter around SigLIP2 vision embeddings that consumes packed patch 
sequences.""" + def __init__(self, config: IsaacVisionConfig): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.patch_size = config.patch_size + super().__init__(config) - self.patch_embedding = nn.Linear( - in_features=config.num_channels * self.patch_size * self.patch_size, - out_features=self.embed_dim, - ) + def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor: + packed_pixel_values, seq_lengths = self._pack_to_batch(seq_patches, spatial_shapes) + if packed_pixel_values is None: + return seq_patches.new_zeros((0, self.embed_dim)) - self.num_patches = config.num_patches - self.position_embedding_size = int(self.num_patches**0.5) - self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + embeddings = super().forward(packed_pixel_values, spatial_shapes) + return self._unpack_from_batch(embeddings, seq_lengths) - def positional_embeddings(self, spatial_shapes: torch.Tensor) -> torch.Tensor: - # Prepare positional embeddings grid: (1, embed_dim, h, w) - positional_embeddings = ( - self.position_embedding.weight.reshape(self.position_embedding_size, self.position_embedding_size, -1) - .permute(2, 0, 1) - .unsqueeze(0) - ) + def _pack_to_batch( + self, + seq_patches: torch.Tensor, + spatial_shapes: torch.Tensor, + ) -> tuple[torch.Tensor | None, torch.Tensor]: + if seq_patches.ndim != 2: + raise ValueError("`seq_patches` is expected to be 2D (total_patches, patch_dim).") + if spatial_shapes.ndim != 2 or spatial_shapes.size(-1) != 2: + raise ValueError("`spatial_shapes` must have shape (num_images, 2) with (height_tokens, width_tokens).") + + seq_lengths = spatial_shapes.long().prod(dim=-1) + total_patches = int(seq_lengths.sum().item()) + if total_patches != seq_patches.size(0): + raise ValueError( + "Mismatch between packed patches and spatial shapes: got " + f"{seq_patches.size(0)} patches but spatial shapes imply {total_patches}." + ) - pos_embeds_list = [] - mode = "bilinear" - align_corners = False - for spatial_shape in spatial_shapes: - height, width = spatial_shape - # Guard to ensure height and width are positive for torch.compile - if height > 0 and width > 0: - resized_pos_embed = F.interpolate( - positional_embeddings, - size=(height, width), - mode=mode, - align_corners=align_corners, - antialias=True, - ) - # Reshape from (1, embed_dim, height, width) to (height*width, embed_dim) - resized_pos_embed = resized_pos_embed.reshape(self.embed_dim, height * width).transpose(0, 1) - else: - # Fallback - should never happen in practice - raise RuntimeError( - "Encountered non-positive spatial dimensions while computing positional embeddings." 
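# Illustrative sketch (toy shapes, standalone; not part of the patched module): the
# pack/unpack round-trip implemented by `_pack_to_batch`/`_unpack_from_batch` above.
# Variable-length patch sequences arrive concatenated along dim 0, are zero-padded into
# a (batch, max_len, patch_dim) tensor for the batched SigLIP2 forward, then flattened back.
import torch

spatial_shapes = torch.tensor([[2, 3], [1, 4]])  # (height_tokens, width_tokens) per image
seq_lengths = spatial_shapes.prod(dim=-1)        # tensor([6, 4])
packed = torch.randn(int(seq_lengths.sum()), 8)  # (total_patches, patch_dim)

padded = packed.new_zeros(len(seq_lengths), int(seq_lengths.max()), packed.size(-1))
start = 0
for i, length in enumerate(seq_lengths.tolist()):
    padded[i, :length] = packed[start : start + length]
    start += length

# Unpacking drops only the padding rows, so the round-trip is lossless.
unpacked = torch.cat([padded[i, :n] for i, n in enumerate(seq_lengths.tolist())], dim=0)
assert torch.equal(unpacked, packed)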
- ) - pos_embeds_list.append(resized_pos_embed) + batch_size = spatial_shapes.size(0) + if batch_size == 0: + return None, seq_lengths + + max_length = int(seq_lengths.max().item()) + patch_dim = seq_patches.size(-1) + device = seq_patches.device + + packed_pixel_values = seq_patches.new_zeros((batch_size, max_length, patch_dim), device=device) - # Concatenate all positional embeddings along the sequence dimension - pos_embeds = torch.cat(pos_embeds_list, dim=0) - return pos_embeds + start = 0 + for batch_idx, length in enumerate(seq_lengths.tolist()): + if length == 0: + continue + end = start + length + packed_pixel_values[batch_idx, :length] = seq_patches[start:end] + start = end - def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor): - # Apply patch embeddings - target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding(seq_patches.to(dtype=target_dtype)) - pos_embeds = self.positional_embeddings(spatial_shapes) + return packed_pixel_values, seq_lengths - # Add positional embeddings to patch embeddings - embeddings = patch_embeds + pos_embeds - return embeddings + def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor) -> torch.Tensor: + output_chunks: list[torch.Tensor] = [] + for batch_idx, length in enumerate(seq_lengths.tolist()): + if length == 0: + continue + output_chunks.append(embeddings[batch_idx, :length]) + if not output_chunks: + return embeddings.new_zeros((0, embeddings.size(-1))) -class IsaacVisionAttention(nn.Module): + return torch.cat(output_chunks, dim=0) + + +class IsaacVisionAttention(Siglip2Attention): """Custom attention that supports variable-length sequences with flash attention.""" ATTENTION_KEY_MAP: dict[str, str] = { @@ -642,28 +663,34 @@ class IsaacVisionAttention(nn.Module): } def __init__(self, vision_config): - super().__init__() + super().__init__(vision_config) self.vision_config = vision_config - self.config = vision_config - self.embed_dim = vision_config.hidden_size - self.num_heads = vision_config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." 
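# Illustrative sketch (toy segment lengths; standalone): the block-diagonal additive mask
# that `ensure_document_attention_mask` falls back to when only `cu_seqlens` is given.
# Each token may attend only within its own packed segment (0 = keep, -inf = drop); the
# actual helper in this file is `build_document_attention_mask`.
import torch

cu_seqlens = torch.tensor([0, 3, 7])  # two packed segments of lengths 3 and 4
total_tokens = int(cu_seqlens[-1])
segment_ids = torch.repeat_interleave(torch.arange(len(cu_seqlens) - 1), cu_seqlens.diff())
allowed = segment_ids[:, None] == segment_ids[None, :]
additive_mask = torch.where(allowed, 0.0, float("-inf"))
attn_mask = additive_mask.view(1, 1, total_tokens, total_tokens)  # broadcasts over batch/heads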
- ) - self.scale = self.head_dim**-0.5 - self.dropout = vision_config.attention_dropout - self.is_causal = False - self.num_key_value_groups = 1 - - self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + self._variable_length_metadata = None + + def _variable_length_context(self, *, cu_seqlens=None, max_seqlen=None): + """Store packed-sequence metadata for the next forward call.""" + self._variable_length_metadata = (cu_seqlens, max_seqlen) + + def _consume_variable_length_metadata(self): + if self._variable_length_metadata is None: + return None, None + cu_seqlens, max_seqlen = self._variable_length_metadata + self._variable_length_metadata = None + return cu_seqlens, max_seqlen + + def forward(self, hidden_states, attention_mask=None, **kwargs): + cu_seqlens = kwargs.pop('cu_seqlens', None) + max_seqlen = kwargs.pop('max_seqlen', None) + kwargs.pop('output_attentions', None) + if kwargs: + unexpected = ', '.join(sorted(kwargs)) + raise TypeError(f'Unexpected kwargs for IsaacVisionAttention.forward: {unexpected}') + cached_cu, cached_max = self._consume_variable_length_metadata() + if cu_seqlens is None: + cu_seqlens = cached_cu + if max_seqlen is None: + max_seqlen = cached_max - def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None): # Expect packed sequences with batch_size == 1 batch_size, L, _ = hidden_states.shape if batch_size != 1: @@ -681,11 +708,12 @@ def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None): attn_impl = getattr(self.vision_config, "_attn_implementation", "flash_attention_3") - attn_mask = build_document_attention_mask( - cu_seqlens=cu_seqlens, - total_tokens=L, - dtype=q.dtype, - device=q.device, + attn_mask = ensure_document_attention_mask( + attention_mask, + cu_seqlens, + L, + q.dtype, + q.device, ) resolved_key = self.ATTENTION_KEY_MAP.get(attn_impl) @@ -735,62 +763,78 @@ def __init__(self, vision_config: IsaacVisionConfig): def forward( self, hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor = None, - max_seqlen: int = None, - ) -> torch.Tensor: - residual = hidden_states - - hidden_states = self.layer_norm1(hidden_states) + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + output_attentions: bool = False, + ): + if cu_seqlens is not None or max_seqlen is not None: + self.self_attn._variable_length_context( + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) - hidden_states, _ = self.self_attn( - hidden_states=hidden_states, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, + attention_mask = ensure_document_attention_mask( + attention_mask, + cu_seqlens, + hidden_states.size(1), + hidden_states.dtype, + hidden_states.device, ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states + return super().forward( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) -class IsaacVisionEncoder(nn.Module): +class IsaacVisionEncoder(HFSiglip2Encoder): """Encoder using Isaac encoder layers with variable-length attention support.""" def __init__(self, config: IsaacVisionConfig): - super().__init__() - self.config = config + super().__init__(config) self.layers = 
nn.ModuleList([IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + def __variable_length_context(self, cu_seqlens, max_seqlen) -> None: + if cu_seqlens is None and max_seqlen is None: + return + + for layer in self.layers: + if isinstance(layer, IsaacVisionEncoderLayer): + layer.self_attn._variable_length_context( + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + @can_return_tuple def forward( self, inputs_embeds, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - output_hidden_states: bool = False, + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ): - all_hidden_states = () if output_hidden_states else None + self.__variable_length_context(cu_seqlens, max_seqlen) - hidden_states = inputs_embeds - - for encoder_layer in self.layers: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - hidden_states = encoder_layer( - hidden_states, - cu_seqlens, - max_seqlen, - ) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + attention_mask = ensure_document_attention_mask( + attention_mask, + cu_seqlens, + inputs_embeds.size(1), + inputs_embeds.dtype, + inputs_embeds.device, + ) - return hidden_states, all_hidden_states, None + return super().forward( + inputs_embeds, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) def _isaac_flash_attention_forward( @@ -1137,11 +1181,13 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): max_seqlen = int(seq_sizes.max().item()) if seq_sizes.numel() > 0 else 0 # Pass through encoder with variable-length attention parameters - hidden_states, _, _ = self.encoder( + encoder_outputs = self.encoder( inputs_embeds=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + return_dict=True, ) + hidden_states = encoder_outputs.last_hidden_state # Apply final layer normalization hidden_states = self.post_layernorm(hidden_states) @@ -1274,65 +1320,6 @@ def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor: return patches -def precompute_cos_sin_3d( - position_ids: torch.Tensor, # shape (3, B, T) - inv_freq: torch.Tensor, # shape (dim//2,) - mrope_half_section: list[int], # sum to dim//2 -) -> tuple[torch.Tensor, torch.Tensor]: - r"""Generate 3D rotary embeddings for multi-axis positions. - - Args: - position_ids (`torch.Tensor`): - Tensor of shape `(3, batch_size, seq_len)` containing positional indices for the x/y/t axes. - inv_freq (`torch.Tensor`): - Precomputed inverse frequency vector used to derive rotary phases. - mrope_half_section (`list[int]`): - Sizes the axis-specific frequency blocks. - - Returns: - `tuple[torch.Tensor, torch.Tensor]`: Cosine and sine tensors, each of shape `(batch_size, seq_len, dim)`, ready - to be passed into rotary attention layers. 
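# Illustrative sketch (toy class; standalone): the "stash then consume" pattern used by
# the encoder above. The inherited Siglip2 forward does not thread `cu_seqlens`/`max_seqlen`
# through its layers, so the packed-sequence metadata is stashed on each attention module
# beforehand and consumed exactly once on its next forward call.
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self._variable_length_metadata = None

    def _variable_length_context(self, *, cu_seqlens=None, max_seqlen=None):
        # Stash metadata for exactly one upcoming forward call.
        self._variable_length_metadata = (cu_seqlens, max_seqlen)

    def forward(self, hidden_states):
        meta, self._variable_length_metadata = self._variable_length_metadata, None
        cu_seqlens, max_seqlen = meta if meta is not None else (None, None)
        # A real module would dispatch on cu_seqlens/max_seqlen here.
        return hidden_states

attn = ToyAttention()
attn._variable_length_context(cu_seqlens=torch.tensor([0, 4]), max_seqlen=4)
attn(torch.randn(1, 4, 8))  # consumes the stashed metadata, then clears it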
- """ - B = position_ids.shape[1] - T = position_ids.shape[2] - dim_half = inv_freq.shape[0] - device = position_ids.device - - # Initialize with full dimension (not half) to match LLaMA - cos_3d = torch.zeros((B, T, dim_half * 2), dtype=torch.float32, device=device) - sin_3d = torch.zeros((B, T, dim_half * 2), dtype=torch.float32, device=device) - - offset = 0 - for d in range(3): - block_size = mrope_half_section[d] - freq_slice = inv_freq[offset : offset + block_size] # shape => (block_size,) - # shape => (B, T, block_size) - phase = position_ids[d].unsqueeze(-1).float() * freq_slice - - cos_part = phase.cos() - sin_part = phase.sin() - - # Duplicate values for both halves of the dimension - cos_3d[:, :, offset : offset + block_size] = cos_part - cos_3d[:, :, dim_half + offset : dim_half + offset + block_size] = cos_part - sin_3d[:, :, offset : offset + block_size] = sin_part - sin_3d[:, :, dim_half + offset : dim_half + offset + block_size] = sin_part - - offset += block_size - - return cos_3d, sin_3d - - -class RopeScaling(TypedDict, total=False): - rope_type: str - factor: float - mrope_section: list[int] - mrope_interleaved: bool - low_freq_factor: float - high_freq_factor: float - original_max_position_embeddings: int - - class IsaacConfig(Qwen3Config): """Configuration class for Isaac multimodal model.""" @@ -1469,12 +1456,12 @@ def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> class IsaacProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] - image_processor_class = "IsaacImageProcessorFast" + image_processor_class = ("IsaacImageProcessorFast",) tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") def __init__( self, - image_processor: "IsaacImageProcessorFast | None" = None, + image_processor: IsaacImageProcessorFast | None = None, tokenizer: Qwen2Tokenizer | None = None, *, vision_token: str = "", @@ -1658,36 +1645,19 @@ class IsaacRotaryEmbedding(nn.Module): def __init__(self, config: IsaacConfig, device=None): super().__init__() - self.config = config rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) - if rope_type not in ROPE_INIT_FUNCTIONS: - raise ValueError(f"Unsupported rope_type '{rope_type}' for IsaacRotaryEmbedding") - - self.rope_type = rope_type - rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] sanitized_scaling = {k: v for k, v in rope_scaling.items() if k not in self.EXTRA_ROPE_KEYS} - if sanitized_scaling != rope_scaling: - config_for_rope = copy.copy(rope_source_cfg) - config_for_rope.rope_scaling = sanitized_scaling - else: - config_for_rope = rope_source_cfg + config_for_rope = copy.copy(rope_source_cfg) + config_for_rope.rope_scaling = sanitized_scaling if sanitized_scaling else None init_device = device if device is not None and getattr(device, "type", None) != "meta" else None - inv_freq, attention_scaling = rope_init_fn(config_for_rope, device=init_device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.attention_scaling = self._normalize_scale(attention_scaling) + self._qwen_rotary = Qwen2_5_VLRotaryEmbedding(config_for_rope, device=init_device) - rotary_half_dim = self.inv_freq.shape[0] + rotary_half_dim = self._qwen_rotary.inv_freq.shape[0] self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim) - - @staticmethod - def _normalize_scale(scale: 
torch.Tensor | float) -> torch.Tensor | float: - if isinstance(scale, torch.Tensor): - return scale.detach().clone() - return float(scale) + self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size @staticmethod def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> list[int]: @@ -1706,27 +1676,54 @@ def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> l ) return section - def forward(self, position_ids: torch.Tensor, modality_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor: + split_sections = tuple(self.mrope_section * 2) + chunks = tensor.split(split_sections, dim=-1) + return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) + + @property + def inv_freq(self) -> torch.Tensor: + return self._qwen_rotary.inv_freq + + def forward( + self, + position_ids: torch.Tensor, + modality_tensor: torch.Tensor, + hidden_states: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if position_ids.ndim != 3 or position_ids.size(-1) != 3: + raise ValueError("`position_ids` must have shape (batch, seq_len, 3) for MRoPE") + if modality_tensor.shape != position_ids.shape[:2]: + raise ValueError("`modality_tensor` must align with the first two dims of `position_ids`") + + if hidden_states is None: + batch, seq_len, _ = position_ids.shape + hidden_states = torch.zeros( + batch, + seq_len, + self.hidden_size, + dtype=torch.float32, + device=position_ids.device, + ) + with torch.no_grad(): - position_ids = position_ids.clone() + pos = position_ids.clone() not_spatial = modality_tensor != VisionType.image.value if not_spatial.any(): - data_1d = position_ids[not_spatial][..., 0].unsqueeze(-1) - position_ids[not_spatial] = data_1d.expand(-1, position_ids.shape[-1]) + data_1d = pos[not_spatial][..., 0].unsqueeze(-1) + pos[not_spatial] = data_1d.expand(-1, pos.shape[-1]) + + pos_axes = pos.permute(2, 0, 1).contiguous() - position_ids = position_ids.permute(2, 0, 1) - cos, sin = precompute_cos_sin_3d(position_ids, self.inv_freq, self.mrope_section) - scale = self.attention_scaling - if isinstance(scale, torch.Tensor): - scale = scale.to(device=cos.device, dtype=cos.dtype) - elif scale != 1.0: - scale = cos.new_tensor(scale) - if isinstance(scale, torch.Tensor) or scale != 1.0: - cos = cos * scale - sin = sin * scale + cos_axes, sin_axes = self._qwen_rotary(hidden_states, pos_axes) - return cos, sin + cos_axes = cos_axes.to(hidden_states.dtype) + sin_axes = sin_axes.to(hidden_states.dtype) + cos_combined = self._combine_axes(cos_axes) + sin_combined = self._combine_axes(sin_axes) + + return cos_combined, sin_combined class IsaacModel(Qwen3PreTrainedModel): supports_gradient_checkpointing = True @@ -1894,7 +1891,11 @@ def forward( position_ids = compute_position_ids_input_ids(input_ids) # Compute MRoPE position embeddings if we have custom rotary_emb - cos, sin = self.rotary_emb(position_ids, modality_tensor) + cos, sin = self.rotary_emb( + position_ids, + modality_tensor, + hidden_states=inputs_embeds, + ) cos = cos.to(inputs_embeds.dtype) sin = sin.to(inputs_embeds.dtype) @@ -2275,4 +2276,3 @@ def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> p_indices = frame_indices[torch.tensor(is_p_frame, device=frames.device)] frames[p_indices] = frames[p_indices] - frames[last_i_indices[p_indices]] return frames - From 92f56b815a305578685de47ab5192e3293d37a7c Mon Sep 17 00:00:00 2001 From: Philipp 
Guevorguian Date: Mon, 10 Nov 2025 10:00:13 +0400 Subject: [PATCH 10/77] fix: changes to enable modular convert --- .../models/isaac/modular_isaac.py | 52 +++++++++---------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index d1c82fe30c53..5c485598e9c4 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -92,50 +92,46 @@ import torch import torch.nn as nn import torch.nn.functional as F -from transformers import ( - AutoImageProcessor, - AutoModel, - AutoTokenizer, - BatchFeature, - Cache, - Qwen3Config, - Qwen3ForCausalLM, - Qwen3PreTrainedModel, -) -from transformers.cache_utils import SlidingWindowCache, StaticCache -from transformers.generation.utils import GenerationMixin -from transformers.image_processing_utils_fast import ( +from ...cache_utils import Cache, SlidingWindowCache, StaticCache +from ...generation.utils import GenerationMixin +from ...feature_extraction_utils import BatchFeature +from ...image_processing_utils_fast import ( BaseImageProcessorFast, DefaultFastImageProcessorKwargs, SizeDict, group_images_by_shape, reorder_images, ) -from transformers.image_utils import ( +from ...models.auto.image_processing_auto import AutoImageProcessor +from ...models.auto.modeling_auto import AutoModel +from ...models.auto.tokenization_auto import AutoTokenizer +from ...models.qwen3.configuration_qwen3 import Qwen3Config +from ...models.qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3PreTrainedModel +from ...image_utils import ( ChannelDimension, PILImageResampling, ) -from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer -from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding -from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig -from transformers.models.siglip2.modeling_siglip2 import ( +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ..qwen2.tokenization_qwen2 import Qwen2Tokenizer +from ..qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding +from ..siglip2.configuration_siglip2 import Siglip2VisionConfig +from ..siglip2.modeling_siglip2 import ( Siglip2Attention, Siglip2Encoder as HFSiglip2Encoder, Siglip2EncoderLayer as HFSiglip2EncoderLayer, Siglip2VisionEmbeddings as HFSiglip2VisionEmbeddings, ) -from transformers.processing_utils import ProcessorMixin, Unpack -from transformers.tokenization_utils import TensorType -from transformers.utils import auto_docstring -from transformers.utils.generic import can_return_tuple +from ...processing_utils import ProcessorMixin, Unpack +from ...tokenization_utils import TensorType +from ...utils import auto_docstring +from ...utils.generic import can_return_tuple # Vision preprocessing constants -from transformers.utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN -from transformers.utils.constants import IMAGENET_STANDARD_STD as VISION_STD -from transformers.utils.import_utils import is_torchdynamo_compiling +from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN +from 
...utils.constants import IMAGENET_STANDARD_STD as VISION_STD +from ...utils.import_utils import is_torchdynamo_compiling try: from genesis.public.tensorstream.tensor_stream import ( From 70bcc770e0d8cdff2b8ef79f922cf57d0874b96a Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 10:00:34 +0400 Subject: [PATCH 11/77] chore: modular convert script artifacts --- .../models/isaac/configuration_isaac.py | 14 +- .../models/isaac/image_processing_isaac.py | 432 ---------------- .../isaac/image_processing_isaac_fast.py | 20 +- .../models/isaac/modeling_isaac.py | 468 +++++++++--------- .../models/isaac/processing_isaac.py | 19 +- 5 files changed, 264 insertions(+), 689 deletions(-) delete mode 100644 src/transformers/models/isaac/image_processing_isaac.py diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index bc3d59b58cef..abd5df690fdf 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -4,7 +4,9 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# Perceptron, Inc. Non-Production License +# Copyright (c) 2024 Perceptron, Inc. All rights reserved. +# Perceptron, Inc. Non-Production License (2024-01-01) + ### 1. Scope and acceptance @@ -88,8 +90,10 @@ import copy from typing import Any +# Build the list of all image processors from ...configuration_utils import PretrainedConfig, layer_type_validation from ...modeling_rope_utils import rope_config_validation +from ...models.qwen3.configuration_qwen3 import Qwen3Config class IsaacVisionConfig(PretrainedConfig): @@ -133,14 +137,6 @@ def __init__( if self._attn_implementation is None: self._attn_implementation = "flash_attention_2" - @property - def attn_implementation(self) -> str | None: - return self._attn_implementation - - @attn_implementation.setter - def attn_implementation(self, value: str | None) -> None: - self._attn_implementation = value - class IsaacConfig(PretrainedConfig): """Configuration class for Isaac multimodal model.""" diff --git a/src/transformers/models/isaac/image_processing_isaac.py b/src/transformers/models/isaac/image_processing_isaac.py deleted file mode 100644 index d38740a4cfaf..000000000000 --- a/src/transformers/models/isaac/image_processing_isaac.py +++ /dev/null @@ -1,432 +0,0 @@ -# ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# This file was automatically generated from src/transformers/models/isaac/modular_isaac.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_isaac.py file directly. One of our CI enforces this. 
-# ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -import math -from collections.abc import Sequence - -import numpy as np -import torch -import torch.nn.functional as F - -from ...feature_extraction_utils import BatchFeature -from ...image_processing_utils import BaseImageProcessor -from ...image_transforms import convert_to_rgb -from ...image_utils import ( - ImageInput, - make_flat_list_of_images, - to_numpy_array, - valid_images, - validate_preprocess_arguments, -) -from ...processing_utils import ImagesKwargs -from ...tokenization_utils import TensorType -from ...utils import filter_out_non_signature_kwargs - - -class IsaacImageProcessorKwargs(ImagesKwargs): - patch_size: int | None - max_num_patches: int | None - min_num_patches: int | None - pixel_shuffle_scale: int | None - do_rescale: bool | None - rescale_factor: float | None - do_normalize: bool | None - image_mean: float | Sequence[float] | None - image_std: float | Sequence[float] | None - do_convert_rgb: bool | None - - -# Vision preprocessing constants -VISION_MEAN = (0.5, 0.5, 0.5) -VISION_STD = (0.5, 0.5, 0.5) -VISION_SCALE = 1 / 255 - - -def _normalize_rgb_values( - values: float | Sequence[float] | tuple[float, ...], - *, - name: str, -) -> tuple[float, float, float]: - """Coerce RGB normalization parameters into a 3-tuple of floats.""" - if isinstance(values, (list, tuple)): - if len(values) == 3: - return tuple(float(v) for v in values) - if len(values) == 1: - value = float(values[0]) - return (value, value, value) - raise ValueError(f"`{name}` must have length 1 or 3 when provided as a sequence. Got length {len(values)}.") - - value = float(values) - return (value, value, value) - - -def _make_writeable(arr: np.ndarray) -> np.ndarray: - if arr.flags.writeable: - return arr - try: - arr.setflags(write=True) - return arr - except ValueError: - return arr.copy() - - -# ============================================================================ -# Configuration -# ============================================================================ - -MAX_PIXELS = 60_000_000 # 60โ€‘megapixel ceiling โ‰ˆ 8200 ร— 7300 px - - -def get_scaled_image_size( - scale: float, - original_size: int, - patch_size: int, - pixel_shuffle_scale: int, -) -> int: - scaled_size = scale * original_size - divisor = patch_size * pixel_shuffle_scale - scaled_size = math.ceil(scaled_size / divisor) * divisor - scaled_size = max(divisor, scaled_size) - return int(scaled_size) - - -def get_image_size_for_max_num_patches( - image_height: int, - image_width: int, - patch_size: int, - max_num_patches: int, - min_num_patches: int | None = None, - eps: float = 1e-5, - pixel_shuffle_scale: int = 1, -) -> tuple[int, int]: - r"""Compute a target resolution whose patch grid satisfies patching parametrization. - - Args: - image_height (`int`): - Height in pixels of the source image prior to any resizing. - image_width (`int`): - Width in pixels of the source image prior to any resizing. - patch_size (`int`): - Size of the square patch used by the vision encoder. - max_num_patches (`int`): - Upper bound on `(height / patch_size) * (width / patch_size)` after resizing. - min_num_patches (`int`, *optional*): - Lower bound on the number of patches. When provided the image will be scaled up if necessary. 
- eps (`float`, *optional*, defaults to 1e-5): - Convergence tolerance for the internal binary search to determine the target dimensions. - pixel_shuffle_scale (`int`, *optional*, defaults to 1): - Additional stride multiplier applied when pixel shuffle later reduces spatial resolution. - - Returns: - `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale` - and respect both the maximum and optional minimum patch-count constraints. - """ - - # Ensure divisibility - divisor = patch_size * pixel_shuffle_scale - adjusted_height = math.ceil(image_height / divisor) * divisor - adjusted_height = max(divisor, adjusted_height) - adjusted_width = math.ceil(image_width / divisor) * divisor - adjusted_width = max(divisor, adjusted_width) - - num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size) - - if min_num_patches is not None and num_patches < min_num_patches: - # Scale up - scale_min, scale_max = 1.0, 100.0 - while (scale_max - scale_min) >= eps: - scale = (scale_min + scale_max) / 2 - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) - num_patches = (target_height / patch_size) * (target_width / patch_size) - if num_patches >= min_num_patches: - scale_max = scale - else: - scale_min = scale - scale = scale_max - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) - return target_height, target_width - elif num_patches <= max_num_patches: - return adjusted_height, adjusted_width - else: - # Scale down - scale_min, scale_max = eps / 10, 1.0 - while (scale_max - scale_min) >= eps: - scale = (scale_min + scale_max) / 2 - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) - num_patches = (target_height / patch_size) * (target_width / patch_size) - if num_patches <= max_num_patches: - scale_min = scale - else: - scale_max = scale - scale = scale_min - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) - return target_height, target_width - - -def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor: - r"""Convert normalized images into flattened ViT-style patches. - - Args: - image (`torch.Tensor`): - Tensor of shape `(num_images, height, width, channels)`. - patch_size (`int`): - Edge length of the square patches. - - Returns: - `torch.Tensor`: - Patch tensor where each position stores the flattened pixels belonging to that patch. - - Raises: - ValueError: If `height` or `width` is not divisible by `patch_size`.
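# Shape walk-through (toy sizes; standalone): the reshape/permute pipeline that
# `patchify_vision` applies, turning (num_images, H, W, C) into a
# (num_images, H/p, W/p, C*p*p) grid of flattened patches.
import torch

image = torch.randn(1, 32, 48, 3)  # (num_images, H, W, C)
p = 16
patches = image.reshape(1, 32 // p, p, 48 // p, p, 3).permute(0, 1, 3, 2, 4, 5)
patches = patches.reshape(1, 32 // p, 48 // p, 3 * p * p)
print(patches.shape)  # torch.Size([1, 2, 3, 768])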
- """ - num_images, height, width, channels = image.shape - if height % patch_size or width % patch_size: - raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.") - patches = image.reshape(num_images, height // patch_size, patch_size, width // patch_size, patch_size, channels) - patches = patches.permute(0, 1, 3, 2, 4, 5) - patches = patches.reshape( - num_images, height // patch_size, width // patch_size, channels * patch_size * patch_size - ) - return patches - - -def _prepare_image_tensor( - image: torch.Tensor, scale: float, mean: tuple[float, ...], std: tuple[float, ...] -) -> torch.Tensor: - """Mirror the prepare_image_tensor utility used in the training pipelines.""" - if not torch.is_floating_point(image): - image = image.float() - - rescaled = image * scale - mean_tensor = torch.tensor(mean, dtype=torch.float32, device=rescaled.device).view(1, 1, 1, -1) - std_tensor = torch.tensor(std, dtype=torch.float32, device=rescaled.device).view(1, 1, 1, -1) - normalized = (rescaled - mean_tensor) / std_tensor - return normalized - - -def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> torch.Tensor: - """Compute residuals for P-frames to stay in sync with the training pipeline.""" - if not any(is_p_frame): - return frames - - frame_indices = torch.arange(len(is_p_frame), device=frames.device) - i_frame_mask = torch.tensor([not flag for flag in is_p_frame], device=frames.device) - last_i_indices = torch.cummax((i_frame_mask * (1 + frame_indices)), dim=0).values.long() - 1 - p_indices = frame_indices[torch.tensor(is_p_frame, device=frames.device)] - frames[p_indices] = frames[p_indices] - frames[last_i_indices[p_indices]] - return frames - - -class IsaacImageProcessor(BaseImageProcessor): - """Image processor that prepares RGB frames for the Isaac vision encoder.""" - - model_input_names = ["patches", "token_grids"] - - def __init__( - self, - patch_size: int = 16, - max_num_patches: int = 256, - min_num_patches: int | None = None, - pixel_shuffle_scale: int = 1, - do_rescale: bool = True, - rescale_factor: float | None = None, - do_normalize: bool = True, - image_mean: float | Sequence[float] | None = None, - image_std: float | Sequence[float] | None = None, - do_convert_rgb: bool = True, - resize_mode: str = "bilinear", - align_corners: bool = False, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - if pixel_shuffle_scale < 1: - raise ValueError("`pixel_shuffle_scale` must be >= 1") - - rescale_value = VISION_SCALE if rescale_factor is None else float(rescale_factor) - mean_value = VISION_MEAN if image_mean is None else image_mean - std_value = VISION_STD if image_std is None else image_std - - self.patch_size = patch_size - self.max_num_patches = max_num_patches - self.min_num_patches = min_num_patches - self.pixel_shuffle_scale = pixel_shuffle_scale - self.do_rescale = do_rescale - self.rescale_factor = rescale_value - self.do_normalize = do_normalize - self.image_mean = _normalize_rgb_values(mean_value, name="image_mean") - self.image_std = _normalize_rgb_values(std_value, name="image_std") - self.do_convert_rgb = do_convert_rgb - self.resize_mode = resize_mode - self.align_corners = align_corners - - @filter_out_non_signature_kwargs() - def preprocess( - self, - images: ImageInput, - patch_size: int | None = None, - max_num_patches: int | None = None, - min_num_patches: int | None = None, - pixel_shuffle_scale: int | None = None, - do_rescale: bool | None = None, - rescale_factor: float | None = None, - 
do_normalize: bool | None = None, - image_mean: float | Sequence[float] | None = None, - image_std: float | Sequence[float] | None = None, - do_convert_rgb: bool | None = None, - return_tensors: str | TensorType | None = None, - ) -> BatchFeature: - patch_size = patch_size if patch_size is not None else self.patch_size - max_num_patches = max_num_patches if max_num_patches is not None else self.max_num_patches - min_num_patches = min_num_patches if min_num_patches is not None else self.min_num_patches - pixel_shuffle_scale = pixel_shuffle_scale if pixel_shuffle_scale is not None else self.pixel_shuffle_scale - do_rescale = self.do_rescale if do_rescale is None else do_rescale - rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor - do_normalize = self.do_normalize if do_normalize is None else do_normalize - image_mean = self.image_mean if image_mean is None else _normalize_rgb_values(image_mean, name="image_mean") - image_std = self.image_std if image_std is None else _normalize_rgb_values(image_std, name="image_std") - do_convert_rgb = self.do_convert_rgb if do_convert_rgb is None else do_convert_rgb - - images = self.fetch_images(images) - images = make_flat_list_of_images(images) - - if not images: - raise ValueError("Received an empty list of images for preprocessing.") - if do_convert_rgb: - images = [convert_to_rgb(image) for image in images] - - if not valid_images(images): - raise ValueError( - "Invalid image type. Expected PIL images, numpy arrays, or tensors convertible to numpy arrays." - ) - - validate_preprocess_arguments( - do_rescale=do_rescale, - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - ) - - patches_list = [] - token_grids = [] - virtual_dims = [] - real_dims = [] - - for image in images: - np_image = to_numpy_array(image) - - if np_image.ndim == 2: - np_image = np.repeat(np_image[..., None], 3, axis=-1) - - height, width = np_image.shape[:2] - if height * width > MAX_PIXELS: - raise ValueError(f"Image (w={width}, h={height}) > MAX=`{MAX_PIXELS}`") - - torch_image = torch.from_numpy(_make_writeable(np_image)) - patches, vidims, rdims = self._process_single_image( - torch_image, - patch_size=patch_size, - max_num_patches=max_num_patches, - min_num_patches=min_num_patches, - pixel_shuffle_scale=pixel_shuffle_scale, - do_rescale=do_rescale, - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - ) - - patches_list.append(patches) - token_grids.append(torch.tensor([patches.size(1), patches.size(2)], dtype=torch.long)) - virtual_dims.append(vidims) - real_dims.append(rdims) - - patches_tensor = torch.cat(patches_list, dim=0) - token_grid_tensor = torch.stack(token_grids, dim=0) - virtual_dims_tensor = torch.tensor(virtual_dims, dtype=torch.long) - real_dims_tensor = torch.tensor(real_dims, dtype=torch.long) - - data = { - "patches": patches_tensor, - "token_grids": token_grid_tensor, - "virtual_pixel_size": virtual_dims_tensor, - "real_pixel_size": real_dims_tensor, - } - - return BatchFeature(data=data, tensor_type=return_tensors) - - def _process_single_image( - self, - image: torch.Tensor, - *, - patch_size: int, - max_num_patches: int, - min_num_patches: int | None, - pixel_shuffle_scale: int, - do_rescale: bool, - rescale_factor: float, - do_normalize: bool, - image_mean: tuple[float, ...], - image_std: tuple[float, ...], - ) -> tuple[torch.Tensor, list[int], list[int]]: - image_uint8 = image.unsqueeze(0) # (1, H, W, C) 
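# Hypothetical usage sketch of the slow processor defined in this (now-deleted) file,
# showing the output contract it shared with the fast path; the input image and sizes
# are illustrative.
import numpy as np
from PIL import Image

processor = IsaacImageProcessor(patch_size=16, max_num_patches=256)
image = Image.fromarray(np.zeros((64, 96, 3), dtype=np.uint8))  # H=64, W=96
features = processor.preprocess(image, return_tensors="pt")
print(sorted(features.keys()))  # ['patches', 'real_pixel_size', 'token_grids', 'virtual_pixel_size']
print(features["token_grids"])  # tensor([[4, 6]]): a 64x96 image yields a 4x6 grid of 16px patches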
- image_chw = image_uint8.permute(0, 3, 1, 2) # (1, C, H, W) - - _, _, orig_height, orig_width = image_chw.shape - target_height, target_width = get_image_size_for_max_num_patches( - orig_height, - orig_width, - patch_size, - max_num_patches, - min_num_patches=min_num_patches, - pixel_shuffle_scale=pixel_shuffle_scale, - ) - - if self.resize_mode in {"linear", "bilinear", "bicubic", "trilinear"}: - resized = F.interpolate( - image_chw, - size=(target_height, target_width), - mode=self.resize_mode, - align_corners=self.align_corners, - ) - else: - resized = F.interpolate( - image_chw, - size=(target_height, target_width), - mode=self.resize_mode, - ) - - resized = resized.permute(0, 2, 3, 1) # (1, H, W, C) - - scale = rescale_factor if do_rescale else 1.0 - mean = image_mean if do_normalize else (0.0, 0.0, 0.0) - std = image_std if do_normalize else (1.0, 1.0, 1.0) - resized = _prepare_image_tensor(resized, scale=scale, mean=mean, std=std) - - resized = _compute_residual_p_frames(resized, is_p_frame=[False]) - - patches = patchify_vision(resized, patch_size=patch_size) - _, h_patches, w_patches, _ = patches.shape - - real_dims = [1, h_patches, w_patches] - if pixel_shuffle_scale > 1: - if (h_patches % pixel_shuffle_scale) or (w_patches % pixel_shuffle_scale): - raise ValueError( - "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled." - ) - virtual_dims = [1, h_patches // pixel_shuffle_scale, w_patches // pixel_shuffle_scale] - else: - virtual_dims = real_dims.copy() - - return patches, virtual_dims, real_dims - - -__all__ = ["IsaacImageProcessor"] diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py b/src/transformers/models/isaac/image_processing_isaac_fast.py index ce778ef509da..d7380f23b1fd 100644 --- a/src/transformers/models/isaac/image_processing_isaac_fast.py +++ b/src/transformers/models/isaac/image_processing_isaac_fast.py @@ -4,7 +4,9 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# Perceptron, Inc. Non-Production License +# Copyright (c) 2024 Perceptron, Inc. All rights reserved. +# Perceptron, Inc. Non-Production License (2024-01-01) + ### 1. 
Scope and acceptance @@ -91,16 +93,12 @@ import torch import torch.nn.functional as F -from ...image_processing_utils_fast import ( - BaseImageProcessorFast, - BatchFeature, - SizeDict, - group_images_by_shape, - reorder_images, -) +from ...feature_extraction_utils import BatchFeature +from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict, group_images_by_shape, reorder_images from ...image_utils import ChannelDimension, PILImageResampling from ...processing_utils import Unpack -from ...utils import TensorType, auto_docstring +from ...tokenization_utils import TensorType +from ...utils import auto_docstring # Vision preprocessing constants from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN @@ -294,7 +292,7 @@ def _validate_preprocess_kwargs(self, **kwargs): def resize( self, - image: torch.Tensor, + image: "torch.Tensor", size: SizeDict, interpolation: Optional[Any] = None, antialias: bool = True, @@ -329,7 +327,7 @@ def resize( def _preprocess( self, - images: list[torch.Tensor], + images: list["torch.Tensor"], do_resize: bool, size: Optional[SizeDict], interpolation: Optional[Any], diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 3b26d5ba7769..ba69ef2a8e48 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -4,7 +4,9 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# Perceptron, Inc. Non-Production License +# Copyright (c) 2024 Perceptron, Inc. All rights reserved. +# Perceptron, Inc. Non-Production License (2024-01-01) + ### 1. 
Scope and acceptance @@ -87,17 +89,10 @@ import copy from collections import defaultdict -from typing import Any, Callable, Optional, TypedDict +from typing import Any, Callable, Optional import torch import torch.nn as nn -import torch.nn.functional as F -from genesis.public.tensorstream.tensor_stream import TensorStream, TextType, VisionType, group_streams -from genesis.public.tensorstream.tensor_stream_utils import ( - compute_mrope_pos_tensor, - modality_mask, - reconstruct_tensor_stream_from_compact_dict, -) from ...activations import ACT2FN from ...cache_utils import Cache, SlidingWindowCache, StaticCache @@ -107,76 +102,81 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...models.auto.modeling_auto import AutoModel +from ...models.qwen3.configuration_qwen3 import Qwen3Config +from ...models.qwen3.modeling_qwen3 import Qwen3PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils import TransformersKwargs, auto_docstring from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import can_return_tuple from ...utils.import_utils import is_torchdynamo_compiling -from ..auto import AutoModel from .configuration_isaac import IsaacConfig, IsaacVisionConfig -class IsaacVisionEmbeddings(nn.Module): +class IsaacVisionEmbeddings(HFSiglip2VisionEmbeddings): + """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences.""" + def __init__(self, config: IsaacVisionConfig): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.patch_size = config.patch_size + super().__init__(config) - self.patch_embedding = nn.Linear( - in_features=config.num_channels * self.patch_size * self.patch_size, - out_features=self.embed_dim, - ) + def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor: + packed_pixel_values, seq_lengths = self._pack_to_batch(seq_patches, spatial_shapes) + if packed_pixel_values is None: + return seq_patches.new_zeros((0, self.embed_dim)) - self.num_patches = config.num_patches - self.position_embedding_size = int(self.num_patches**0.5) - self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + embeddings = super().forward(packed_pixel_values, spatial_shapes) + return self._unpack_from_batch(embeddings, seq_lengths) - def positional_embeddings(self, spatial_shapes: torch.Tensor) -> torch.Tensor: - # Prepare positional embeddings grid: (1, embed_dim, h, w) - positional_embeddings = ( - self.position_embedding.weight.reshape(self.position_embedding_size, self.position_embedding_size, -1) - .permute(2, 0, 1) - .unsqueeze(0) - ) + def _pack_to_batch( + self, + seq_patches: torch.Tensor, + spatial_shapes: torch.Tensor, + ) -> tuple[torch.Tensor | None, torch.Tensor]: + if seq_patches.ndim != 2: + raise ValueError("`seq_patches` is expected to be 2D (total_patches, patch_dim).") + if spatial_shapes.ndim != 2 or spatial_shapes.size(-1) != 2: + raise ValueError("`spatial_shapes` must have shape (num_images, 2) with (height_tokens, width_tokens).") + + seq_lengths = spatial_shapes.long().prod(dim=-1) + total_patches = int(seq_lengths.sum().item()) + if total_patches != seq_patches.size(0): + raise 
ValueError( + "Mismatch between packed patches and spatial shapes: got " + f"{seq_patches.size(0)} patches but spatial shapes imply {total_patches}." + ) - pos_embeds_list = [] - mode = "bilinear" - align_corners = False - for spatial_shape in spatial_shapes: - height, width = spatial_shape - # Guard to ensure height and width are positive for torch.compile - if height > 0 and width > 0: - resized_pos_embed = F.interpolate( - positional_embeddings, - size=(height, width), - mode=mode, - align_corners=align_corners, - antialias=True, - ) - # Reshape from (1, embed_dim, height, width) to (height*width, embed_dim) - resized_pos_embed = resized_pos_embed.reshape(self.embed_dim, height * width).transpose(0, 1) - else: - # Fallback - should never happen in practice - raise RuntimeError( - "Encountered non-positive spatial dimensions while computing positional embeddings." - ) - pos_embeds_list.append(resized_pos_embed) + batch_size = spatial_shapes.size(0) + if batch_size == 0: + return None, seq_lengths - # Concatenate all positional embeddings along the sequence dimension - pos_embeds = torch.cat(pos_embeds_list, dim=0) - return pos_embeds + max_length = int(seq_lengths.max().item()) + patch_dim = seq_patches.size(-1) + device = seq_patches.device - def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor): - # Apply patch embeddings - target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding(seq_patches.to(dtype=target_dtype)) - pos_embeds = self.positional_embeddings(spatial_shapes) + packed_pixel_values = seq_patches.new_zeros((batch_size, max_length, patch_dim), device=device) - # Add positional embeddings to patch embeddings - embeddings = patch_embeds + pos_embeds - return embeddings + start = 0 + for batch_idx, length in enumerate(seq_lengths.tolist()): + if length == 0: + continue + end = start + length + packed_pixel_values[batch_idx, :length] = seq_patches[start:end] + start = end + + return packed_pixel_values, seq_lengths + + def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor) -> torch.Tensor: + output_chunks: list[torch.Tensor] = [] + for batch_idx, length in enumerate(seq_lengths.tolist()): + if length == 0: + continue + output_chunks.append(embeddings[batch_idx, :length]) + + if not output_chunks: + return embeddings.new_zeros((0, embeddings.size(-1))) + + return torch.cat(output_chunks, dim=0) def build_document_attention_mask( @@ -204,6 +204,24 @@ return additive_mask.view(1, 1, total_tokens, total_tokens) +def ensure_document_attention_mask( + attention_mask: Optional[torch.Tensor], + cu_seqlens: Optional[torch.Tensor], + total_tokens: int, + dtype: torch.dtype, + device: torch.device, +) -> Optional[torch.Tensor]: + if attention_mask is not None or cu_seqlens is None: + return attention_mask + + return build_document_attention_mask( + cu_seqlens=cu_seqlens, + total_tokens=total_tokens, + dtype=dtype, + device=device, + ) + + class IsaacVisionAttention(nn.Module): """Custom attention that supports variable-length sequences with flash attention.""" @@ -220,10 +238,10 @@ class IsaacVisionAttention(nn.Module): def __init__(self, vision_config): super().__init__() - self.vision_config = vision_config - self.config = vision_config - self.embed_dim = vision_config.hidden_size - self.num_heads = vision_config.num_attention_heads + config = vision_config + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim //
self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( @@ -231,16 +248,30 @@ def __init__(self, vision_config): f" {self.num_heads})." ) self.scale = self.head_dim**-0.5 - self.dropout = vision_config.attention_dropout + self.dropout = vision_config.attention_dropout self.is_causal = False - self.num_key_value_groups = 1 self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.vision_config = vision_config + self._variable_length_metadata = None + + def forward(self, hidden_states, attention_mask=None, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + cu_seqlens = kwargs.pop("cu_seqlens", None) + max_seqlen = kwargs.pop("max_seqlen", None) + kwargs.pop("output_attentions", None) + if kwargs: + unexpected = ", ".join(sorted(kwargs)) + raise TypeError(f"Unexpected kwargs for IsaacVisionAttention.forward: {unexpected}") + cached_cu, cached_max = self._consume_variable_length_metadata() + if cu_seqlens is None: + cu_seqlens = cached_cu + if max_seqlen is None: + max_seqlen = cached_max - def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None): # Expect packed sequences with batch_size == 1 batch_size, L, _ = hidden_states.shape if batch_size != 1: @@ -258,11 +289,12 @@ def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None): attn_impl = getattr(self.vision_config, "_attn_implementation", "flash_attention_3") - attn_mask = build_document_attention_mask( - cu_seqlens=cu_seqlens, - total_tokens=L, - dtype=q.dtype, - device=q.device, + attn_mask = ensure_document_attention_mask( + attention_mask, + cu_seqlens, + L, + q.dtype, + q.device, ) resolved_key = self.ATTENTION_KEY_MAP.get(attn_impl) @@ -301,6 +333,17 @@ def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None): y = self.out_proj(y_lhd.reshape(L, self.embed_dim)) return y.unsqueeze(0), None # (1, L, E) + def _variable_length_context(self, *, cu_seqlens=None, max_seqlen=None): + """Store packed-sequence metadata for the next forward call.""" + self._variable_length_metadata = (cu_seqlens, max_seqlen) + + def _consume_variable_length_metadata(self): + if self._variable_length_metadata is None: + return None, None + cu_seqlens, max_seqlen = self._variable_length_metadata + self._variable_length_metadata = None + return cu_seqlens, max_seqlen + class IsaacVisionEncoderLayer(HFSiglip2EncoderLayer): """Isaac vision encoder layer with variable-length attention.""" @@ -312,62 +355,78 @@ def __init__(self, vision_config: IsaacVisionConfig): def forward( self, hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor = None, - max_seqlen: int = None, - ) -> torch.Tensor: - residual = hidden_states - - hidden_states = self.layer_norm1(hidden_states) + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + output_attentions: bool = False, + ): + if cu_seqlens is not None or max_seqlen is not None: + self.self_attn._variable_length_context( + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) - hidden_states, _ = self.self_attn( - hidden_states=hidden_states, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, + attention_mask = ensure_document_attention_mask( + attention_mask, + cu_seqlens, + hidden_states.size(1), + hidden_states.dtype, + hidden_states.device, ) - hidden_states = residual + 
hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states + return super().forward( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) -class IsaacVisionEncoder(nn.Module): +class IsaacVisionEncoder(HFSiglip2Encoder): """Encoder using Isaac encoder layers with variable-length attention support.""" def __init__(self, config: IsaacVisionConfig): - super().__init__() - self.config = config + super().__init__(config) self.layers = nn.ModuleList([IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + def __variable_length_context(self, cu_seqlens, max_seqlen) -> None: + if cu_seqlens is None and max_seqlen is None: + return + + for layer in self.layers: + if isinstance(layer, IsaacVisionEncoderLayer): + layer.self_attn._variable_length_context( + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + @can_return_tuple def forward( self, inputs_embeds, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - output_hidden_states: bool = False, + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ): - all_hidden_states = () if output_hidden_states else None - - hidden_states = inputs_embeds - - for encoder_layer in self.layers: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + self.__variable_length_context(cu_seqlens, max_seqlen) - hidden_states = encoder_layer( - hidden_states, - cu_seqlens, - max_seqlen, - ) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + attention_mask = ensure_document_attention_mask( + attention_mask, + cu_seqlens, + inputs_embeds.size(1), + inputs_embeds.dtype, + inputs_embeds.device, + ) - return hidden_states, all_hidden_states, None + return super().forward( + inputs_embeds, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) def create_pixel_shuffle_index_map( @@ -522,11 +581,13 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): max_seqlen = int(seq_sizes.max().item()) if seq_sizes.numel() > 0 else 0 # Pass through encoder with variable-length attention parameters - hidden_states, _, _ = self.encoder( + encoder_outputs = self.encoder( inputs_embeds=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + return_dict=True, ) + hidden_states = encoder_outputs.last_hidden_state # Apply final layer normalization hidden_states = self.post_layernorm(hidden_states) @@ -544,101 +605,25 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): return hidden_states -class RopeScaling(TypedDict, total=False): - rope_type: str - factor: float - mrope_section: list[int] - mrope_interleaved: bool - low_freq_factor: float - high_freq_factor: float - original_max_position_embeddings: int - - -def precompute_cos_sin_3d( - position_ids: torch.Tensor, # shape (3, B, T) - inv_freq: torch.Tensor, # shape (dim//2,) - mrope_half_section: list[int], # sum to dim//2 -) -> tuple[torch.Tensor, torch.Tensor]: - r"""Generate 3D rotary embeddings for multi-axis positions. 
- - Args: - position_ids (`torch.Tensor`): - Tensor of shape `(3, batch_size, seq_len)` containing positional indices for the x/y/t axes. - inv_freq (`torch.Tensor`): - Precomputed inverse frequency vector used to derive rotary phases. - mrope_half_section (`list[int]`): - Sizes the axis-specific frequency blocks. - - Returns: - `tuple[torch.Tensor, torch.Tensor]`: Cosine and sine tensors, each of shape `(batch_size, seq_len, dim)`, ready - to be passed into rotary attention layers. - """ - B = position_ids.shape[1] - T = position_ids.shape[2] - dim_half = inv_freq.shape[0] - device = position_ids.device - - # Initialize with full dimension (not half) to match LLaMA - cos_3d = torch.zeros((B, T, dim_half * 2), dtype=torch.float32, device=device) - sin_3d = torch.zeros((B, T, dim_half * 2), dtype=torch.float32, device=device) - - offset = 0 - for d in range(3): - block_size = mrope_half_section[d] - freq_slice = inv_freq[offset : offset + block_size] # shape => (block_size,) - # shape => (B, T, block_size) - phase = position_ids[d].unsqueeze(-1).float() * freq_slice - - cos_part = phase.cos() - sin_part = phase.sin() - - # Duplicate values for both halves of the dimension - cos_3d[:, :, offset : offset + block_size] = cos_part - cos_3d[:, :, dim_half + offset : dim_half + offset + block_size] = cos_part - sin_3d[:, :, offset : offset + block_size] = sin_part - sin_3d[:, :, dim_half + offset : dim_half + offset + block_size] = sin_part - - offset += block_size - - return cos_3d, sin_3d - - class IsaacRotaryEmbedding(nn.Module): EXTRA_ROPE_KEYS = {"mrope_section", "mrope_interleaved"} def __init__(self, config: IsaacConfig, device=None): super().__init__() - self.config = config rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) - if rope_type not in ROPE_INIT_FUNCTIONS: - raise ValueError(f"Unsupported rope_type '{rope_type}' for IsaacRotaryEmbedding") - - self.rope_type = rope_type - rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] sanitized_scaling = {k: v for k, v in rope_scaling.items() if k not in self.EXTRA_ROPE_KEYS} - if sanitized_scaling != rope_scaling: - config_for_rope = copy.copy(rope_source_cfg) - config_for_rope.rope_scaling = sanitized_scaling - else: - config_for_rope = rope_source_cfg + config_for_rope = copy.copy(rope_source_cfg) + config_for_rope.rope_scaling = sanitized_scaling if sanitized_scaling else None init_device = device if device is not None and getattr(device, "type", None) != "meta" else None - inv_freq, attention_scaling = rope_init_fn(config_for_rope, device=init_device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.attention_scaling = self._normalize_scale(attention_scaling) + self._qwen_rotary = Qwen2_5_VLRotaryEmbedding(config_for_rope, device=init_device) - rotary_half_dim = self.inv_freq.shape[0] + rotary_half_dim = self._qwen_rotary.inv_freq.shape[0] self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim) - - @staticmethod - def _normalize_scale(scale: torch.Tensor | float) -> torch.Tensor | float: - if isinstance(scale, torch.Tensor): - return scale.detach().clone() - return float(scale) + self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size @staticmethod def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> list[int]: @@ -657,26 +642,54 @@ 
def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> l ) return section - def forward(self, position_ids: torch.Tensor, modality_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor: + split_sections = tuple(self.mrope_section * 2) + chunks = tensor.split(split_sections, dim=-1) + return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) + + @property + def inv_freq(self) -> torch.Tensor: + return self._qwen_rotary.inv_freq + + def forward( + self, + position_ids: torch.Tensor, + modality_tensor: torch.Tensor, + hidden_states: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if position_ids.ndim != 3 or position_ids.size(-1) != 3: + raise ValueError("`position_ids` must have shape (batch, seq_len, 3) for MRoPE") + if modality_tensor.shape != position_ids.shape[:2]: + raise ValueError("`modality_tensor` must align with the first two dims of `position_ids`") + + if hidden_states is None: + batch, seq_len, _ = position_ids.shape + hidden_states = torch.zeros( + batch, + seq_len, + self.hidden_size, + dtype=torch.float32, + device=position_ids.device, + ) + with torch.no_grad(): - position_ids = position_ids.clone() + pos = position_ids.clone() not_spatial = modality_tensor != VisionType.image.value if not_spatial.any(): - data_1d = position_ids[not_spatial][..., 0].unsqueeze(-1) - position_ids[not_spatial] = data_1d.expand(-1, position_ids.shape[-1]) + data_1d = pos[not_spatial][..., 0].unsqueeze(-1) + pos[not_spatial] = data_1d.expand(-1, pos.shape[-1]) + + pos_axes = pos.permute(2, 0, 1).contiguous() + + cos_axes, sin_axes = self._qwen_rotary(hidden_states, pos_axes) + + cos_axes = cos_axes.to(hidden_states.dtype) + sin_axes = sin_axes.to(hidden_states.dtype) - position_ids = position_ids.permute(2, 0, 1) - cos, sin = precompute_cos_sin_3d(position_ids, self.inv_freq, self.mrope_section) - scale = self.attention_scaling - if isinstance(scale, torch.Tensor): - scale = scale.to(device=cos.device, dtype=cos.dtype) - elif scale != 1.0: - scale = cos.new_tensor(scale) - if isinstance(scale, torch.Tensor) or scale != 1.0: - cos = cos * scale - sin = sin * scale + cos_combined = self._combine_axes(cos_axes) + sin_combined = self._combine_axes(sin_axes) - return cos, sin + return cos_combined, sin_combined @use_kernel_forward_from_hub("RMSNorm") @@ -750,21 +763,16 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -def repeat_kv(hidden_states: torch.Tensor, num_key_value_groups: int) -> torch.Tensor: - """Repeat key/value heads for grouped-query attention.""" - - if num_key_value_groups == 1: +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: return hidden_states - - batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, - num_key_value_heads, - num_key_value_groups, - seq_len, - head_dim, - ) - return hidden_states.reshape(batch, num_key_value_heads * num_key_value_groups, seq_len, head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( @@ -1119,7 +1127,11 @@ def forward( position_ids = compute_position_ids_input_ids(input_ids) # Compute MRoPE position embeddings if we have custom rotary_emb - cos, sin = self.rotary_emb(position_ids, modality_tensor) + cos, sin = self.rotary_emb( + position_ids, + modality_tensor, + hidden_states=inputs_embeds, + ) cos = cos.to(inputs_embeds.dtype) sin = sin.to(inputs_embeds.dtype) diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 894832a36a9d..6fbe37a94b12 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -4,7 +4,9 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# Perceptron, Inc. Non-Production License +# Copyright (c) 2024 Perceptron, Inc. All rights reserved. +# Perceptron, Inc. Non-Production License (2024-01-01) + ### 1. 
Scope and acceptance @@ -86,17 +88,16 @@ import math import re +from typing import Optional import PIL.Image import torch -from genesis.public.tensorstream.tensor_stream import Event, Stream, TensorStream, TextType, VisionType, create_stream -from genesis.public.tensorstream.tensor_stream_utils import slice as ts_slice -from genesis.public.tensorstream.tensor_stream_utils import tensor_stream_token_view -from ...image_processing_utils_fast import BatchFeature, DefaultFastImageProcessorKwargs +from ...feature_extraction_utils import BatchFeature +from ...image_processing_utils_fast import DefaultFastImageProcessorKwargs +from ...models.auto.tokenization_auto import AutoTokenizer from ...processing_utils import ProcessorMixin -from ...utils import TensorType -from ..auto import AutoTokenizer +from ...tokenization_utils import TensorType from .configuration_isaac import IsaacConfig @@ -157,12 +158,12 @@ def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> class IsaacProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] - image_processor_class = "IsaacImageProcessorFast" + image_processor_class = ("IsaacImageProcessorFast",) tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") def __init__( self, - image_processor: "IsaacImageProcessorFast | None" = None, + image_processor: Optional["IsaacImageProcessorFast"] = None, tokenizer: Qwen2Tokenizer | None = None, *, vision_token: str = "", From 5656b83394acd20d43244e79c9be533e43f49c54 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 10:02:50 +0400 Subject: [PATCH 12/77] style: remove redundant registration --- .../models/isaac/modular_isaac.py | 25 ++++++------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 5c485598e9c4..635d791ecd21 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -2244,23 +2244,6 @@ def prepare_inputs_for_generation( def can_generate(self) -> bool: return True - -AutoImageProcessor.register( - IsaacConfig, - fast_image_processor_class=IsaacImageProcessorFast, - exist_ok=True, -) - - -__all__ = [ - "IsaacConfig", - "IsaacModel", - "IsaacForConditionalGeneration", - "IsaacImageProcessorFast", - "IsaacProcessor", -] - - def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> torch.Tensor: """Compute residuals for P-frames to stay in sync with the training pipeline.""" if not any(is_p_frame): @@ -2272,3 +2255,11 @@ def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> p_indices = frame_indices[torch.tensor(is_p_frame, device=frames.device)] frames[p_indices] = frames[p_indices] - frames[last_i_indices[p_indices]] return frames + +__all__ = [ + "IsaacConfig", + "IsaacModel", + "IsaacForConditionalGeneration", + "IsaacImageProcessorFast", + "IsaacProcessor", +] From 963f8c1b4a697a7c12d1ffd6f139f01bc166a508 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 10:04:27 +0400 Subject: [PATCH 13/77] style: organize auto file entries --- src/transformers/models/auto/configuration_auto.py | 4 ++-- src/transformers/models/auto/modeling_auto.py | 4 ++-- src/transformers/models/auto/processing_auto.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index d0a4d937f6f0..d07b001f43d8 100644 
--- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -208,6 +208,7 @@ ("instructblipvideo", "InstructBlipVideoConfig"), ("internvl", "InternVLConfig"), ("internvl_vision", "InternVLVisionConfig"), + ("isaac", "IsaacConfig"), ("jamba", "JambaConfig"), ("janus", "JanusConfig"), ("jetmoe", "JetMoeConfig"), @@ -300,7 +301,6 @@ ("perceiver", "PerceiverConfig"), ("perception_encoder", "TimmWrapperConfig"), ("perception_lm", "PerceptionLMConfig"), - ("isaac", "IsaacConfig"), ("persimmon", "PersimmonConfig"), ("phi", "PhiConfig"), ("phi3", "Phi3Config"), @@ -642,6 +642,7 @@ ("instructblipvideo", "InstructBlipVideo"), ("internvl", "InternVL"), ("internvl_vision", "InternVLVision"), + ("isaac", "Isaac"), ("jamba", "Jamba"), ("janus", "Janus"), ("jetmoe", "JetMoe"), @@ -745,7 +746,6 @@ ("perceiver", "Perceiver"), ("perception_encoder", "PerceptionEncoder"), ("perception_lm", "PerceptionLM"), - ("isaac", "Isaac"), ("persimmon", "Persimmon"), ("phi", "Phi"), ("phi3", "Phi3"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 13e6ae32ab8f..c6548897657d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -208,6 +208,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("instructblipvideo", "InstructBlipVideoModel"), ("internvl", "InternVLModel"), ("internvl_vision", "InternVLVisionModel"), + ("isaac", "IsaacModel"), ("jamba", "JambaModel"), ("janus", "JanusModel"), ("jetmoe", "JetMoeModel"), @@ -299,7 +300,6 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("perceiver", "PerceiverModel"), ("perception_encoder", "PerceptionEncoder"), ("perception_lm", "PerceptionLMModel"), - ("isaac", "IsaacModel"), ("persimmon", "PersimmonModel"), ("phi", "PhiModel"), ("phi3", "Phi3Model"), @@ -1022,6 +1022,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("idefics3", "Idefics3ForConditionalGeneration"), ("instructblip", "InstructBlipForConditionalGeneration"), ("internvl", "InternVLForConditionalGeneration"), + ("isaac", "IsaacForConditionalGeneration"), ("janus", "JanusForConditionalGeneration"), ("kosmos-2", "Kosmos2ForConditionalGeneration"), ("kosmos-2.5", "Kosmos2_5ForConditionalGeneration"), @@ -1035,7 +1036,6 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ovis2", "Ovis2ForConditionalGeneration"), ("paligemma", "PaliGemmaForConditionalGeneration"), ("perception_lm", "PerceptionLMForConditionalGeneration"), - ("isaac", "IsaacForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"), ("pixtral", "LlavaForConditionalGeneration"), ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 5ee7a913114c..aaa0b02f2375 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -87,6 +87,7 @@ ("instructblip", "InstructBlipProcessor"), ("instructblipvideo", "InstructBlipVideoProcessor"), ("internvl", "InternVLProcessor"), + ("isaac", "IsaacProcessor"), ("janus", "JanusProcessor"), ("kosmos-2", "Kosmos2Processor"), ("kosmos-2.5", "Kosmos2_5Processor"), @@ -112,7 +113,6 @@ ("owlvit", "OwlViTProcessor"), ("paligemma", "PaliGemmaProcessor"), ("perception_lm", "PerceptionLMProcessor"), - ("isaac", "IsaacProcessor"), ("phi4_multimodal", "Phi4MultimodalProcessor"), ("pix2struct", 
"Pix2StructProcessor"), ("pixtral", "PixtralProcessor"), From 74f9f3be42be43679b176fa07629407418d83ded Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 10:20:35 +0400 Subject: [PATCH 14/77] style: lints --- .../models/isaac/configuration_isaac.py | 1 - .../isaac/image_processing_isaac_fast.py | 4 +- .../models/isaac/modeling_isaac.py | 6 + .../models/isaac/modular_isaac.py | 230 +++++++++--------- .../models/isaac/processing_isaac.py | 3 + 5 files changed, 132 insertions(+), 112 deletions(-) diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index abd5df690fdf..a076ed7cc1d6 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -90,7 +90,6 @@ import copy from typing import Any -# Build the list of all image processors from ...configuration_utils import PretrainedConfig, layer_type_validation from ...modeling_rope_utils import rope_config_validation from ...models.qwen3.configuration_qwen3 import Qwen3Config diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py b/src/transformers/models/isaac/image_processing_isaac_fast.py index d7380f23b1fd..ce599e6e0508 100644 --- a/src/transformers/models/isaac/image_processing_isaac_fast.py +++ b/src/transformers/models/isaac/image_processing_isaac_fast.py @@ -292,7 +292,7 @@ def _validate_preprocess_kwargs(self, **kwargs): def resize( self, - image: "torch.Tensor", + image: torch.Tensor, size: SizeDict, interpolation: Optional[Any] = None, antialias: bool = True, @@ -327,7 +327,7 @@ def resize( def _preprocess( self, - images: list["torch.Tensor"], + images: list[torch.Tensor], do_resize: bool, size: Optional[SizeDict], interpolation: Optional[Any], diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index ba69ef2a8e48..a0aaa38c9f18 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -93,6 +93,12 @@ import torch import torch.nn as nn +from genesis.public.tensorstream.tensor_stream import TensorStream, TextType, VisionType, group_streams +from genesis.public.tensorstream.tensor_stream_utils import ( + compute_mrope_pos_tensor, + modality_mask, + reconstruct_tensor_stream_from_compact_dict, +) from ...activations import ACT2FN from ...cache_utils import Cache, SlidingWindowCache, StaticCache diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 635d791ecd21..4755ee68beb9 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -5,18 +5,18 @@ ### 1. Scope and acceptance # **1.1. Scope of the Agreement.** -# This Agreement applies to any use, modification, or Distribution of any Perceptron Model by You, regardless of the source You obtained a copy of such Perceptron Model. +# This Agreement applies to any use, modification, or Distribution of any Perceptron Model by You, regardless of the source You obtained a copy of such Perceptron Model. # -# **1.2. Acceptance.** By accessing, using, modifying, Distributing a Perceptron Model, or by creating, using or distributing a Derivative of the Perceptron Model, You agree to be bound by this Agreement. +# **1.2. 
Acceptance.** By accessing, using, modifying, Distributing a Perceptron Model, or by creating, using or distributing a Derivative of the Perceptron Model, You agree to be bound by this Agreement. # # **1.3. Acceptance on behalf of a third-party.** If You accept this Agreement on behalf of Your employer or another person or entity, You warrant and represent that You have the authority to act and accept this Agreement on their behalf. In such a case, the word โ€œYouโ€ in this Agreement will refer to Your employer or such other person or entity. # -# ## 2. License -# **2.1. Grant of rights.** Subject to Section 3 below, Perceptron, Inc. hereby grants You a non-exclusive, royalty-free, worldwide, non-sublicensable, non-transferable, limited license to use, copy, modify, and Distribute under the conditions provided in Section 2.2 below, the Perceptron Model and any Derivatives made by or for Perceptron, Inc. and to create Derivatives of the Perceptron Model. +# ## 2. License +# **2.1. Grant of rights.** Subject to Section 3 below, Perceptron, Inc. hereby grants You a non-exclusive, royalty-free, worldwide, non-sublicensable, non-transferable, limited license to use, copy, modify, and Distribute under the conditions provided in Section 2.2 below, the Perceptron Model and any Derivatives made by or for Perceptron, Inc. and to create Derivatives of the Perceptron Model. # -# **2.2. Distribution of Perceptron Model and Derivatives made by or for Perceptron, Inc..** Subject to Section 3 below, You may Distribute copies of the Perceptron Model and/or Derivatives made by or for Perceptron, Inc., under the following conditions: -# - You must make available a copy of this Agreement to third-party recipients of the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. you Distribute, it being specified that any rights to use the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. shall be directly granted by Perceptron, Inc. to said third-party recipients pursuant to the Perceptron, Inc. Non-Production License agreement executed between these parties; -# - You must retain in all copies of the Perceptron Models the following attribution notice within a โ€œNoticeโ€ text file distributed as part of such copies: โ€œLicensed by Perceptron, Inc. under the Perceptron, Inc. Non-Production Licenseโ€. +# **2.2. Distribution of Perceptron Model and Derivatives made by or for Perceptron, Inc..** Subject to Section 3 below, You may Distribute copies of the Perceptron Model and/or Derivatives made by or for Perceptron, Inc., under the following conditions: +# - You must make available a copy of this Agreement to third-party recipients of the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. you Distribute, it being specified that any rights to use the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. shall be directly granted by Perceptron, Inc. to said third-party recipients pursuant to the Perceptron, Inc. Non-Production License agreement executed between these parties; +# - You must retain in all copies of the Perceptron Models the following attribution notice within a โ€œNoticeโ€ text file distributed as part of such copies: โ€œLicensed by Perceptron, Inc. under the Perceptron, Inc. Non-Production Licenseโ€. # # **2.3. 
Distribution of Derivatives made by or for You.** Subject to Section 3 below, You may Distribute any Derivatives made by or for You under additional or different terms and conditions, provided that: # - In any event, the use and modification of Perceptron Model and/or Derivatives made by or for Perceptron, Inc. shall remain governed by the terms and conditions of this Agreement; @@ -24,12 +24,12 @@ # - Any terms and conditions You impose on any third-party recipients relating to Derivatives made by or for You shall neither limit such third-party recipientsโ€™ use of the Perceptron Model or any Derivatives made by or for Perceptron, Inc. in accordance with the Perceptron, Inc. Non-Production License nor conflict with any of its terms and conditions. # # ## 3. Limitations -# **3.1. Misrepresentation.** You must not misrepresent or imply, through any means, that the Derivatives made by or for You and/or any modified version of the Perceptron Model You Distribute under your name and responsibility is an official product of Perceptron, Inc. or has been endorsed, approved or validated by Perceptron, Inc., unless You are authorized by Us to do so in writing. +# **3.1. Misrepresentation.** You must not misrepresent or imply, through any means, that the Derivatives made by or for You and/or any modified version of the Perceptron Model You Distribute under your name and responsibility is an official product of Perceptron, Inc. or has been endorsed, approved or validated by Perceptron, Inc., unless You are authorized by Us to do so in writing. # # **3.2. Usage Limitation** -# - You shall only use the Perceptron Models and Derivatives (whether or not created by Perceptron, Inc.) for testing, research, Personal, or evaluation purposes in Non-Production Environments; +# - You shall only use the Perceptron Models and Derivatives (whether or not created by Perceptron, Inc.) for testing, research, Personal, or evaluation purposes in Non-Production Environments; # - Subject to the foregoing, You shall not supply the Perceptron Models or Derivatives in the course of a commercial activity, whether in return for payment or free of charge, in any medium or form, including but not limited to through a hosted or managed service (e.g. SaaS, cloud instances, etc.), or behind a software layer. -# +# # **3.3. Usage not permitted under this Agreement.** If You want to use a Perceptron Model or a Derivative for any purpose that is not expressly authorized under this Agreement, You must request a license from Perceptron, Inc., which Perceptron, Inc. may grant to You in Perceptron, Inc.โ€™s sole discretion. Please contact Perceptron, Inc. at the following e-mail address if You want to discuss such a license: sales@perceptron.inc # # ## 4. Intellectual Property @@ -37,10 +37,10 @@ # # **4.2. Outputs.** We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs You generate and their subsequent uses in accordance with this Agreement. # -# **4.3. Derivatives.** By entering into this Agreement, You accept that any Derivatives that You may create or that may be created for You shall be subject to the restrictions set out in Section 3 of this Agreement. +# **4.3. Derivatives.** By entering into this Agreement, You accept that any Derivatives that You may create or that may be created for You shall be subject to the restrictions set out in Section 3 of this Agreement. # # # 5. Liability -# **5.1. 
Limitation of liability.** In no event, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall Perceptron, Inc. be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use the Perceptron Models and Derivatives (including but not limited to damages for loss of data, loss of goodwill, loss of expected profit or savings, work stoppage, computer failure or malfunction, or any damage caused by malware or security breaches), even if Perceptron, Inc. has been advised of the possibility of such damages. +# **5.1. Limitation of liability.** In no event, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall Perceptron, Inc. be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use the Perceptron Models and Derivatives (including but not limited to damages for loss of data, loss of goodwill, loss of expected profit or savings, work stoppage, computer failure or malfunction, or any damage caused by malware or security breaches), even if Perceptron, Inc. has been advised of the possibility of such damages. # # **5.2. Indemnification.** You agree to indemnify and hold harmless Perceptron, Inc. from and against any claims, damages, or losses arising out of or related to Your use or Distribution of the Perceptron Models and Derivatives. # @@ -48,9 +48,9 @@ # **6.1. Disclaimer.** Unless required by applicable law or agreed to in writing, Perceptron, Inc. provides the Perceptron Models and Derivatives on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. Perceptron, Inc. does not represent nor warrant that the Perceptron Models and Derivatives will be error-free, meet Your or any third partyโ€™s requirements, be secure or will allow You or any third party to achieve any kind of result or generate any kind of content. You are solely responsible for determining the appropriateness of using or Distributing the Perceptron Models and Derivatives and assume any risks associated with Your exercise of rights under this Agreement. # # # 7. Termination -# **7.1. Term.** This Agreement is effective as of the date of your acceptance of this Agreement or access to the concerned Perceptron Models or Derivatives and will continue until terminated in accordance with the following terms. +# **7.1. Term.** This Agreement is effective as of the date of your acceptance of this Agreement or access to the concerned Perceptron Models or Derivatives and will continue until terminated in accordance with the following terms. # -# **7.2. Termination.** Perceptron, Inc. may terminate this Agreement at any time if You are in breach of this Agreement. Upon termination of this Agreement, You must cease to use all Perceptron Models and Derivatives and shall permanently delete any copy thereof. Sections 5, 6, 7 and 8 shall survive the termination of this Agreement. +# **7.2. Termination.** Perceptron, Inc. may terminate this Agreement at any time if You are in breach of this Agreement. 
Upon termination of this Agreement, You must cease to use all Perceptron Models and Derivatives and shall permanently delete any copy thereof. Sections 5, 6, 7 and 8 shall survive the termination of this Agreement. # # **7.3. Litigation.** If You initiate any legal action or proceedings against Us or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the Model or a Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement will immediately terminate as of the date such legal action or claim is filed or initiated. # @@ -61,22 +61,22 @@ # # **8.3. Severability.** If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. # -# # 9. Definitions -# **โ€œAgreementโ€**: means this Perceptron, Inc. Non-Production License agreement governing the access, use, and Distribution of the Perceptron Models and Derivatives. +# # 9. Definitions +# **โ€œAgreementโ€**: means this Perceptron, Inc. Non-Production License agreement governing the access, use, and Distribution of the Perceptron Models and Derivatives. # -# **โ€œDerivativeโ€**: means any (i) modified version of the Perceptron Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the Perceptron Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered as Derivatives under this Agreement. +# **โ€œDerivativeโ€**: means any (i) modified version of the Perceptron Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the Perceptron Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered as Derivatives under this Agreement. # -# **โ€œDistributionโ€**, **โ€œDistributingโ€**, **โ€œDistributeโ€** or **โ€œDistributedโ€**: means providing or making available, by any means, a copy of the Perceptron Models and/or the Derivatives as the case may be, subject to Section 3 of this Agreement. +# **โ€œDistributionโ€**, **โ€œDistributingโ€**, **โ€œDistributeโ€** or **โ€œDistributedโ€**: means providing or making available, by any means, a copy of the Perceptron Models and/or the Derivatives as the case may be, subject to Section 3 of this Agreement. # # **โ€œPerceptron, Inc.โ€**, **โ€œWeโ€** or **โ€œUsโ€**: means Perceptron, Inc., a Delaware corporation with its principal place of business at 10900 NE 8th St Suite 613, Bellevue, WA 98004. # -# **โ€œPerceptron Modelโ€**: means the foundational large language model(s), and its elements which include algorithms, software, instructed checkpoints, parameters, source code (inference code, evaluation code and, if applicable, fine-tuning code) and any other elements associated thereto made available by Perceptron, Inc. under this Agreement, including, if any, the technical documentation, manuals and instructions for the use and operation thereof. +# **โ€œPerceptron Modelโ€**: means the foundational large language model(s), and its elements which include algorithms, software, instructed checkpoints, parameters, source code (inference code, evaluation code and, if applicable, fine-tuning code) and any other elements associated thereto made available by Perceptron, Inc. 
under this Agreement, including, if any, the technical documentation, manuals and instructions for the use and operation thereof. # -# **โ€œNon-Production Environmentโ€**: means any setting, use case, or application of the Perceptron Models or Derivatives that expressly excludes live, real-world conditions, commercial operations, revenue-generating activities, or direct interactions with or impacts on end users (such as, for instance, Your employees or customers). Non-Production Environment may include, but is not limited to, any setting, use case, or application for research, development, testing, quality assurance, training, internal evaluation (other than any internal usage by employees in the context of the companyโ€™s business activities), and demonstration purposes. +# **โ€œNon-Production Environmentโ€**: means any setting, use case, or application of the Perceptron Models or Derivatives that expressly excludes live, real-world conditions, commercial operations, revenue-generating activities, or direct interactions with or impacts on end users (such as, for instance, Your employees or customers). Non-Production Environment may include, but is not limited to, any setting, use case, or application for research, development, testing, quality assurance, training, internal evaluation (other than any internal usage by employees in the context of the companyโ€™s business activities), and demonstration purposes. # -# **โ€œOutputsโ€**: means any content generated by the operation of the Perceptron Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a Perceptron Models, such as any fine-tuned versions of the Perceptron Models, the weights, or parameters. +# **โ€œOutputsโ€**: means any content generated by the operation of the Perceptron Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a Perceptron Models, such as any fine-tuned versions of the Perceptron Models, the weights, or parameters. # -# **โ€œPersonalโ€**: means any use of a Perceptron Model or a Derivative that is (i) solely for personal, non-profit and non-commercial purposes and (ii) not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities. For illustration purposes, Personal use of a Model or a Derivative does not include any usage by individuals employed in companies in the context of their daily tasks, any activity that is intended to generate revenue, or that is performed on behalf of a commercial entity. +# **โ€œPersonalโ€**: means any use of a Perceptron Model or a Derivative that is (i) solely for personal, non-profit and non-commercial purposes and (ii) not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities. For illustration purposes, Personal use of a Model or a Derivative does not include any usage by individuals employed in companies in the context of their daily tasks, any activity that is intended to generate revenue, or that is performed on behalf of a commercial entity. # # **โ€œYouโ€**: means the individual or entity entering into this Agreement with Perceptron, Inc.. 
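# For reference, the hunks in this series around `ensure_document_attention_mask`, `IsaacVisionAttention`, and
# `IsaacVisionEncoder` all revolve around packed-sequence ("document") masking: every image's patch tokens are
# concatenated into a single batch-of-1 sequence, and attention must stay inside each image's own block. A minimal
# standalone sketch of that masking idea follows; the helper name `toy_document_mask` and the toy sizes are
# illustrative assumptions, not code from the patch.

import torch

def toy_document_mask(cu_seqlens: torch.Tensor, total_tokens: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Additive block-diagonal mask: 0 inside each segment, dtype-min across segments."""
    # cu_seqlens holds cumulative boundaries, e.g. [0, 3, 5] for two segments of 3 and 2 tokens.
    token_ids = torch.arange(total_tokens)
    segment_ids = torch.bucketize(token_ids, cu_seqlens[1:-1], right=True)
    same_segment = segment_ids[:, None] == segment_ids[None, :]
    mask = torch.full((total_tokens, total_tokens), torch.finfo(dtype).min, dtype=dtype)
    mask[same_segment] = 0.0
    return mask.view(1, 1, total_tokens, total_tokens)  # broadcastable over (batch, heads)

cu = torch.tensor([0, 3, 5])  # two images packed back to back: 3 and 2 patch tokens
print(toy_document_mask(cu, total_tokens=5)[0, 0])

# An additive bias of this shape suits the eager and SDPA dispatch targets seen above, while the
# flash-attention path in the patch consumes the raw `cu_seqlens` directly instead of a dense mask.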
@@ -86,15 +86,17 @@ import math import re from collections import defaultdict -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union import PIL.Image import torch import torch.nn as nn import torch.nn.functional as F + from ...cache_utils import Cache, SlidingWindowCache, StaticCache -from ...generation.utils import GenerationMixin from ...feature_extraction_utils import BatchFeature +from ...generation.utils import GenerationMixin from ...image_processing_utils_fast import ( BaseImageProcessorFast, DefaultFastImageProcessorKwargs, @@ -102,11 +104,6 @@ group_images_by_shape, reorder_images, ) -from ...models.auto.image_processing_auto import AutoImageProcessor -from ...models.auto.modeling_auto import AutoModel -from ...models.auto.tokenization_auto import AutoTokenizer -from ...models.qwen3.configuration_qwen3 import Qwen3Config -from ...models.qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3PreTrainedModel from ...image_utils import ( ChannelDimension, PILImageResampling, @@ -114,49 +111,54 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...models.auto.modeling_auto import AutoModel +from ...models.auto.tokenization_auto import AutoTokenizer +from ...models.qwen3.configuration_qwen3 import Qwen3Config +from ...models.qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3PreTrainedModel +from ...processing_utils import ProcessorMixin, Unpack +from ...tokenization_utils import TensorType +from ...utils import auto_docstring + +# Vision preprocessing constants +from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN +from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD +from ...utils.generic import can_return_tuple +from ...utils.import_utils import is_torchdynamo_compiling from ..qwen2.tokenization_qwen2 import Qwen2Tokenizer from ..qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding from ..siglip2.configuration_siglip2 import Siglip2VisionConfig from ..siglip2.modeling_siglip2 import ( Siglip2Attention, +) +from ..siglip2.modeling_siglip2 import ( Siglip2Encoder as HFSiglip2Encoder, +) +from ..siglip2.modeling_siglip2 import ( Siglip2EncoderLayer as HFSiglip2EncoderLayer, +) +from ..siglip2.modeling_siglip2 import ( Siglip2VisionEmbeddings as HFSiglip2VisionEmbeddings, ) -from ...processing_utils import ProcessorMixin, Unpack -from ...tokenization_utils import TensorType -from ...utils import auto_docstring -from ...utils.generic import can_return_tuple -# Vision preprocessing constants -from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN -from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD -from ...utils.import_utils import is_torchdynamo_compiling -try: - from genesis.public.tensorstream.tensor_stream import ( - Event, - Stream, - TensorStream, - TextType, - VisionType, - create_stream, - group_streams, - ) - from genesis.public.tensorstream.tensor_stream_utils import ( - compute_mrope_pos_tensor, - modality_mask, - reconstruct_tensor_stream_from_compact_dict, - tensor_stream_token_view, - ) - from genesis.public.tensorstream.tensor_stream_utils import ( - slice as ts_slice, - ) -except ModuleNotFoundError as exc: # pragma: no cover - import guard - raise ModuleNotFoundError( - "genesis.public.tensorstream is required for the Isaac HuggingFace integration. 
" - "Ensure the TensorStream package is installed and on PYTHONPATH." - ) from exc +from genesis.public.tensorstream.tensor_stream import ( + Event, + Stream, + TensorStream, + TextType, + VisionType, + create_stream, + group_streams, +) +from genesis.public.tensorstream.tensor_stream_utils import ( + compute_mrope_pos_tensor, + modality_mask, + reconstruct_tensor_stream_from_compact_dict, + tensor_stream_token_view, +) +from genesis.public.tensorstream.tensor_stream_utils import ( + slice as ts_slice, +) _ORIGINAL_ATTENTION_FUNCTIONS: dict[str, Callable[..., tuple[torch.Tensor, Optional[torch.Tensor]]]] = {} @@ -249,7 +251,6 @@ def __init__( raise ValueError("`pixel_shuffle_scale` must be >= 1") self.pixel_shuffle_scale = pixel_shuffle_scale - def _validate_preprocess_kwargs(self, **kwargs): # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for. kwargs.pop("do_resize", None) @@ -261,7 +262,7 @@ def _validate_preprocess_kwargs(self, **kwargs): def resize( self, - image: "torch.Tensor", + image: torch.Tensor, size: SizeDict, interpolation: Optional[Any] = None, antialias: bool = True, @@ -296,7 +297,7 @@ def resize( def _preprocess( self, - images: list["torch.Tensor"], + images: list[torch.Tensor], do_resize: bool, size: Optional[SizeDict], interpolation: Optional[Any], @@ -323,7 +324,6 @@ def _preprocess( if do_pad: raise ValueError("`do_pad` is not supported by IsaacImageProcessorFast.") - grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) processed_patches_grouped: dict[tuple[int, ...], torch.Tensor] = {} token_grids_grouped: dict[tuple[int, ...], torch.Tensor] = {} @@ -341,9 +341,7 @@ def _preprocess( channels = 3 if original_height * original_width > self.MAX_PIXELS: - raise ValueError( - f"Image (w={original_width}, h={original_height}) > MAX=`{self.MAX_PIXELS}`" - ) + raise ValueError(f"Image (w={original_width}, h={original_height}) > MAX=`{self.MAX_PIXELS}`") target_height, target_width = get_image_size_for_max_num_patches( original_height, @@ -363,9 +361,7 @@ def _preprocess( ) else: if ((original_height % patch_size) != 0) or ((original_width % patch_size) != 0): - raise ValueError( - "Image dimensions must be divisible by patch_size when resize is disabled." 
- ) + raise ValueError("Image dimensions must be divisible by patch_size when resize is disabled.") image_batch = stacked_images target_height, target_width = original_height, original_width @@ -385,17 +381,25 @@ def _preprocess( patches = patchify_vision(nhwc_images, patch_size=patch_size) _, height_tokens, width_tokens, _ = patches.shape - token_grid = torch.tensor( - [height_tokens, width_tokens], - dtype=torch.long, - device=patches.device, - ).unsqueeze(0).repeat(batch_size, 1) + token_grid = ( + torch.tensor( + [height_tokens, width_tokens], + dtype=torch.long, + device=patches.device, + ) + .unsqueeze(0) + .repeat(batch_size, 1) + ) - real_dim = torch.tensor( - [1, height_tokens, width_tokens], - dtype=torch.long, - device=patches.device, - ).unsqueeze(0).repeat(batch_size, 1) + real_dim = ( + torch.tensor( + [1, height_tokens, width_tokens], + dtype=torch.long, + device=patches.device, + ) + .unsqueeze(0) + .repeat(batch_size, 1) + ) if pixel_shuffle_scale > 1: if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): @@ -408,11 +412,15 @@ def _preprocess( virtual_height = height_tokens virtual_width = width_tokens - virtual_dim = torch.tensor( - [1, virtual_height, virtual_width], - dtype=torch.long, - device=patches.device, - ).unsqueeze(0).repeat(batch_size, 1) + virtual_dim = ( + torch.tensor( + [1, virtual_height, virtual_width], + dtype=torch.long, + device=patches.device, + ) + .unsqueeze(0) + .repeat(batch_size, 1) + ) processed_patches_grouped[shape] = patches token_grids_grouped[shape] = token_grid @@ -440,8 +448,6 @@ def _preprocess( ) - - def _max_from_cu(cu: torch.Tensor | None, fallback: int) -> int: """Helper to compute max sequence length from cumulative sequence lengths.""" if cu is None or len(cu) < 2: @@ -474,8 +480,6 @@ def build_document_attention_mask( return additive_mask.view(1, 1, total_tokens, total_tokens) - - def ensure_document_attention_mask( attention_mask: Optional[torch.Tensor], cu_seqlens: Optional[torch.Tensor], @@ -675,12 +679,12 @@ def _consume_variable_length_metadata(self): return cu_seqlens, max_seqlen def forward(self, hidden_states, attention_mask=None, **kwargs): - cu_seqlens = kwargs.pop('cu_seqlens', None) - max_seqlen = kwargs.pop('max_seqlen', None) - kwargs.pop('output_attentions', None) + cu_seqlens = kwargs.pop("cu_seqlens", None) + max_seqlen = kwargs.pop("max_seqlen", None) + kwargs.pop("output_attentions", None) if kwargs: - unexpected = ', '.join(sorted(kwargs)) - raise TypeError(f'Unexpected kwargs for IsaacVisionAttention.forward: {unexpected}') + unexpected = ", ".join(sorted(kwargs)) + raise TypeError(f"Unexpected kwargs for IsaacVisionAttention.forward: {unexpected}") cached_cu, cached_max = self._consume_variable_length_metadata() if cu_seqlens is None: cu_seqlens = cached_cu @@ -861,7 +865,9 @@ def _isaac_flash_attention_forward( ) if query.dim() != 4 or query.size(0) != 1: - raise ValueError("IsaacVisionAttention expects packed sequences with batch size 1 when using packed attention.") + raise ValueError( + "IsaacVisionAttention expects packed sequences with batch size 1 when using packed attention." 
+ ) _, num_heads, seq_len, head_dim = query.shape q_lhd = query.transpose(1, 2).reshape(seq_len, num_heads, head_dim) @@ -930,7 +936,9 @@ def _isaac_sdpa_forward( ) if query.dim() != 4 or query.size(0) != 1: - raise ValueError("IsaacVisionAttention expects packed sequences with batch size 1 when using packed attention.") + raise ValueError( + "IsaacVisionAttention expects packed sequences with batch size 1 when using packed attention." + ) _, num_heads, seq_len, head_dim = query.shape q_lhd = query.transpose(1, 2).reshape(seq_len, num_heads, head_dim) @@ -992,7 +1000,9 @@ def _isaac_eager_forward( ) if query.dim() != 4 or query.size(0) != 1: - raise ValueError("IsaacVisionAttention expects packed sequences with batch size 1 when using packed attention.") + raise ValueError( + "IsaacVisionAttention expects packed sequences with batch size 1 when using packed attention." + ) _, num_heads, seq_len, head_dim = query.shape q_lhd = query.transpose(1, 2).reshape(seq_len, num_heads, head_dim) @@ -1057,9 +1067,7 @@ def create_pixel_shuffle_index_map( # Safety: all spatial dims must be divisible by the scale factor # Cannot run under torch compile fullgraph mode hence if not is_torchdynamo_compiling(): - if not ( - (token_grids[:, 0] % scale_factor == 0).all() and (token_grids[:, 1] % scale_factor == 0).all() - ): + if not ((token_grids[:, 0] % scale_factor == 0).all() and (token_grids[:, 1] % scale_factor == 0).all()): raise AssertionError( "Every (H,W) in `token_grids` must be divisible by " f"scale_factor={scale_factor}, got {token_grids.tolist()}" @@ -1312,7 +1320,9 @@ def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor: raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.") patches = image.reshape(num_images, height // patch_size, patch_size, width // patch_size, patch_size, channels) patches = patches.permute(0, 1, 3, 2, 4, 5) - patches = patches.reshape(num_images, height // patch_size, width // patch_size, channels * patch_size * patch_size) + patches = patches.reshape( + num_images, height // patch_size, width // patch_size, channels * patch_size * patch_size + ) return patches @@ -1327,7 +1337,7 @@ def __init__( self, vision_config: IsaacVisionConfig | None = None, text_config: Qwen3Config | dict | None = None, - vision_rescale_factor: float = 1/255, + vision_rescale_factor: float = 1 / 255, max_sequence_length: int = 16384, vision_token: str = "", **kwargs, @@ -1387,7 +1397,6 @@ def rope_scaling(self, value): @property def vision_attn_implementation(self) -> str | None: - value = getattr(self.vision_config, "_attn_implementation", None) if value is None: value = getattr(self.vision_config, "attn_implementation", None) @@ -1457,7 +1466,7 @@ class IsaacProcessor(ProcessorMixin): def __init__( self, - image_processor: IsaacImageProcessorFast | None = None, + image_processor: Optional["IsaacImageProcessorFast"] = None, tokenizer: Qwen2Tokenizer | None = None, *, vision_token: str = "", @@ -1476,9 +1485,7 @@ def __init__( vision_token = config.vision_token rescale_factor = config.vision_rescale_factor - resolved_rescale_factor = ( - float(rescale_factor) if rescale_factor is not None else float(1/255) - ) + resolved_rescale_factor = float(rescale_factor) if rescale_factor is not None else float(1 / 255) if config is not None: config.vision_rescale_factor = resolved_rescale_factor @@ -1721,6 +1728,7 @@ def forward( return cos_combined, sin_combined + class IsaacModel(Qwen3PreTrainedModel): supports_gradient_checkpointing = 
True @@ -2050,7 +2058,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( causal_mask = attention_mask else: min_dtype = torch.finfo(dtype).min - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also @@ -2244,6 +2254,7 @@ def prepare_inputs_for_generation( def can_generate(self) -> bool: return True + def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> torch.Tensor: """Compute residuals for P-frames to stay in sync with the training pipeline.""" if not any(is_p_frame): @@ -2256,6 +2267,7 @@ def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> frames[p_indices] = frames[p_indices] - frames[last_i_indices[p_indices]] return frames + __all__ = [ "IsaacConfig", "IsaacModel", diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 6fbe37a94b12..ed76210fe99f 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -92,6 +92,9 @@ import PIL.Image import torch +from genesis.public.tensorstream.tensor_stream import Event, Stream, TensorStream, TextType, VisionType, create_stream +from genesis.public.tensorstream.tensor_stream_utils import slice as ts_slice +from genesis.public.tensorstream.tensor_stream_utils import tensor_stream_token_view from ...feature_extraction_utils import BatchFeature from ...image_processing_utils_fast import DefaultFastImageProcessorKwargs From f56f064f6e4ef6bfe57fbfd202f757a391a24ffc Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 10:28:19 +0400 Subject: [PATCH 15/77] fix: processor typing --- .../models/isaac/modular_isaac.py | 40 +++++++++---------- .../models/isaac/processing_isaac.py | 3 +- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 4755ee68beb9..64bf9b2c5390 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -93,6 +93,24 @@ import torch import torch.nn as nn import torch.nn.functional as F +from genesis.public.tensorstream.tensor_stream import ( + Event, + Stream, + TensorStream, + TextType, + VisionType, + create_stream, + group_streams, +) +from genesis.public.tensorstream.tensor_stream_utils import ( + compute_mrope_pos_tensor, + modality_mask, + reconstruct_tensor_stream_from_compact_dict, + tensor_stream_token_view, +) +from genesis.public.tensorstream.tensor_stream_utils import ( + slice as ts_slice, +) from ...cache_utils import Cache, SlidingWindowCache, StaticCache from ...feature_extraction_utils import BatchFeature @@ -141,26 +159,6 @@ ) -from genesis.public.tensorstream.tensor_stream import ( - Event, - Stream, - TensorStream, - TextType, - VisionType, - create_stream, - group_streams, -) -from genesis.public.tensorstream.tensor_stream_utils import ( - compute_mrope_pos_tensor, - modality_mask, - reconstruct_tensor_stream_from_compact_dict, - tensor_stream_token_view, -) -from genesis.public.tensorstream.tensor_stream_utils import ( - 
slice as ts_slice, -) - - _ORIGINAL_ATTENTION_FUNCTIONS: dict[str, Callable[..., tuple[torch.Tensor, Optional[torch.Tensor]]]] = {} for _attn_name in ("flash_attention_2", "sdpa", "eager"): if _attn_name in ALL_ATTENTION_FUNCTIONS: @@ -1466,7 +1464,7 @@ class IsaacProcessor(ProcessorMixin): def __init__( self, - image_processor: Optional["IsaacImageProcessorFast"] = None, + image_processor: "IsaacImageProcessorFast | None" = None, tokenizer: Qwen2Tokenizer | None = None, *, vision_token: str = "", diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index ed76210fe99f..47b0b5a766af 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -88,7 +88,6 @@ import math import re -from typing import Optional import PIL.Image import torch @@ -166,7 +165,7 @@ class IsaacProcessor(ProcessorMixin): def __init__( self, - image_processor: Optional["IsaacImageProcessorFast"] = None, + image_processor: "IsaacImageProcessorFast | None" = None, tokenizer: Qwen2Tokenizer | None = None, *, vision_token: str = "", From 4c5c19d3abd829fca9d63763161ec673f17b7349 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 10:37:05 +0400 Subject: [PATCH 16/77] fix: allow image processor typing --- src/transformers/models/isaac/modular_isaac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 64bf9b2c5390..e4571e77ee89 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -1464,7 +1464,7 @@ class IsaacProcessor(ProcessorMixin): def __init__( self, - image_processor: "IsaacImageProcessorFast | None" = None, + image_processor: "IsaacImageProcessorFast | None" = None, # noqa tokenizer: Qwen2Tokenizer | None = None, *, vision_token: str = "", From 8a95e64e34305d6789cf8695e64ae7f7499da0b2 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 10:43:25 +0400 Subject: [PATCH 17/77] style: | for unions --- .../models/isaac/modular_isaac.py | 86 +++++++++---------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index e4571e77ee89..b9c518f3c923 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -87,7 +87,7 @@ import re from collections import defaultdict from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from typing import Any, Callable import PIL.Image import torch @@ -159,7 +159,7 @@ ) -_ORIGINAL_ATTENTION_FUNCTIONS: dict[str, Callable[..., tuple[torch.Tensor, Optional[torch.Tensor]]]] = {} +_ORIGINAL_ATTENTION_FUNCTIONS: dict[str, Callable[..., tuple[torch.Tensor, torch.Tensor | None]]] = {} for _attn_name in ("flash_attention_2", "sdpa", "eager"): if _attn_name in ALL_ATTENTION_FUNCTIONS: _ORIGINAL_ATTENTION_FUNCTIONS[_attn_name] = ALL_ATTENTION_FUNCTIONS[_attn_name] @@ -185,9 +185,9 @@ def __init__( self, pixel_shuffle_scale_factor: int = 1, num_patches: int = 256, - **kwargs, + **super_kwargs, ): - super().__init__(**kwargs) + super().__init__(**super_kwargs) # Add our custom fields self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor @@ -240,9 +240,9 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): def __init__( self, - **kwargs: 
Unpack[IsaacImageProcessorKwargs], + **super_kwargs: Unpack[IsaacImageProcessorKwargs], ) -> None: - super().__init__(**kwargs) + super().__init__(**super_kwargs) pixel_shuffle_scale = 1 if self.pixel_shuffle_scale is None else int(self.pixel_shuffle_scale) if pixel_shuffle_scale < 1: @@ -262,7 +262,7 @@ def resize( self, image: torch.Tensor, size: SizeDict, - interpolation: Optional[Any] = None, + interpolation: Any | None = None, antialias: bool = True, **kwargs, ) -> torch.Tensor: @@ -297,19 +297,19 @@ def _preprocess( self, images: list[torch.Tensor], do_resize: bool, - size: Optional[SizeDict], - interpolation: Optional[Any], + size: SizeDict | None, + interpolation: Any | None, do_center_crop: bool, - crop_size: Optional[SizeDict], - do_rescale: Optional[bool], - rescale_factor: Optional[float], - do_normalize: Optional[bool], - image_mean: Optional[Union[float, Sequence[float]]], - image_std: Optional[Union[float, Sequence[float]]], - disable_grouping: Optional[bool] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - do_pad: Optional[bool] = None, - pad_size: Optional[SizeDict] = None, + crop_size: SizeDict | None, + do_rescale: bool | None, + rescale_factor: float | None, + do_normalize: bool | None, + image_mean: float | Sequence[float] | None, + image_std: float | Sequence[float] | None, + disable_grouping: bool | None = None, + return_tensors: str | TensorType | None = None, + do_pad: bool | None = None, + pad_size: SizeDict | None = None, *, patch_size: int | None = None, max_num_patches: int | None = None, @@ -479,12 +479,12 @@ def build_document_attention_mask( def ensure_document_attention_mask( - attention_mask: Optional[torch.Tensor], - cu_seqlens: Optional[torch.Tensor], + attention_mask: torch.Tensor | None, + cu_seqlens: torch.Tensor | None, total_tokens: int, dtype: torch.dtype, device: torch.device, -) -> Optional[torch.Tensor]: +) -> torch.Tensor | None: if attention_mask is not None or cu_seqlens is None: return attention_mask @@ -761,9 +761,9 @@ def __init__(self, vision_config: IsaacVisionConfig): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, + attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, output_attentions: bool = False, ): if cu_seqlens is not None or max_seqlen is not None: @@ -809,12 +809,12 @@ def __variable_length_context(self, cu_seqlens, max_seqlen) -> None: def forward( self, inputs_embeds, - attention_mask: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ): self.__variable_length_context(cu_seqlens, max_seqlen) @@ -840,12 +840,12 @@ def _isaac_flash_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor], + attention_mask: torch.Tensor | None, dropout: float = 0.0, - scaling: Optional[float] = None, + scaling: float | None = None, is_causal: bool = False, **kwargs, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, 
torch.Tensor | None]: base_fn = _ORIGINAL_ATTENTION_FUNCTIONS.get("flash_attention_2") if not isinstance(module, IsaacVisionAttention) or base_fn is None: if base_fn is None: @@ -911,12 +911,12 @@ def _isaac_sdpa_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor], + attention_mask: torch.Tensor | None, dropout: float = 0.0, - scaling: Optional[float] = None, + scaling: float | None = None, is_causal: bool = False, **kwargs, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor | None]: base_fn = _ORIGINAL_ATTENTION_FUNCTIONS.get("sdpa") if not isinstance(module, IsaacVisionAttention) or base_fn is None: if base_fn is None: @@ -975,12 +975,12 @@ def _isaac_eager_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor], + attention_mask: torch.Tensor | None, dropout: float = 0.0, - scaling: Optional[float] = None, + scaling: float | None = None, is_causal: bool = False, **kwargs, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor | None]: base_fn = _ORIGINAL_ATTENTION_FUNCTIONS.get("eager") if not isinstance(module, IsaacVisionAttention) or base_fn is None: if base_fn is None: @@ -1338,10 +1338,10 @@ def __init__( vision_rescale_factor: float = 1 / 255, max_sequence_length: int = 16384, vision_token: str = "", - **kwargs, + **super_kwargs, ): self._rope_scaling: dict[str, Any] | None = None - resolved_text_config = kwargs.pop("text_config", text_config) + resolved_text_config = super_kwargs.pop("text_config", text_config) if isinstance(resolved_text_config, Qwen3Config): text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict()) elif isinstance(resolved_text_config, dict): @@ -1351,7 +1351,7 @@ def __init__( else: raise TypeError("`text_config` must be a mapping or `Qwen3Config` instance when provided.") - text_config_kwargs.update(kwargs) + text_config_kwargs.update(super_kwargs) super().__init__(**text_config_kwargs) self.text_config = Qwen3Config(**text_config_kwargs) From 25523ba8823bffce0df13c345cc8b49988404538 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 10:47:50 +0400 Subject: [PATCH 18/77] fix: don't alias siglip --- src/transformers/models/isaac/modular_isaac.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index b9c518f3c923..317706bb74f4 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -147,15 +147,9 @@ from ..siglip2.configuration_siglip2 import Siglip2VisionConfig from ..siglip2.modeling_siglip2 import ( Siglip2Attention, -) -from ..siglip2.modeling_siglip2 import ( - Siglip2Encoder as HFSiglip2Encoder, -) -from ..siglip2.modeling_siglip2 import ( - Siglip2EncoderLayer as HFSiglip2EncoderLayer, -) -from ..siglip2.modeling_siglip2 import ( - Siglip2VisionEmbeddings as HFSiglip2VisionEmbeddings, + Siglip2Encoder, + Siglip2EncoderLayer, + Siglip2VisionEmbeddings, ) @@ -581,7 +575,7 @@ def sdpa_document_mask_forward( return Y.squeeze(0).permute(1, 0, 2) # Back to (L, H, D) -class IsaacVisionEmbeddings(HFSiglip2VisionEmbeddings): +class IsaacVisionEmbeddings(Siglip2VisionEmbeddings): """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences.""" def __init__(self, config: IsaacVisionConfig): @@ -751,7 +745,7 @@ def forward(self, 
hidden_states, attention_mask=None, **kwargs): return y.unsqueeze(0), None # (1, L, E) -class IsaacVisionEncoderLayer(HFSiglip2EncoderLayer): +class IsaacVisionEncoderLayer(Siglip2EncoderLayer): """Isaac vision encoder layer with variable-length attention.""" def __init__(self, vision_config: IsaacVisionConfig): @@ -787,7 +781,7 @@ def forward( ) -class IsaacVisionEncoder(HFSiglip2Encoder): +class IsaacVisionEncoder(Siglip2Encoder): """Encoder using Isaac encoder layers with variable-length attention support.""" def __init__(self, config: IsaacVisionConfig): From d1dc712ce4dfa16d6d9533137a99cd2680ab244b Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 11:03:37 +0400 Subject: [PATCH 19/77] fix: rename vision config to config to be consistent with base class --- src/transformers/models/isaac/modular_isaac.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 317706bb74f4..875c7eab1ee1 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -654,9 +654,9 @@ class IsaacVisionAttention(Siglip2Attention): "isaac_eager": "isaac_eager", } - def __init__(self, vision_config): - super().__init__(vision_config) - self.vision_config = vision_config + def __init__(self, config): + super().__init__(config) + self.config = config self._variable_length_metadata = None def _variable_length_context(self, *, cu_seqlens=None, max_seqlen=None): @@ -698,7 +698,7 @@ def forward(self, hidden_states, attention_mask=None, **kwargs): k = self.k_proj(x).view(L, H, D) v = self.v_proj(x).view(L, H, D) - attn_impl = getattr(self.vision_config, "_attn_implementation", "flash_attention_3") + attn_impl = getattr(self.config, "_attn_implementation", "flash_attention_3") attn_mask = ensure_document_attention_mask( attention_mask, From d80c9f6b2f3fab7b720284976aa755ab25cd728d Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 11:08:43 +0400 Subject: [PATCH 20/77] fix: additional remakes --- .../models/isaac/modular_isaac.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 875c7eab1ee1..a82ae31ac057 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -179,9 +179,9 @@ def __init__( self, pixel_shuffle_scale_factor: int = 1, num_patches: int = 256, - **super_kwargs, + **kwargs, ): - super().__init__(**super_kwargs) + super().__init__(**kwargs) # Add our custom fields self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor @@ -234,9 +234,9 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): def __init__( self, - **super_kwargs: Unpack[IsaacImageProcessorKwargs], + **kwargs: Unpack[IsaacImageProcessorKwargs], ) -> None: - super().__init__(**super_kwargs) + super().__init__(**kwargs) pixel_shuffle_scale = 1 if self.pixel_shuffle_scale is None else int(self.pixel_shuffle_scale) if pixel_shuffle_scale < 1: @@ -748,9 +748,9 @@ def forward(self, hidden_states, attention_mask=None, **kwargs): class IsaacVisionEncoderLayer(Siglip2EncoderLayer): """Isaac vision encoder layer with variable-length attention.""" - def __init__(self, vision_config: IsaacVisionConfig): - super().__init__(vision_config) - self.self_attn = IsaacVisionAttention(vision_config) + def __init__(self, config: 
IsaacVisionConfig): + super().__init__(config) + self.self_attn = IsaacVisionAttention(config) def forward( self, @@ -1332,10 +1332,10 @@ def __init__( vision_rescale_factor: float = 1 / 255, max_sequence_length: int = 16384, vision_token: str = "", - **super_kwargs, + **kwargs, ): self._rope_scaling: dict[str, Any] | None = None - resolved_text_config = super_kwargs.pop("text_config", text_config) + resolved_text_config = kwargs.pop("text_config", text_config) if isinstance(resolved_text_config, Qwen3Config): text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict()) elif isinstance(resolved_text_config, dict): @@ -1345,7 +1345,7 @@ def __init__( else: raise TypeError("`text_config` must be a mapping or `Qwen3Config` instance when provided.") - text_config_kwargs.update(super_kwargs) + text_config_kwargs.update(kwargs) super().__init__(**text_config_kwargs) self.text_config = Qwen3Config(**text_config_kwargs) From 58d73116f60569fdcef538b3e231a5a1a4beeebb Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 11:10:49 +0400 Subject: [PATCH 21/77] chore: convert artifacts --- .../isaac/image_processing_isaac_fast.py | 28 +- .../models/isaac/modeling_isaac.py | 277 ++++++++++++++---- .../models/isaac/processing_isaac.py | 2 +- 3 files changed, 231 insertions(+), 76 deletions(-) diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py b/src/transformers/models/isaac/image_processing_isaac_fast.py index ce599e6e0508..57c8eb3062e8 100644 --- a/src/transformers/models/isaac/image_processing_isaac_fast.py +++ b/src/transformers/models/isaac/image_processing_isaac_fast.py @@ -88,7 +88,7 @@ import math from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any import torch import torch.nn.functional as F @@ -294,7 +294,7 @@ def resize( self, image: torch.Tensor, size: SizeDict, - interpolation: Optional[Any] = None, + interpolation: Any | None = None, antialias: bool = True, **kwargs, ) -> torch.Tensor: @@ -329,19 +329,19 @@ def _preprocess( self, images: list[torch.Tensor], do_resize: bool, - size: Optional[SizeDict], - interpolation: Optional[Any], + size: SizeDict | None, + interpolation: Any | None, do_center_crop: bool, - crop_size: Optional[SizeDict], - do_rescale: Optional[bool], - rescale_factor: Optional[float], - do_normalize: Optional[bool], - image_mean: Optional[Union[float, Sequence[float]]], - image_std: Optional[Union[float, Sequence[float]]], - disable_grouping: Optional[bool] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - do_pad: Optional[bool] = None, - pad_size: Optional[SizeDict] = None, + crop_size: SizeDict | None, + do_rescale: bool | None, + rescale_factor: float | None, + do_normalize: bool | None, + image_mean: float | Sequence[float] | None, + image_std: float | Sequence[float] | None, + disable_grouping: bool | None = None, + return_tensors: str | TensorType | None = None, + do_pad: bool | None = None, + pad_size: SizeDict | None = None, *, patch_size: int | None = None, max_num_patches: int | None = None, diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index a0aaa38c9f18..ae231cdbe53a 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -93,6 +93,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from genesis.public.tensorstream.tensor_stream import TensorStream, TextType, VisionType, group_streams from 
genesis.public.tensorstream.tensor_stream_utils import ( compute_mrope_pos_tensor, @@ -107,7 +108,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...models.auto.modeling_auto import AutoModel from ...models.qwen3.configuration_qwen3 import Qwen3Config @@ -120,13 +121,91 @@ from .configuration_isaac import IsaacConfig, IsaacVisionConfig -class IsaacVisionEmbeddings(HFSiglip2VisionEmbeddings): +class IsaacVisionEmbeddings(nn.Module): """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences.""" def __init__(self, config: IsaacVisionConfig): - super().__init__(config) + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Linear( + in_features=config.num_channels * self.patch_size * self.patch_size, + out_features=self.embed_dim, + ) + + self.num_patches = config.num_patches + self.position_embedding_size = int(self.num_patches**0.5) + self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + + @staticmethod + def resize_positional_embeddings( + positional_embeddings: torch.Tensor, + spatial_shapes: torch.LongTensor, + max_length: int, + ) -> torch.Tensor: + """ + Resize positional embeddings to image-specific size and pad to a fixed size. + + Args: + positional_embeddings (`torch.Tensor`): + Position embeddings of shape (height, width, embed_dim) + spatial_shapes (`torch.LongTensor`): + Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to + max_length (`int`): + Maximum length of the positional embeddings to pad resized positional embeddings to + + Returns: + `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim) + """ + batch_size = spatial_shapes.shape[0] + embed_dim = positional_embeddings.shape[-1] + source_dtype = positional_embeddings.dtype + + resulted_positional_embeddings = torch.empty( + (batch_size, max_length, embed_dim), + device=positional_embeddings.device, + dtype=source_dtype, + ) + + # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation + positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0) + + # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU + if positional_embeddings.device.type == "cpu": + positional_embeddings = positional_embeddings.to(torch.float32) + + for i in range(batch_size): + # (1, dim, height, width) -> (1, dim, target_height, target_width) + height, width = spatial_shapes[i] + resized_embeddings = F.interpolate( + positional_embeddings, + size=(height, width), + mode="bilinear", + align_corners=False, + antialias=True, + ) + + # (1, dim, target_height, target_width) -> (target_height * target_width, dim) + resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1) + + # Cast to original dtype + resized_embeddings = resized_embeddings.to(source_dtype) + + resulted_positional_embeddings[i, : height * width] = resized_embeddings + resulted_positional_embeddings[i, height * width :] = resized_embeddings[0] + + return resulted_positional_embeddings def forward(self, 
seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor: + """ + Args: + pixel_values (`torch.FloatTensor`): + Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size) + spatial_shapes (`list[tuple[int, int]]`): + Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to + """ packed_pixel_values, seq_lengths = self._pack_to_batch(seq_patches, spatial_shapes) if packed_pixel_values is None: return seq_patches.new_zeros((0, self.embed_dim)) @@ -211,12 +290,12 @@ def build_document_attention_mask( def ensure_document_attention_mask( - attention_mask: Optional[torch.Tensor], - cu_seqlens: Optional[torch.Tensor], + attention_mask: torch.Tensor | None, + cu_seqlens: torch.Tensor | None, total_tokens: int, dtype: torch.dtype, device: torch.device, -) -> Optional[torch.Tensor]: +) -> torch.Tensor | None: if attention_mask is not None or cu_seqlens is None: return attention_mask @@ -242,7 +321,7 @@ class IsaacVisionAttention(nn.Module): "isaac_eager": "isaac_eager", } - def __init__(self, vision_config): + def __init__(self, config): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -261,7 +340,6 @@ def __init__(self, vision_config): self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.vision_config = vision_config self._variable_length_metadata = None def forward(self, hidden_states, attention_mask=None, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -293,7 +371,7 @@ def forward(self, hidden_states, attention_mask=None, **kwargs) -> tuple[torch.T k = self.k_proj(x).view(L, H, D) v = self.v_proj(x).view(L, H, D) - attn_impl = getattr(self.vision_config, "_attn_implementation", "flash_attention_3") + attn_impl = getattr(self.config, "_attn_implementation", "flash_attention_3") attn_mask = ensure_document_attention_mask( attention_mask, @@ -351,21 +429,50 @@ def _consume_variable_length_metadata(self): return cu_seqlens, max_seqlen -class IsaacVisionEncoderLayer(HFSiglip2EncoderLayer): +class IsaacMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class IsaacVisionEncoderLayer(GradientCheckpointingLayer): """Isaac vision encoder layer with variable-length attention.""" - def __init__(self, vision_config: IsaacVisionConfig): - super().__init__(vision_config) - self.self_attn = IsaacVisionAttention(vision_config) + def __init__(self, config: IsaacVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = IsaacVisionAttention(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = IsaacMLP(config) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, + attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + 
max_seqlen: int | None = None, output_attentions: bool = False, - ): + ) -> tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ if cu_seqlens is not None or max_seqlen is not None: self.self_attn._variable_length_context( cu_seqlens=cu_seqlens, @@ -379,43 +486,72 @@ def forward( hidden_states.dtype, hidden_states.device, ) + residual = hidden_states - return super().forward( - hidden_states, + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs -class IsaacVisionEncoder(HFSiglip2Encoder): +class IsaacVisionEncoder(nn.Module): """Encoder using Isaac encoder layers with variable-length attention support.""" def __init__(self, config: IsaacVisionConfig): - super().__init__(config) + super().__init__() + self.config = config self.layers = nn.ModuleList([IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False - def __variable_length_context(self, cu_seqlens, max_seqlen) -> None: - if cu_seqlens is None and max_seqlen is None: - return - - for layer in self.layers: - if isinstance(layer, IsaacVisionEncoderLayer): - layer.self_attn._variable_length_context( - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - + # Ignore copy @can_return_tuple def forward( self, inputs_embeds, - attention_mask: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): + attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> BaseModelOutput: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ self.__variable_length_context(cu_seqlens, max_seqlen) attention_mask = ensure_document_attention_mask( @@ -425,15 +561,50 @@ def forward( inputs_embeds.dtype, inputs_embeds.device, ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) - return super().forward( - inputs_embeds, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, ) + def __variable_length_context(self, cu_seqlens, max_seqlen) -> None: + if cu_seqlens is None and max_seqlen is None: + return + + for layer in self.layers: + if isinstance(layer, IsaacVisionEncoderLayer): + layer.self_attn._variable_length_context( + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + def create_pixel_shuffle_index_map( seq_sizes: torch.Tensor, @@ -719,22 +890,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class IsaacMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 47b0b5a766af..003d36988e9a 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -165,7 +165,7 @@ class IsaacProcessor(ProcessorMixin): def __init__( self, - image_processor: "IsaacImageProcessorFast | None" = None, + image_processor: "IsaacImageProcessorFast | None" = None, # noqa tokenizer: Qwen2Tokenizer | None = None, *, vision_token: str = "", From ffb3b9f2cae54c00624b03b75efb17b899473392 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 11:35:43 +0400 Subject: [PATCH 22/77] style: make 
style changes --- src/transformers/models/__init__.py | 2 +- src/transformers/models/isaac/__init__.py | 2 +- tests/models/isaac/__init__.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 4377af520590..4cf99b94ceff 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -170,6 +170,7 @@ from .instructblip import * from .instructblipvideo import * from .internvl import * + from .isaac import * from .jamba import * from .janus import * from .jetmoe import * @@ -256,7 +257,6 @@ from .pegasus_x import * from .perceiver import * from .perception_lm import * - from .isaac import * from .persimmon import * from .phi import * from .phi3 import * diff --git a/src/transformers/models/isaac/__init__.py b/src/transformers/models/isaac/__init__.py index fbc25598385d..8ff2b88ec9af 100644 --- a/src/transformers/models/isaac/__init__.py +++ b/src/transformers/models/isaac/__init__.py @@ -25,4 +25,4 @@ import sys _file = globals()["__file__"] - sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) \ No newline at end of file + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/tests/models/isaac/__init__.py b/tests/models/isaac/__init__.py index 199f5353a864..2f76d5676d10 100644 --- a/tests/models/isaac/__init__.py +++ b/tests/models/isaac/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. 
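The packed-attention work in the patches above all relies on the `cu_seqlens` convention documented in modular_isaac.py: a prefix-sum tensor of length num_documents + 1 whose successive differences are the per-document token counts. The self-contained sketch below reconstructs the dense block-diagonal additive mask that `ensure_document_attention_mask` materializes from such a tensor when no mask is supplied; the helper name `block_diagonal_mask` is invented here for illustration and is not part of the patch.

import torch

def block_diagonal_mask(cu_seqlens: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # cu_seqlens is a prefix sum such as [0, 3, 5]: consecutive differences
    # (3 and 2 here) are the token counts of the packed documents.
    total = int(cu_seqlens[-1])
    doc_ids = torch.zeros(total, dtype=torch.long)
    for i in range(cu_seqlens.numel() - 1):
        doc_ids[int(cu_seqlens[i]) : int(cu_seqlens[i + 1])] = i
    same_doc = doc_ids[:, None] == doc_ids[None, :]
    # Cross-document positions receive the most negative representable value,
    # mirroring the torch.finfo(dtype).min fill used by the modeling code.
    mask = torch.full((total, total), torch.finfo(dtype).min, dtype=dtype)
    mask[same_doc] = 0.0
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, total_tokens, total_tokens)

cu = torch.tensor([0, 3, 5])  # two packed documents: 3 tokens + 2 tokens
print(block_diagonal_mask(cu).shape)  # torch.Size([1, 1, 5, 5])

Tokens attend only within their own document, which is also why the vision attention paths raise "expects packed sequences with batch size 1": the batch dimension is folded into a single packed sequence, and document boundaries are carried entirely by `cu_seqlens`.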
From 92c36d4973fd83f1362633a3c2bb5d88949a4764 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 11:39:04 +0400 Subject: [PATCH 23/77] refactor: bespoke isaac config --- .../models/isaac/modular_isaac.py | 72 +++++++++++++++++-- 1 file changed, 66 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index a82ae31ac057..4b4c3a12ab62 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -136,6 +136,7 @@ from ...processing_utils import ProcessorMixin, Unpack from ...tokenization_utils import TensorType from ...utils import auto_docstring +from ...configuration_utils import PretrainedConfig # Vision preprocessing constants from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN @@ -177,15 +178,35 @@ class IsaacVisionConfig(Siglip2VisionConfig): def __init__( self, - pixel_shuffle_scale_factor: int = 1, - num_patches: int = 256, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + num_patches=256, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + pixel_shuffle_scale_factor=1, **kwargs, ): + # Copied from transformers.models.siglip2.configuration_siglip2.Siglip2VisionConfig.__init__ super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.num_patches = num_patches + # Add our custom fields self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor - self.num_patches = num_patches if self._attn_implementation is None: self._attn_implementation = "flash_attention_2" @@ -1318,7 +1339,7 @@ def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor: return patches -class IsaacConfig(Qwen3Config): +class IsaacConfig(PretrainedConfig): """Configuration class for Isaac multimodal model.""" model_type = "isaac" @@ -1347,13 +1368,52 @@ def __init__( text_config_kwargs.update(kwargs) - super().__init__(**text_config_kwargs) - self.text_config = Qwen3Config(**text_config_kwargs) + self.text_config = self.sub_configs["text_config"](**text_config_kwargs) + + super().__init__(**kwargs) + if self._rope_scaling is None: self._rope_scaling = getattr(self.text_config, "rope_scaling", None) else: self.text_config.rope_scaling = self._rope_scaling + # Mirror frequently accessed Qwen3 attributes at the composite config level for BC. 
+ self.tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False) + self.vocab_size = self.text_config.vocab_size + self.max_position_embeddings = self.text_config.max_position_embeddings + self.hidden_size = self.text_config.hidden_size + self.intermediate_size = self.text_config.intermediate_size + self.num_hidden_layers = self.text_config.num_hidden_layers + self.num_attention_heads = self.text_config.num_attention_heads + self.use_sliding_window = getattr(self.text_config, "use_sliding_window", False) + sliding_window = getattr(self.text_config, "sliding_window", None) + self.sliding_window = sliding_window if self.use_sliding_window else None + self.max_window_layers = getattr(self.text_config, "max_window_layers", None) + self.num_key_value_heads = getattr(self.text_config, "num_key_value_heads", None) + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + self.head_dim = self.text_config.head_dim + self.hidden_act = self.text_config.hidden_act + self.initializer_range = self.text_config.initializer_range + self.rms_norm_eps = self.text_config.rms_norm_eps + self.use_cache = self.text_config.use_cache + self.rope_theta = self.text_config.rope_theta + self.attention_bias = getattr(self.text_config, "attention_bias", False) + self.attention_dropout = getattr(self.text_config, "attention_dropout", 0.0) + + # Validate rotary parameters now that they have been mirrored locally. + rope_config_validation(self) + + self.layer_types = getattr(self.text_config, "layer_types", None) + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + # Handle vision config - either dict or IsaacVisionConfig instance if isinstance(vision_config, dict): self.vision_config = self.sub_configs["vision_config"](**vision_config) From 79eb96bfa4c69c43fdad1ae4134254c937ed14a0 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 11:39:50 +0400 Subject: [PATCH 24/77] style: ruff organize imports --- src/transformers/models/isaac/modular_isaac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 4b4c3a12ab62..41e281702892 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -113,6 +113,7 @@ ) from ...cache_utils import Cache, SlidingWindowCache, StaticCache +from ...configuration_utils import PretrainedConfig from ...feature_extraction_utils import BatchFeature from ...generation.utils import GenerationMixin from ...image_processing_utils_fast import ( @@ -136,7 +137,6 @@ from ...processing_utils import ProcessorMixin, Unpack from ...tokenization_utils import TensorType from ...utils import auto_docstring -from ...configuration_utils import PretrainedConfig # Vision preprocessing constants from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN From 1c6479aae1cf9765bf654dce1e3b74383980f203 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 11:41:32 +0400 Subject: [PATCH 25/77] chore: convert configuration artifact --- .../models/isaac/configuration_isaac.py | 102 ++++++++---------- 1 file changed, 46 insertions(+), 56 deletions(-) diff --git a/src/transformers/models/isaac/configuration_isaac.py 
b/src/transformers/models/isaac/configuration_isaac.py index a076ed7cc1d6..29f351142f24 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -113,8 +113,17 @@ class IsaacVisionConfig(PretrainedConfig): def __init__( self, - pixel_shuffle_scale_factor: int = 1, - num_patches: int = 256, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + num_patches=256, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + pixel_shuffle_scale_factor=1, **kwargs, ): super().__init__(**kwargs) @@ -141,23 +150,6 @@ class IsaacConfig(PretrainedConfig): """Configuration class for Isaac multimodal model.""" model_type = "isaac" - keys_to_ignore_at_inference = ["past_key_values"] - - # Default tensor parallel plan for base model `Isaac` - base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", - "layers.*.mlp.up_proj": "colwise", - "layers.*.mlp.down_proj": "rowwise", - } - base_model_pp_plan = { - "embed_tokens": (["input_ids"], ["inputs_embeds"]), - "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), - "norm": (["hidden_states"], ["hidden_states"]), - } sub_configs = {"vision_config": IsaacVisionConfig, "text_config": Qwen3Config} image_processor_type = "IsaacImageProcessor" @@ -170,10 +162,6 @@ def __init__( vision_token: str = "", **kwargs, ): - super().__init__( - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) self._rope_scaling: dict[str, Any] | None = None resolved_text_config = kwargs.pop("text_config", text_config) if isinstance(resolved_text_config, Qwen3Config): @@ -186,37 +174,44 @@ def __init__( raise TypeError("`text_config` must be a mapping or `Qwen3Config` instance when provided.") text_config_kwargs.update(kwargs) - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window if self.use_sliding_window else None - self.max_window_layers = max_window_layers - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads + self.text_config = self.sub_configs["text_config"](**text_config_kwargs) - self.num_key_value_heads = num_key_value_heads - self.head_dim = head_dim - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.rope_scaling = rope_scaling - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - # Validate the correctness of rotary position embeddings parameters - # BC: if there is a 'type' field, move it to 'rope_type'. - if self.rope_scaling is not None and "type" in self.rope_scaling: - self.rope_scaling["rope_type"] = self.rope_scaling["type"] + super().__init__(**kwargs) + + if self._rope_scaling is None: + self._rope_scaling = getattr(self.text_config, "rope_scaling", None) + else: + self.text_config.rope_scaling = self._rope_scaling + + # Mirror frequently accessed Qwen3 attributes at the composite config level for BC. 
+ self.tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False) + self.vocab_size = self.text_config.vocab_size + self.max_position_embeddings = self.text_config.max_position_embeddings + self.hidden_size = self.text_config.hidden_size + self.intermediate_size = self.text_config.intermediate_size + self.num_hidden_layers = self.text_config.num_hidden_layers + self.num_attention_heads = self.text_config.num_attention_heads + self.use_sliding_window = getattr(self.text_config, "use_sliding_window", False) + sliding_window = getattr(self.text_config, "sliding_window", None) + self.sliding_window = sliding_window if self.use_sliding_window else None + self.max_window_layers = getattr(self.text_config, "max_window_layers", None) + self.num_key_value_heads = getattr(self.text_config, "num_key_value_heads", None) + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + self.head_dim = self.text_config.head_dim + self.hidden_act = self.text_config.hidden_act + self.initializer_range = self.text_config.initializer_range + self.rms_norm_eps = self.text_config.rms_norm_eps + self.use_cache = self.text_config.use_cache + self.rope_theta = self.text_config.rope_theta + self.attention_bias = getattr(self.text_config, "attention_bias", False) + self.attention_dropout = getattr(self.text_config, "attention_dropout", 0.0) + + # Validate rotary parameters now that they have been mirrored locally. rope_config_validation(self) - self.layer_types = layer_types + self.layer_types = getattr(self.text_config, "layer_types", None) if self.layer_types is None: self.layer_types = [ "sliding_attention" @@ -224,12 +219,7 @@ def __init__( else "full_attention" for i in range(self.num_hidden_layers) ] - layer_type_validation(self.layer_types) - self.text_config = Qwen3Config(**text_config_kwargs) - if self._rope_scaling is None: - self._rope_scaling = getattr(self.text_config, "rope_scaling", None) - else: - self.text_config.rope_scaling = self._rope_scaling + layer_type_validation(self.layer_types, self.num_hidden_layers) # Handle vision config - either dict or IsaacVisionConfig instance if isinstance(vision_config, dict): From 289921655854e523a05d80caa5edc39dc8ba4fe9 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 11:45:21 +0400 Subject: [PATCH 26/77] fix: get imports in --- src/transformers/models/isaac/modular_isaac.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 41e281702892..53b83fa2350c 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -113,7 +113,7 @@ ) from ...cache_utils import Cache, SlidingWindowCache, StaticCache -from ...configuration_utils import PretrainedConfig +from ...configuration_utils import PretrainedConfig, layer_type_validation from ...feature_extraction_utils import BatchFeature from ...generation.utils import GenerationMixin from ...image_processing_utils_fast import ( @@ -129,6 +129,7 @@ ) from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...models.auto.modeling_auto import AutoModel from ...models.auto.tokenization_auto import AutoTokenizer From aec77214c055fb9ce75816eff1a3e2f204f7855e Mon Sep 17 00:00:00 2001 
From: Philipp Guevorguian Date: Mon, 10 Nov 2025 11:56:55 +0400 Subject: [PATCH 27/77] style: string typing of qwen2 --- src/transformers/models/isaac/modular_isaac.py | 4 ++-- src/transformers/models/isaac/processing_isaac.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 53b83fa2350c..3527900200a6 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -1519,8 +1519,8 @@ class IsaacProcessor(ProcessorMixin): def __init__( self, - image_processor: "IsaacImageProcessorFast | None" = None, # noqa - tokenizer: Qwen2Tokenizer | None = None, + image_processor: "IsaacImageProcessorFast | None" = None, # noqa: UP037 + tokenizer: "Qwen2Tokenizer | None" = None, # noqa: UP037 *, vision_token: str = "", max_sequence_length: int = 16384, diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 003d36988e9a..0bef663e9e3b 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -165,8 +165,8 @@ class IsaacProcessor(ProcessorMixin): def __init__( self, - image_processor: "IsaacImageProcessorFast | None" = None, # noqa - tokenizer: Qwen2Tokenizer | None = None, + image_processor: "IsaacImageProcessorFast | None" = None, # noqa: UP037 + tokenizer: "Qwen2Tokenizer | None" = None, # noqa: UP037 *, vision_token: str = "", max_sequence_length: int = 16384, From c0b10b648f419ac28162555c2d8daca44aa5df97 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 12:14:43 +0400 Subject: [PATCH 28/77] fix: remove image processor and tokenizer typing --- src/transformers/models/isaac/modular_isaac.py | 5 ++--- src/transformers/models/isaac/processing_isaac.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 3527900200a6..c6b60d5d1ae3 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -144,7 +144,6 @@ from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD from ...utils.generic import can_return_tuple from ...utils.import_utils import is_torchdynamo_compiling -from ..qwen2.tokenization_qwen2 import Qwen2Tokenizer from ..qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding from ..siglip2.configuration_siglip2 import Siglip2VisionConfig from ..siglip2.modeling_siglip2 import ( @@ -1519,8 +1518,8 @@ class IsaacProcessor(ProcessorMixin): def __init__( self, - image_processor: "IsaacImageProcessorFast | None" = None, # noqa: UP037 - tokenizer: "Qwen2Tokenizer | None" = None, # noqa: UP037 + image_processor, + tokenizer, *, vision_token: str = "", max_sequence_length: int = 16384, diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 0bef663e9e3b..87d37f773668 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -165,8 +165,8 @@ class IsaacProcessor(ProcessorMixin): def __init__( self, - image_processor: "IsaacImageProcessorFast | None" = None, # noqa: UP037 - tokenizer: "Qwen2Tokenizer | None" = None, # noqa: UP037 + image_processor, + tokenizer, *, vision_token: str = "", max_sequence_length: int = 16384, From 302374d33061a75f8504241810ef49ff79f00575 Mon 
Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 12:16:59 +0400 Subject: [PATCH 29/77] fix: enable qwen_2_5_vl import --- src/transformers/models/isaac/modeling_isaac.py | 3 ++- src/transformers/models/isaac/modular_isaac.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index ae231cdbe53a..93c1f3884d72 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -118,6 +118,7 @@ from ...utils.deprecation import deprecate_kwarg from ...utils.generic import can_return_tuple from ...utils.import_utils import is_torchdynamo_compiling +from ..qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling from .configuration_isaac import IsaacConfig, IsaacVisionConfig @@ -796,7 +797,7 @@ def __init__(self, config: IsaacConfig, device=None): config_for_rope.rope_scaling = sanitized_scaling if sanitized_scaling else None init_device = device if device is not None and getattr(device, "type", None) != "meta" else None - self._qwen_rotary = Qwen2_5_VLRotaryEmbedding(config_for_rope, device=init_device) + self._qwen_rotary = qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding(config_for_rope, device=init_device) rotary_half_dim = self._qwen_rotary.inv_freq.shape[0] self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index c6b60d5d1ae3..f2ed7dba569d 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -144,7 +144,7 @@ from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD from ...utils.generic import can_return_tuple from ...utils.import_utils import is_torchdynamo_compiling -from ..qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding +from ..qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling from ..siglip2.configuration_siglip2 import Siglip2VisionConfig from ..siglip2.modeling_siglip2 import ( Siglip2Attention, @@ -1708,7 +1708,7 @@ def __init__(self, config: IsaacConfig, device=None): config_for_rope.rope_scaling = sanitized_scaling if sanitized_scaling else None init_device = device if device is not None and getattr(device, "type", None) != "meta" else None - self._qwen_rotary = Qwen2_5_VLRotaryEmbedding(config_for_rope, device=init_device) + self._qwen_rotary = qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding(config_for_rope, device=init_device) rotary_half_dim = self._qwen_rotary.inv_freq.shape[0] self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim) From de9dc808d831d16f0fbc822b02de128cc66be30b Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 12:20:52 +0400 Subject: [PATCH 30/77] style: remove unnecessary copy text --- src/transformers/models/isaac/modular_isaac.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index f2ed7dba569d..d9c5f45df340 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -191,7 +191,6 @@ def __init__( pixel_shuffle_scale_factor=1, **kwargs, ): - # Copied from transformers.models.siglip2.configuration_siglip2.Siglip2VisionConfig.__init__ super().__init__(**kwargs) self.hidden_size = hidden_size From 
107ecded5bfb9314bec253e6c253831946dea834 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 12:34:38 +0400 Subject: [PATCH 31/77] fix: fix copies fix: proper imports --- .../models/isaac/modeling_isaac.py | 79 +++---------------- .../models/isaac/modular_isaac.py | 20 ++--- .../models/isaac/processing_isaac.py | 6 +- 3 files changed, 23 insertions(+), 82 deletions(-) diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 93c1f3884d72..0ea71ca1a2cb 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -94,12 +94,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from genesis.public.tensorstream.tensor_stream import TensorStream, TextType, VisionType, group_streams -from genesis.public.tensorstream.tensor_stream_utils import ( +from perceptron.tensorstream.ops import ( compute_mrope_pos_tensor, modality_mask, reconstruct_tensor_stream_from_compact_dict, ) +from perceptron.tensorstream.tensorstream import TensorStream, TextType, VisionType, group_streams from ...activations import ACT2FN from ...cache_utils import Cache, SlidingWindowCache, StaticCache @@ -456,6 +456,7 @@ def __init__(self, config: IsaacVisionConfig): self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = IsaacMLP(config) + @auto_docstring def forward( self, hidden_states: torch.Tensor, @@ -463,17 +464,7 @@ def forward( cu_seqlens: torch.Tensor | None = None, max_seqlen: int | None = None, output_attentions: bool = False, - ) -> tuple[torch.FloatTensor]: - """ - Args: - hidden_states (`torch.FloatTensor`): - Input to the layer of shape `(batch, seq_len, embed_dim)`. - attention_mask (`torch.FloatTensor`): - Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - """ + ) -> torch.FloatTensor: if cu_seqlens is not None or max_seqlen is not None: self.self_attn._variable_length_context( cu_seqlens=cu_seqlens, @@ -490,10 +481,10 @@ def forward( residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states, attn_weights = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, - output_attentions=output_attentions, + **kwargs, ) hidden_states = residual + hidden_states @@ -502,12 +493,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states class IsaacVisionEncoder(nn.Module): @@ -531,28 +517,6 @@ def forward( output_hidden_states: bool | None = None, return_dict: bool | None = None, ) -> BaseModelOutput: - r""" - Args: - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ self.__variable_length_context(cu_seqlens, max_seqlen) attention_mask = ensure_document_attention_mask( @@ -562,38 +526,15 @@ def forward( inputs_embeds.dtype, inputs_embeds.device, ) - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - hidden_states = inputs_embeds for encoder_layer in self.layers: - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - layer_outputs = encoder_layer( + hidden_states = encoder_layer( hidden_states, attention_mask, - output_attentions=output_attentions, + **kwargs, ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=encoder_states, - attentions=all_attentions, - ) + return BaseModelOutput(last_hidden_state=hidden_states) def __variable_length_context(self, cu_seqlens, max_seqlen) -> None: if cu_seqlens is None and max_seqlen is None: diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index d9c5f45df340..e60291185df9 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -93,7 +93,16 @@ import torch import torch.nn as nn import torch.nn.functional as F -from genesis.public.tensorstream.tensor_stream import ( +from perceptron.tensorstream.ops import ( + compute_mrope_pos_tensor, + modality_mask, + reconstruct_tensor_stream_from_compact_dict, + tensor_stream_token_view, +) +from perceptron.tensorstream.ops import ( + slice as ts_slice, +) +from perceptron.tensorstream.tensorstream import ( Event, Stream, TensorStream, @@ -102,15 +111,6 @@ create_stream, group_streams, ) -from genesis.public.tensorstream.tensor_stream_utils import ( - compute_mrope_pos_tensor, - modality_mask, - reconstruct_tensor_stream_from_compact_dict, - tensor_stream_token_view, -) -from genesis.public.tensorstream.tensor_stream_utils import ( - slice as ts_slice, -) from ...cache_utils import Cache, SlidingWindowCache, StaticCache from ...configuration_utils import PretrainedConfig, layer_type_validation diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 87d37f773668..debcc6ec612b 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -91,9 +91,9 @@ import PIL.Image import torch -from genesis.public.tensorstream.tensor_stream import Event, Stream, TensorStream, 
TextType, VisionType, create_stream -from genesis.public.tensorstream.tensor_stream_utils import slice as ts_slice -from genesis.public.tensorstream.tensor_stream_utils import tensor_stream_token_view +from perceptron.tensorstream.ops import slice as ts_slice +from perceptron.tensorstream.ops import tensor_stream_token_view +from perceptron.tensorstream.tensorstream import Event, Stream, TensorStream, TextType, VisionType, create_stream from ...feature_extraction_utils import BatchFeature from ...image_processing_utils_fast import DefaultFastImageProcessorKwargs From fd5e399e15fbfb2ce85373f9e8ebe5cc12a798d1 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 13:06:10 +0400 Subject: [PATCH 32/77] style: pass kwargs and docstrings --- .../models/isaac/modular_isaac.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index e60291185df9..4a26b3c390a8 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -142,7 +142,7 @@ # Vision preprocessing constants from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD -from ...utils.generic import can_return_tuple +from ...utils.generic import TransformersKwargs, can_return_tuple from ...utils.import_utils import is_torchdynamo_compiling from ..qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling from ..siglip2.configuration_siglip2 import Siglip2VisionConfig @@ -779,7 +779,16 @@ def forward( cu_seqlens: torch.Tensor | None = None, max_seqlen: int | None = None, output_attentions: bool = False, + **kwargs: Unpack[TransformersKwargs], ): + r""" + cu_seqlens (`torch.Tensor`, *optional*): + Prefix-sum tensor whose length equals the number of documents + 1. The difference between successive + entries gives each document's token count and enables block-diagonal attention masking for packed batches. + max_seqlen (`int`, *optional*): + Maximum document length referenced by `cu_seqlens`. Passed to FlashAttention so it can size temporary + buffers for packed variable-length attention. + """ if cu_seqlens is not None or max_seqlen is not None: self.self_attn._variable_length_context( cu_seqlens=cu_seqlens, @@ -798,6 +807,7 @@ def forward( hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, + **kwargs, ) @@ -829,6 +839,7 @@ def forward( output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, + **kwargs: Unpack[TransformersKwargs], ): self.__variable_length_context(cu_seqlens, max_seqlen) @@ -846,6 +857,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + **kwargs, ) @@ -2199,9 +2211,13 @@ def forward( cache_position: torch.LongTensor | None = None, **kwargs, ) -> tuple | CausalLMOutputWithPast: - """ + r""" Forward pass for conditional generation supporting both standard inputs and TensorStream. - Uses our embed_stream approach for multimodal inputs. + + tensor_stream (`TensorStream`, *optional*): + Packed multimodal stream (text, vision, audio tokens) that already encodes spatial metadata. When provided, + the model derives embeddings, modality masks, and 3D rotary coordinates directly from the stream instead of + `input_ids`. 
""" # Don't compute embeddings here - let the model handle it From 206b82a1100a64efe3ac00fc8a49f4ab7e097bed Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 13:06:17 +0400 Subject: [PATCH 33/77] chore: artifact --- .../models/isaac/modeling_isaac.py | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 0ea71ca1a2cb..b8044d2f05c1 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -114,9 +114,9 @@ from ...models.qwen3.configuration_qwen3 import Qwen3Config from ...models.qwen3.modeling_qwen3 import Qwen3PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring +from ...utils import auto_docstring from ...utils.deprecation import deprecate_kwarg -from ...utils.generic import can_return_tuple +from ...utils.generic import TransformersKwargs, can_return_tuple from ...utils.import_utils import is_torchdynamo_compiling from ..qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling from .configuration_isaac import IsaacConfig, IsaacVisionConfig @@ -464,7 +464,16 @@ def forward( cu_seqlens: torch.Tensor | None = None, max_seqlen: int | None = None, output_attentions: bool = False, + **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: + r""" + cu_seqlens (`torch.Tensor`, *optional*): + Prefix-sum tensor whose length equals the number of documents + 1. The difference between successive + entries gives each document's token count and enables block-diagonal attention masking for packed batches. + max_seqlen (`int`, *optional*): + Maximum document length referenced by `cu_seqlens`. Passed to FlashAttention so it can size temporary + buffers for packed variable-length attention. + """ if cu_seqlens is not None or max_seqlen is not None: self.self_attn._variable_length_context( cu_seqlens=cu_seqlens, @@ -516,6 +525,7 @@ def forward( output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: self.__variable_length_context(cu_seqlens, max_seqlen) @@ -1479,9 +1489,13 @@ def forward( cache_position: torch.LongTensor | None = None, **kwargs, ) -> tuple | CausalLMOutputWithPast: - """ + r""" Forward pass for conditional generation supporting both standard inputs and TensorStream. - Uses our embed_stream approach for multimodal inputs. + + tensor_stream (`TensorStream`, *optional*): + Packed multimodal stream (text, vision, audio tokens) that already encodes spatial metadata. When provided, + the model derives embeddings, modality masks, and 3D rotary coordinates directly from the stream instead of + `input_ids`. 
""" # Don't compute embeddings here - let the model handle it From 510eb05f80dc26e67943a22ed8e45eb6aa2040d8 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 13:33:09 +0400 Subject: [PATCH 34/77] style: revert UP045 typing for autodocstring to work --- .../models/isaac/configuration_isaac.py | 14 +- .../isaac/image_processing_isaac_fast.py | 56 ++--- .../models/isaac/modeling_isaac.py | 100 ++++---- .../models/isaac/modular_isaac.py | 224 +++++++++--------- .../models/isaac/processing_isaac.py | 21 +- 5 files changed, 207 insertions(+), 208 deletions(-) diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index 29f351142f24..ec7f72c569f7 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -88,7 +88,7 @@ import copy -from typing import Any +from typing import Any, Optional, Union from ...configuration_utils import PretrainedConfig, layer_type_validation from ...modeling_rope_utils import rope_config_validation @@ -109,7 +109,7 @@ class IsaacVisionConfig(PretrainedConfig): model_type = "isaac_vision" base_config_key = "vision_config" - _attn_implementation: str | None = None + _attn_implementation: Optional[str] = None def __init__( self, @@ -155,14 +155,14 @@ class IsaacConfig(PretrainedConfig): def __init__( self, - vision_config: IsaacVisionConfig | None = None, - text_config: Qwen3Config | dict | None = None, + vision_config: Optional[IsaacVisionConfig] = None, + text_config: Optional[Union[Qwen3Config, dict]] = None, vision_rescale_factor: float = 1 / 255, max_sequence_length: int = 16384, vision_token: str = "", **kwargs, ): - self._rope_scaling: dict[str, Any] | None = None + self._rope_scaling: Optional[dict[str, Any]] = None resolved_text_config = kwargs.pop("text_config", text_config) if isinstance(resolved_text_config, Qwen3Config): text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict()) @@ -255,14 +255,14 @@ def rope_scaling(self, value): self.text_config.rope_scaling = value @property - def vision_attn_implementation(self) -> str | None: + def vision_attn_implementation(self) -> Optional[str]: value = getattr(self.vision_config, "_attn_implementation", None) if value is None: value = getattr(self.vision_config, "attn_implementation", None) return value @vision_attn_implementation.setter - def vision_attn_implementation(self, value: str | None) -> None: + def vision_attn_implementation(self, value: Optional[str]) -> None: self.vision_config._attn_implementation = value if value is not None: self.vision_config.attn_implementation = value diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py b/src/transformers/models/isaac/image_processing_isaac_fast.py index 57c8eb3062e8..0c257e4d3db2 100644 --- a/src/transformers/models/isaac/image_processing_isaac_fast.py +++ b/src/transformers/models/isaac/image_processing_isaac_fast.py @@ -88,7 +88,7 @@ import math from collections.abc import Sequence -from typing import Any +from typing import Any, Optional, Union import torch import torch.nn.functional as F @@ -124,7 +124,7 @@ def get_image_size_for_max_num_patches( image_width: int, patch_size: int, max_num_patches: int, - min_num_patches: int | None = None, + min_num_patches: Optional[int] = None, eps: float = 1e-5, pixel_shuffle_scale: int = 1, ) -> tuple[int, int]: @@ -247,16 +247,16 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): unused_kwargs = ["size", "do_center_crop", 
"crop_size"] do_resize = True - size: SizeDict | None = None - default_to_square: bool | None = None + size: Optional[SizeDict] = None + default_to_square: Optional[bool] = None do_center_crop = False - crop_size: SizeDict | None = None - patch_size: int | None = 16 - max_num_patches: int | None = 256 - min_num_patches: int | None = None - pixel_shuffle_scale: int | None = 1 + crop_size: Optional[SizeDict] = None + patch_size: Optional[int] = 16 + max_num_patches: Optional[int] = 256 + min_num_patches: Optional[int] = None + pixel_shuffle_scale: Optional[int] = 1 do_pad = False - pad_size: SizeDict | None = None + pad_size: Optional[SizeDict] = None do_rescale = True rescale_factor = 1 / 255 do_normalize = True @@ -268,7 +268,7 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): input_data_format = None device = None disable_grouping = False - size_divisor: int | None = None + size_divisor: Optional[int] = None def __init__( self, @@ -294,7 +294,7 @@ def resize( self, image: torch.Tensor, size: SizeDict, - interpolation: Any | None = None, + interpolation: Optional[Any] = None, antialias: bool = True, **kwargs, ) -> torch.Tensor: @@ -329,24 +329,24 @@ def _preprocess( self, images: list[torch.Tensor], do_resize: bool, - size: SizeDict | None, - interpolation: Any | None, + size: Optional[SizeDict], + interpolation: Optional[Any], do_center_crop: bool, - crop_size: SizeDict | None, - do_rescale: bool | None, - rescale_factor: float | None, - do_normalize: bool | None, - image_mean: float | Sequence[float] | None, - image_std: float | Sequence[float] | None, - disable_grouping: bool | None = None, - return_tensors: str | TensorType | None = None, - do_pad: bool | None = None, - pad_size: SizeDict | None = None, + crop_size: Optional[SizeDict], + do_rescale: Optional[bool], + rescale_factor: Optional[float], + do_normalize: Optional[bool], + image_mean: Optional[Union[float, Sequence[float]]], + image_std: Optional[Union[float, Sequence[float]]], + disable_grouping: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + do_pad: Optional[bool] = None, + pad_size: Optional[SizeDict] = None, *, - patch_size: int | None = None, - max_num_patches: int | None = None, - min_num_patches: int | None = None, - pixel_shuffle_scale: int | None = None, + patch_size: Optional[int] = None, + max_num_patches: Optional[int] = None, + min_num_patches: Optional[int] = None, + pixel_shuffle_scale: Optional[int] = None, **kwargs, ) -> BatchFeature: if do_center_crop: diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index b8044d2f05c1..f0819c867e68 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -218,7 +218,7 @@ def _pack_to_batch( self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor, - ) -> tuple[torch.Tensor | None, torch.Tensor]: + ) -> tuple[Optional[torch.Tensor], torch.Tensor]: if seq_patches.ndim != 2: raise ValueError("`seq_patches` is expected to be 2D (total_patches, patch_dim).") if spatial_shapes.ndim != 2 or spatial_shapes.size(-1) != 2: @@ -266,11 +266,11 @@ def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor def build_document_attention_mask( - cu_seqlens: torch.Tensor | None, + cu_seqlens: Optional[torch.Tensor], total_tokens: int, dtype: torch.dtype, device: torch.device, -) -> torch.Tensor | None: +) -> Optional[torch.Tensor]: """Creates an additive attention mask that blocks cross-document 
attention.""" if cu_seqlens is None: @@ -291,12 +291,12 @@ def build_document_attention_mask( def ensure_document_attention_mask( - attention_mask: torch.Tensor | None, - cu_seqlens: torch.Tensor | None, + attention_mask: Optional[torch.Tensor], + cu_seqlens: Optional[torch.Tensor], total_tokens: int, dtype: torch.dtype, device: torch.device, -) -> torch.Tensor | None: +) -> Optional[torch.Tensor]: if attention_mask is not None or cu_seqlens is None: return attention_mask @@ -460,9 +460,9 @@ def __init__(self, config: IsaacVisionConfig): def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, output_attentions: bool = False, **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: @@ -519,12 +519,12 @@ def __init__(self, config: IsaacVisionConfig): def forward( self, inputs_embeds, - attention_mask: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: self.__variable_length_context(cu_seqlens, max_seqlen) @@ -562,7 +562,7 @@ def create_pixel_shuffle_index_map( seq_sizes: torch.Tensor, token_grids: torch.Tensor, scale_factor: int = 1, - device: torch.device | None = None, + device: Optional[torch.device] = None, ) -> torch.Tensor: """ Build a gather-index map that tells us, for every *output* token after @@ -755,7 +755,7 @@ def __init__(self, config: IsaacConfig, device=None): self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size @staticmethod - def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> list[int]: + def _resolve_mrope_section(section: Optional[list[int]], rotary_half_dim: int) -> list[int]: if section is None: weights = (2, 1, 1) base = [rotary_half_dim * w // sum(weights) for w in weights] @@ -784,7 +784,7 @@ def forward( self, position_ids: torch.Tensor, modality_tensor: torch.Tensor, - hidden_states: torch.Tensor | None = None, + hidden_states: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: if position_ids.ndim != 3 or position_ids.size(-1) != 3: raise ValueError("`position_ids` must have shape (batch, seq_len, 3) for MRoPE") @@ -1186,17 +1186,17 @@ def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor: def forward( self, - input_ids: torch.LongTensor | None = None, - tensor_stream: TensorStream | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - modality_tensor: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - use_cache: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - cache_position: torch.LongTensor | None = None, + input_ids: Optional[torch.LongTensor] = None, + tensor_stream: Optional[TensorStream] = None, + attention_mask: Optional[torch.Tensor] = None, 
+ position_ids: Optional[torch.LongTensor] = None, + modality_tensor: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple | BaseModelOutputWithPast: """ @@ -1476,22 +1476,20 @@ def __init__(self, config: IsaacConfig): @auto_docstring def forward( self, - input_ids: torch.LongTensor | None = None, - tensor_stream: TensorStream | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - cache_position: torch.LongTensor | None = None, + input_ids: Optional[torch.LongTensor] = None, + tensor_stream: Optional[TensorStream] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple | CausalLMOutputWithPast: r""" - Forward pass for conditional generation supporting both standard inputs and TensorStream. - tensor_stream (`TensorStream`, *optional*): Packed multimodal stream (text, vision, audio tokens) that already encodes spatial metadata. When provided, the model derives embeddings, modality masks, and 3D rotary coordinates directly from the stream instead of @@ -1561,9 +1559,9 @@ def forward( def get_rope_index( self, - input_ids: torch.Tensor | None, - tensor_stream: TensorStream | None, - attention_mask: torch.Tensor | None, + input_ids: Optional[torch.Tensor], + tensor_stream: Optional[TensorStream], + attention_mask: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: """Compute MRoPE position ids from a TensorStream (or 1D fallback). 
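A note on the `get_rope_index` contract shown above: the docstring's "1D fallback" means that for text-only inputs every MRoPE axis carries the same sequential index, so the 3-axis rotary embedding degenerates to ordinary RoPE. A minimal sketch of that fallback, assuming padding positions are parked at 0 (an illustration choice, not taken from this patch):

```python
# Hedged sketch of the text-only "1D fallback" named in the get_rope_index
# docstring: with no vision events, every MRoPE axis (t, h, w) carries the
# same running position. The padding fill value is an assumption.
import torch

def text_only_rope_index(attention_mask: torch.Tensor) -> torch.Tensor:
    positions = attention_mask.long().cumsum(-1) - 1            # running position per sample
    positions = positions.masked_fill(attention_mask == 0, 0)   # assumed padding value
    return positions.unsqueeze(-1).expand(*positions.shape, 3)  # (batch, seq_len, 3)

mask = torch.tensor([[1, 1, 1, 0]])
print(text_only_rope_index(mask)[0].tolist())  # [[0, 0, 0], [1, 1, 1], [2, 2, 2], [0, 0, 0]]
```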
@@ -1595,12 +1593,12 @@ def get_rope_index( def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, - past_key_values: list[torch.FloatTensor] | None = None, - attention_mask: torch.Tensor | None = None, - inputs_embeds: torch.FloatTensor | None = None, - tensor_stream: TensorStream | None = None, - cache_position: torch.LongTensor | None = None, - position_ids: torch.LongTensor | None = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + tensor_stream: Optional[TensorStream] = None, + cache_position: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, use_cache: bool = True, **kwargs, ) -> dict[str, Any]: diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 4a26b3c390a8..a8c22174ed35 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -87,7 +87,7 @@ import re from collections import defaultdict from collections.abc import Sequence -from typing import Any, Callable +from typing import Any, Callable, Optional, Union import PIL.Image import torch @@ -154,7 +154,7 @@ ) -_ORIGINAL_ATTENTION_FUNCTIONS: dict[str, Callable[..., tuple[torch.Tensor, torch.Tensor | None]]] = {} +_ORIGINAL_ATTENTION_FUNCTIONS: dict[str, Callable[..., tuple[torch.Tensor, Optional[torch.Tensor]]]] = {} for _attn_name in ("flash_attention_2", "sdpa", "eager"): if _attn_name in ALL_ATTENTION_FUNCTIONS: _ORIGINAL_ATTENTION_FUNCTIONS[_attn_name] = ALL_ATTENTION_FUNCTIONS[_attn_name] @@ -174,7 +174,7 @@ class IsaacVisionConfig(Siglip2VisionConfig): model_type = "isaac_vision" base_config_key = "vision_config" - _attn_implementation: str | None = None + _attn_implementation: Optional[str] = None def __init__( self, @@ -212,10 +212,10 @@ def __init__( class IsaacImageProcessorKwargs(DefaultFastImageProcessorKwargs, total=False): - patch_size: int | None - max_num_patches: int | None - min_num_patches: int | None - pixel_shuffle_scale: int | None + patch_size: Optional[int] + max_num_patches: Optional[int] + min_num_patches: Optional[int] + pixel_shuffle_scale: Optional[int] @auto_docstring @@ -229,16 +229,16 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): unused_kwargs = ["size", "do_center_crop", "crop_size"] do_resize = True - size: SizeDict | None = None - default_to_square: bool | None = None + size: Optional[SizeDict] = None + default_to_square: Optional[bool] = None do_center_crop = False - crop_size: SizeDict | None = None - patch_size: int | None = 16 - max_num_patches: int | None = 256 - min_num_patches: int | None = None - pixel_shuffle_scale: int | None = 1 + crop_size: Optional[SizeDict] = None + patch_size: Optional[int] = 16 + max_num_patches: Optional[int] = 256 + min_num_patches: Optional[int] = None + pixel_shuffle_scale: Optional[int] = 1 do_pad = False - pad_size: SizeDict | None = None + pad_size: Optional[SizeDict] = None do_rescale = True rescale_factor = 1 / 255 do_normalize = True @@ -250,7 +250,7 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): input_data_format = None device = None disable_grouping = False - size_divisor: int | None = None + size_divisor: Optional[int] = None def __init__( self, @@ -276,7 +276,7 @@ def resize( self, image: torch.Tensor, size: SizeDict, - interpolation: Any | None = None, + interpolation: Optional[Any] = None, antialias: bool = True, **kwargs, ) -> torch.Tensor: 
@@ -311,24 +311,24 @@ def _preprocess( self, images: list[torch.Tensor], do_resize: bool, - size: SizeDict | None, - interpolation: Any | None, + size: Optional[SizeDict], + interpolation: Optional[Any], do_center_crop: bool, - crop_size: SizeDict | None, - do_rescale: bool | None, - rescale_factor: float | None, - do_normalize: bool | None, - image_mean: float | Sequence[float] | None, - image_std: float | Sequence[float] | None, - disable_grouping: bool | None = None, - return_tensors: str | TensorType | None = None, - do_pad: bool | None = None, - pad_size: SizeDict | None = None, + crop_size: Optional[SizeDict], + do_rescale: Optional[bool], + rescale_factor: Optional[float], + do_normalize: Optional[bool], + image_mean: Optional[Union[float, Sequence[float]]], + image_std: Optional[Union[float, Sequence[float]]], + disable_grouping: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + do_pad: Optional[bool] = None, + pad_size: Optional[SizeDict] = None, *, - patch_size: int | None = None, - max_num_patches: int | None = None, - min_num_patches: int | None = None, - pixel_shuffle_scale: int | None = None, + patch_size: Optional[int] = None, + max_num_patches: Optional[int] = None, + min_num_patches: Optional[int] = None, + pixel_shuffle_scale: Optional[int] = None, **kwargs, ) -> BatchFeature: if do_center_crop: @@ -460,7 +460,7 @@ def _preprocess( ) -def _max_from_cu(cu: torch.Tensor | None, fallback: int) -> int: +def _max_from_cu(cu: Optional[torch.Tensor], fallback: int) -> int: """Helper to compute max sequence length from cumulative sequence lengths.""" if cu is None or len(cu) < 2: return fallback @@ -468,11 +468,11 @@ def _max_from_cu(cu: torch.Tensor | None, fallback: int) -> int: def build_document_attention_mask( - cu_seqlens: torch.Tensor | None, + cu_seqlens: Optional[torch.Tensor], total_tokens: int, dtype: torch.dtype, device: torch.device, -) -> torch.Tensor | None: +) -> Optional[torch.Tensor]: """Creates an additive attention mask that blocks cross-document attention.""" if cu_seqlens is None: @@ -493,12 +493,12 @@ def build_document_attention_mask( def ensure_document_attention_mask( - attention_mask: torch.Tensor | None, - cu_seqlens: torch.Tensor | None, + attention_mask: Optional[torch.Tensor], + cu_seqlens: Optional[torch.Tensor], total_tokens: int, dtype: torch.dtype, device: torch.device, -) -> torch.Tensor | None: +) -> Optional[torch.Tensor]: if attention_mask is not None or cu_seqlens is None: return attention_mask @@ -515,12 +515,12 @@ def flash_attention_document_mask_forward( q_lhd: torch.Tensor, # (L, H, D) k_lhd: torch.Tensor, # (L, H, D) v_lhd: torch.Tensor, # (L, H, D) - attention_mask: torch.Tensor | None = None, # unused for FA path + attention_mask: Optional[torch.Tensor] = None, # unused for FA path dropout: float = 0.0, - scaling: float | None = None, - cum_seq_q: torch.Tensor | None = None, - cum_seq_k: torch.Tensor | None = None, - max_seqlen: int | None = None, + scaling: Optional[float] = None, + cum_seq_q: Optional[torch.Tensor] = None, + cum_seq_k: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, is_causal: bool = False, **kwargs, ) -> tuple[torch.Tensor, None]: @@ -566,9 +566,9 @@ def sdpa_document_mask_forward( k_lhd: torch.Tensor, # (L, H, D) v_lhd: torch.Tensor, # (L, H, D) dropout: float, - scaling: float | None, - attention_mask: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, + scaling: Optional[float], + attention_mask: Optional[torch.Tensor] = None, + 
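For orientation, the `resize` override above targets the size chosen by `get_image_size_for_max_num_patches` (whose signature appears earlier in this patch): the image is scaled so its patch grid fits the patch budget, with both sides divisible by `patch_size * pixel_shuffle_scale`. A minimal sketch of that sizing contract, assuming `max_num_patches` bounds the post-pixel-shuffle token count and ignoring the real helper's `min_num_patches` and `eps` handling:

```python
# Minimal sketch, not this patch's implementation: scale (height, width) so
# the patch grid stays within max_num_patches, snapping both sides to
# multiples of patch_size * pixel_shuffle_scale. Treating max_num_patches as
# a bound on post-shuffle tokens is an assumption here.
import math

def fit_to_patch_budget(
    height: int,
    width: int,
    patch_size: int = 16,
    max_num_patches: int = 256,
    pixel_shuffle_scale: int = 1,
) -> tuple[int, int]:
    divisor = patch_size * pixel_shuffle_scale
    scale = math.sqrt(max_num_patches * divisor**2 / (height * width))
    new_height = max(divisor, math.floor(height * scale / divisor) * divisor)
    new_width = max(divisor, math.floor(width * scale / divisor) * divisor)
    return new_height, new_width

print(fit_to_patch_budget(1080, 1920))  # (192, 336): a 12 x 21 grid, 252 patches <= 256
```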
cu_seqlens: Optional[torch.Tensor] = None, ) -> torch.Tensor: """SDPA with block-diagonal masking for variable-length sequences.""" L, H, D = q_lhd.shape @@ -613,7 +613,7 @@ def _pack_to_batch( self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor, - ) -> tuple[torch.Tensor | None, torch.Tensor]: + ) -> tuple[Optional[torch.Tensor], torch.Tensor]: if seq_patches.ndim != 2: raise ValueError("`seq_patches` is expected to be 2D (total_patches, patch_dim).") if spatial_shapes.ndim != 2 or spatial_shapes.size(-1) != 2: @@ -775,9 +775,9 @@ def __init__(self, config: IsaacVisionConfig): def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, output_attentions: bool = False, **kwargs: Unpack[TransformersKwargs], ): @@ -833,12 +833,12 @@ def __variable_length_context(self, cu_seqlens, max_seqlen) -> None: def forward( self, inputs_embeds, - attention_mask: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): self.__variable_length_context(cu_seqlens, max_seqlen) @@ -866,12 +866,12 @@ def _isaac_flash_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: torch.Tensor | None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, - scaling: float | None = None, + scaling: Optional[float] = None, is_causal: bool = False, **kwargs, -) -> tuple[torch.Tensor, torch.Tensor | None]: +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: base_fn = _ORIGINAL_ATTENTION_FUNCTIONS.get("flash_attention_2") if not isinstance(module, IsaacVisionAttention) or base_fn is None: if base_fn is None: @@ -937,12 +937,12 @@ def _isaac_sdpa_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: torch.Tensor | None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, - scaling: float | None = None, + scaling: Optional[float] = None, is_causal: bool = False, **kwargs, -) -> tuple[torch.Tensor, torch.Tensor | None]: +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: base_fn = _ORIGINAL_ATTENTION_FUNCTIONS.get("sdpa") if not isinstance(module, IsaacVisionAttention) or base_fn is None: if base_fn is None: @@ -1001,12 +1001,12 @@ def _isaac_eager_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: torch.Tensor | None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, - scaling: float | None = None, + scaling: Optional[float] = None, is_causal: bool = False, **kwargs, -) -> tuple[torch.Tensor, torch.Tensor | None]: +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: base_fn = _ORIGINAL_ATTENTION_FUNCTIONS.get("eager") if not isinstance(module, IsaacVisionAttention) or base_fn is None: if base_fn is None: @@ -1061,7 +1061,7 @@ def create_pixel_shuffle_index_map( seq_sizes: torch.Tensor, token_grids: torch.Tensor, scale_factor: int = 1, - device: torch.device | None = None, + device: 
Optional[torch.device] = None, ) -> torch.Tensor: """ Build a gather-index map that tells us, for every *output* token after @@ -1251,7 +1251,7 @@ def get_image_size_for_max_num_patches( image_width: int, patch_size: int, max_num_patches: int, - min_num_patches: int | None = None, + min_num_patches: Optional[int] = None, eps: float = 1e-5, pixel_shuffle_scale: int = 1, ) -> tuple[int, int]: @@ -1359,14 +1359,14 @@ class IsaacConfig(PretrainedConfig): def __init__( self, - vision_config: IsaacVisionConfig | None = None, - text_config: Qwen3Config | dict | None = None, + vision_config: Optional[IsaacVisionConfig] = None, + text_config: Optional[Union[Qwen3Config, dict]] = None, vision_rescale_factor: float = 1 / 255, max_sequence_length: int = 16384, vision_token: str = "", **kwargs, ): - self._rope_scaling: dict[str, Any] | None = None + self._rope_scaling: Optional[dict[str, Any]] = None resolved_text_config = kwargs.pop("text_config", text_config) if isinstance(resolved_text_config, Qwen3Config): text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict()) @@ -1459,14 +1459,14 @@ def rope_scaling(self, value): self.text_config.rope_scaling = value @property - def vision_attn_implementation(self) -> str | None: + def vision_attn_implementation(self) -> Optional[str]: value = getattr(self.vision_config, "_attn_implementation", None) if value is None: value = getattr(self.vision_config, "attn_implementation", None) return value @vision_attn_implementation.setter - def vision_attn_implementation(self, value: str | None) -> None: + def vision_attn_implementation(self, value: Optional[str]) -> None: self.vision_config._attn_implementation = value if value is not None: self.vision_config.attn_implementation = value @@ -1534,8 +1534,8 @@ def __init__( *, vision_token: str = "", max_sequence_length: int = 16384, - rescale_factor: float | None = None, - config: IsaacConfig | dict | None = None, + rescale_factor: Optional[float] = None, + config: Optional[Union[IsaacConfig, dict]] = None, ) -> None: if tokenizer is None: raise ValueError("`tokenizer` must be provided to initialize IsaacProcessor.") @@ -1568,7 +1568,7 @@ def __init__( def build_event_stream_simple( self, text: str, - images: list[PIL.Image.Image] | None = None, + images: Optional[list[PIL.Image.Image]] = None, ) -> Stream: events = [] # Process text and images @@ -1613,9 +1613,9 @@ def build_event_stream_simple( def __call__( self, - text: str | list[str], - images: PIL.Image.Image | list[PIL.Image.Image] | None = None, - return_tensors: str | TensorType | None = TensorType.PYTORCH, + text: Union[str, list[str]], + images: Optional[Union[PIL.Image.Image, list[PIL.Image.Image]]] = None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, **kwargs, ) -> BatchFeature: """ @@ -1726,7 +1726,7 @@ def __init__(self, config: IsaacConfig, device=None): self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size @staticmethod - def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> list[int]: + def _resolve_mrope_section(section: Optional[list[int]], rotary_half_dim: int) -> list[int]: if section is None: weights = (2, 1, 1) base = [rotary_half_dim * w // sum(weights) for w in weights] @@ -1755,7 +1755,7 @@ def forward( self, position_ids: torch.Tensor, modality_tensor: torch.Tensor, - hidden_states: torch.Tensor | None = None, + hidden_states: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: if position_ids.ndim != 3 or position_ids.size(-1) != 
3: raise ValueError("`position_ids` must have shape (batch, seq_len, 3) for MRoPE") @@ -1904,17 +1904,17 @@ def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor: def forward( self, - input_ids: torch.LongTensor | None = None, - tensor_stream: TensorStream | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - modality_tensor: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - use_cache: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - cache_position: torch.LongTensor | None = None, + input_ids: Optional[torch.LongTensor] = None, + tensor_stream: Optional[TensorStream] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + modality_tensor: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple | BaseModelOutputWithPast: """ @@ -2165,9 +2165,9 @@ def __init__(self, config: IsaacConfig): def get_rope_index( self, - input_ids: torch.Tensor | None, - tensor_stream: TensorStream | None, - attention_mask: torch.Tensor | None, + input_ids: Optional[torch.Tensor], + tensor_stream: Optional[TensorStream], + attention_mask: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: """Compute MRoPE position ids from a TensorStream (or 1D fallback). @@ -2198,17 +2198,17 @@ def get_rope_index( def forward( self, - input_ids: torch.LongTensor | None = None, - tensor_stream: TensorStream | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - cache_position: torch.LongTensor | None = None, + input_ids: Optional[torch.LongTensor] = None, + tensor_stream: Optional[TensorStream] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple | CausalLMOutputWithPast: r""" @@ -2284,12 +2284,12 @@ def forward( def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, - past_key_values: list[torch.FloatTensor] | None = None, - attention_mask: torch.Tensor | None = None, - inputs_embeds: torch.FloatTensor | None = None, - tensor_stream: TensorStream | None = None, - cache_position: torch.LongTensor | None = None, - position_ids: torch.LongTensor | None = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + tensor_stream: Optional[TensorStream] = None, + cache_position: Optional[torch.LongTensor] = None, + position_ids: 
Optional[torch.LongTensor] = None, use_cache: bool = True, **kwargs, ) -> dict[str, Any]: diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index debcc6ec612b..3de31b1a7436 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -88,6 +88,7 @@ import math import re +from typing import Optional, Union import PIL.Image import torch @@ -104,10 +105,10 @@ class IsaacImageProcessorKwargs(DefaultFastImageProcessorKwargs, total=False): - patch_size: int | None - max_num_patches: int | None - min_num_patches: int | None - pixel_shuffle_scale: int | None + patch_size: Optional[int] + max_num_patches: Optional[int] + min_num_patches: Optional[int] + pixel_shuffle_scale: Optional[int] # ============================================================================ @@ -170,8 +171,8 @@ def __init__( *, vision_token: str = "", max_sequence_length: int = 16384, - rescale_factor: float | None = None, - config: IsaacConfig | dict | None = None, + rescale_factor: Optional[float] = None, + config: Optional[Union[IsaacConfig, dict]] = None, ) -> None: if tokenizer is None: raise ValueError("`tokenizer` must be provided to initialize IsaacProcessor.") @@ -204,7 +205,7 @@ def __init__( def build_event_stream_simple( self, text: str, - images: list[PIL.Image.Image] | None = None, + images: Optional[list[PIL.Image.Image]] = None, ) -> Stream: events = [] # Process text and images @@ -249,9 +250,9 @@ def build_event_stream_simple( def __call__( self, - text: str | list[str], - images: PIL.Image.Image | list[PIL.Image.Image] | None = None, - return_tensors: str | TensorType | None = TensorType.PYTORCH, + text: Union[str, list[str]], + images: Optional[Union[PIL.Image.Image, list[PIL.Image.Image]]] = None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, **kwargs, ) -> BatchFeature: """ From 5da40566f488de1033927b1d1500221eb70a2547 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 13:54:12 +0400 Subject: [PATCH 35/77] fix: latest transformers changes --- src/transformers/models/isaac/configuration_isaac.py | 4 ++-- src/transformers/models/isaac/modeling_isaac.py | 3 ++- src/transformers/models/isaac/modular_isaac.py | 8 ++++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index ec7f72c569f7..d45775ba8db9 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -90,12 +90,12 @@ import copy from typing import Any, Optional, Union -from ...configuration_utils import PretrainedConfig, layer_type_validation +from ...configuration_utils import PreTrainedConfig, PretrainedConfig, layer_type_validation from ...modeling_rope_utils import rope_config_validation from ...models.qwen3.configuration_qwen3 import Qwen3Config -class IsaacVisionConfig(PretrainedConfig): +class IsaacVisionConfig(PreTrainedConfig): """Vision configuration for Isaac with Pixel Shuffle support. Extends Siglip2VisionConfig with additional fields for pixel shuffle. 
diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index f0819c867e68..b7b7b9de648b 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -89,7 +89,8 @@ import copy from collections import defaultdict -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Optional import torch import torch.nn as nn diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index a8c22174ed35..ce4bef525d8f 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -86,8 +86,8 @@ import math import re from collections import defaultdict -from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Sequence +from typing import Any, Optional, Union import PIL.Image import torch @@ -118,7 +118,7 @@ from ...generation.utils import GenerationMixin from ...image_processing_utils_fast import ( BaseImageProcessorFast, - DefaultFastImageProcessorKwargs, + ImagesKwargs, SizeDict, group_images_by_shape, reorder_images, @@ -211,7 +211,7 @@ def __init__( self._attn_implementation = "flash_attention_2" -class IsaacImageProcessorKwargs(DefaultFastImageProcessorKwargs, total=False): +class IsaacImageProcessorKwargs(ImagesKwargs, total=False): patch_size: Optional[int] max_num_patches: Optional[int] min_num_patches: Optional[int] From 4a97889621a2d08e0fe00d40d88a89ddab2c521f Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 14:26:54 +0400 Subject: [PATCH 36/77] chore: new transformers convert --- .../models/isaac/image_processing_isaac.py | 98 +++++++++++++++++++ .../isaac/image_processing_isaac_fast.py | 2 +- .../models/isaac/modeling_isaac.py | 8 +- .../models/isaac/processing_isaac.py | 8 -- 4 files changed, 102 insertions(+), 14 deletions(-) create mode 100644 src/transformers/models/isaac/image_processing_isaac.py diff --git a/src/transformers/models/isaac/image_processing_isaac.py b/src/transformers/models/isaac/image_processing_isaac.py new file mode 100644 index 000000000000..ecd28aaae954 --- /dev/null +++ b/src/transformers/models/isaac/image_processing_isaac.py @@ -0,0 +1,98 @@ +# ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ +# This file was automatically generated from src/transformers/models/isaac/modular_isaac.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_isaac.py file directly. One of our CI enforces this. +# ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ +# Copyright (c) 2024 Perceptron, Inc. All rights reserved. +# Perceptron, Inc. Non-Production License (2024-01-01) + + +### 1. Scope and acceptance + +# **1.1. Scope of the Agreement.** +# This Agreement applies to any use, modification, or Distribution of any Perceptron Model by You, regardless of the source You obtained a copy of such Perceptron Model. +# +# **1.2. 
Acceptance.** By accessing, using, modifying, Distributing a Perceptron Model, or by creating, using or distributing a Derivative of the Perceptron Model, You agree to be bound by this Agreement. +# +# **1.3. Acceptance on behalf of a third-party.** If You accept this Agreement on behalf of Your employer or another person or entity, You warrant and represent that You have the authority to act and accept this Agreement on their behalf. In such a case, the word โ€œYouโ€ in this Agreement will refer to Your employer or such other person or entity. +# +# ## 2. License +# **2.1. Grant of rights.** Subject to Section 3 below, Perceptron, Inc. hereby grants You a non-exclusive, royalty-free, worldwide, non-sublicensable, non-transferable, limited license to use, copy, modify, and Distribute under the conditions provided in Section 2.2 below, the Perceptron Model and any Derivatives made by or for Perceptron, Inc. and to create Derivatives of the Perceptron Model. +# +# **2.2. Distribution of Perceptron Model and Derivatives made by or for Perceptron, Inc..** Subject to Section 3 below, You may Distribute copies of the Perceptron Model and/or Derivatives made by or for Perceptron, Inc., under the following conditions: +# - You must make available a copy of this Agreement to third-party recipients of the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. you Distribute, it being specified that any rights to use the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. shall be directly granted by Perceptron, Inc. to said third-party recipients pursuant to the Perceptron, Inc. Non-Production License agreement executed between these parties; +# - You must retain in all copies of the Perceptron Models the following attribution notice within a โ€œNoticeโ€ text file distributed as part of such copies: โ€œLicensed by Perceptron, Inc. under the Perceptron, Inc. Non-Production Licenseโ€. +# +# **2.3. Distribution of Derivatives made by or for You.** Subject to Section 3 below, You may Distribute any Derivatives made by or for You under additional or different terms and conditions, provided that: +# - In any event, the use and modification of Perceptron Model and/or Derivatives made by or for Perceptron, Inc. shall remain governed by the terms and conditions of this Agreement; +# - You include in any such Derivatives made by or for You prominent notices stating that You modified the concerned Perceptron Model; and +# - Any terms and conditions You impose on any third-party recipients relating to Derivatives made by or for You shall neither limit such third-party recipientsโ€™ use of the Perceptron Model or any Derivatives made by or for Perceptron, Inc. in accordance with the Perceptron, Inc. Non-Production License nor conflict with any of its terms and conditions. +# +# ## 3. Limitations +# **3.1. Misrepresentation.** You must not misrepresent or imply, through any means, that the Derivatives made by or for You and/or any modified version of the Perceptron Model You Distribute under your name and responsibility is an official product of Perceptron, Inc. or has been endorsed, approved or validated by Perceptron, Inc., unless You are authorized by Us to do so in writing. +# +# **3.2. Usage Limitation** +# - You shall only use the Perceptron Models and Derivatives (whether or not created by Perceptron, Inc.) 
for testing, research, Personal, or evaluation purposes in Non-Production Environments; +# - Subject to the foregoing, You shall not supply the Perceptron Models or Derivatives in the course of a commercial activity, whether in return for payment or free of charge, in any medium or form, including but not limited to through a hosted or managed service (e.g. SaaS, cloud instances, etc.), or behind a software layer. +# +# **3.3. Usage not permitted under this Agreement.** If You want to use a Perceptron Model or a Derivative for any purpose that is not expressly authorized under this Agreement, You must request a license from Perceptron, Inc., which Perceptron, Inc. may grant to You in Perceptron, Inc.โ€™s sole discretion. Please contact Perceptron, Inc. at the following e-mail address if You want to discuss such a license: sales@perceptron.inc +# +# ## 4. Intellectual Property +# **4.1. Trademarks.** No trademark licenses are granted under this Agreement, and in connection with the Perceptron Models, You may not use any name or mark owned by or associated with Perceptron, Inc. or any of its affiliates, except (i) as required for reasonable and customary use in describing and Distributing the Perceptron Models and Derivatives made by or for Perceptron, Inc. and (ii) for attribution purposes as required by this Agreement. +# +# **4.2. Outputs.** We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs You generate and their subsequent uses in accordance with this Agreement. +# +# **4.3. Derivatives.** By entering into this Agreement, You accept that any Derivatives that You may create or that may be created for You shall be subject to the restrictions set out in Section 3 of this Agreement. +# +# # 5. Liability +# **5.1. Limitation of liability.** In no event, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall Perceptron, Inc. be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use the Perceptron Models and Derivatives (including but not limited to damages for loss of data, loss of goodwill, loss of expected profit or savings, work stoppage, computer failure or malfunction, or any damage caused by malware or security breaches), even if Perceptron, Inc. has been advised of the possibility of such damages. +# +# **5.2. Indemnification.** You agree to indemnify and hold harmless Perceptron, Inc. from and against any claims, damages, or losses arising out of or related to Your use or Distribution of the Perceptron Models and Derivatives. +# +# ## 6. Warranty +# **6.1. Disclaimer.** Unless required by applicable law or agreed to in writing, Perceptron, Inc. provides the Perceptron Models and Derivatives on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. Perceptron, Inc. does not represent nor warrant that the Perceptron Models and Derivatives will be error-free, meet Your or any third partyโ€™s requirements, be secure or will allow You or any third party to achieve any kind of result or generate any kind of content. 
You are solely responsible for determining the appropriateness of using or Distributing the Perceptron Models and Derivatives and assume any risks associated with Your exercise of rights under this Agreement. +# +# # 7. Termination +# **7.1. Term.** This Agreement is effective as of the date of your acceptance of this Agreement or access to the concerned Perceptron Models or Derivatives and will continue until terminated in accordance with the following terms. +# +# **7.2. Termination.** Perceptron, Inc. may terminate this Agreement at any time if You are in breach of this Agreement. Upon termination of this Agreement, You must cease to use all Perceptron Models and Derivatives and shall permanently delete any copy thereof. Sections 5, 6, 7 and 8 shall survive the termination of this Agreement. +# +# **7.3. Litigation.** If You initiate any legal action or proceedings against Us or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the Model or a Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement will immediately terminate as of the date such legal action or claim is filed or initiated. +# +# # 8. General provisions +# 8.1. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the State of Washington, without regard to its conflict of law principles. +# +# 8.2. Jurisdiction. The state and federal courts located in King County, Washington shall have exclusive jurisdiction over any dispute arising out of or relating to this Agreement, and You and We consent to personal jurisdiction and venue in such courts. +# +# **8.3. Severability.** If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. +# +# # 9. Definitions +# **โ€œAgreementโ€**: means this Perceptron, Inc. Non-Production License agreement governing the access, use, and Distribution of the Perceptron Models and Derivatives. +# +# **โ€œDerivativeโ€**: means any (i) modified version of the Perceptron Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the Perceptron Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered as Derivatives under this Agreement. +# +# **โ€œDistributionโ€**, **โ€œDistributingโ€**, **โ€œDistributeโ€** or **โ€œDistributedโ€**: means providing or making available, by any means, a copy of the Perceptron Models and/or the Derivatives as the case may be, subject to Section 3 of this Agreement. +# +# **โ€œPerceptron, Inc.โ€**, **โ€œWeโ€** or **โ€œUsโ€**: means Perceptron, Inc., a Delaware corporation with its principal place of business at 10900 NE 8th St Suite 613, Bellevue, WA 98004. +# +# **โ€œPerceptron Modelโ€**: means the foundational large language model(s), and its elements which include algorithms, software, instructed checkpoints, parameters, source code (inference code, evaluation code and, if applicable, fine-tuning code) and any other elements associated thereto made available by Perceptron, Inc. under this Agreement, including, if any, the technical documentation, manuals and instructions for the use and operation thereof. 
+# +# **“Non-Production Environment”**: means any setting, use case, or application of the Perceptron Models or Derivatives that expressly excludes live, real-world conditions, commercial operations, revenue-generating activities, or direct interactions with or impacts on end users (such as, for instance, Your employees or customers). Non-Production Environment may include, but is not limited to, any setting, use case, or application for research, development, testing, quality assurance, training, internal evaluation (other than any internal usage by employees in the context of the company’s business activities), and demonstration purposes. +# +# **“Outputs”**: means any content generated by the operation of the Perceptron Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a Perceptron Model, such as any fine-tuned versions of the Perceptron Models, the weights, or parameters. +# +# **“Personal”**: means any use of a Perceptron Model or a Derivative that is (i) solely for personal, non-profit and non-commercial purposes and (ii) not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities. For illustration purposes, Personal use of a Model or a Derivative does not include any usage by individuals employed in companies in the context of their daily tasks, any activity that is intended to generate revenue, or that is performed on behalf of a commercial entity. +# +# **“You”**: means the individual or entity entering into this Agreement with Perceptron, Inc. + +from typing import Optional + +from ...image_processing_utils_fast import ImagesKwargs + + +class IsaacImageProcessorKwargs(ImagesKwargs, total=False): + patch_size: Optional[int] + max_num_patches: Optional[int] + min_num_patches: Optional[int] + pixel_shuffle_scale: Optional[int] diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py index 0c257e4d3db2..789dae85b40a 100644 --- a/src/transformers/models/isaac/image_processing_isaac_fast.py +++ b/src/transformers/models/isaac/image_processing_isaac_fast.py @@ -103,7 +103,7 @@ # Vision preprocessing constants from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD -from .processing_isaac import IsaacImageProcessorKwargs +from .image_processing_isaac import IsaacImageProcessorKwargs def get_scaled_image_size( diff --git a/src/transformers/models/isaac/modeling_isaac.py index b7b7b9de648b..1066b732a2d8 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -116,7 +116,6 @@ from ...models.qwen3.modeling_qwen3 import Qwen3PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring -from ...utils.deprecation import deprecate_kwarg from ...utils.generic import TransformersKwargs, can_return_tuple from ...utils.import_utils import is_torchdynamo_compiling from ..qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling @@ -920,6 +919,7 @@ class IsaacAttention(nn.Module): def __init__(self, config: IsaacConfig, layer_idx: int): super().__init__() + self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None self.config = config self.layer_idx = layer_idx self.head_dim =
getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) @@ -942,9 +942,8 @@ def __init__(self, config: IsaacConfig, layer_idx: int): ) self.q_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape - self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -1002,7 +1001,6 @@ def __init__(self, config: IsaacConfig, layer_idx: int): self.post_attention_layernorm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] - @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -1011,7 +1009,7 @@ def forward( past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: residual = hidden_states diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 3de31b1a7436..1313eb3db119 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -97,20 +97,12 @@ from perceptron.tensorstream.tensorstream import Event, Stream, TensorStream, TextType, VisionType, create_stream from ...feature_extraction_utils import BatchFeature -from ...image_processing_utils_fast import DefaultFastImageProcessorKwargs from ...models.auto.tokenization_auto import AutoTokenizer from ...processing_utils import ProcessorMixin from ...tokenization_utils import TensorType from .configuration_isaac import IsaacConfig -class IsaacImageProcessorKwargs(DefaultFastImageProcessorKwargs, total=False): - patch_size: Optional[int] - max_num_patches: Optional[int] - min_num_patches: Optional[int] - pixel_shuffle_scale: Optional[int] - - # ============================================================================ # Processor Components # ============================================================================ From 0d553958163baa99140bd71d016da48de027f2b7 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 14:43:23 +0400 Subject: [PATCH 37/77] again --- src/transformers/models/isaac/modeling_isaac.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 1066b732a2d8..6eab664d2465 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -1489,6 +1489,8 @@ def forward( **kwargs, ) -> tuple | CausalLMOutputWithPast: r""" + Forward pass for conditional generation supporting both standard inputs and TensorStream. + tensor_stream (`TensorStream`, *optional*): Packed multimodal stream (text, vision, audio tokens) that already encodes spatial metadata. 
When provided, the model derives embeddings, modality masks, and 3D rotary coordinates directly from the stream instead of From 287a461def2fa3868d9189ad5b1bc4df4f8cce1d Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 15:07:57 +0400 Subject: [PATCH 38/77] fix: export pretrained model --- src/transformers/models/isaac/modeling_isaac.py | 2 +- src/transformers/models/isaac/modular_isaac.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 6eab664d2465..5e0b71c62a2b 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -1632,4 +1632,4 @@ def can_generate(self) -> bool: return True -__all__ = ["IsaacModel", "IsaacForConditionalGeneration"] +__all__ = ["IsaacModel", "IsaacPreTrainedModel", "IsaacForConditionalGeneration"] diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index ce4bef525d8f..8f0342f5e5bd 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -1351,7 +1351,11 @@ def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor: class IsaacConfig(PretrainedConfig): - """Configuration class for Isaac multimodal model.""" + """Configuration class for Isaac multimodal model. + + This configuration corresponds to checkpoints such as + [Perceptron/isaac-base](https://huggingface.co/Perceptron/isaac-base). + """ model_type = "isaac" sub_configs = {"vision_config": IsaacVisionConfig, "text_config": Qwen3Config} @@ -2338,6 +2342,7 @@ def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> __all__ = [ "IsaacConfig", "IsaacModel", + "IsaacPreTrainedModel", # noqa: F822 "IsaacForConditionalGeneration", "IsaacImageProcessorFast", "IsaacProcessor", From c43cb5de1b72f5a49fa64a3f97243879cfa42078 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 15:48:56 +0400 Subject: [PATCH 39/77] test: add placeholder tests --- tests/models/isaac/test_modeling_isaac.py | 141 ++++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 tests/models/isaac/test_modeling_isaac.py diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py new file mode 100644 index 000000000000..58f24c4e637a --- /dev/null +++ b/tests/models/isaac/test_modeling_isaac.py @@ -0,0 +1,141 @@ +import unittest + +from transformers import IsaacConfig, IsaacForConditionalGeneration, IsaacModel, is_torch_available +from transformers.testing_utils import require_torch, torch_device + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ids_tensor + + +if is_torch_available(): + import torch + + +class IsaacModelTester: + def __init__( + self, + parent, + batch_size=2, + seq_length=5, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + self.text_config = { + "bos_token_id": 0, + "eos_token_id": 1, + "pad_token_id": 2, + "hidden_act": "silu", + "head_dim": hidden_size // num_attention_heads, + "hidden_size": hidden_size, + "vocab_size": vocab_size, + "intermediate_size": hidden_size * 3, + 
"max_position_embeddings": 128, + "model_type": "qwen3", + "num_attention_heads": num_attention_heads, + "num_hidden_layers": num_hidden_layers, + "num_key_value_heads": num_attention_heads, + "rope_parameters": {"rope_type": "default", "mrope_section": [16, 8, 8], "mrope_interleaved": True}, + "rope_theta": 10000, + "tie_word_embeddings": True, + } + + self.vision_config = { + "hidden_size": hidden_size, + "intermediate_size": hidden_size * 2, + "num_hidden_layers": 1, + "num_attention_heads": num_attention_heads, + "num_channels": 3, + "num_patches": 64, + "patch_size": 4, + "pixel_shuffle_scale_factor": 1, + "attention_dropout": 0.0, + "layer_norm_eps": 1e-6, + } + + def get_config(self): + config = IsaacConfig( + text_config=self.text_config, + vision_config=self.vision_config, + ) + # Rely on vanilla SDPA so the tests do not need flash attention. + config._attn_implementation = "sdpa" + config.text_config._attn_implementation = "sdpa" + config.vision_attn_implementation = "sdpa" + return config + + def prepare_config_and_inputs(self): + config = self.get_config() + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = torch.ones( + (self.batch_size, self.seq_length), + dtype=torch.long, + device=torch_device, + ) + labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + return config, input_ids, attention_mask, labels + + +@require_torch +class IsaacModelTest(unittest.TestCase): + all_model_classes = (IsaacModel, IsaacForConditionalGeneration) if is_torch_available() else () + + def setUp(self): + self.model_tester = IsaacModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=IsaacConfig, + has_text_modality=True, + common_properties=["hidden_size"], + text_config=self.model_tester.text_config, + vision_config=self.model_tester.vision_config, + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model_forward(self): + config, input_ids, attention_mask, _ = self.model_tester.prepare_config_and_inputs() + model = IsaacModel(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(input_ids=input_ids, attention_mask=attention_mask) + + self.assertEqual( + result.last_hidden_state.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, config.hidden_size), + ) + + def test_for_conditional_generation(self): + config, input_ids, attention_mask, labels = self.model_tester.prepare_config_and_inputs() + model = IsaacForConditionalGeneration(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.vocab_size), + ) + self.assertIsNotNone(result.loss) + + def test_prepare_inputs_for_generation(self): + config, input_ids, attention_mask, _ = self.model_tester.prepare_config_and_inputs() + model = IsaacForConditionalGeneration(config) + model.to(torch_device) + + prepared_inputs = model.prepare_inputs_for_generation(input_ids=input_ids, attention_mask=attention_mask) + self.assertIn("input_ids", prepared_inputs) + self.assertIn("position_ids", prepared_inputs) + self.assertIsNone(prepared_inputs["position_ids"]) From c84df288c084805b4975f45294b618334da83a60 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 16:27:23 +0400 Subject: [PATCH 40/77] docs: add seed documentation --- 
docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/isaac.md | 130 ++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+) create mode 100644 docs/source/en/model_doc/isaac.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c5ce9fbdb9c4..2df16437cb19 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1088,6 +1088,8 @@ title: InstructBlipVideo - local: model_doc/internvl title: InternVL + - local: model_doc/isaac + title: Isaac - local: model_doc/janus title: Janus - local: model_doc/kosmos-2 diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md new file mode 100644 index 000000000000..d5af32870bae --- /dev/null +++ b/docs/source/en/model_doc/isaac.md @@ -0,0 +1,130 @@ + +*This model was added to Hugging Face Transformers in 2025.* + +
+*Badges: PyTorch, FlashAttention, SDPA*
+ +# Isaac + +Isaac is Perceptron's vision-language model (VLM) that pairs a SigLIP2 vision encoder with a Qwen3 decoder-only stack. The +architecture is designed for efficient long-context multimodal interactions and supports interleaving images with +text. The vision encoder supports variable resolutions, with an optional pixel-shuffle step that merges +neighboring patches before they reach the decoder, which keeps the KV-cache and compute requirements manageable in long +conversations. Text and vision tokens are unified via the [`TensorStream`](https://github.com/perceptron-ai-inc/perceptron/tree/main/src/perceptron/tensorstream) abstraction so +that modal boundaries, spatial coordinates, and rescaling parameters are preserved throughout the model stack. + +Key implementation notes: + +- **Packed vision attention** – `IsaacVisionEncoder` keeps track of per-image patch lengths and uses specialized attention + kernels with custom `AttentionMaskConverter` utilities so the decoder only applies attention to real patches while supporting + both FlashAttention and SDPA. +- **TensorStream-first pipeline** – `IsaacProcessor` converts chat templates into multimodal streams where every image gets a + dedicated event with spatial metadata. `IsaacModel` can embed that stream directly (using `embed_stream`) and automatically + derive multi-dimensional RoPE coordinates, so you only need to provide the `tensor_stream` during the first decoding step. +- **Fast image pre-processing** – `IsaacImageProcessorFast` solves for the closest resolution that fits within the requested context. + +Isaac checkpoints are distributed under Perceptron's Non-Production license; please review the license that ships with the +weights before using them in commercial settings. + +## Usage example + +`IsaacProcessor` expects that every `<image>` token in the rendered prompt has a +matching image. The processor returns both standard tokenized inputs and a `TensorStream`. You should pass the stream to the +model (only the first generation step requires it) alongside the regular tensors. + +```py +import torch +from PIL import Image +from transformers import AutoProcessor, IsaacForConditionalGeneration + +model_id = "Perceptron/isaac-base" +processor = AutoProcessor.from_pretrained(model_id) +model = IsaacForConditionalGeneration.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + device_map="auto", + attn_implementation="flash_attention_2", +) + +images = [Image.open("chart.png"), Image.open("panel.jpg")] +messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": images[0]}, + {"type": "image", "image": images[1]}, + {"type": "text", "text": "Compare the two figures and explain what changed."}, + ], + } +] + +# Render the chat template to text so we can pass text+images together. +prompt = processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, +) + +# IsaacProcessor builds TensorStream events internally when both text and images are provided.
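+# The returned batch carries a `tensor_stream` entry, which is a TensorStream object rather than +# a plain tensor, so it is moved to the device separately from the tokenized tensors below.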
+batch = processor(text=prompt, images=images, return_tensors="pt") + +tensor_stream = batch.pop("tensor_stream").to(model.device) +inputs = {name: tensor.to(model.device) for name, tensor in batch.items()} + +with torch.inference_mode(): + generated = model.generate( + **inputs, + tensor_stream=tensor_stream, + max_new_tokens=256, + temperature=0.2, + eos_token_id=processor.tokenizer.eos_token_id, + pad_token_id=processor.tokenizer.eos_token_id, + ) + +response = processor.post_process_image_text_to_text( + generated, + skip_special_tokens=True, +)[0] +print(response) +``` + +## IsaacConfig + +[[autodoc]] IsaacConfig + +## IsaacModel + +[[autodoc]] IsaacModel + - forward + +## IsaacForConditionalGeneration + +[[autodoc]] IsaacForConditionalGeneration + - forward + +## IsaacProcessor + +[[autodoc]] IsaacProcessor + +## IsaacImageProcessorFast + +[[autodoc]] IsaacImageProcessorFast From bf432bcc951dca5102d736181a2ef44725e56d54 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 16:27:45 +0400 Subject: [PATCH 41/77] docs: point to isaac model checkpoint --- src/transformers/models/isaac/configuration_isaac.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index d45775ba8db9..aac683cce15c 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -147,7 +147,11 @@ def __init__( class IsaacConfig(PretrainedConfig): - """Configuration class for Isaac multimodal model.""" + """Configuration class for Isaac multimodal model. + + This configuration corresponds to checkpoints such as + [Perceptron/isaac-base](https://huggingface.co/Perceptron/isaac-base). + """ model_type = "isaac" sub_configs = {"vision_config": IsaacVisionConfig, "text_config": Qwen3Config} From 080f22dc2fad3c1973ac1f31dd798169bba59214 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 16:48:16 +0400 Subject: [PATCH 42/77] fix: set config fields in model --- src/transformers/models/isaac/modeling_isaac.py | 5 +++++ src/transformers/models/isaac/modular_isaac.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 5e0b71c62a2b..436ed1396724 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -1109,6 +1109,11 @@ def __init__(self, config: IsaacConfig): VisionType: self.embed_vision, } + # Keep track of config attributes that downstream utilities may query directly on the model. + self.max_sequence_length = config.max_sequence_length + self.vision_rescale_factor = config.vision_rescale_factor + self.vision_token = config.vision_token + def get_input_embeddings(self) -> nn.Module: return self.text_model.get_input_embeddings() diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 8f0342f5e5bd..36dba0137845 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -1832,6 +1832,11 @@ def __init__(self, config: IsaacConfig): VisionType: self.embed_vision, } + # Keep track of config attributes that downstream utilities may query directly on the model. 
+ self.max_sequence_length = config.max_sequence_length + self.vision_rescale_factor = config.vision_rescale_factor + self.vision_token = config.vision_token + def get_input_embeddings(self) -> nn.Module: return self.text_model.get_input_embeddings() From 0764c2cbd6458506982464dff2c86017a6ff19c3 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 10 Nov 2025 17:04:16 +0400 Subject: [PATCH 43/77] docs: add date stamp --- docs/source/en/model_doc/isaac.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md index d5af32870bae..763678931c0b 100644 --- a/docs/source/en/model_doc/isaac.md +++ b/docs/source/en/model_doc/isaac.md @@ -13,6 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> +*This model was released on {release_date} and added to Hugging Face Transformers on 2025-11-10.* *This model was added to Hugging Face Transformers in 2025.*
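[Editor's note] The pixel-shuffle step described in the Isaac model doc above merges each scale x scale neighborhood of patch embeddings into a single, wider token before the decoder sees them, which is what keeps the KV-cache small. Below is a minimal, self-contained sketch of that idea; the function name and the example sizes are illustrative assumptions, not Isaac's actual implementation:

```py
import torch


def pixel_shuffle_merge(patches: torch.Tensor, height: int, width: int, scale: int = 2) -> torch.Tensor:
    """Merge each `scale x scale` block of patch embeddings into one wider token (illustrative)."""
    dim = patches.shape[-1]
    grid = patches.reshape(height, width, dim)
    # Split the grid into scale x scale neighborhoods, then group each neighborhood together.
    grid = grid.reshape(height // scale, scale, width // scale, scale, dim)
    grid = grid.permute(0, 2, 1, 3, 4)
    return grid.reshape(-1, dim * scale * scale)


# 256 patch embeddings of width 64 collapse into 64 decoder tokens of width 256.
tokens = pixel_shuffle_merge(torch.randn(16 * 16, 64), height=16, width=16)
assert tokens.shape == (64, 256)
```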
From 43f8b8172da8a0b5af18c8013a5546cc3cae49d7 Mon Sep 17 00:00:00 2001 From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com> Date: Fri, 14 Nov 2025 17:28:15 +0400 Subject: [PATCH 44/77] Update isaac.md --- docs/source/en/model_doc/isaac.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md index 763678931c0b..3317b8cbd84f 100644 --- a/docs/source/en/model_doc/isaac.md +++ b/docs/source/en/model_doc/isaac.md @@ -31,7 +31,7 @@ architecture is designed for efficient long-context multimodal interactions, and text. The vision encoder supports variable resolutions, with an optional pixel-shuffle step that merges neighboring patches before they reach the decoder, which keeps the KV-cache and compute requirements manageable in long conversations. Text and vision tokens are unified via the [`TensorStream`](https://github.com/perceptron-ai-inc/perceptron/tree/main/src/perceptron/tensorstream) abstraction so -that modal boundaries, spatial coordinates, and rescaling parameters are preserved throughout the model stack. +that modal boundaries, spatial coordinates, and rescaling parameters are preserved throughout the model stack. For more information, refer to the [technical report](https://github.com/perceptron-ai-inc/perceptron/blob/main/papers/isaac_01.pdf). Key implementation notes: From 059002525d40231d00d90839abc7b467565acb56 Mon Sep 17 00:00:00 2001 From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com> Date: Mon, 17 Nov 2025 18:51:59 +0400 Subject: [PATCH 45/77] Isaac e2e tests + passing make fixup (#4) * test: fix mrope expectation * fix: default cache position setting * test: add golden values * fix: add isaac image processor to auto * fix: handle new delegated cache functionality * fix: update deprecated argument name * fix: pop unused args to pass validation * fix: alias rope for compatibility with checkpoints * test: extensive e2e testing of isaac * feat: convert to modular artifacts * test: add vision token to pre-defined messages * fix: hand roll siglip embeddings so it's moved properly to modeling isaac * test: fix separate vision token to isolated message * fix: move implementation into class so implementations are copied to modeling file * test: save logit stats for more interpretable test deviations * WIP tests are passing!
* test: fix off by one from not stripping whitespace * test: generation consistency tests passing * test: test-specific dtype and device * test: add vocab size property for testing utils compatibility * test: move tensorstream to correct device * test: linting * fix: unregister * chore: linting * fix: temporarily remove vision embeddings copy for fix copies * test: update tests to load model from revision with updated configs * chore: latest transformers modular convert artifact * chore: make fixup completion artifact --- docs/source/en/model_doc/isaac.md | 2 +- .../models/auto/image_processing_auto.py | 1 + .../models/isaac/configuration_isaac.py | 33 + .../models/isaac/modeling_isaac.py | 280 ++++++-- .../models/isaac/modular_isaac.py | 644 ++++++++--------- .../models/isaac/isaac_checkpoint_hashes.json | 5 + .../models/isaac/isaac_generation_golden.json | 450 ++++++++++++ tests/models/isaac/test_modeling_isaac.py | 674 +++++++++++++++++- 8 files changed, 1715 insertions(+), 374 deletions(-) create mode 100644 tests/models/isaac/isaac_checkpoint_hashes.json create mode 100644 tests/models/isaac/isaac_generation_golden.json diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md index 3317b8cbd84f..5b991f166acf 100644 --- a/docs/source/en/model_doc/isaac.md +++ b/docs/source/en/model_doc/isaac.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on {release_date} and added to Hugging Face Transformers on 2025-11-10.* +*This model was released on {release_date} and added to Hugging Face Transformers on 2025-11-17.* *This model was added to Hugging Face Transformers in 2025.*
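[Editor's note] The configuration hunks below add a `rope_parameters` alias that stays in sync with `rope_scaling`, and make `to_dict()` round-trip the nested text/vision configs. A hedged sketch of the intended behavior, assuming default `IsaacConfig` values; the concrete rope dict is illustrative only:

```py
from transformers import IsaacConfig

cfg = IsaacConfig()
cfg.rope_parameters = {"rope_type": "default"}  # setter mirrors the value onto rope_scaling
assert cfg.rope_scaling == {"rope_type": "default"}

state = cfg.to_dict()  # nested sub-configs serialize as plain dicts
assert isinstance(state["text_config"], dict)
assert isinstance(state["vision_config"], dict)
roundtrip = IsaacConfig(**state)  # and can be fed back through __init__
```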
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 1de08b72934c..201b4a73861a 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -123,6 +123,7 @@ ("imagegpt", ("ImageGPTImageProcessor", "ImageGPTImageProcessorFast")), ("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")), ("internvl", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")), + ("isaac", (None, "IsaacImageProcessorFast")), ("janus", ("JanusImageProcessor", "JanusImageProcessorFast")), ("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("kosmos-2.5", ("Kosmos2_5ImageProcessor", "Kosmos2_5ImageProcessorFast")), diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index aac683cce15c..53963cf0e07b 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -167,6 +167,7 @@ def __init__( **kwargs, ): self._rope_scaling: Optional[dict[str, Any]] = None + self._rope_parameters: Optional[dict[str, Any]] = None resolved_text_config = kwargs.pop("text_config", text_config) if isinstance(resolved_text_config, Qwen3Config): text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict()) @@ -180,6 +181,11 @@ def __init__( text_config_kwargs.update(kwargs) self.text_config = self.sub_configs["text_config"](**text_config_kwargs) + if not hasattr(self.text_config, "rope_theta"): + rope_theta_override = text_config_kwargs.get("rope_theta", kwargs.get("rope_theta")) + if rope_theta_override is None: + rope_theta_override = getattr(Qwen3Config(), "rope_theta", 10000.0) + self.text_config.rope_theta = rope_theta_override super().__init__(**kwargs) @@ -188,6 +194,9 @@ def __init__( else: self.text_config.rope_scaling = self._rope_scaling + # Keep rope parameters alias in sync with upstream expectations + self._rope_parameters = self._rope_scaling + # Mirror frequently accessed Qwen3 attributes at the composite config level for BC. 
self.tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False) self.vocab_size = self.text_config.vocab_size @@ -258,6 +267,21 @@ def rope_scaling(self, value): if hasattr(self, "text_config") and self.text_config is not None: self.text_config.rope_scaling = value + @property + def rope_parameters(self) -> dict[str, Any] | None: + """Alias introduced upstream for rope scaling dictionaries.""" + value = self._rope_parameters + if value is None: + value = self.rope_scaling + if value is None: + return {"rope_type": "default"} + return value + + @rope_parameters.setter + def rope_parameters(self, value: dict[str, Any] | None) -> None: + self._rope_parameters = value + self.rope_scaling = value + @property def vision_attn_implementation(self) -> Optional[str]: value = getattr(self.vision_config, "_attn_implementation", None) @@ -273,5 +297,14 @@ def vision_attn_implementation(self, value: Optional[str]) -> None: elif hasattr(self.vision_config, "attn_implementation"): delattr(self.vision_config, "attn_implementation") + def to_dict(self): + output = super().to_dict() + # Ensure nested configs round-trip through dict serialization + if hasattr(self, "text_config") and self.text_config is not None: + output["text_config"] = self.text_config.to_dict() + if hasattr(self, "vision_config") and self.vision_config is not None: + output["vision_config"] = self.vision_config.to_dict() + return output + __all__ = ["IsaacConfig"] diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 436ed1396724..3920e53f9d09 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -103,7 +103,7 @@ from perceptron.tensorstream.tensorstream import TensorStream, TextType, VisionType, group_streams from ...activations import ACT2FN -from ...cache_utils import Cache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation.utils import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -140,6 +140,26 @@ def __init__(self, config: IsaacVisionConfig): self.position_embedding_size = int(self.num_patches**0.5) self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor: + packed_pixel_values, seq_lengths = self._pack_to_batch(seq_patches, spatial_shapes) + if packed_pixel_values is None: + return seq_patches.new_zeros((0, self.embed_dim)) + + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(packed_pixel_values.to(dtype=target_dtype)) + + positional_embeddings = self.position_embedding.weight.reshape( + self.position_embedding_size, + self.position_embedding_size, + -1, + ) + resized_positional_embeddings = self.resize_positional_embeddings( + positional_embeddings, spatial_shapes, max_length=packed_pixel_values.shape[1] + ) + + embeddings = patch_embeds + resized_positional_embeddings + return self._unpack_from_batch(embeddings, seq_lengths) + @staticmethod def resize_positional_embeddings( positional_embeddings: torch.Tensor, @@ -199,21 +219,6 @@ def resize_positional_embeddings( return resulted_positional_embeddings - def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor: - """ - Args: - pixel_values (`torch.FloatTensor`): - Pixel values of shape 
(batch_size, max_num_patches, num_channels * patch_size * patch_size) - spatial_shapes (`list[tuple[int, int]]`): - Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to - """ - packed_pixel_values, seq_lengths = self._pack_to_batch(seq_patches, spatial_shapes) - if packed_pixel_values is None: - return seq_patches.new_zeros((0, self.embed_dim)) - - embeddings = super().forward(packed_pixel_values, spatial_shapes) - return self._unpack_from_batch(embeddings, seq_lengths) - def _pack_to_batch( self, seq_patches: torch.Tensor, @@ -265,6 +270,13 @@ def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor return torch.cat(output_chunks, dim=0) +def _max_from_cu(cu: Optional[torch.Tensor], fallback: int) -> int: + """Helper to compute max sequence length from cumulative sequence lengths.""" + if cu is None or len(cu) < 2: + return fallback + return int((cu[1:] - cu[:-1]).max().item()) + + def build_document_attention_mask( cu_seqlens: Optional[torch.Tensor], total_tokens: int, @@ -321,6 +333,7 @@ class IsaacVisionAttention(nn.Module): "eager": "isaac_eager", "isaac_eager": "isaac_eager", } + _FLASH_IMPLS = frozenset(("isaac_flash_attention_2", "isaac_flash_attention_3")) def __init__(self, config): super().__init__() @@ -348,6 +361,8 @@ def forward(self, hidden_states, attention_mask=None, **kwargs) -> tuple[torch.T cu_seqlens = kwargs.pop("cu_seqlens", None) max_seqlen = kwargs.pop("max_seqlen", None) kwargs.pop("output_attentions", None) + kwargs.pop("output_hidden_states", None) + kwargs.pop("return_dict", None) if kwargs: unexpected = ", ".join(sorted(kwargs)) raise TypeError(f"Unexpected kwargs for IsaacVisionAttention.forward: {unexpected}") @@ -382,41 +397,70 @@ def forward(self, hidden_states, attention_mask=None, **kwargs) -> tuple[torch.T q.device, ) - resolved_key = self.ATTENTION_KEY_MAP.get(attn_impl) - attention_fn = ALL_ATTENTION_FUNCTIONS.get(resolved_key) if resolved_key is not None else None - if attention_fn is None: - raise ValueError(f"Attention implementation {attn_impl} not found.") + resolved_key = self.ATTENTION_KEY_MAP.get(attn_impl, attn_impl) - query_states = q.transpose(0, 1).unsqueeze(0) - key_states = k.transpose(0, 1).unsqueeze(0) - value_states = v.transpose(0, 1).unsqueeze(0) - - attention_kwargs: dict[str, Any] = { - "dropout": p_drop, - "scaling": self.scale, - "is_causal": False, - } - if cu_seqlens is not None: - attention_kwargs["cu_seq_lens_q"] = cu_seqlens - attention_kwargs["cu_seq_lens_k"] = cu_seqlens - if max_seqlen is not None: - attention_kwargs["max_length_q"] = max_seqlen - attention_kwargs["max_length_k"] = max_seqlen - - attn_output, _ = attention_fn( - self, - query_states, - key_states, - value_states, - attn_mask, - **attention_kwargs, - ) + attn_weights = None + if resolved_key in self._FLASH_IMPLS: + y_lhd = self._flash_attention_forward( + q_lhd=q, + k_lhd=k, + v_lhd=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout=p_drop, + ) + elif resolved_key == "isaac_sdpa": + y_lhd = self._sdpa_attention_forward( + q_lhd=q, + k_lhd=k, + v_lhd=v, + attention_mask=attn_mask, + cu_seqlens=cu_seqlens, + dropout=p_drop, + ) + elif resolved_key == "isaac_eager": + y_lhd, attn_weights = self._eager_attention_forward( + q_lhd=q, + k_lhd=k, + v_lhd=v, + attention_mask=attn_mask, + dropout=p_drop, + ) + else: + attention_fn = ALL_ATTENTION_FUNCTIONS.get(resolved_key) + if attention_fn is None: + raise ValueError(f"Attention implementation {attn_impl} not found.") + + query_states = q.transpose(0, 
1).unsqueeze(0) + key_states = k.transpose(0, 1).unsqueeze(0) + value_states = v.transpose(0, 1).unsqueeze(0) + + attention_kwargs: dict[str, Any] = { + "dropout": p_drop, + "scaling": self.scale, + "is_causal": False, + } + if cu_seqlens is not None: + attention_kwargs["cu_seq_lens_q"] = cu_seqlens + attention_kwargs["cu_seq_lens_k"] = cu_seqlens + if max_seqlen is not None: + attention_kwargs["max_length_q"] = max_seqlen + attention_kwargs["max_length_k"] = max_seqlen + + attn_output, attn_weights = attention_fn( + self, + query_states, + key_states, + value_states, + attn_mask, + **attention_kwargs, + ) - y_lhd = attn_output.squeeze(0).permute(1, 0, 2).contiguous() + y_lhd = attn_output.squeeze(0).permute(1, 0, 2).contiguous() # Merge heads and project y = self.out_proj(y_lhd.reshape(L, self.embed_dim)) - return y.unsqueeze(0), None # (1, L, E) + return y.unsqueeze(0), attn_weights # (1, L, E) def _variable_length_context(self, *, cu_seqlens=None, max_seqlen=None): """Store packed-sequence metadata for the next forward call.""" @@ -429,6 +473,114 @@ def _consume_variable_length_metadata(self): self._variable_length_metadata = None return cu_seqlens, max_seqlen + @staticmethod + def _max_from_cu(cu: Optional[torch.Tensor], fallback: int) -> int: + if cu is None or cu.numel() < 2: + return fallback + return int((cu[1:] - cu[:-1]).max().item()) + + def _flash_attention_forward( + self, + *, + q_lhd: torch.Tensor, + k_lhd: torch.Tensor, + v_lhd: torch.Tensor, + cu_seqlens: Optional[torch.Tensor], + max_seqlen: Optional[int], + dropout: float, + ) -> torch.Tensor: + L = q_lhd.size(0) + if max_seqlen is not None: + max_q = max_k = int(max_seqlen) + else: + max_q = max_k = self._max_from_cu(cu_seqlens, L) + + if not q_lhd.is_contiguous(): + q_lhd = q_lhd.contiguous() + if not k_lhd.is_contiguous(): + k_lhd = k_lhd.contiguous() + if not v_lhd.is_contiguous(): + v_lhd = v_lhd.contiguous() + + out_lhd, *_ = torch.ops.aten._flash_attention_forward( + query=q_lhd, + key=k_lhd, + value=v_lhd, + cum_seq_q=cu_seqlens, + cum_seq_k=cu_seqlens, + max_q=max_q, + max_k=max_k, + dropout_p=dropout, + is_causal=False, + return_debug_mask=False, + scale=self.scale, + window_size_left=-1, + window_size_right=-1, + alibi_slopes=None, + ) + return out_lhd + + def _sdpa_attention_forward( + self, + *, + q_lhd: torch.Tensor, + k_lhd: torch.Tensor, + v_lhd: torch.Tensor, + attention_mask: Optional[torch.Tensor], + cu_seqlens: Optional[torch.Tensor], + dropout: float, + ) -> torch.Tensor: + L = q_lhd.size(0) + attn_mask = attention_mask + if attn_mask is None: + attn_mask = build_document_attention_mask( + cu_seqlens=cu_seqlens, + total_tokens=L, + dtype=q_lhd.dtype, + device=q_lhd.device, + ) + + q = q_lhd.permute(1, 0, 2).unsqueeze(0) + k = k_lhd.permute(1, 0, 2).unsqueeze(0) + v = v_lhd.permute(1, 0, 2).unsqueeze(0) + + if attn_mask is not None and attn_mask.dtype != q.dtype: + attn_mask = attn_mask.to(q.dtype) + + output = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=dropout, + scale=self.scale, + is_causal=False, + ) + return output.squeeze(0).permute(1, 0, 2).contiguous() + + def _eager_attention_forward( + self, + *, + q_lhd: torch.Tensor, + k_lhd: torch.Tensor, + v_lhd: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + attn_weights = torch.matmul(q_lhd, k_lhd.transpose(1, 2)) * self.scale + if attention_mask is not None: + mask = attention_mask + if mask.dim() == 4: + mask = mask.squeeze(0).squeeze(0) + 
attn_weights = attn_weights + mask + + attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_lhd.dtype) + if dropout and self.training: + attn_weights = F.dropout(attn_weights, p=dropout, training=True) + + attn_output_lhd = torch.matmul(attn_weights, v_lhd) + return attn_output_lhd, attn_weights + class IsaacMLP(nn.Module): def __init__(self, config): @@ -1236,6 +1388,19 @@ def forward( elif inputs_embeds is None: raise ValueError("You have to specify either tensor_stream, input_ids or inputs_embeds") + # Ensure cache exists when requested + if use_cache and past_key_values is None: + cache_config = self.config.get_text_config() if hasattr(self.config, "get_text_config") else self.config + past_key_values = DynamicCache(config=cache_config) + + if cache_position is None and (past_key_values is not None or use_cache): + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + # Create default position_ids if not provided if position_ids is None: if tensor_stream is not None: @@ -1266,7 +1431,7 @@ def forward( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=(cos, sin), @@ -1459,7 +1624,7 @@ class IsaacPreTrainedModel(PreTrainedModel): class IsaacForConditionalGeneration(IsaacPreTrainedModel, GenerationMixin): """Isaac multimodal model for conditional generation.""" - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} @@ -1611,6 +1776,23 @@ def prepare_inputs_for_generation( """ Prepare inputs for generation, handling TensorStream inputs properly. """ + if cache_position is None: + seq_length = None + device = None + if input_ids is not None: + seq_length = input_ids.shape[1] + device = input_ids.device + elif inputs_embeds is not None: + seq_length = inputs_embeds.shape[1] + device = inputs_embeds.device + elif tensor_stream is not None: + _, seq_length = tensor_stream.shape + device = tensor_stream.device + if seq_length is not None: + # prepare_inputs_for_generation may be invoked outside `generate`, so synthesize the + # same cache positions that GenerationMixin would have created during prefill. 
+ cache_position = torch.arange(seq_length, dtype=torch.long, device=device) + # Call parent preparation model_inputs = super().prepare_inputs_for_generation( input_ids, @@ -1623,6 +1805,8 @@ def prepare_inputs_for_generation( **kwargs, ) + cache_position = model_inputs.get("cache_position", cache_position) + # Handle TensorStream for first forward pass only if tensor_stream is not None and (cache_position is None or cache_position[0] == 0): model_inputs["tensor_stream"] = tensor_stream diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 36dba0137845..b5b6723b43e9 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -86,7 +86,7 @@ import math import re from collections import defaultdict -from collections.abc import Callable, Sequence +from collections.abc import Sequence from typing import Any, Optional, Union import PIL.Image @@ -112,7 +112,7 @@ group_streams, ) -from ...cache_utils import Cache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...configuration_utils import PretrainedConfig, layer_type_validation from ...feature_extraction_utils import BatchFeature from ...generation.utils import GenerationMixin @@ -150,16 +150,9 @@ Siglip2Attention, Siglip2Encoder, Siglip2EncoderLayer, - Siglip2VisionEmbeddings, ) -_ORIGINAL_ATTENTION_FUNCTIONS: dict[str, Callable[..., tuple[torch.Tensor, Optional[torch.Tensor]]]] = {} -for _attn_name in ("flash_attention_2", "sdpa", "eager"): - if _attn_name in ALL_ATTENTION_FUNCTIONS: - _ORIGINAL_ATTENTION_FUNCTIONS[_attn_name] = ALL_ATTENTION_FUNCTIONS[_attn_name] - - class IsaacVisionConfig(Siglip2VisionConfig): """Vision configuration for Isaac with Pixel Shuffle support. 
@@ -510,105 +503,103 @@ def ensure_document_attention_mask( ) -def flash_attention_document_mask_forward( - module: torch.nn.Module, - q_lhd: torch.Tensor, # (L, H, D) - k_lhd: torch.Tensor, # (L, H, D) - v_lhd: torch.Tensor, # (L, H, D) - attention_mask: Optional[torch.Tensor] = None, # unused for FA path - dropout: float = 0.0, - scaling: Optional[float] = None, - cum_seq_q: Optional[torch.Tensor] = None, - cum_seq_k: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - is_causal: bool = False, - **kwargs, -) -> tuple[torch.Tensor, None]: - """FlashAttention that consumes (L, H, D) directly to avoid layout churn.""" - L, H, D = q_lhd.shape - - # Compute max block length once (honor caller when provided) - if max_seqlen is not None: - max_q = max_k = int(max_seqlen) - else: - max_q = _max_from_cu(cum_seq_q, L) - max_k = _max_from_cu(cum_seq_k, L) - - # Ensure contiguity only if needed - if not q_lhd.is_contiguous(): - q_lhd = q_lhd.contiguous() - if not k_lhd.is_contiguous(): - k_lhd = k_lhd.contiguous() - if not v_lhd.is_contiguous(): - v_lhd = v_lhd.contiguous() - - out_lhd, *_ = torch.ops.aten._flash_attention_forward( - query=q_lhd, # (L, H, D) - key=k_lhd, # (L, H, D) - value=v_lhd, # (L, H, D) - cum_seq_q=cum_seq_q, - cum_seq_k=cum_seq_k, - max_q=max_q, - max_k=max_k, - dropout_p=dropout, - is_causal=is_causal, - return_debug_mask=False, - scale=scaling, - window_size_left=-1, - window_size_right=-1, - alibi_slopes=None, - ) - return out_lhd, None # (L, H, D) +class IsaacVisionEmbeddings(nn.Module): + """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences.""" + def __init__(self, config: IsaacVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size -def sdpa_document_mask_forward( - q_lhd: torch.Tensor, # (L, H, D) - k_lhd: torch.Tensor, # (L, H, D) - v_lhd: torch.Tensor, # (L, H, D) - dropout: float, - scaling: Optional[float], - attention_mask: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """SDPA with block-diagonal masking for variable-length sequences.""" - L, H, D = q_lhd.shape - - # Transpose to (1, H, L, D) format for SDPA - Q = q_lhd.permute(1, 0, 2).unsqueeze(0) - K = k_lhd.permute(1, 0, 2).unsqueeze(0) - V = v_lhd.permute(1, 0, 2).unsqueeze(0) - - # Build block-diagonal mask for variable-length sequences - attn_mask = attention_mask - if attn_mask is None: - attn_mask = build_document_attention_mask( - cu_seqlens=cu_seqlens, - total_tokens=L, - dtype=q_lhd.dtype, - device=q_lhd.device, + self.patch_embedding = nn.Linear( + in_features=config.num_channels * self.patch_size * self.patch_size, + out_features=self.embed_dim, ) - if attn_mask is not None and attn_mask.dtype != Q.dtype: - attn_mask = attn_mask.to(Q.dtype) - - Y = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask, dropout_p=dropout, scale=scaling) - return Y.squeeze(0).permute(1, 0, 2) # Back to (L, H, D) - - -class IsaacVisionEmbeddings(Siglip2VisionEmbeddings): - """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences.""" - - def __init__(self, config: IsaacVisionConfig): - super().__init__(config) + self.num_patches = config.num_patches + self.position_embedding_size = int(self.num_patches**0.5) + self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor: packed_pixel_values, seq_lengths = 
self._pack_to_batch(seq_patches, spatial_shapes) if packed_pixel_values is None: return seq_patches.new_zeros((0, self.embed_dim)) - embeddings = super().forward(packed_pixel_values, spatial_shapes) + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(packed_pixel_values.to(dtype=target_dtype)) + + positional_embeddings = self.position_embedding.weight.reshape( + self.position_embedding_size, + self.position_embedding_size, + -1, + ) + resized_positional_embeddings = self.resize_positional_embeddings( + positional_embeddings, spatial_shapes, max_length=packed_pixel_values.shape[1] + ) + + embeddings = patch_embeds + resized_positional_embeddings return self._unpack_from_batch(embeddings, seq_lengths) + @staticmethod + def resize_positional_embeddings( + positional_embeddings: torch.Tensor, + spatial_shapes: torch.LongTensor, + max_length: int, + ) -> torch.Tensor: + """ + Resize positional embeddings to image-specific size and pad to a fixed size. + + Args: + positional_embeddings (`torch.Tensor`): + Position embeddings of shape (height, width, embed_dim) + spatial_shapes (`torch.LongTensor`): + Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to + max_length (`int`): + Maximum length of the positional embeddings to pad resized positional embeddings to + + Returns: + `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim) + """ + batch_size = spatial_shapes.shape[0] + embed_dim = positional_embeddings.shape[-1] + source_dtype = positional_embeddings.dtype + + resulted_positional_embeddings = torch.empty( + (batch_size, max_length, embed_dim), + device=positional_embeddings.device, + dtype=source_dtype, + ) + + # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation + positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0) + + # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU + if positional_embeddings.device.type == "cpu": + positional_embeddings = positional_embeddings.to(torch.float32) + + for i in range(batch_size): + # (1, dim, height, width) -> (1, dim, target_height, target_width) + height, width = spatial_shapes[i] + resized_embeddings = F.interpolate( + positional_embeddings, + size=(height, width), + mode="bilinear", + align_corners=False, + antialias=True, + ) + + # (1, dim, target_height, target_width) -> (target_height * target_width, dim) + resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1) + + # Cast to original dtype + resized_embeddings = resized_embeddings.to(source_dtype) + + resulted_positional_embeddings[i, : height * width] = resized_embeddings + resulted_positional_embeddings[i, height * width :] = resized_embeddings[0] + + return resulted_positional_embeddings + def _pack_to_batch( self, seq_patches: torch.Tensor, @@ -673,6 +664,7 @@ class IsaacVisionAttention(Siglip2Attention): "eager": "isaac_eager", "isaac_eager": "isaac_eager", } + _FLASH_IMPLS = frozenset(("isaac_flash_attention_2", "isaac_flash_attention_3")) def __init__(self, config): super().__init__(config) @@ -694,6 +686,8 @@ def forward(self, hidden_states, attention_mask=None, **kwargs): cu_seqlens = kwargs.pop("cu_seqlens", None) max_seqlen = kwargs.pop("max_seqlen", None) kwargs.pop("output_attentions", None) + kwargs.pop("output_hidden_states", None) + kwargs.pop("return_dict", None) if kwargs: unexpected = ", ".join(sorted(kwargs)) raise TypeError(f"Unexpected kwargs for IsaacVisionAttention.forward: 
{unexpected}") @@ -728,41 +722,178 @@ def forward(self, hidden_states, attention_mask=None, **kwargs): q.device, ) - resolved_key = self.ATTENTION_KEY_MAP.get(attn_impl) - attention_fn = ALL_ATTENTION_FUNCTIONS.get(resolved_key) if resolved_key is not None else None - if attention_fn is None: - raise ValueError(f"Attention implementation {attn_impl} not found.") + resolved_key = self.ATTENTION_KEY_MAP.get(attn_impl, attn_impl) + + attn_weights = None + if resolved_key in self._FLASH_IMPLS: + y_lhd = self._flash_attention_forward( + q_lhd=q, + k_lhd=k, + v_lhd=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout=p_drop, + ) + elif resolved_key == "isaac_sdpa": + y_lhd = self._sdpa_attention_forward( + q_lhd=q, + k_lhd=k, + v_lhd=v, + attention_mask=attn_mask, + cu_seqlens=cu_seqlens, + dropout=p_drop, + ) + elif resolved_key == "isaac_eager": + y_lhd, attn_weights = self._eager_attention_forward( + q_lhd=q, + k_lhd=k, + v_lhd=v, + attention_mask=attn_mask, + dropout=p_drop, + ) + else: + attention_fn = ALL_ATTENTION_FUNCTIONS.get(resolved_key) + if attention_fn is None: + raise ValueError(f"Attention implementation {attn_impl} not found.") + + query_states = q.transpose(0, 1).unsqueeze(0) + key_states = k.transpose(0, 1).unsqueeze(0) + value_states = v.transpose(0, 1).unsqueeze(0) + + attention_kwargs: dict[str, Any] = { + "dropout": p_drop, + "scaling": self.scale, + "is_causal": False, + } + if cu_seqlens is not None: + attention_kwargs["cu_seq_lens_q"] = cu_seqlens + attention_kwargs["cu_seq_lens_k"] = cu_seqlens + if max_seqlen is not None: + attention_kwargs["max_length_q"] = max_seqlen + attention_kwargs["max_length_k"] = max_seqlen + + attn_output, attn_weights = attention_fn( + self, + query_states, + key_states, + value_states, + attn_mask, + **attention_kwargs, + ) - query_states = q.transpose(0, 1).unsqueeze(0) - key_states = k.transpose(0, 1).unsqueeze(0) - value_states = v.transpose(0, 1).unsqueeze(0) + y_lhd = attn_output.squeeze(0).permute(1, 0, 2).contiguous() - attention_kwargs: dict[str, Any] = { - "dropout": p_drop, - "scaling": self.scale, - "is_causal": False, - } - if cu_seqlens is not None: - attention_kwargs["cu_seq_lens_q"] = cu_seqlens - attention_kwargs["cu_seq_lens_k"] = cu_seqlens + # Merge heads and project + y = self.out_proj(y_lhd.reshape(L, self.embed_dim)) + return y.unsqueeze(0), attn_weights # (1, L, E) + + @staticmethod + def _max_from_cu(cu: Optional[torch.Tensor], fallback: int) -> int: + if cu is None or cu.numel() < 2: + return fallback + return int((cu[1:] - cu[:-1]).max().item()) + + def _flash_attention_forward( + self, + *, + q_lhd: torch.Tensor, + k_lhd: torch.Tensor, + v_lhd: torch.Tensor, + cu_seqlens: Optional[torch.Tensor], + max_seqlen: Optional[int], + dropout: float, + ) -> torch.Tensor: + L = q_lhd.size(0) if max_seqlen is not None: - attention_kwargs["max_length_q"] = max_seqlen - attention_kwargs["max_length_k"] = max_seqlen - - attn_output, _ = attention_fn( - self, - query_states, - key_states, - value_states, - attn_mask, - **attention_kwargs, + max_q = max_k = int(max_seqlen) + else: + max_q = max_k = self._max_from_cu(cu_seqlens, L) + + if not q_lhd.is_contiguous(): + q_lhd = q_lhd.contiguous() + if not k_lhd.is_contiguous(): + k_lhd = k_lhd.contiguous() + if not v_lhd.is_contiguous(): + v_lhd = v_lhd.contiguous() + + out_lhd, *_ = torch.ops.aten._flash_attention_forward( + query=q_lhd, + key=k_lhd, + value=v_lhd, + cum_seq_q=cu_seqlens, + cum_seq_k=cu_seqlens, + max_q=max_q, + max_k=max_k, + dropout_p=dropout, + 
is_causal=False, + return_debug_mask=False, + scale=self.scale, + window_size_left=-1, + window_size_right=-1, + alibi_slopes=None, ) + return out_lhd - y_lhd = attn_output.squeeze(0).permute(1, 0, 2).contiguous() + def _sdpa_attention_forward( + self, + *, + q_lhd: torch.Tensor, + k_lhd: torch.Tensor, + v_lhd: torch.Tensor, + attention_mask: Optional[torch.Tensor], + cu_seqlens: Optional[torch.Tensor], + dropout: float, + ) -> torch.Tensor: + L = q_lhd.size(0) + attn_mask = attention_mask + if attn_mask is None: + attn_mask = build_document_attention_mask( + cu_seqlens=cu_seqlens, + total_tokens=L, + dtype=q_lhd.dtype, + device=q_lhd.device, + ) - # Merge heads and project - y = self.out_proj(y_lhd.reshape(L, self.embed_dim)) - return y.unsqueeze(0), None # (1, L, E) + q = q_lhd.permute(1, 0, 2).unsqueeze(0) + k = k_lhd.permute(1, 0, 2).unsqueeze(0) + v = v_lhd.permute(1, 0, 2).unsqueeze(0) + + if attn_mask is not None and attn_mask.dtype != q.dtype: + attn_mask = attn_mask.to(q.dtype) + + output = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=dropout, + scale=self.scale, + is_causal=False, + ) + return output.squeeze(0).permute(1, 0, 2).contiguous() + + def _eager_attention_forward( + self, + *, + q_lhd: torch.Tensor, + k_lhd: torch.Tensor, + v_lhd: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + attn_weights = torch.matmul(q_lhd, k_lhd.transpose(1, 2)) * self.scale + if attention_mask is not None: + mask = attention_mask + if mask.dim() == 4: + mask = mask.squeeze(0).squeeze(0) + attn_weights = attn_weights + mask + + attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_lhd.dtype) + if dropout and self.training: + attn_weights = F.dropout(attn_weights, p=dropout, training=True) + + attn_output_lhd = torch.matmul(attn_weights, v_lhd) + return attn_output_lhd, attn_weights class IsaacVisionEncoderLayer(Siglip2EncoderLayer): @@ -861,202 +992,6 @@ def forward( ) -def _isaac_flash_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - dropout: float = 0.0, - scaling: Optional[float] = None, - is_causal: bool = False, - **kwargs, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - base_fn = _ORIGINAL_ATTENTION_FUNCTIONS.get("flash_attention_2") - if not isinstance(module, IsaacVisionAttention) or base_fn is None: - if base_fn is None: - raise ValueError("Base flash attention function unavailable for fallback.") - return base_fn( - module, - query, - key, - value, - attention_mask, - dropout=dropout, - scaling=scaling, - is_causal=is_causal, - **kwargs, - ) - - if query.dim() != 4 or query.size(0) != 1: - raise ValueError( - "IsaacVisionAttention expects packed sequences with batch size 1 when using packed attention." 
- ) - - _, num_heads, seq_len, head_dim = query.shape - q_lhd = query.transpose(1, 2).reshape(seq_len, num_heads, head_dim) - k_lhd = key.transpose(1, 2).reshape(seq_len, num_heads, head_dim) - v_lhd = value.transpose(1, 2).reshape(seq_len, num_heads, head_dim) - - cum_seq_q = kwargs.get("cu_seq_lens_q") - cum_seq_k = kwargs.get("cu_seq_lens_k", cum_seq_q) - max_seqlen = kwargs.get("max_length_q") - - effective_dropout = dropout if dropout is not None else (module.dropout if module.training else 0.0) - effective_scaling = module.scale if scaling is None else scaling - - attn_mask = attention_mask - if attn_mask is None: - attn_mask = build_document_attention_mask( - cu_seqlens=cum_seq_q, - total_tokens=seq_len, - dtype=q_lhd.dtype, - device=q_lhd.device, - ) - - attn_output_lhd, attn_weights = flash_attention_document_mask_forward( - module, - q_lhd, - k_lhd, - v_lhd, - attention_mask=attn_mask, - dropout=effective_dropout, - scaling=effective_scaling, - cum_seq_q=cum_seq_q, - cum_seq_k=cum_seq_k, - max_seqlen=max_seqlen, - is_causal=is_causal, - ) - - attn_output = attn_output_lhd.permute(1, 0, 2).unsqueeze(0) - return attn_output, attn_weights - - -def _isaac_sdpa_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - dropout: float = 0.0, - scaling: Optional[float] = None, - is_causal: bool = False, - **kwargs, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - base_fn = _ORIGINAL_ATTENTION_FUNCTIONS.get("sdpa") - if not isinstance(module, IsaacVisionAttention) or base_fn is None: - if base_fn is None: - raise ValueError("Base SDPA function unavailable for fallback.") - return base_fn( - module, - query, - key, - value, - attention_mask, - dropout=dropout, - scaling=scaling, - is_causal=is_causal, - **kwargs, - ) - - if query.dim() != 4 or query.size(0) != 1: - raise ValueError( - "IsaacVisionAttention expects packed sequences with batch size 1 when using packed attention." 
- ) - - _, num_heads, seq_len, head_dim = query.shape - q_lhd = query.transpose(1, 2).reshape(seq_len, num_heads, head_dim) - k_lhd = key.transpose(1, 2).reshape(seq_len, num_heads, head_dim) - v_lhd = value.transpose(1, 2).reshape(seq_len, num_heads, head_dim) - - cum_seq = kwargs.get("cu_seq_lens_q") - effective_dropout = dropout if dropout is not None else (module.dropout if module.training else 0.0) - effective_scaling = module.scale if scaling is None else scaling - - attn_mask = attention_mask - if attn_mask is None: - attn_mask = build_document_attention_mask( - cu_seqlens=cum_seq, - total_tokens=seq_len, - dtype=q_lhd.dtype, - device=q_lhd.device, - ) - - attn_output_lhd = sdpa_document_mask_forward( - q_lhd, - k_lhd, - v_lhd, - dropout=effective_dropout, - scaling=effective_scaling, - attention_mask=attn_mask, - cu_seqlens=cum_seq, - ) - - attn_output = attn_output_lhd.permute(1, 0, 2).unsqueeze(0) - return attn_output, None - - -def _isaac_eager_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - dropout: float = 0.0, - scaling: Optional[float] = None, - is_causal: bool = False, - **kwargs, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - base_fn = _ORIGINAL_ATTENTION_FUNCTIONS.get("eager") - if not isinstance(module, IsaacVisionAttention) or base_fn is None: - if base_fn is None: - raise ValueError("Base eager attention function unavailable for fallback.") - return base_fn( - module, - query, - key, - value, - attention_mask, - dropout=dropout, - scaling=scaling, - is_causal=is_causal, - **kwargs, - ) - - if query.dim() != 4 or query.size(0) != 1: - raise ValueError( - "IsaacVisionAttention expects packed sequences with batch size 1 when using packed attention." 
- ) - - _, num_heads, seq_len, head_dim = query.shape - q_lhd = query.transpose(1, 2).reshape(seq_len, num_heads, head_dim) - k_lhd = key.transpose(1, 2).reshape(seq_len, num_heads, head_dim) - v_lhd = value.transpose(1, 2).reshape(seq_len, num_heads, head_dim) - - effective_scaling = module.scale if scaling is None else scaling - attn_weights = torch.matmul(q_lhd, k_lhd.transpose(1, 2)) * effective_scaling - - if attention_mask is not None: - mask = attention_mask - if mask.dim() == 4: - mask = mask.squeeze(0).squeeze(0) - attn_weights = attn_weights + mask - - attn_weights = torch.softmax(attn_weights, dim=-1) - if dropout and module.training: - attn_weights = F.dropout(attn_weights, p=dropout, training=True) - - attn_output_lhd = torch.matmul(attn_weights, v_lhd) - attn_output = attn_output_lhd.permute(1, 0, 2).unsqueeze(0) - return attn_output, attn_weights - - -ALL_ATTENTION_FUNCTIONS.register("isaac_flash_attention_2", _isaac_flash_attention_forward) -ALL_ATTENTION_FUNCTIONS.register("isaac_flash_attention_3", _isaac_flash_attention_forward) -ALL_ATTENTION_FUNCTIONS.register("isaac_sdpa", _isaac_sdpa_forward) -ALL_ATTENTION_FUNCTIONS.register("isaac_eager", _isaac_eager_forward) - - def create_pixel_shuffle_index_map( seq_sizes: torch.Tensor, token_grids: torch.Tensor, @@ -1371,6 +1306,7 @@ def __init__( **kwargs, ): self._rope_scaling: Optional[dict[str, Any]] = None + self._rope_parameters: Optional[dict[str, Any]] = None resolved_text_config = kwargs.pop("text_config", text_config) if isinstance(resolved_text_config, Qwen3Config): text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict()) @@ -1384,6 +1320,11 @@ def __init__( text_config_kwargs.update(kwargs) self.text_config = self.sub_configs["text_config"](**text_config_kwargs) + if not hasattr(self.text_config, "rope_theta"): + rope_theta_override = text_config_kwargs.get("rope_theta", kwargs.get("rope_theta")) + if rope_theta_override is None: + rope_theta_override = getattr(Qwen3Config(), "rope_theta", 10000.0) + self.text_config.rope_theta = rope_theta_override super().__init__(**kwargs) @@ -1392,6 +1333,9 @@ def __init__( else: self.text_config.rope_scaling = self._rope_scaling + # Keep rope parameters alias in sync with upstream expectations + self._rope_parameters = self._rope_scaling + # Mirror frequently accessed Qwen3 attributes at the composite config level for BC. 
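        # e.g. config.vocab_size reads straight through to config.text_config.vocab_size,
        # so callers never need to reach into the nested text sub-config.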
self.tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False) self.vocab_size = self.text_config.vocab_size @@ -1462,6 +1406,21 @@ def rope_scaling(self, value): if hasattr(self, "text_config") and self.text_config is not None: self.text_config.rope_scaling = value + @property + def rope_parameters(self) -> dict[str, Any] | None: + """Alias introduced upstream for rope scaling dictionaries.""" + value = self._rope_parameters + if value is None: + value = self.rope_scaling + if value is None: + return {"rope_type": "default"} + return value + + @rope_parameters.setter + def rope_parameters(self, value: dict[str, Any] | None) -> None: + self._rope_parameters = value + self.rope_scaling = value + @property def vision_attn_implementation(self) -> Optional[str]: value = getattr(self.vision_config, "_attn_implementation", None) @@ -1477,6 +1436,15 @@ def vision_attn_implementation(self, value: Optional[str]) -> None: elif hasattr(self.vision_config, "attn_implementation"): delattr(self.vision_config, "attn_implementation") + def to_dict(self): + output = super().to_dict() + # Ensure nested configs round-trip through dict serialization + if hasattr(self, "text_config") and self.text_config is not None: + output["text_config"] = self.text_config.to_dict() + if hasattr(self, "vision_config") and self.vision_config is not None: + output["vision_config"] = self.vision_config.to_dict() + return output + # ============================================================================ # Processor Components @@ -1959,6 +1927,19 @@ def forward( elif inputs_embeds is None: raise ValueError("You have to specify either tensor_stream, input_ids or inputs_embeds") + # Ensure cache exists when requested + if use_cache and past_key_values is None: + cache_config = self.config.get_text_config() if hasattr(self.config, "get_text_config") else self.config + past_key_values = DynamicCache(config=cache_config) + + if cache_position is None and (past_key_values is not None or use_cache): + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + # Create default position_ids if not provided if position_ids is None: if tensor_stream is not None: @@ -1989,7 +1970,7 @@ def forward( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=(cos, sin), @@ -2305,6 +2286,23 @@ def prepare_inputs_for_generation( """ Prepare inputs for generation, handling TensorStream inputs properly. """ + if cache_position is None: + seq_length = None + device = None + if input_ids is not None: + seq_length = input_ids.shape[1] + device = input_ids.device + elif inputs_embeds is not None: + seq_length = inputs_embeds.shape[1] + device = inputs_embeds.device + elif tensor_stream is not None: + _, seq_length = tensor_stream.shape + device = tensor_stream.device + if seq_length is not None: + # prepare_inputs_for_generation may be invoked outside `generate`, so synthesize the + # same cache positions that GenerationMixin would have created during prefill. 
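+                # e.g. a prefill over a 7-token prompt yields cache_position
+                # tensor([0, 1, ..., 6]); subsequent decode steps receive the single
+                # next index from GenerationMixin rather than this synthesized range.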
+ cache_position = torch.arange(seq_length, dtype=torch.long, device=device) + # Call parent preparation model_inputs = super().prepare_inputs_for_generation( input_ids, @@ -2317,6 +2315,8 @@ def prepare_inputs_for_generation( **kwargs, ) + cache_position = model_inputs.get("cache_position", cache_position) + # Handle TensorStream for first forward pass only if tensor_stream is not None and (cache_position is None or cache_position[0] == 0): model_inputs["tensor_stream"] = tensor_stream diff --git a/tests/models/isaac/isaac_checkpoint_hashes.json b/tests/models/isaac/isaac_checkpoint_hashes.json new file mode 100644 index 000000000000..1898c3e23955 --- /dev/null +++ b/tests/models/isaac/isaac_checkpoint_hashes.json @@ -0,0 +1,5 @@ +{ + "full_model": "e00d024be29cc0a6790dc9f3c2504ad12176dea2332fe342a6272d2c92efdef5", + "core_model": "24bd017cf86d113aefb08bf7d109a196fae831e7a57300c4be05a8f78d2b4b6e", + "vision_modules": "c90c2e60f270c96a7a5c7c2e93815158b5f3247ab076c819da0f4f6358d033c5" +} diff --git a/tests/models/isaac/isaac_generation_golden.json b/tests/models/isaac/isaac_generation_golden.json new file mode 100644 index 000000000000..c662e22e2670 --- /dev/null +++ b/tests/models/isaac/isaac_generation_golden.json @@ -0,0 +1,450 @@ +{ + "logits_statistics": { + "shape": [ + 10, + 151936 + ], + "numel": 1519360, + "mean": 0.0666899336, + "std": 2.8427821364, + "min": -12.0625, + "max": 31.0, + "sum": 101326.0175427794, + "l2_norm": 3505.0433579135 + }, + "input_ids": [ + [ + 151644, + 872, + 198, + 74785, + 419, + 2168, + 25, + 151645, + 198, + 151644, + 872, + 198, + 768, + 743, + 480, + 159, + -154, + -256, + -256, + -256, + -256, + -256, + -256, + -154, + 159, + 480, + 743, + 768, + 743, + 718, + 462, + 149, + -157, + -256, + -256, + -256, + -256, + -256, + -256, + -157, + 149, + 462, + 718, + 743, + 480, + 462, + 273, + 42, + -183, + -256, + -256, + -256, + -256, + -256, + -256, + -183, + 42, + 273, + 462, + 480, + 159, + 149, + 43, + -87, + -214, + -256, + -256, + -256, + -256, + -256, + -256, + -214, + -87, + 43, + 149, + 159, + 151645, + 198, + 151644, + 77091 + ] + ], + "tensor_stream": { + "shape": [ + 1, + 80 + ], + "token_view": [ + [ + 151644, + 872, + 198, + 74785, + 419, + 2168, + 25, + 151645, + 198, + 151644, + 872, + 198, + 768, + 743, + 480, + 159, + -154, + -256, + -256, + -256, + -256, + -256, + -256, + -154, + 159, + 480, + 743, + 768, + 743, + 718, + 462, + 149, + -157, + -256, + -256, + -256, + -256, + -256, + -256, + -157, + 149, + 462, + 718, + 743, + 480, + 462, + 273, + 42, + -183, + -256, + -256, + -256, + -256, + -256, + -256, + -183, + 42, + 273, + 462, + 480, + 159, + 149, + 43, + -87, + -214, + -256, + -256, + -256, + -256, + -256, + -256, + -214, + -87, + 43, + 149, + 159, + 151645, + 198, + 151644, + 77091 + ] + ], + "modality_mask": [ + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1 + ] + ], + "role_mask": [ + [ + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + 
-1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1 + ] + ] + }, + "decoded_text": "user\nDescribe this image:\nuser\nug\tifable\ufffd\ufffdable\tifug\tifadd pro\ufffd\ufffd proadd\tifable proleKKle proable\ufffd\ufffdLL\ufffd\ufffd\nassistant\n\n\n\n\nThe image is a close", + "token_ids": [ + 151644, + 872, + 198, + 74785, + 419, + 2168, + 25, + 151645, + 198, + 151644, + 872, + 198, + 768, + 743, + 480, + 159, + -154, + -256, + -256, + -256, + -256, + -256, + -256, + -154, + 159, + 480, + 743, + 768, + 743, + 718, + 462, + 149, + -157, + -256, + -256, + -256, + -256, + -256, + -256, + -157, + 149, + 462, + 718, + 743, + 480, + 462, + 273, + 42, + -183, + -256, + -256, + -256, + -256, + -256, + -256, + -183, + 42, + 273, + 462, + 480, + 159, + 149, + 43, + -87, + -214, + -256, + -256, + -256, + -256, + -256, + -256, + -214, + -87, + 43, + 149, + 159, + 151645, + 198, + 151644, + 77091, + 198, + 151667, + 271, + 151668, + 271, + 785, + 2168, + 374, + 264, + 3265 + ] +} diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py index 58f24c4e637a..2f4405a04662 100644 --- a/tests/models/isaac/test_modeling_isaac.py +++ b/tests/models/isaac/test_modeling_isaac.py @@ -1,7 +1,35 @@ +import base64 +import hashlib +import io +import json +import os import unittest +from functools import lru_cache +from pathlib import Path -from transformers import IsaacConfig, IsaacForConditionalGeneration, IsaacModel, is_torch_available -from transformers.testing_utils import require_torch, torch_device +import pytest + +from transformers import ( + AutoProcessor, + AutoTokenizer, + IsaacConfig, + IsaacForConditionalGeneration, + IsaacModel, + PreTrainedTokenizer, + is_torch_available, +) +from transformers.models.isaac.configuration_isaac import IsaacVisionConfig +from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast +from transformers.models.isaac.modeling_isaac import IsaacVisionAttention +from transformers.models.isaac.processing_isaac import IsaacProcessor +from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.utils import is_offline_mode, is_vision_available + + +if is_vision_available(): + from PIL import Image +else: + Image = None from ...test_configuration_common import ConfigTester from ...test_modeling_common import ids_tensor @@ -10,6 +38,354 @@ if is_torch_available(): import torch +try: + from perceptron.tensorstream.ops import modality_mask, role_mask, tensor_stream_token_view + from perceptron.tensorstream.tensorstream import TensorStream +except Exception: + TensorStream = None + + +tensorstream_required = pytest.mark.skipif(TensorStream is None, reason="TensorStream backend is not available") + +MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base") +MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/3") or None +LOCAL_CHECKPOINT = os.environ.get("ISAAC_TEST_MODEL_PATH") +HASH_FILE = Path(__file__).with_name("isaac_checkpoint_hashes.json") +GENERATION_GOLDEN_FILE = Path(__file__).with_name("isaac_generation_golden.json") +HASH_FILTERS = { + "full_model": {"include": None, "exclude": None}, + "core_model": {"include": None, "exclude": {"vision_embedding", "audio_embedding", "inv_freq"}}, + "vision_modules": {"include": {"vision_embedding"}, "exclude": None}, +} +RED_DOT_B64 = 
"iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg==" + + +def tensor_stream_snapshot(ts: TensorStream) -> dict[str, object]: + """Summarize TensorStream tokens/modalities using public utilities.""" + + token_view = tensor_stream_token_view(ts).cpu().tolist() + modality = modality_mask(ts).cpu().tolist() + roles = role_mask(ts).cpu().tolist() + + return { + "shape": list(ts.shape), + "token_view": token_view, + "modality_mask": modality, + "role_mask": roles, + } + + +def _assert_tensor_stream_snapshot_equal(actual: dict[str, object], expected: dict[str, object]) -> None: + assert actual["shape"] == expected["shape"], "TensorStream shape changed" + assert actual["token_view"] == expected["token_view"], "TensorStream token view changed" + assert actual["modality_mask"] == expected["modality_mask"], "TensorStream modality mask changed" + assert actual["role_mask"] == expected["role_mask"], "TensorStream role mask changed" + + +def _tensor_to_bytes(tensor): + cpu_tensor = tensor.detach().cpu().contiguous() + if cpu_tensor.is_floating_point(): + cpu_tensor = cpu_tensor.to(dtype=torch.float32) + return cpu_tensor.numpy().tobytes() + + +def _iter_filtered_items(state_dict, include=None, exclude=None): + for name, tensor in state_dict.items(): + if include and not any(token in name for token in include): + continue + if exclude and any(token in name for token in exclude): + continue + yield name, tensor + + +def _hash_state_dict(state_dict, *, include=None, exclude=None): + hasher = hashlib.sha256() + items = sorted(_iter_filtered_items(state_dict, include=include, exclude=exclude), key=lambda kv: kv[0]) + for name, tensor in items: + hasher.update(name.encode("utf-8")) + hasher.update(b"\0") + hasher.update(_tensor_to_bytes(tensor)) + return hasher.hexdigest() + + +def compute_logits_statistics(tensor: torch.Tensor) -> dict[str, object]: + """ + Summarize logits with simple statistics that are stable across minor + implementation changes yet still sensitive to behavioral regressions. + """ + + float_tensor = tensor.detach().to(torch.float32).cpu() + flat = float_tensor.reshape(-1).to(torch.float64) + + def _rounded(value: torch.Tensor | float) -> float: + return round(float(value), 10) + + return { + "shape": list(float_tensor.shape), + "numel": flat.numel(), + "mean": _rounded(flat.mean()), + "std": _rounded(flat.std(unbiased=False)), + "min": _rounded(flat.min()), + "max": _rounded(flat.max()), + "sum": _rounded(flat.sum()), + "l2_norm": _rounded(torch.linalg.vector_norm(flat, ord=2)), + } + + +def _assert_logits_statistics_close( + actual: dict[str, object], + expected: dict[str, object], + *, + rel: float = 1e-5, + abs_tol: float = 1e-6, +) -> None: + assert actual["shape"] == expected["shape"], "Logits shape changed" + assert actual["numel"] == expected["numel"], "Logits numel changed" + for key in ("mean", "std", "min", "max", "sum", "l2_norm"): + assert actual[key] == pytest.approx( + expected[key], + rel=rel, + abs=abs_tol, + ), f"Logits statistic '{key}' drifted" + + +def _hf_from_pretrained(cls, pretrained_id, **kwargs): + """ + Wrapper around `cls.from_pretrained` that automatically injects + the test revision (if any) from MODEL_REVISION. 
+ """ + if MODEL_REVISION is not None: + kwargs.setdefault("revision", MODEL_REVISION) + return cls.from_pretrained(pretrained_id, **kwargs) + + +@pytest.fixture(scope="session") +def tokenizer(isaac_reference_checkpoint): + """Load the tokenizer from the converted Perceptron HF checkpoint.""" + return _hf_from_pretrained( + AutoTokenizer, + isaac_reference_checkpoint, + trust_remote_code=True, + use_fast=False, + ) + + +@require_torch +def test_isaac_sdpa_attention_backend(): + config = IsaacVisionConfig( + hidden_size=32, + intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=4, + num_channels=3, + num_patches=16, + patch_size=4, + ) + config._attn_implementation = "sdpa" + + attn_module = IsaacVisionAttention(config).eval() + seq_len = 8 + hidden_states = torch.randn(1, seq_len, config.hidden_size) + cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32) + + with torch.no_grad(): + outputs, attn_weights = attn_module( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=seq_len, + ) + + assert outputs.shape == hidden_states.shape + assert attn_weights is None + + +def _hash_tensor(tensor): + hasher = hashlib.sha256() + hasher.update(_tensor_to_bytes(tensor)) + return hasher.hexdigest() + + +@lru_cache(maxsize=1) +def _load_expected_hashes(): + if not HASH_FILE.exists(): + return None + with HASH_FILE.open("r", encoding="utf-8") as fh: + return json.load(fh) + + +@lru_cache(maxsize=1) +def _load_generation_golden(): + if not GENERATION_GOLDEN_FILE.exists(): + return None + with GENERATION_GOLDEN_FILE.open("r", encoding="utf-8") as fh: + return json.load(fh) + + +def safe_decode(tokenizer, token_ids): + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.tolist() + try: + text = tokenizer.decode(token_ids, skip_special_tokens=True) + except Exception: + tokens = tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=True) + tokens = [tok for tok in tokens if tok is not None] + text = tokenizer.convert_tokens_to_string(tokens) + return text.strip() if isinstance(text, str) else text + + +@lru_cache(maxsize=1) +def _load_red_dot_image(): + if Image is None: + return None + data = base64.b64decode(RED_DOT_B64) + return Image.open(io.BytesIO(data)).convert("RGB") + + +def _reference_checkpoint_or_skip(): + if TensorStream is None: + pytest.skip("TensorStream dependency is required for Isaac integration tests.") + if LOCAL_CHECKPOINT: + resolved = Path(LOCAL_CHECKPOINT).expanduser() + if not resolved.exists(): + pytest.skip(f"Local checkpoint path {resolved} does not exist.") + return str(resolved) + if is_offline_mode(): + pytest.skip("Offline mode: set ISAAC_TEST_MODEL_PATH to a local checkpoint to run these tests.") + return MODEL_ID + + +class SimpleIsaacTokenizer(PreTrainedTokenizer): + vocab_files_names = {} + model_input_names = ["input_ids"] + + def __init__(self): + self._vocab = { + "": 0, + "": 1, + "": 2, + "": 3, + "": 4, + } + self._ids_to_tokens = {idx: tok for tok, idx in self._vocab.items()} + super().__init__( + bos_token="", + eos_token="", + pad_token="", + unk_token="", + additional_special_tokens=[""], + model_max_length=512, + ) + self.chat_template = ( + "{% for message in messages %}" + "{{ message['role'] }}: {{ message['content'] | trim }}\n" + "{% endfor %}" + "{% if add_generation_prompt %}assistant:{% endif %}" + ) + + def get_vocab(self): + return dict(self._vocab) + + def _tokenize(self, text): + clean = text.replace("\n", " ").strip() + if not clean: + return [] + return [token for token in clean.split(" ") 
if token] + + def _convert_token_to_id(self, token): + if token not in self._vocab: + next_id = len(self._vocab) + self._vocab[token] = next_id + self._ids_to_tokens[next_id] = token + return self._vocab[token] + + def _convert_id_to_token(self, index): + return self._ids_to_tokens.get(index, self.unk_token) + + @property + def vocab_size(self) -> int: + return len(self._vocab) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + if token_ids_1 is not None: + token_ids_0 = token_ids_0 + token_ids_1 + return [self.bos_token_id] + list(token_ids_0) + [self.eos_token_id] + + def save_vocabulary(self, save_directory, filename_prefix=None): + return () + + +def _make_dummy_image(size=(32, 32), color=(255, 0, 0)): + if Image is None: + raise RuntimeError("PIL.Image is not available in this environment.") + return Image.new("RGB", size, color=color) + + +@pytest.fixture +def isaac_tiny_config(): + tester = IsaacModelTester(parent=None) + return tester.get_config() + + +@pytest.fixture +def isaac_tokenizer(): + return SimpleIsaacTokenizer() + + +@pytest.fixture +def isaac_processor(isaac_tokenizer, isaac_tiny_config): + vision_config = isaac_tiny_config.vision_config + image_processor = IsaacImageProcessorFast( + patch_size=vision_config.patch_size, + max_num_patches=vision_config.num_patches, + pixel_shuffle_scale=vision_config.pixel_shuffle_scale_factor, + rescale_factor=isaac_tiny_config.vision_rescale_factor, + ) + return IsaacProcessor( + image_processor=image_processor, + tokenizer=isaac_tokenizer, + config=isaac_tiny_config, + ) + + +@pytest.fixture(scope="session") +def isaac_reference_checkpoint(): + return _reference_checkpoint_or_skip() + + +@pytest.fixture(scope="session") +def isaac_config(isaac_reference_checkpoint): + """Load IsaacConfig from the converted checkpoint.""" + # Load the config directly from the converted checkpoint + config = _hf_from_pretrained(IsaacConfig, isaac_reference_checkpoint) + # Most tests assume flash attention in vision unless they explicitly override it. 
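+    # Individual tests can still opt back out, e.g. config.vision_attn_implementation = "sdpa".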
+ config.vision_attn_implementation = "flash_attention_2" + return config + + +@pytest.fixture(scope="session") +def isaac_reference_model(isaac_reference_checkpoint, isaac_config): + model_config = IsaacConfig.from_dict(isaac_config.to_dict()) + model_config.vision_attn_implementation = isaac_config.vision_attn_implementation + model = _hf_from_pretrained( + IsaacForConditionalGeneration, + isaac_reference_checkpoint, + config=model_config, + attn_implementation="sdpa", + ) + return model + + +@pytest.fixture(scope="session") +def isaac_reference_processor(isaac_reference_checkpoint): + try: + processor = _hf_from_pretrained(AutoProcessor, isaac_reference_checkpoint) + except (OSError, ValueError) as error: + raise RuntimeError(f"Unable to load reference Isaac processor from {isaac_reference_checkpoint}") from error + print(f"[Isaac tests] Loaded processor type: {type(processor)} from {isaac_reference_checkpoint}") + if not isinstance(processor, IsaacProcessor): + pytest.skip("Loaded processor is not an IsaacProcessor instance.") + return processor + class IsaacModelTester: def __init__( @@ -44,7 +420,11 @@ def __init__( "num_attention_heads": num_attention_heads, "num_hidden_layers": num_hidden_layers, "num_key_value_heads": num_attention_heads, - "rope_parameters": {"rope_type": "default", "mrope_section": [16, 8, 8], "mrope_interleaved": True}, + # Keep the same multi-RoPE setup as the reference checkpoints but shrink the + # sections so they sum to the rotary half-dimension (4) of this tiny test model. + "rope_parameters": {"rope_type": "default", "mrope_section": [2, 1, 1], "mrope_interleaved": True}, + # Qwen3 config expects `rope_theta` to be present on the text sub-config, so we + # set it explicitly to mimic real checkpoints and keep attribute mirroring working. 
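+            # With interleaved mRoPE, [2, 1, 1] allocates two of the four rotary
+            # frequencies to the temporal axis and one each to height and width.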
"rope_theta": 10000, "tie_word_embeddings": True, } @@ -139,3 +519,291 @@ def test_prepare_inputs_for_generation(self): self.assertIn("input_ids", prepared_inputs) self.assertIn("position_ids", prepared_inputs) self.assertIsNone(prepared_inputs["position_ids"]) + + +def test_isaac_config_extends_qwen3_defaults(isaac_tiny_config): + assert isaac_tiny_config.hidden_size == isaac_tiny_config.text_config.hidden_size + assert isaac_tiny_config.num_attention_heads == isaac_tiny_config.text_config.num_attention_heads + assert isaac_tiny_config.model_type == "isaac" + assert isaac_tiny_config.vision_config is not None + assert isaac_tiny_config.vision_config.patch_size == 4 + assert isaac_tiny_config.vision_config.num_patches == 64 + assert isaac_tiny_config.max_sequence_length == 16384 + assert isaac_tiny_config.vision_rescale_factor == pytest.approx(1 / 255) + assert isaac_tiny_config.vision_token == "" + + +@require_torch +def test_isaac_for_conditional_generation_initialization(isaac_tiny_config): + model = IsaacForConditionalGeneration(isaac_tiny_config) + model.to(torch_device) + assert hasattr(model, "model") + assert hasattr(model, "lm_head") + assert hasattr(model.model, "vision_embedding") + assert hasattr(model.model, "embed_fns") + + input_ids = torch.randint(0, isaac_tiny_config.vocab_size, (1, 10), device=torch_device, dtype=torch.long) + with torch.no_grad(): + outputs = model(input_ids=input_ids, return_dict=True) + assert outputs.logits.shape == (1, 10, isaac_tiny_config.vocab_size) + + +@require_torch +def test_isaac_for_conditional_generation_loss_and_generate_flag(isaac_tiny_config): + model = IsaacForConditionalGeneration(isaac_tiny_config).to(torch_device) + assert model.can_generate() + + batch_size, seq_len = 1, 8 + input_ids = torch.randint(0, isaac_tiny_config.vocab_size, (batch_size, seq_len), device=torch_device) + labels = torch.randint(0, isaac_tiny_config.vocab_size, (batch_size, seq_len), device=torch_device) + with torch.no_grad(): + outputs = model(input_ids=input_ids, labels=labels, return_dict=True) + assert outputs.loss is not None + assert outputs.loss.ndim == 0 + assert outputs.logits.shape == (batch_size, seq_len, isaac_tiny_config.vocab_size) + + +@require_torch +@require_vision +@tensorstream_required +def test_isaac_processor_matches_config_defaults(isaac_processor, isaac_tiny_config): + assert isaac_processor.vision_token == isaac_tiny_config.vision_token + assert isaac_processor.max_sequence_length == isaac_tiny_config.max_sequence_length + assert isaac_processor.config is isaac_tiny_config + assert isinstance(isaac_processor.image_processor, IsaacImageProcessorFast) + assert isaac_processor.image_processor.rescale_factor == pytest.approx(isaac_tiny_config.vision_rescale_factor) + + +@require_torch +@require_vision +@tensorstream_required +def test_isaac_processor_text_only_round_trip(isaac_processor): + messages = [{"role": "user", "content": "Hello, how are you?"}] + prompt = isaac_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + outputs = isaac_processor(text=prompt, images=None, return_tensors="pt") + + assert "input_ids" in outputs + assert "tensor_stream" in outputs + assert isinstance(outputs["tensor_stream"], TensorStream) + assert outputs["input_ids"].shape[0] == 1 + + +@require_torch +@require_vision +@tensorstream_required +def test_isaac_processor_with_single_image(isaac_processor): + vision_token = isaac_processor.vision_token + text = f"Look at this {vision_token} and describe it." 
+ image = _make_dummy_image() + + outputs = isaac_processor(text=text, images=[image], return_tensors="pt") + assert isinstance(outputs["tensor_stream"], TensorStream) + assert outputs["input_ids"].ndim == 2 + + +@require_torch +@require_vision +@tensorstream_required +def test_isaac_processor_with_multiple_images(isaac_processor): + vision_token = isaac_processor.vision_token + text = f"First {vision_token} then {vision_token}" + images = [_make_dummy_image(color=(255, 0, 0)), _make_dummy_image(color=(0, 255, 0))] + + outputs = isaac_processor(text=text, images=images, return_tensors="pt") + assert isinstance(outputs["tensor_stream"], TensorStream) + assert outputs["input_ids"].shape[0] == 1 + + +@require_torch +@require_vision +@tensorstream_required +def test_isaac_processor_error_on_image_mismatch(isaac_processor): + vision_token = isaac_processor.vision_token + text = f"{vision_token} {vision_token}" + image = _make_dummy_image() + + with pytest.raises(ValueError, match="must match number of images"): + isaac_processor(text=text, images=[image], return_tensors="pt") + + +@require_torch +@require_vision +@tensorstream_required +def test_isaac_processor_consistent_tensor_stream_types(isaac_processor): + text_only = "Simple question?" + text_with_image = f"Describe this {isaac_processor.vision_token}" + image = _make_dummy_image() + + outputs_text = isaac_processor(text=text_only, images=None, return_tensors="pt") + outputs_image = isaac_processor(text=text_with_image, images=[image], return_tensors="pt") + + assert isinstance(outputs_text["tensor_stream"], TensorStream) + assert isinstance(outputs_image["tensor_stream"], TensorStream) + assert outputs_text["input_ids"].shape[0] == outputs_image["input_ids"].shape[0] == 1 + + +@require_torch +@require_vision +@tensorstream_required +def test_isaac_generation_with_tensor_stream(isaac_processor, isaac_tiny_config): + model = IsaacForConditionalGeneration(isaac_tiny_config).to(torch_device) + model.eval() + + messages = [{"role": "user", "content": "Hello there!"}] + prompt = isaac_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + processed = isaac_processor(text=prompt, images=None, return_tensors="pt") + + input_ids = processed["input_ids"].to(torch_device) + tensor_stream = processed["tensor_stream"] + tensor_stream = tensor_stream.to(torch_device) + generated = model.generate( + input_ids=input_ids, + tensor_stream=tensor_stream, + max_new_tokens=5, + do_sample=False, + pad_token_id=isaac_processor.tokenizer.pad_token_id, + eos_token_id=isaac_processor.tokenizer.eos_token_id, + ) + + assert generated.shape[0] == 1 + assert generated.shape[1] >= input_ids.shape[1] + decoded_prompt = isaac_processor.tokenizer.decode(generated[0], skip_special_tokens=True) + assert isinstance(decoded_prompt, str) + assert decoded_prompt.strip() != "" + + +@require_torch +@slow +@tensorstream_required +def test_isaac_checkpoint_hashes(isaac_reference_model): + isaac_reference_model = isaac_reference_model.to("cpu") + expected_hashes = _load_expected_hashes() + if not expected_hashes: + pytest.skip(f"Missing golden hashes file at {HASH_FILE}.") + + missing = [subset for subset in HASH_FILTERS if subset not in expected_hashes] + if missing: + pytest.skip(f"Golden hashes missing entries for: {', '.join(missing)}") + + isaac_reference_model.to("cpu") + state_dict = isaac_reference_model.state_dict() + for subset, filters in HASH_FILTERS.items(): + current_hash = _hash_state_dict(state_dict, include=filters["include"], 
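+            # include/exclude match substrings of parameter names (see HASH_FILTERS above)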
exclude=filters["exclude"]) + assert current_hash == expected_hashes[subset], f"Hash mismatch for subset '{subset}'" + + +def create_isaac_processor( + tokenizer, + isaac_config, + *, + image_processor=None, + **overrides, +): + """Helper to construct IsaacProcessor without requiring an IsaacConfig instance.""" + params = { + "vision_token": isaac_config.vision_token, + "max_sequence_length": isaac_config.max_sequence_length, + "vision_patch_size": isaac_config.vision_patch_size, + "vision_max_num_patches": isaac_config.vision_max_num_patches, + "vision_min_num_patches": isaac_config.vision_min_num_patches, + "pixel_shuffle_scale": isaac_config.pixel_shuffle_scale, + "rescale_factor": isaac_config.vision_rescale_factor, + "image_mean": tuple(isaac_config.vision_mean), + "image_std": tuple(isaac_config.vision_std), + "vision_attn_implementation": isaac_config.vision_attn_implementation, + } + params.update(overrides) + + processor_image = image_processor + if processor_image is None: + processor_image = IsaacImageProcessorFast( + patch_size=params["vision_patch_size"], + max_num_patches=params["vision_max_num_patches"], + min_num_patches=params["vision_min_num_patches"], + pixel_shuffle_scale=params["pixel_shuffle_scale"], + rescale_factor=params["rescale_factor"], + image_mean=params["image_mean"], + image_std=params["image_std"], + ) + processor_params = { + "vision_token": isaac_config.vision_token, + "max_sequence_length": isaac_config.max_sequence_length, + "rescale_factor": isaac_config.vision_rescale_factor, + } + + return IsaacProcessor( + image_processor=processor_image, + tokenizer=tokenizer, + config=isaac_config, + **processor_params, + ) + + +@require_torch +@require_vision +@slow +@tensorstream_required +def test_hf_generate_vs_training_generate_logits(isaac_reference_model, isaac_reference_processor): + device = "cuda" + dtype = torch.bfloat16 + isaac_reference_model = isaac_reference_model.to(device=device, dtype=dtype) + isaac_reference_model.eval() + golden = _load_generation_golden() + if not golden: + pytest.skip(f"Missing generation golden file at {GENERATION_GOLDEN_FILE}.") + + image = _load_red_dot_image() + if image is None: + pytest.skip("PIL.Image is required for Isaac generation tests.") + + messages = [ + { + "role": "user", + "content": "Describe this image:", + }, + { + "role": "user", + "content": "", + }, + ] + prompt = isaac_reference_processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ).strip() + batch = isaac_reference_processor(text=prompt, images=[image], return_tensors="pt") + + input_ids = batch["input_ids"] + tensor_stream = batch["tensor_stream"] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + isaac_reference_model.to(device) + input_ids = input_ids.to(device) + if tensor_stream is not None and hasattr(tensor_stream, "to"): + tensor_stream = tensor_stream.to(device) + + torch.manual_seed(0) + with torch.no_grad(): + outputs = isaac_reference_model.generate( + input_ids=input_ids, + tensor_stream=tensor_stream, + max_new_tokens=10, + do_sample=False, + pad_token_id=isaac_reference_processor.tokenizer.eos_token_id, + eos_token_id=isaac_reference_processor.tokenizer.eos_token_id, + return_dict_in_generate=True, + output_logits=True, + ) + + logits = torch.cat(outputs.logits, dim=0).to(torch.float32).cpu() + logits_stats = compute_logits_statistics(logits) + generated_ids = outputs.sequences[0].tolist() + + assert generated_ids == golden["token_ids"], "Generated token ids changed" + if 
"logits_statistics" in golden: + _assert_logits_statistics_close(logits_stats, golden["logits_statistics"]) + else: + pytest.fail( + "Golden file missing both logits_statistics and logits_hash. " + f"Regenerate {GENERATION_GOLDEN_FILE} via scripts/update_isaac_hashes.py." + ) + + isaac_reference_model.to("cpu") From 046309904aa70db26119bcce96753eb7ea201d0d Mon Sep 17 00:00:00 2001 From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com> Date: Wed, 3 Dec 2025 19:31:57 +0400 Subject: [PATCH 46/77] fix: update TensorType import for latest changes in transformers main (#5) * test: fix mrope expectation * fix: default cache position setting * test: add golden values * fix: add isaac image processor to auto * fix: handle new delegated cache functionality * fix: update deprecated argument name * fix: pop unused args to pass validation * fix: alias rope for compatibility with checkpoints * test: extensive e2e testing of isaac * feat: convert to modular artifacts * test: add vision token to pre-defined messages * fix: hand roll siglip embeddings so it's moved properly to modeling isaac * test: fix separate vision token to isolated message * fix: move implementation into class so implementations are copied to modeling file * test: save logit stats for more interpretable test deviations * WIP tests are passing! * test: fix off by one from not stripping whitespace * test: generation consistency tests passing * test: test-specific dtype and device * test: add vocab size property for testing utils compatibility * test: move tensorstream to correct device * test: linting * fix: unregister * chore: linting * fix: remporarily remove vision embeddings copy for fix copies * test: update tests to load model from revision with updated configs * chore: latest transformers modular convert artifact * chore: make fixup completion artifact * fix: update TensorType import * chore: convert utility artifacts --- .../models/isaac/image_processing_isaac_fast.py | 3 +-- src/transformers/models/isaac/modeling_isaac.py | 6 +++--- src/transformers/models/isaac/modular_isaac.py | 3 +-- src/transformers/models/isaac/processing_isaac.py | 2 +- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py b/src/transformers/models/isaac/image_processing_isaac_fast.py index 789dae85b40a..d0129357bc85 100644 --- a/src/transformers/models/isaac/image_processing_isaac_fast.py +++ b/src/transformers/models/isaac/image_processing_isaac_fast.py @@ -97,8 +97,7 @@ from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict, group_images_by_shape, reorder_images from ...image_utils import ChannelDimension, PILImageResampling from ...processing_utils import Unpack -from ...tokenization_utils import TensorType -from ...utils import auto_docstring +from ...utils import TensorType, auto_docstring # Vision preprocessing constants from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 3920e53f9d09..e9970e3b0f7b 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -105,7 +105,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation.utils import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, 
use_kernel_func_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -1001,6 +1001,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -1092,6 +1093,7 @@ def __init__(self, config: IsaacConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None @@ -1659,8 +1661,6 @@ def forward( **kwargs, ) -> tuple | CausalLMOutputWithPast: r""" - Forward pass for conditional generation supporting both standard inputs and TensorStream. - tensor_stream (`TensorStream`, *optional*): Packed multimodal stream (text, vision, audio tokens) that already encodes spatial metadata. When provided, the model derives embeddings, modality masks, and 3D rotary coordinates directly from the stream instead of diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index b5b6723b43e9..9bb4e1414c4d 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -136,8 +136,7 @@ from ...models.qwen3.configuration_qwen3 import Qwen3Config from ...models.qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3PreTrainedModel from ...processing_utils import ProcessorMixin, Unpack -from ...tokenization_utils import TensorType -from ...utils import auto_docstring +from ...utils import auto_docstring, TensorType # Vision preprocessing constants from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 1313eb3db119..45a5ff650322 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -99,7 +99,7 @@ from ...feature_extraction_utils import BatchFeature from ...models.auto.tokenization_auto import AutoTokenizer from ...processing_utils import ProcessorMixin -from ...tokenization_utils import TensorType +from ...utils import TensorType from .configuration_isaac import IsaacConfig From 95296b7ea79a83430b686badd8d061c404581164 Mon Sep 17 00:00:00 2001 From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com> Date: Thu, 4 Dec 2025 18:32:15 +0400 Subject: [PATCH 47/77] fix: updates for v5 standards (#6) * test: fix mrope expectation * fix: default cache position setting * test: add golden values * fix: add isaac image processor to auto * fix: handle new delegated cache functionality * fix: update deprecated argument name * fix: pop unused args to pass validation * fix: alias rope for compatibility with checkpoints * test: extensive e2e testing of isaac * feat: convert to modular artifacts * test: add vision token to pre-defined messages * fix: hand roll siglip embeddings so it's moved properly to modeling isaac * test: fix separate vision token to isolated 
message * fix: move implementation into class so implementations are copied to modeling file * test: save logit stats for more interpretable test deviations * WIP tests are passing! * test: fix off by one from not stripping whitespace * test: generation consistency tests passing * test: test-specific dtype and device * test: add vocab size property for testing utils compatibility * test: move tensorstream to correct device * test: linting * fix: unregister * chore: linting * fix: remporarily remove vision embeddings copy for fix copies * test: update tests to load model from revision with updated configs * chore: latest transformers modular convert artifact * chore: make fixup completion artifact * fix: update TensorType import * chore: convert utility artifacts * style(v5): support batch encoding for processor input * test(v5): processor supports batchencoding * chore: convert artifact * style(v5): set and prioritize rope_parameters when loading config * style: drop reference to fast/slow backend for tokenizers * style(v5): mock tokenizer inherits from new PythonBackend base class * fix(v5): rely on and output rope_parameters instead of high level rope_theta * test: isaac config accepts rope_theta and returns rope_parameters * chore: convert utility artifact * style: make fixup lints --- .../models/isaac/configuration_isaac.py | 90 +++++++++--- .../models/isaac/modeling_isaac.py | 20 ++- .../models/isaac/modular_isaac.py | 137 ++++++++++++++---- .../models/isaac/processing_isaac.py | 23 ++- tests/models/isaac/test_modeling_isaac.py | 34 ++++- 5 files changed, 246 insertions(+), 58 deletions(-) diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index 53963cf0e07b..ae7293bcd1e7 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -145,6 +145,13 @@ def __init__( if self._attn_implementation is None: self._attn_implementation = "flash_attention_2" + # Keep legacy and new attention implementation fields in sync + existing_attn_impl = getattr(self, "attn_implementation", None) + if existing_attn_impl is None: + self.attn_implementation = self._attn_implementation + else: + self._attn_implementation = existing_attn_impl + class IsaacConfig(PretrainedConfig): """Configuration class for Isaac multimodal model. 
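The hunk below folds the legacy scalar rope_theta into a v5-style rope_parameters
dict. A minimal sketch of the precedence it implements, using a hypothetical helper
name for illustration only:

    def _normalize_rope(rope_parameters=None, rope_scaling=None, rope_theta=None):
        # Prefer the new-style dict; fall back to the deprecated rope_scaling alias.
        params = rope_parameters or rope_scaling
        if params is None and rope_theta is not None:
            # Legacy checkpoints carried only a scalar theta.
            params = {"rope_type": "default", "rope_theta": rope_theta}
        elif params is not None and rope_theta is not None and "rope_theta" not in params:
            # Fold a stray scalar theta into a copy of the dict.
            params = {**params, "rope_theta": rope_theta}
        return params

    assert _normalize_rope(rope_theta=10000) == {"rope_type": "default", "rope_theta": 10000}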
@@ -168,6 +175,7 @@ def __init__( ): self._rope_scaling: Optional[dict[str, Any]] = None self._rope_parameters: Optional[dict[str, Any]] = None + resolved_text_config = kwargs.pop("text_config", text_config) if isinstance(resolved_text_config, Qwen3Config): text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict()) @@ -180,22 +188,48 @@ def __init__( text_config_kwargs.update(kwargs) - self.text_config = self.sub_configs["text_config"](**text_config_kwargs) - if not hasattr(self.text_config, "rope_theta"): - rope_theta_override = text_config_kwargs.get("rope_theta", kwargs.get("rope_theta")) - if rope_theta_override is None: - rope_theta_override = getattr(Qwen3Config(), "rope_theta", 10000.0) - self.text_config.rope_theta = rope_theta_override + legacy_rope_theta = text_config_kwargs.pop("rope_theta", None) + incoming_rope_params = text_config_kwargs.pop("rope_parameters", None) + incoming_rope_scaling = text_config_kwargs.pop("rope_scaling", None) + normalized_rope_params = incoming_rope_params or incoming_rope_scaling + if normalized_rope_params is None and legacy_rope_theta is not None: + normalized_rope_params = {"rope_type": "default", "rope_theta": legacy_rope_theta} + elif ( + normalized_rope_params is not None + and legacy_rope_theta is not None + and "rope_theta" not in normalized_rope_params + ): + normalized_rope_params = {**normalized_rope_params, "rope_theta": legacy_rope_theta} + if normalized_rope_params is not None: + text_config_kwargs["rope_parameters"] = normalized_rope_params - super().__init__(**kwargs) + self.text_config = self.sub_configs["text_config"](**text_config_kwargs) - if self._rope_scaling is None: - self._rope_scaling = getattr(self.text_config, "rope_scaling", None) + # Normalize rope parameters on the text config (prefer rope_parameters; alias rope_scaling) + self._rope_parameters = getattr(self.text_config, "rope_parameters", None) + if self._rope_parameters is None: + self._rope_parameters = getattr(self.text_config, "rope_scaling", None) + if self._rope_parameters is None and normalized_rope_params is not None: + self._rope_parameters = normalized_rope_params + if self._rope_parameters is None: + self._rope_parameters = {"rope_type": "default"} + + try: + self.text_config.rope_parameters = self._rope_parameters + except AttributeError: + setattr(self.text_config, "rope_parameters", self._rope_parameters) + if hasattr(self.text_config, "rope_scaling"): + self.text_config.rope_scaling = self._rope_parameters else: - self.text_config.rope_scaling = self._rope_scaling + try: + setattr(self.text_config, "rope_scaling", self._rope_parameters) + except Exception: + pass + + super().__init__(**kwargs) # Keep rope parameters alias in sync with upstream expectations - self._rope_parameters = self._rope_scaling + self._rope_scaling = self._rope_parameters # Mirror frequently accessed Qwen3 attributes at the composite config level for BC. 
self.tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False) @@ -217,7 +251,6 @@ def __init__( self.initializer_range = self.text_config.initializer_range self.rms_norm_eps = self.text_config.rms_norm_eps self.use_cache = self.text_config.use_cache - self.rope_theta = self.text_config.rope_theta self.attention_bias = getattr(self.text_config, "attention_bias", False) self.attention_dropout = getattr(self.text_config, "attention_dropout", 0.0) @@ -258,21 +291,33 @@ def get_text_config(self, *_, **kwargs) -> Qwen3Config: @property def rope_scaling(self): if hasattr(self, "text_config") and self.text_config is not None: - return getattr(self.text_config, "rope_scaling", None) - return self._rope_scaling + return getattr(self.text_config, "rope_parameters", None) or getattr( + self.text_config, "rope_scaling", None + ) + return self._rope_parameters @rope_scaling.setter def rope_scaling(self, value): + self._rope_parameters = value self._rope_scaling = value if hasattr(self, "text_config") and self.text_config is not None: - self.text_config.rope_scaling = value + try: + self.text_config.rope_parameters = value + except AttributeError: + setattr(self.text_config, "rope_parameters", value) + try: + self.text_config.rope_scaling = value + except AttributeError: + pass @property def rope_parameters(self) -> dict[str, Any] | None: """Alias introduced upstream for rope scaling dictionaries.""" value = self._rope_parameters - if value is None: - value = self.rope_scaling + if value is None and hasattr(self, "text_config") and self.text_config is not None: + value = getattr(self.text_config, "rope_parameters", None) or getattr( + self.text_config, "rope_scaling", None + ) if value is None: return {"rope_type": "default"} return value @@ -280,6 +325,7 @@ def rope_parameters(self) -> dict[str, Any] | None: @rope_parameters.setter def rope_parameters(self, value: dict[str, Any] | None) -> None: self._rope_parameters = value + self._rope_scaling = value self.rope_scaling = value @property @@ -299,9 +345,17 @@ def vision_attn_implementation(self, value: Optional[str]) -> None: def to_dict(self): output = super().to_dict() + rope_params = self.rope_parameters + output["rope_parameters"] = rope_params + output.pop("rope_scaling", None) + output.pop("rope_theta", None) # Ensure nested configs round-trip through dict serialization if hasattr(self, "text_config") and self.text_config is not None: - output["text_config"] = self.text_config.to_dict() + text_config_dict = self.text_config.to_dict() + text_config_dict.pop("rope_theta", None) + text_config_dict.pop("rope_scaling", None) + text_config_dict["rope_parameters"] = rope_params + output["text_config"] = text_config_dict if hasattr(self, "vision_config") and self.vision_config is not None: output["vision_config"] = self.vision_config.to_dict() return output diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index e9970e3b0f7b..c7a898f3464d 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -893,17 +893,29 @@ def __init__(self, config: IsaacConfig, device=None): super().__init__() rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config - rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} + rope_params = ( + getattr(rope_source_cfg, "rope_parameters", None) or getattr(rope_source_cfg, "rope_scaling", None) or {} + ) + legacy_rope_theta = 
getattr(rope_source_cfg, "rope_theta", None) + if legacy_rope_theta is not None and isinstance(rope_params, dict) and "rope_theta" not in rope_params: + rope_params = {**rope_params, "rope_theta": legacy_rope_theta} - sanitized_scaling = {k: v for k, v in rope_scaling.items() if k not in self.EXTRA_ROPE_KEYS} + sanitized_params = {k: v for k, v in rope_params.items() if k not in self.EXTRA_ROPE_KEYS} config_for_rope = copy.copy(rope_source_cfg) - config_for_rope.rope_scaling = sanitized_scaling if sanitized_scaling else None + config_for_rope.rope_parameters = sanitized_params if sanitized_params else None + if hasattr(config_for_rope, "rope_scaling"): + config_for_rope.rope_scaling = sanitized_params if sanitized_params else None + if hasattr(config_for_rope, "rope_theta"): + try: + delattr(config_for_rope, "rope_theta") + except Exception: + config_for_rope.rope_theta = None init_device = device if device is not None and getattr(device, "type", None) != "meta" else None self._qwen_rotary = qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding(config_for_rope, device=init_device) rotary_half_dim = self._qwen_rotary.inv_freq.shape[0] - self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim) + self.mrope_section = self._resolve_mrope_section(rope_params.get("mrope_section"), rotary_half_dim) self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size @staticmethod diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 9bb4e1414c4d..31d7d21e2287 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -136,7 +136,8 @@ from ...models.qwen3.configuration_qwen3 import Qwen3Config from ...models.qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3PreTrainedModel from ...processing_utils import ProcessorMixin, Unpack -from ...utils import auto_docstring, TensorType +from ...tokenization_utils_base import BatchEncoding +from ...utils import TensorType, auto_docstring # Vision preprocessing constants from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN @@ -202,6 +203,13 @@ def __init__( if self._attn_implementation is None: self._attn_implementation = "flash_attention_2" + # Keep legacy and new attention implementation fields in sync + existing_attn_impl = getattr(self, "attn_implementation", None) + if existing_attn_impl is None: + self.attn_implementation = self._attn_implementation + else: + self._attn_implementation = existing_attn_impl + class IsaacImageProcessorKwargs(ImagesKwargs, total=False): patch_size: Optional[int] @@ -1306,6 +1314,7 @@ def __init__( ): self._rope_scaling: Optional[dict[str, Any]] = None self._rope_parameters: Optional[dict[str, Any]] = None + resolved_text_config = kwargs.pop("text_config", text_config) if isinstance(resolved_text_config, Qwen3Config): text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict()) @@ -1318,22 +1327,48 @@ def __init__( text_config_kwargs.update(kwargs) - self.text_config = self.sub_configs["text_config"](**text_config_kwargs) - if not hasattr(self.text_config, "rope_theta"): - rope_theta_override = text_config_kwargs.get("rope_theta", kwargs.get("rope_theta")) - if rope_theta_override is None: - rope_theta_override = getattr(Qwen3Config(), "rope_theta", 10000.0) - self.text_config.rope_theta = rope_theta_override + legacy_rope_theta = text_config_kwargs.pop("rope_theta", None) + incoming_rope_params = 
text_config_kwargs.pop("rope_parameters", None) + incoming_rope_scaling = text_config_kwargs.pop("rope_scaling", None) + normalized_rope_params = incoming_rope_params or incoming_rope_scaling + if normalized_rope_params is None and legacy_rope_theta is not None: + normalized_rope_params = {"rope_type": "default", "rope_theta": legacy_rope_theta} + elif ( + normalized_rope_params is not None + and legacy_rope_theta is not None + and "rope_theta" not in normalized_rope_params + ): + normalized_rope_params = {**normalized_rope_params, "rope_theta": legacy_rope_theta} + if normalized_rope_params is not None: + text_config_kwargs["rope_parameters"] = normalized_rope_params - super().__init__(**kwargs) + self.text_config = self.sub_configs["text_config"](**text_config_kwargs) - if self._rope_scaling is None: - self._rope_scaling = getattr(self.text_config, "rope_scaling", None) + # Normalize rope parameters on the text config (prefer rope_parameters; alias rope_scaling) + self._rope_parameters = getattr(self.text_config, "rope_parameters", None) + if self._rope_parameters is None: + self._rope_parameters = getattr(self.text_config, "rope_scaling", None) + if self._rope_parameters is None and normalized_rope_params is not None: + self._rope_parameters = normalized_rope_params + if self._rope_parameters is None: + self._rope_parameters = {"rope_type": "default"} + + try: + self.text_config.rope_parameters = self._rope_parameters + except AttributeError: + setattr(self.text_config, "rope_parameters", self._rope_parameters) + if hasattr(self.text_config, "rope_scaling"): + self.text_config.rope_scaling = self._rope_parameters else: - self.text_config.rope_scaling = self._rope_scaling + try: + setattr(self.text_config, "rope_scaling", self._rope_parameters) + except Exception: + pass + + super().__init__(**kwargs) # Keep rope parameters alias in sync with upstream expectations - self._rope_parameters = self._rope_scaling + self._rope_scaling = self._rope_parameters # Mirror frequently accessed Qwen3 attributes at the composite config level for BC. 
self.tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False) @@ -1355,7 +1390,6 @@ def __init__( self.initializer_range = self.text_config.initializer_range self.rms_norm_eps = self.text_config.rms_norm_eps self.use_cache = self.text_config.use_cache - self.rope_theta = self.text_config.rope_theta self.attention_bias = getattr(self.text_config, "attention_bias", False) self.attention_dropout = getattr(self.text_config, "attention_dropout", 0.0) @@ -1396,21 +1430,33 @@ def get_text_config(self, *_, **kwargs) -> Qwen3Config: @property def rope_scaling(self): if hasattr(self, "text_config") and self.text_config is not None: - return getattr(self.text_config, "rope_scaling", None) - return self._rope_scaling + return getattr(self.text_config, "rope_parameters", None) or getattr( + self.text_config, "rope_scaling", None + ) + return self._rope_parameters @rope_scaling.setter def rope_scaling(self, value): + self._rope_parameters = value self._rope_scaling = value if hasattr(self, "text_config") and self.text_config is not None: - self.text_config.rope_scaling = value + try: + self.text_config.rope_parameters = value + except AttributeError: + setattr(self.text_config, "rope_parameters", value) + try: + self.text_config.rope_scaling = value + except AttributeError: + pass @property def rope_parameters(self) -> dict[str, Any] | None: """Alias introduced upstream for rope scaling dictionaries.""" value = self._rope_parameters - if value is None: - value = self.rope_scaling + if value is None and hasattr(self, "text_config") and self.text_config is not None: + value = getattr(self.text_config, "rope_parameters", None) or getattr( + self.text_config, "rope_scaling", None + ) if value is None: return {"rope_type": "default"} return value @@ -1418,6 +1464,7 @@ def rope_parameters(self) -> dict[str, Any] | None: @rope_parameters.setter def rope_parameters(self, value: dict[str, Any] | None) -> None: self._rope_parameters = value + self._rope_scaling = value self.rope_scaling = value @property @@ -1437,9 +1484,17 @@ def vision_attn_implementation(self, value: Optional[str]) -> None: def to_dict(self): output = super().to_dict() + rope_params = self.rope_parameters + output["rope_parameters"] = rope_params + output.pop("rope_scaling", None) + output.pop("rope_theta", None) # Ensure nested configs round-trip through dict serialization if hasattr(self, "text_config") and self.text_config is not None: - output["text_config"] = self.text_config.to_dict() + text_config_dict = self.text_config.to_dict() + text_config_dict.pop("rope_theta", None) + text_config_dict.pop("rope_scaling", None) + text_config_dict["rope_parameters"] = rope_params + output["text_config"] = text_config_dict if hasattr(self, "vision_config") and self.vision_config is not None: output["vision_config"] = self.vision_config.to_dict() return output @@ -1599,8 +1654,26 @@ def __call__( Returns: BatchFeature with input_ids and tensor_stream """ - # Normalize inputs to lists - if isinstance(text, str): + # Normalize inputs to lists and support BatchEncoding (v5 apply_chat_template default) + encoding_input = None + if isinstance(text, BatchEncoding): + encoding_input = text + elif isinstance(text, dict) and "input_ids" in text: + encoding_input = BatchEncoding(text) + + if encoding_input is not None: + input_ids_field = encoding_input["input_ids"] + if isinstance(input_ids_field, torch.Tensor): + ids = input_ids_field + else: + ids = torch.tensor(input_ids_field) + if ids.ndim == 1: + ids = ids.unsqueeze(0) + if 
ids.size(0) != 1: + raise ValueError("IsaacProcessor currently supports batch_size=1 for chat templates.") + decoded_text = self.tokenizer.decode(ids[0].tolist(), skip_special_tokens=False) + texts = [decoded_text] + elif isinstance(text, str): texts = [text] else: texts = text @@ -1683,17 +1756,29 @@ def __init__(self, config: IsaacConfig, device=None): super().__init__() rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config - rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} + rope_params = ( + getattr(rope_source_cfg, "rope_parameters", None) or getattr(rope_source_cfg, "rope_scaling", None) or {} + ) + legacy_rope_theta = getattr(rope_source_cfg, "rope_theta", None) + if legacy_rope_theta is not None and isinstance(rope_params, dict) and "rope_theta" not in rope_params: + rope_params = {**rope_params, "rope_theta": legacy_rope_theta} - sanitized_scaling = {k: v for k, v in rope_scaling.items() if k not in self.EXTRA_ROPE_KEYS} + sanitized_params = {k: v for k, v in rope_params.items() if k not in self.EXTRA_ROPE_KEYS} config_for_rope = copy.copy(rope_source_cfg) - config_for_rope.rope_scaling = sanitized_scaling if sanitized_scaling else None + config_for_rope.rope_parameters = sanitized_params if sanitized_params else None + if hasattr(config_for_rope, "rope_scaling"): + config_for_rope.rope_scaling = sanitized_params if sanitized_params else None + if hasattr(config_for_rope, "rope_theta"): + try: + delattr(config_for_rope, "rope_theta") + except Exception: + config_for_rope.rope_theta = None init_device = device if device is not None and getattr(device, "type", None) != "meta" else None self._qwen_rotary = qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding(config_for_rope, device=init_device) rotary_half_dim = self._qwen_rotary.inv_freq.shape[0] - self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim) + self.mrope_section = self._resolve_mrope_section(rope_params.get("mrope_section"), rotary_half_dim) self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size @staticmethod @@ -2201,8 +2286,6 @@ def forward( **kwargs, ) -> tuple | CausalLMOutputWithPast: r""" - Forward pass for conditional generation supporting both standard inputs and TensorStream. - tensor_stream (`TensorStream`, *optional*): Packed multimodal stream (text, vision, audio tokens) that already encodes spatial metadata. 
When provided, the model derives embeddings, modality masks, and 3D rotary coordinates directly from the stream instead of diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 45a5ff650322..af13ac6a0f48 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -99,6 +99,7 @@ from ...feature_extraction_utils import BatchFeature from ...models.auto.tokenization_auto import AutoTokenizer from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import BatchEncoding from ...utils import TensorType from .configuration_isaac import IsaacConfig @@ -257,8 +258,26 @@ def __call__( Returns: BatchFeature with input_ids and tensor_stream """ - # Normalize inputs to lists - if isinstance(text, str): + # Normalize inputs to lists and support BatchEncoding (v5 apply_chat_template default) + encoding_input = None + if isinstance(text, BatchEncoding): + encoding_input = text + elif isinstance(text, dict) and "input_ids" in text: + encoding_input = BatchEncoding(text) + + if encoding_input is not None: + input_ids_field = encoding_input["input_ids"] + if isinstance(input_ids_field, torch.Tensor): + ids = input_ids_field + else: + ids = torch.tensor(input_ids_field) + if ids.ndim == 1: + ids = ids.unsqueeze(0) + if ids.size(0) != 1: + raise ValueError("IsaacProcessor currently supports batch_size=1 for chat templates.") + decoded_text = self.tokenizer.decode(ids[0].tolist(), skip_special_tokens=False) + texts = [decoded_text] + elif isinstance(text, str): texts = [text] else: texts = text diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py index 2f4405a04662..02c1ab114f2f 100644 --- a/tests/models/isaac/test_modeling_isaac.py +++ b/tests/models/isaac/test_modeling_isaac.py @@ -15,7 +15,7 @@ IsaacConfig, IsaacForConditionalGeneration, IsaacModel, - PreTrainedTokenizer, + PythonBackend, is_torch_available, ) from transformers.models.isaac.configuration_isaac import IsaacVisionConfig @@ -166,7 +166,6 @@ def tokenizer(isaac_reference_checkpoint): AutoTokenizer, isaac_reference_checkpoint, trust_remote_code=True, - use_fast=False, ) @@ -254,7 +253,7 @@ def _reference_checkpoint_or_skip(): return MODEL_ID -class SimpleIsaacTokenizer(PreTrainedTokenizer): +class SimpleIsaacTokenizer(PythonBackend): vocab_files_names = {} model_input_names = ["input_ids"] @@ -272,7 +271,7 @@ def __init__(self): eos_token="", pad_token="", unk_token="", - additional_special_tokens=[""], + extra_special_tokens=[""], model_max_length=512, ) self.chat_template = ( @@ -423,9 +422,6 @@ def __init__( # Keep the same multi-RoPE setup as the reference checkpoints but shrink the # sections so they sum to the rotary half-dimension (4) of this tiny test model. "rope_parameters": {"rope_type": "default", "mrope_section": [2, 1, 1], "mrope_interleaved": True}, - # Qwen3 config expects `rope_theta` to be present on the text sub-config, so we - # set it explicitly to mimic real checkpoints and keep attribute mirroring working. 
- "rope_theta": 10000, "tie_word_embeddings": True, } @@ -533,6 +529,16 @@ def test_isaac_config_extends_qwen3_defaults(isaac_tiny_config): assert isaac_tiny_config.vision_token == "" +def test_isaac_config_migrates_legacy_rope_theta(): + cfg = IsaacConfig(text_config={"rope_theta": 12345}) + assert cfg.rope_parameters.get("rope_theta") == 12345 + assert cfg.rope_parameters.get("rope_type") == "default" + serialized = cfg.to_dict() + assert "rope_theta" not in serialized + assert "rope_theta" not in serialized.get("text_config", {}) + assert serialized["rope_parameters"].get("rope_theta") == 12345 + + @require_torch def test_isaac_for_conditional_generation_initialization(isaac_tiny_config): model = IsaacForConditionalGeneration(isaac_tiny_config) @@ -588,6 +594,20 @@ def test_isaac_processor_text_only_round_trip(isaac_processor): assert outputs["input_ids"].shape[0] == 1 +@require_torch +@tensorstream_required +def test_isaac_processor_accepts_batchencoding_chat_template(isaac_processor): + messages = [{"role": "user", "content": "Hello, how are you?"}] + batch_encoding = isaac_processor.apply_chat_template(messages, add_generation_prompt=True) + + outputs = isaac_processor(text=batch_encoding, images=None, return_tensors="pt") + + assert "input_ids" in outputs + assert "tensor_stream" in outputs + assert isinstance(outputs["tensor_stream"], TensorStream) + assert outputs["input_ids"].shape[0] == 1 + + @require_torch @require_vision @tensorstream_required From 1cb3c4b2f8d53e32a80651b8efc9704b4163212a Mon Sep 17 00:00:00 2001 From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com> Date: Tue, 9 Dec 2025 11:32:57 +0400 Subject: [PATCH 48/77] feat: guard perceptron imports (#7) * style: drop unused func * feat: add perceptron availability check * test: use new perceptron availability utility * feat: guard imports in modular file * test: update tensorstream requirement tests * chore: convert script artifacts * test: update is_offline_mode import --- .../models/isaac/modeling_isaac.py | 33 ++++++----- .../models/isaac/modular_isaac.py | 59 +++++++++++-------- .../models/isaac/processing_isaac.py | 19 +++++- src/transformers/utils/import_utils.py | 8 +++ tests/models/isaac/test_modeling_isaac.py | 34 ++++++----- 5 files changed, 96 insertions(+), 57 deletions(-) diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index c7a898f3464d..05f4f40e3f75 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -95,12 +95,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from perceptron.tensorstream.ops import ( - compute_mrope_pos_tensor, - modality_mask, - reconstruct_tensor_stream_from_compact_dict, -) -from perceptron.tensorstream.tensorstream import TensorStream, TextType, VisionType, group_streams from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache @@ -117,11 +111,29 @@ from ...processing_utils import Unpack from ...utils import auto_docstring from ...utils.generic import TransformersKwargs, can_return_tuple -from ...utils.import_utils import is_torchdynamo_compiling +from ...utils.import_utils import is_perceptron_available, is_torchdynamo_compiling from ..qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling from .configuration_isaac import IsaacConfig, IsaacVisionConfig +if is_perceptron_available(): + from perceptron.tensorstream.ops import ( + 
compute_mrope_pos_tensor,
+        modality_mask,
+        reconstruct_tensor_stream_from_compact_dict,
+    )
+    from perceptron.tensorstream.tensorstream import TensorStream, TextType, VisionType, group_streams
+else:
+    compute_mrope_pos_tensor = None
+    modality_mask = None
+    reconstruct_tensor_stream_from_compact_dict = None
+    TensorStream = None
+    TextType = None
+    VisionType = None
+    group_streams = None
+
+
 class IsaacVisionEmbeddings(nn.Module):
     """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences."""

@@ -270,13 +282,6 @@ def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor
         return torch.cat(output_chunks, dim=0)


-def _max_from_cu(cu: Optional[torch.Tensor], fallback: int) -> int:
-    """Helper to compute max sequence length from cumulative sequence lengths."""
-    if cu is None or len(cu) < 2:
-        return fallback
-    return int((cu[1:] - cu[:-1]).max().item())
-
-
 def build_document_attention_mask(
     cu_seqlens: Optional[torch.Tensor],
     total_tokens: int,
diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py
index 31d7d21e2287..123e96a1fca9 100644
--- a/src/transformers/models/isaac/modular_isaac.py
+++ b/src/transformers/models/isaac/modular_isaac.py
@@ -93,24 +93,43 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from perceptron.tensorstream.ops import (
-    compute_mrope_pos_tensor,
-    modality_mask,
-    reconstruct_tensor_stream_from_compact_dict,
-    tensor_stream_token_view,
-)
-from perceptron.tensorstream.ops import (
-    slice as ts_slice,
-)
-from perceptron.tensorstream.tensorstream import (
-    Event,
-    Stream,
-    TensorStream,
-    TextType,
-    VisionType,
-    create_stream,
-    group_streams,
-)
+
+from ...utils.import_utils import is_perceptron_available, is_torchdynamo_compiling
+
+
+if is_perceptron_available():
+    from perceptron.tensorstream.ops import (
+        compute_mrope_pos_tensor,
+        modality_mask,
+        reconstruct_tensor_stream_from_compact_dict,
+        tensor_stream_token_view,
+    )
+    from perceptron.tensorstream.ops import (
+        slice as ts_slice,
+    )
+    from perceptron.tensorstream.tensorstream import (
+        Event,
+        Stream,
+        TensorStream,
+        TextType,
+        VisionType,
+        create_stream,
+        group_streams,
+    )
+else:
+    compute_mrope_pos_tensor = None
+    modality_mask = None
+    reconstruct_tensor_stream_from_compact_dict = None
+    tensor_stream_token_view = None
+    ts_slice = None
+    Event = None
+    Stream = None
+    TensorStream = None
+    TextType = None
+    VisionType = None
+    create_stream = None
+    group_streams = None
+
 from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
 from ...configuration_utils import PretrainedConfig, layer_type_validation
@@ -143,7 +158,6 @@
 from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN
 from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD
 from ...utils.generic import TransformersKwargs, can_return_tuple
-from ...utils.import_utils import is_torchdynamo_compiling
 from ..qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling
 from ..siglip2.configuration_siglip2 import Siglip2VisionConfig
 from ..siglip2.modeling_siglip2 import (
@@ -460,13 +474,6 @@ def _preprocess(
     )


-def _max_from_cu(cu: Optional[torch.Tensor], fallback: int) -> int:
-    """Helper to compute max sequence length from cumulative sequence lengths."""
-    if cu is None or len(cu) < 2:
-        return fallback
-    return int((cu[1:] - cu[:-1]).max().item())
-
-
 def build_document_attention_mask(
     cu_seqlens: Optional[torch.Tensor],
     total_tokens: int,
diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py
index af13ac6a0f48..ffd583550927 100644
---
a/src/transformers/models/isaac/processing_isaac.py
+++ b/src/transformers/models/isaac/processing_isaac.py
@@ -92,18 +92,31 @@
 import PIL.Image
 import torch

-from perceptron.tensorstream.ops import slice as ts_slice
-from perceptron.tensorstream.ops import tensor_stream_token_view
-from perceptron.tensorstream.tensorstream import Event, Stream, TensorStream, TextType, VisionType, create_stream
 from ...feature_extraction_utils import BatchFeature
 from ...models.auto.tokenization_auto import AutoTokenizer
 from ...processing_utils import ProcessorMixin
 from ...tokenization_utils_base import BatchEncoding
 from ...utils import TensorType
+from ...utils.import_utils import is_perceptron_available
 from .configuration_isaac import IsaacConfig


+if is_perceptron_available():
+    from perceptron.tensorstream.ops import slice as ts_slice
+    from perceptron.tensorstream.ops import tensor_stream_token_view
+    from perceptron.tensorstream.tensorstream import Event, Stream, TensorStream, TextType, VisionType, create_stream
+else:
+    ts_slice = None
+    tensor_stream_token_view = None
+    Event = None
+    Stream = None
+    TensorStream = None
+    TextType = None
+    VisionType = None
+    create_stream = None
+
+
 # ============================================================================
 # Processor Components
 # ============================================================================
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index af34c54ed305..8ad20ee2ff35 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -698,6 +698,11 @@ def is_mamba_2_ssm_available() -> bool:
     return is_torch_cuda_available() and is_available and version.parse(mamba_ssm_version) >= version.parse("2.0.4")


+@lru_cache
+def is_perceptron_available() -> bool:
+    return is_torch_cuda_available() and _is_package_available("perceptron")
+
+
 @lru_cache
 def is_flash_linear_attention_available():
     is_available, fla_version = _is_package_available("fla", return_version=True)
diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py
index 02c1ab114f2f..bcf928c882e9 100644
--- a/tests/models/isaac/test_modeling_isaac.py
+++ b/tests/models/isaac/test_modeling_isaac.py
@@ -8,6 +8,7 @@
 from pathlib import Path

 import pytest
+from huggingface_hub import is_offline_mode

 from transformers import (
     AutoProcessor,
@@ -23,7 +24,8 @@
 from transformers.models.isaac.modeling_isaac import IsaacVisionAttention
 from transformers.models.isaac.processing_isaac import IsaacProcessor
 from transformers.testing_utils import require_torch, require_vision, slow, torch_device
-from transformers.utils import is_offline_mode, is_vision_available
+from transformers.utils import is_vision_available
+from transformers.utils.import_utils import is_perceptron_available


 if is_vision_available():
@@ -38,14 +40,14 @@
 if is_torch_available():
     import torch

-try:
+if is_perceptron_available():
     from perceptron.tensorstream.ops import modality_mask, role_mask, tensor_stream_token_view
     from perceptron.tensorstream.tensorstream import TensorStream
-except Exception:
+else:
     TensorStream = None


-tensorstream_required = pytest.mark.skipif(TensorStream is None, reason="TensorStream backend is not available")
+require_tensorstream = pytest.mark.skipif(TensorStream is None, reason="TensorStream backend is not available")

 MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base")
 MODEL_REVISION =
os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/3") or None @@ -479,6 +481,7 @@ def setUp(self): def test_config(self): self.config_tester.run_common_tests() + @require_tensorstream def test_model_forward(self): config, input_ids, attention_mask, _ = self.model_tester.prepare_config_and_inputs() model = IsaacModel(config) @@ -492,6 +495,7 @@ def test_model_forward(self): (self.model_tester.batch_size, self.model_tester.seq_length, config.hidden_size), ) + @require_tensorstream def test_for_conditional_generation(self): config, input_ids, attention_mask, labels = self.model_tester.prepare_config_and_inputs() model = IsaacForConditionalGeneration(config) @@ -540,6 +544,7 @@ def test_isaac_config_migrates_legacy_rope_theta(): @require_torch +@require_tensorstream def test_isaac_for_conditional_generation_initialization(isaac_tiny_config): model = IsaacForConditionalGeneration(isaac_tiny_config) model.to(torch_device) @@ -555,6 +560,7 @@ def test_isaac_for_conditional_generation_initialization(isaac_tiny_config): @require_torch +@require_tensorstream def test_isaac_for_conditional_generation_loss_and_generate_flag(isaac_tiny_config): model = IsaacForConditionalGeneration(isaac_tiny_config).to(torch_device) assert model.can_generate() @@ -571,7 +577,7 @@ def test_isaac_for_conditional_generation_loss_and_generate_flag(isaac_tiny_conf @require_torch @require_vision -@tensorstream_required +@require_tensorstream def test_isaac_processor_matches_config_defaults(isaac_processor, isaac_tiny_config): assert isaac_processor.vision_token == isaac_tiny_config.vision_token assert isaac_processor.max_sequence_length == isaac_tiny_config.max_sequence_length @@ -582,7 +588,7 @@ def test_isaac_processor_matches_config_defaults(isaac_processor, isaac_tiny_con @require_torch @require_vision -@tensorstream_required +@require_tensorstream def test_isaac_processor_text_only_round_trip(isaac_processor): messages = [{"role": "user", "content": "Hello, how are you?"}] prompt = isaac_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) @@ -595,7 +601,7 @@ def test_isaac_processor_text_only_round_trip(isaac_processor): @require_torch -@tensorstream_required +@require_tensorstream def test_isaac_processor_accepts_batchencoding_chat_template(isaac_processor): messages = [{"role": "user", "content": "Hello, how are you?"}] batch_encoding = isaac_processor.apply_chat_template(messages, add_generation_prompt=True) @@ -610,7 +616,7 @@ def test_isaac_processor_accepts_batchencoding_chat_template(isaac_processor): @require_torch @require_vision -@tensorstream_required +@require_tensorstream def test_isaac_processor_with_single_image(isaac_processor): vision_token = isaac_processor.vision_token text = f"Look at this {vision_token} and describe it." 
@@ -623,7 +629,7 @@ def test_isaac_processor_with_single_image(isaac_processor): @require_torch @require_vision -@tensorstream_required +@require_tensorstream def test_isaac_processor_with_multiple_images(isaac_processor): vision_token = isaac_processor.vision_token text = f"First {vision_token} then {vision_token}" @@ -636,7 +642,7 @@ def test_isaac_processor_with_multiple_images(isaac_processor): @require_torch @require_vision -@tensorstream_required +@require_tensorstream def test_isaac_processor_error_on_image_mismatch(isaac_processor): vision_token = isaac_processor.vision_token text = f"{vision_token} {vision_token}" @@ -648,7 +654,7 @@ def test_isaac_processor_error_on_image_mismatch(isaac_processor): @require_torch @require_vision -@tensorstream_required +@require_tensorstream def test_isaac_processor_consistent_tensor_stream_types(isaac_processor): text_only = "Simple question?" text_with_image = f"Describe this {isaac_processor.vision_token}" @@ -664,7 +670,7 @@ def test_isaac_processor_consistent_tensor_stream_types(isaac_processor): @require_torch @require_vision -@tensorstream_required +@require_tensorstream def test_isaac_generation_with_tensor_stream(isaac_processor, isaac_tiny_config): model = IsaacForConditionalGeneration(isaac_tiny_config).to(torch_device) model.eval() @@ -694,7 +700,7 @@ def test_isaac_generation_with_tensor_stream(isaac_processor, isaac_tiny_config) @require_torch @slow -@tensorstream_required +@require_tensorstream def test_isaac_checkpoint_hashes(isaac_reference_model): isaac_reference_model = isaac_reference_model.to("cpu") expected_hashes = _load_expected_hashes() @@ -762,7 +768,7 @@ def create_isaac_processor( @require_torch @require_vision @slow -@tensorstream_required +@require_tensorstream def test_hf_generate_vs_training_generate_logits(isaac_reference_model, isaac_reference_processor): device = "cuda" dtype = torch.bfloat16 From d439313f2d4e8d6c57fcd34b4cc0376d70480248 Mon Sep 17 00:00:00 2001 From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com> Date: Tue, 9 Dec 2025 12:13:48 +0400 Subject: [PATCH 49/77] fix: guard PIL import (#8) * fix: make vision import conditional * chore: convert artifact * docs: fixup artifact --- docs/source/en/model_doc/isaac.md | 2 +- src/transformers/models/isaac/modular_isaac.py | 7 +++++-- src/transformers/models/isaac/processing_isaac.py | 7 +++++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md index 5b991f166acf..bd2f11908d0e 100644 --- a/docs/source/en/model_doc/isaac.md +++ b/docs/source/en/model_doc/isaac.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on {release_date} and added to Hugging Face Transformers on 2025-11-17.* +*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-08.* *This model was added to Hugging Face Transformers in 2025.*
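The PIL guard in the next hunks follows the same shape as the perceptron guard from the previous patch: bind the optional name at import time, fall back to None, and only dereference it behind an availability check. A minimal sketch of the pattern, assuming the standard transformers utilities (the as_image_list helper is hypothetical and not part of this patch):

    from transformers.utils.import_utils import is_vision_available

    if is_vision_available():
        from PIL.Image import Image
    else:
        # Fallback binding keeps the module importable when Pillow is missing.
        Image = None


    def as_image_list(images):
        # Check availability before the isinstance test: isinstance(x, None)
        # raises TypeError, so Image must only be dereferenced when it is a type.
        if Image is not None and isinstance(images, Image):
            return [images]
        return list(images)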
diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 123e96a1fca9..c6a3b2457c94 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -89,12 +89,15 @@ from collections.abc import Sequence from typing import Any, Optional, Union -import PIL.Image import torch import torch.nn as nn import torch.nn.functional as F -from ...utils.import_utils import is_perceptron_available, is_torchdynamo_compiling +from ...utils.import_utils import is_perceptron_available, is_torchdynamo_compiling, is_vision_available + + +if is_vision_available(): + import PIL.Image if is_perceptron_available(): diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index ffd583550927..749b6e7f5f67 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -90,7 +90,6 @@ import re from typing import Optional, Union -import PIL.Image import torch from ...feature_extraction_utils import BatchFeature @@ -98,10 +97,14 @@ from ...processing_utils import ProcessorMixin from ...tokenization_utils_base import BatchEncoding from ...utils import TensorType -from ...utils.import_utils import is_perceptron_available +from ...utils.import_utils import is_perceptron_available, is_vision_available from .configuration_isaac import IsaacConfig +if is_vision_available(): + import PIL.Image + + if is_perceptron_available(): from perceptron.tensorstream.ops import slice as ts_slice from perceptron.tensorstream.ops import tensor_stream_token_view From e2fe9f96899c5212bf49364947cfe3420a76c27f Mon Sep 17 00:00:00 2001 From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com> Date: Tue, 9 Dec 2025 13:04:33 +0400 Subject: [PATCH 50/77] fix: guard perceptron PIL and torch imports for CI (#9) * style: drop unused func * feat: add perceptron availability check * test: use new perceptron availability utility * feat: guard imports in modular file * test: update tensorstream requirement tests * chore: convert script artifacts * test: update is_offline_mode import * fix: make vision import conditional * chore: convert artifact * docs: fixup artifact * fix: guard torch import * fix: guard PIL import --- docs/source/en/model_doc/isaac.md | 2 +- .../isaac/image_processing_isaac_fast.py | 9 ++++--- .../models/isaac/modeling_isaac.py | 16 ++++++++---- .../models/isaac/modular_isaac.py | 25 +++++++++++++------ .../models/isaac/processing_isaac.py | 18 +++++++------ 5 files changed, 46 insertions(+), 24 deletions(-) diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md index bd2f11908d0e..ec57da58250e 100644 --- a/docs/source/en/model_doc/isaac.md +++ b/docs/source/en/model_doc/isaac.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-08.* +*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-09.* *This model was added to Hugging Face Transformers in 2025.*
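This patch extends the same availability guard to torch itself, so that importing processor and image-processor modules on a torch-free CI runner succeeds; torch-dependent names are only resolved once a method actually runs. A minimal sketch, assuming the standard transformers utilities (resize_patches is a hypothetical helper, not part of this patch):

    from transformers.utils.import_utils import is_torch_available

    if is_torch_available():
        import torch
        import torch.nn.functional as F


    def resize_patches(pixel_values, size):
        # F is resolved at call time: without torch this raises NameError when
        # called, but importing the module itself stays side-effect free.
        return F.interpolate(pixel_values, size=size, mode="bilinear", align_corners=False)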
diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py b/src/transformers/models/isaac/image_processing_isaac_fast.py index d0129357bc85..634460a14199 100644 --- a/src/transformers/models/isaac/image_processing_isaac_fast.py +++ b/src/transformers/models/isaac/image_processing_isaac_fast.py @@ -90,9 +90,6 @@ from collections.abc import Sequence from typing import Any, Optional, Union -import torch -import torch.nn.functional as F - from ...feature_extraction_utils import BatchFeature from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict, group_images_by_shape, reorder_images from ...image_utils import ChannelDimension, PILImageResampling @@ -102,9 +99,15 @@ # Vision preprocessing constants from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD +from ...utils.import_utils import is_torch_available from .image_processing_isaac import IsaacImageProcessorKwargs +if is_torch_available(): + import torch + import torch.nn.functional as F + + def get_scaled_image_size( scale: float, original_size: int, diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 05f4f40e3f75..2982a3512f79 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -92,10 +92,6 @@ from collections.abc import Callable from typing import Any, Optional -import torch -import torch.nn as nn -import torch.nn.functional as F - from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation.utils import GenerationMixin @@ -111,11 +107,21 @@ from ...processing_utils import Unpack from ...utils import auto_docstring from ...utils.generic import TransformersKwargs, can_return_tuple -from ...utils.import_utils import is_perceptron_available, is_torchdynamo_compiling +from ...utils.import_utils import ( + is_perceptron_available, + is_torch_available, + is_torchdynamo_compiling, +) from ..qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling from .configuration_isaac import IsaacConfig, IsaacVisionConfig +if is_torch_available(): + import torch + import torch.nn as nn + import torch.nn.functional as F + + if is_perceptron_available(): from perceptron.tensorstream.ops import ( compute_mrope_pos_tensor, diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index c6a3b2457c94..7c3b5010e675 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -89,15 +89,24 @@ from collections.abc import Sequence from typing import Any, Optional, Union -import torch -import torch.nn as nn -import torch.nn.functional as F +from ...utils.import_utils import ( + is_perceptron_available, + is_torch_available, + is_torchdynamo_compiling, + is_vision_available, +) + -from ...utils.import_utils import is_perceptron_available, is_torchdynamo_compiling, is_vision_available +if is_torch_available(): + import torch + import torch.nn as nn + import torch.nn.functional as F if is_vision_available(): - import PIL.Image + from PIL.Image import Image +else: + Image = None if is_perceptron_available(): @@ -1604,7 +1613,7 @@ def __init__( def build_event_stream_simple( self, text: str, - images: Optional[list[PIL.Image.Image]] = None, + images: Optional[list[Image]] = None, ) -> Stream: events = [] # Process text and images @@ -1650,7 +1659,7 @@ 
def build_event_stream_simple(
     def __call__(
         self,
         text: Union[str, list[str]],
-        images: Optional[Union[PIL.Image.Image, list[PIL.Image.Image]]] = None,
+        images: Optional[Union[Image, list[Image]]] = None,
         return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
         **kwargs,
     ) -> BatchFeature:
@@ -1689,7 +1698,7 @@ def __call__(
             texts = text

         if images is not None:
-            if isinstance(images, PIL.Image.Image):
+            if isinstance(images, Image):
                 images_list = [images]
             else:
                 images_list = images
diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py
index 749b6e7f5f67..cf5dbd35dc5c 100644
--- a/src/transformers/models/isaac/processing_isaac.py
+++ b/src/transformers/models/isaac/processing_isaac.py
@@ -90,19 +90,23 @@
 import re
 from typing import Optional, Union

-import torch
-
 from ...feature_extraction_utils import BatchFeature
 from ...models.auto.tokenization_auto import AutoTokenizer
 from ...processing_utils import ProcessorMixin
 from ...tokenization_utils_base import BatchEncoding
 from ...utils import TensorType
-from ...utils.import_utils import is_perceptron_available, is_vision_available
+from ...utils.import_utils import is_perceptron_available, is_torch_available, is_vision_available
 from .configuration_isaac import IsaacConfig


+if is_torch_available():
+    import torch
+
+
 if is_vision_available():
-    import PIL.Image
+    from PIL.Image import Image
+else:
+    Image = None


 if is_perceptron_available():
@@ -214,7 +218,7 @@ def __init__(
     def build_event_stream_simple(
         self,
         text: str,
-        images: Optional[list[PIL.Image.Image]] = None,
+        images: Optional[list[Image]] = None,
     ) -> Stream:
         events = []
         # Process text and images
@@ -260,7 +264,7 @@ def build_event_stream_simple(
     def __call__(
         self,
         text: Union[str, list[str]],
-        images: Optional[Union[PIL.Image.Image, list[PIL.Image.Image]]] = None,
+        images: Optional[Union[Image, list[Image]]] = None,
         return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
         **kwargs,
     ) -> BatchFeature:
@@ -299,7 +303,7 @@ def __call__(
             texts = text

         if images is not None:
-            if isinstance(images, PIL.Image.Image):
+            if isinstance(images, Image):
                 images_list = [images]
             else:
                 images_list = images

From 257f47c80fa85c2935787ec3a75f2b471a1dc401 Mon Sep 17 00:00:00 2001
From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com>
Date: Fri, 12 Dec 2025 20:58:18 +0400
Subject: [PATCH 51/77] review revisions (#10)

* style: drop unused function
* fix: do not expose gradient checkpoint flag
* refactor: move test files to fixtures directory, refer to them in test file
* style: remove redundant routing
* feat: isaac model forward autodoc + check inputs
* docs: add docstring for isaac forward
* style: remove conditions on pixel shuffle scale
* style: explicitly use torch dtype's min value to avoid float casting
* style: remove unnecessary forward() configuration handling
* refactor: move to updated masking API
* chore: license + module docstring for test
* docs: update license
* refactor: doc mask rework wip 2
* refactor: doc mask refactor finished
* refactor: callable rework
* style: pass all args down for interface compatibility
* style: remove extra cast on attention implementation
* test: update tests
* test: remove outdated tests
* refactor: isolated vision embedding class
* refactor: simplify attention flow to prepare for proper handling
* refactor: simplify config
* chore: convert artifacts
* test: add text only test
* style: cache
* tests: expand integration testing

---
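One style bullet above is worth spelling out: building an additive attention mask from float("-inf") (a Python float) forces a cast when it is added to bf16/fp16 attention scores, whereas torch.finfo(dtype).min keeps the mask in the scores' own dtype. A minimal sketch of a block-diagonal document mask in that style (illustrative only; this is not the patch's build_document_attention_mask implementation):

    import torch


    def additive_document_mask(cu_seqlens: torch.Tensor, total_tokens: int, dtype: torch.dtype) -> torch.Tensor:
        # Derive a document id per token from cumulative sequence lengths.
        doc_ids = torch.zeros(total_tokens, dtype=torch.long)
        for i in range(len(cu_seqlens) - 1):
            doc_ids[cu_seqlens[i] : cu_seqlens[i + 1]] = i
        same_doc = doc_ids[:, None] == doc_ids[None, :]
        # Cross-document positions get the dtype's own minimum, not float("-inf").
        mask = torch.zeros((total_tokens, total_tokens), dtype=dtype)
        return mask.masked_fill(~same_doc, torch.finfo(dtype).min)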
docs/source/en/model_doc/isaac.md | 2 +- .../models/isaac/configuration_isaac.py | 226 +----- .../models/isaac/image_processing_isaac.py | 89 +-- .../isaac/image_processing_isaac_fast.py | 105 +-- .../models/isaac/modeling_isaac.py | 492 +++++-------- .../models/isaac/modular_isaac.py | 667 +++++------------- .../models/isaac/processing_isaac.py | 112 +-- .../isaac/isaac_checkpoint_hashes.json | 0 .../isaac/isaac_generation_golden.json | 0 tests/models/isaac/test_modeling_isaac.py | 436 +++++++++--- 10 files changed, 763 insertions(+), 1366 deletions(-) rename tests/{models => fixtures}/isaac/isaac_checkpoint_hashes.json (100%) rename tests/{models => fixtures}/isaac/isaac_generation_golden.json (100%) diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md index ec57da58250e..7cd17f2a7f3d 100644 --- a/docs/source/en/model_doc/isaac.md +++ b/docs/source/en/model_doc/isaac.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-09.* +*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-12.* *This model was added to Hugging Face Transformers in 2025.*
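After the simplification below, IsaacConfig mirrors only the Qwen3 attributes that downstream code actually reads (vocab_size, hidden_size, num_hidden_layers, num_attention_heads, head_dim, hidden_act, use_cache, rope_theta); everything else lives on config.text_config. A usage sketch with tiny illustrative values, not a real checkpoint:

    from transformers import IsaacConfig

    cfg = IsaacConfig(text_config={"hidden_size": 64, "num_hidden_layers": 2, "num_attention_heads": 4})

    # Mirrored at the composite level for backward compatibility:
    assert cfg.hidden_size == cfg.text_config.hidden_size
    assert cfg.rope_theta == cfg.text_config.rope_theta

    # Less commonly accessed attributes are read from the sub-config directly:
    print(cfg.text_config.rms_norm_eps)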
diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index ae7293bcd1e7..38e785ef1528 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -4,87 +4,20 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# Copyright (c) 2024 Perceptron, Inc. All rights reserved. -# Perceptron, Inc. Non-Production License (2024-01-01) - - -### 1. Scope and acceptance - -# **1.1. Scope of the Agreement.** -# This Agreement applies to any use, modification, or Distribution of any Perceptron Model by You, regardless of the source You obtained a copy of such Perceptron Model. -# -# **1.2. Acceptance.** By accessing, using, modifying, Distributing a Perceptron Model, or by creating, using or distributing a Derivative of the Perceptron Model, You agree to be bound by this Agreement. -# -# **1.3. Acceptance on behalf of a third-party.** If You accept this Agreement on behalf of Your employer or another person or entity, You warrant and represent that You have the authority to act and accept this Agreement on their behalf. In such a case, the word โ€œYouโ€ in this Agreement will refer to Your employer or such other person or entity. -# -# ## 2. License -# **2.1. Grant of rights.** Subject to Section 3 below, Perceptron, Inc. hereby grants You a non-exclusive, royalty-free, worldwide, non-sublicensable, non-transferable, limited license to use, copy, modify, and Distribute under the conditions provided in Section 2.2 below, the Perceptron Model and any Derivatives made by or for Perceptron, Inc. and to create Derivatives of the Perceptron Model. -# -# **2.2. Distribution of Perceptron Model and Derivatives made by or for Perceptron, Inc..** Subject to Section 3 below, You may Distribute copies of the Perceptron Model and/or Derivatives made by or for Perceptron, Inc., under the following conditions: -# - You must make available a copy of this Agreement to third-party recipients of the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. you Distribute, it being specified that any rights to use the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. shall be directly granted by Perceptron, Inc. to said third-party recipients pursuant to the Perceptron, Inc. Non-Production License agreement executed between these parties; -# - You must retain in all copies of the Perceptron Models the following attribution notice within a โ€œNoticeโ€ text file distributed as part of such copies: โ€œLicensed by Perceptron, Inc. under the Perceptron, Inc. Non-Production Licenseโ€. -# -# **2.3. Distribution of Derivatives made by or for You.** Subject to Section 3 below, You may Distribute any Derivatives made by or for You under additional or different terms and conditions, provided that: -# - In any event, the use and modification of Perceptron Model and/or Derivatives made by or for Perceptron, Inc. 
shall remain governed by the terms and conditions of this Agreement; -# - You include in any such Derivatives made by or for You prominent notices stating that You modified the concerned Perceptron Model; and -# - Any terms and conditions You impose on any third-party recipients relating to Derivatives made by or for You shall neither limit such third-party recipientsโ€™ use of the Perceptron Model or any Derivatives made by or for Perceptron, Inc. in accordance with the Perceptron, Inc. Non-Production License nor conflict with any of its terms and conditions. -# -# ## 3. Limitations -# **3.1. Misrepresentation.** You must not misrepresent or imply, through any means, that the Derivatives made by or for You and/or any modified version of the Perceptron Model You Distribute under your name and responsibility is an official product of Perceptron, Inc. or has been endorsed, approved or validated by Perceptron, Inc., unless You are authorized by Us to do so in writing. -# -# **3.2. Usage Limitation** -# - You shall only use the Perceptron Models and Derivatives (whether or not created by Perceptron, Inc.) for testing, research, Personal, or evaluation purposes in Non-Production Environments; -# - Subject to the foregoing, You shall not supply the Perceptron Models or Derivatives in the course of a commercial activity, whether in return for payment or free of charge, in any medium or form, including but not limited to through a hosted or managed service (e.g. SaaS, cloud instances, etc.), or behind a software layer. -# -# **3.3. Usage not permitted under this Agreement.** If You want to use a Perceptron Model or a Derivative for any purpose that is not expressly authorized under this Agreement, You must request a license from Perceptron, Inc., which Perceptron, Inc. may grant to You in Perceptron, Inc.โ€™s sole discretion. Please contact Perceptron, Inc. at the following e-mail address if You want to discuss such a license: sales@perceptron.inc -# -# ## 4. Intellectual Property -# **4.1. Trademarks.** No trademark licenses are granted under this Agreement, and in connection with the Perceptron Models, You may not use any name or mark owned by or associated with Perceptron, Inc. or any of its affiliates, except (i) as required for reasonable and customary use in describing and Distributing the Perceptron Models and Derivatives made by or for Perceptron, Inc. and (ii) for attribution purposes as required by this Agreement. -# -# **4.2. Outputs.** We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs You generate and their subsequent uses in accordance with this Agreement. -# -# **4.3. Derivatives.** By entering into this Agreement, You accept that any Derivatives that You may create or that may be created for You shall be subject to the restrictions set out in Section 3 of this Agreement. -# -# # 5. Liability -# **5.1. Limitation of liability.** In no event, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall Perceptron, Inc. be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use the Perceptron Models and Derivatives (including but not limited to damages for loss of data, loss of goodwill, loss of expected profit or savings, work stoppage, computer failure or malfunction, or any damage caused by malware or security breaches), even if Perceptron, Inc. 
has been advised of the possibility of such damages. -# -# **5.2. Indemnification.** You agree to indemnify and hold harmless Perceptron, Inc. from and against any claims, damages, or losses arising out of or related to Your use or Distribution of the Perceptron Models and Derivatives. -# -# ## 6. Warranty -# **6.1. Disclaimer.** Unless required by applicable law or agreed to in writing, Perceptron, Inc. provides the Perceptron Models and Derivatives on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. Perceptron, Inc. does not represent nor warrant that the Perceptron Models and Derivatives will be error-free, meet Your or any third partyโ€™s requirements, be secure or will allow You or any third party to achieve any kind of result or generate any kind of content. You are solely responsible for determining the appropriateness of using or Distributing the Perceptron Models and Derivatives and assume any risks associated with Your exercise of rights under this Agreement. -# -# # 7. Termination -# **7.1. Term.** This Agreement is effective as of the date of your acceptance of this Agreement or access to the concerned Perceptron Models or Derivatives and will continue until terminated in accordance with the following terms. -# -# **7.2. Termination.** Perceptron, Inc. may terminate this Agreement at any time if You are in breach of this Agreement. Upon termination of this Agreement, You must cease to use all Perceptron Models and Derivatives and shall permanently delete any copy thereof. Sections 5, 6, 7 and 8 shall survive the termination of this Agreement. -# -# **7.3. Litigation.** If You initiate any legal action or proceedings against Us or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the Model or a Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement will immediately terminate as of the date such legal action or claim is filed or initiated. -# -# # 8. General provisions -# 8.1. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the State of Washington, without regard to its conflict of law principles. -# -# 8.2. Jurisdiction. The state and federal courts located in King County, Washington shall have exclusive jurisdiction over any dispute arising out of or relating to this Agreement, and You and We consent to personal jurisdiction and venue in such courts. -# -# **8.3. Severability.** If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. -# -# # 9. Definitions -# **โ€œAgreementโ€**: means this Perceptron, Inc. Non-Production License agreement governing the access, use, and Distribution of the Perceptron Models and Derivatives. +# coding=utf-8 +# Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. # -# **โ€œDerivativeโ€**: means any (i) modified version of the Perceptron Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the Perceptron Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered as Derivatives under this Agreement. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# **โ€œDistributionโ€**, **โ€œDistributingโ€**, **โ€œDistributeโ€** or **โ€œDistributedโ€**: means providing or making available, by any means, a copy of the Perceptron Models and/or the Derivatives as the case may be, subject to Section 3 of this Agreement. +# http://www.apache.org/licenses/LICENSE-2.0 # -# **โ€œPerceptron, Inc.โ€**, **โ€œWeโ€** or **โ€œUsโ€**: means Perceptron, Inc., a Delaware corporation with its principal place of business at 10900 NE 8th St Suite 613, Bellevue, WA 98004. -# -# **โ€œPerceptron Modelโ€**: means the foundational large language model(s), and its elements which include algorithms, software, instructed checkpoints, parameters, source code (inference code, evaluation code and, if applicable, fine-tuning code) and any other elements associated thereto made available by Perceptron, Inc. under this Agreement, including, if any, the technical documentation, manuals and instructions for the use and operation thereof. -# -# **โ€œNon-Production Environmentโ€**: means any setting, use case, or application of the Perceptron Models or Derivatives that expressly excludes live, real-world conditions, commercial operations, revenue-generating activities, or direct interactions with or impacts on end users (such as, for instance, Your employees or customers). Non-Production Environment may include, but is not limited to, any setting, use case, or application for research, development, testing, quality assurance, training, internal evaluation (other than any internal usage by employees in the context of the companyโ€™s business activities), and demonstration purposes. -# -# **โ€œOutputsโ€**: means any content generated by the operation of the Perceptron Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a Perceptron Models, such as any fine-tuned versions of the Perceptron Models, the weights, or parameters. -# -# **โ€œPersonalโ€**: means any use of a Perceptron Model or a Derivative that is (i) solely for personal, non-profit and non-commercial purposes and (ii) not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities. For illustration purposes, Personal use of a Model or a Derivative does not include any usage by individuals employed in companies in the context of their daily tasks, any activity that is intended to generate revenue, or that is performed on behalf of a commercial entity. -# -# **โ€œYouโ€**: means the individual or entity entering into this Agreement with Perceptron, Inc.. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import copy @@ -109,7 +42,6 @@ class IsaacVisionConfig(PreTrainedConfig): model_type = "isaac_vision" base_config_key = "vision_config" - _attn_implementation: Optional[str] = None def __init__( self, @@ -142,16 +74,6 @@ def __init__( # Add our custom fields self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor - if self._attn_implementation is None: - self._attn_implementation = "flash_attention_2" - - # Keep legacy and new attention implementation fields in sync - existing_attn_impl = getattr(self, "attn_implementation", None) - if existing_attn_impl is None: - self.attn_implementation = self._attn_implementation - else: - self._attn_implementation = existing_attn_impl - class IsaacConfig(PretrainedConfig): """Configuration class for Isaac multimodal model. @@ -173,9 +95,7 @@ def __init__( vision_token: str = "", **kwargs, ): - self._rope_scaling: Optional[dict[str, Any]] = None self._rope_parameters: Optional[dict[str, Any]] = None - resolved_text_config = kwargs.pop("text_config", text_config) if isinstance(resolved_text_config, Qwen3Config): text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict()) @@ -188,83 +108,37 @@ def __init__( text_config_kwargs.update(kwargs) - legacy_rope_theta = text_config_kwargs.pop("rope_theta", None) - incoming_rope_params = text_config_kwargs.pop("rope_parameters", None) - incoming_rope_scaling = text_config_kwargs.pop("rope_scaling", None) - normalized_rope_params = incoming_rope_params or incoming_rope_scaling - if normalized_rope_params is None and legacy_rope_theta is not None: - normalized_rope_params = {"rope_type": "default", "rope_theta": legacy_rope_theta} - elif ( - normalized_rope_params is not None - and legacy_rope_theta is not None - and "rope_theta" not in normalized_rope_params - ): - normalized_rope_params = {**normalized_rope_params, "rope_theta": legacy_rope_theta} - if normalized_rope_params is not None: - text_config_kwargs["rope_parameters"] = normalized_rope_params - self.text_config = self.sub_configs["text_config"](**text_config_kwargs) - - # Normalize rope parameters on the text config (prefer rope_parameters; alias rope_scaling) - self._rope_parameters = getattr(self.text_config, "rope_parameters", None) - if self._rope_parameters is None: - self._rope_parameters = getattr(self.text_config, "rope_scaling", None) - if self._rope_parameters is None and normalized_rope_params is not None: - self._rope_parameters = normalized_rope_params - if self._rope_parameters is None: - self._rope_parameters = {"rope_type": "default"} - - try: - self.text_config.rope_parameters = self._rope_parameters - except AttributeError: - setattr(self.text_config, "rope_parameters", self._rope_parameters) - if hasattr(self.text_config, "rope_scaling"): - self.text_config.rope_scaling = self._rope_parameters - else: - try: - setattr(self.text_config, "rope_scaling", self._rope_parameters) - except Exception: - pass + if not hasattr(self.text_config, "rope_theta"): + rope_theta_override = text_config_kwargs.get("rope_theta", kwargs.get("rope_theta")) + if rope_theta_override is None: + rope_theta_override = getattr(Qwen3Config(), "rope_theta", 10000.0) + self.text_config.rope_theta = rope_theta_override super().__init__(**kwargs) + if self._rope_scaling is None: + self._rope_scaling = getattr(self.text_config, "rope_scaling", None) + else: + self.text_config.rope_scaling = self._rope_scaling + # Keep rope parameters alias in sync with upstream expectations - self._rope_scaling = self._rope_parameters + self._rope_parameters = 
self._rope_scaling # Mirror frequently accessed Qwen3 attributes at the composite config level for BC. - self.tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False) self.vocab_size = self.text_config.vocab_size - self.max_position_embeddings = self.text_config.max_position_embeddings self.hidden_size = self.text_config.hidden_size - self.intermediate_size = self.text_config.intermediate_size self.num_hidden_layers = self.text_config.num_hidden_layers self.num_attention_heads = self.text_config.num_attention_heads - self.use_sliding_window = getattr(self.text_config, "use_sliding_window", False) - sliding_window = getattr(self.text_config, "sliding_window", None) - self.sliding_window = sliding_window if self.use_sliding_window else None - self.max_window_layers = getattr(self.text_config, "max_window_layers", None) - self.num_key_value_heads = getattr(self.text_config, "num_key_value_heads", None) - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads self.head_dim = self.text_config.head_dim self.hidden_act = self.text_config.hidden_act - self.initializer_range = self.text_config.initializer_range - self.rms_norm_eps = self.text_config.rms_norm_eps self.use_cache = self.text_config.use_cache - self.attention_bias = getattr(self.text_config, "attention_bias", False) - self.attention_dropout = getattr(self.text_config, "attention_dropout", 0.0) + self.rope_theta = self.text_config.rope_theta # Validate rotary parameters now that they have been mirrored locally. rope_config_validation(self) self.layer_types = getattr(self.text_config, "layer_types", None) - if self.layer_types is None: - self.layer_types = [ - "sliding_attention" - if self.sliding_window is not None and i >= self.max_window_layers - else "full_attention" - for i in range(self.num_hidden_layers) - ] layer_type_validation(self.layer_types, self.num_hidden_layers) # Handle vision config - either dict or IsaacVisionConfig instance @@ -282,42 +156,24 @@ def __init__( self.max_sequence_length = max_sequence_length self.vision_token = vision_token - def get_text_config(self, *_, **kwargs) -> Qwen3Config: - # Accept optional decoder/encoder flags to align with HF composite configs - kwargs.pop("decoder", None) - kwargs.pop("encoder", None) - return self.text_config - @property def rope_scaling(self): if hasattr(self, "text_config") and self.text_config is not None: - return getattr(self.text_config, "rope_parameters", None) or getattr( - self.text_config, "rope_scaling", None - ) - return self._rope_parameters + return getattr(self.text_config, "rope_scaling", None) + return self._rope_scaling @rope_scaling.setter def rope_scaling(self, value): - self._rope_parameters = value self._rope_scaling = value if hasattr(self, "text_config") and self.text_config is not None: - try: - self.text_config.rope_parameters = value - except AttributeError: - setattr(self.text_config, "rope_parameters", value) - try: - self.text_config.rope_scaling = value - except AttributeError: - pass + self.text_config.rope_scaling = value @property def rope_parameters(self) -> dict[str, Any] | None: """Alias introduced upstream for rope scaling dictionaries.""" value = self._rope_parameters - if value is None and hasattr(self, "text_config") and self.text_config is not None: - value = getattr(self.text_config, "rope_parameters", None) or getattr( - self.text_config, "rope_scaling", None - ) + if value is None: + value = self.rope_scaling if value is None: return {"rope_type": "default"} return value 
@@ -325,37 +181,13 @@ def rope_parameters(self) -> dict[str, Any] | None: @rope_parameters.setter def rope_parameters(self, value: dict[str, Any] | None) -> None: self._rope_parameters = value - self._rope_scaling = value self.rope_scaling = value - @property - def vision_attn_implementation(self) -> Optional[str]: - value = getattr(self.vision_config, "_attn_implementation", None) - if value is None: - value = getattr(self.vision_config, "attn_implementation", None) - return value - - @vision_attn_implementation.setter - def vision_attn_implementation(self, value: Optional[str]) -> None: - self.vision_config._attn_implementation = value - if value is not None: - self.vision_config.attn_implementation = value - elif hasattr(self.vision_config, "attn_implementation"): - delattr(self.vision_config, "attn_implementation") - def to_dict(self): output = super().to_dict() - rope_params = self.rope_parameters - output["rope_parameters"] = rope_params - output.pop("rope_scaling", None) - output.pop("rope_theta", None) # Ensure nested configs round-trip through dict serialization if hasattr(self, "text_config") and self.text_config is not None: - text_config_dict = self.text_config.to_dict() - text_config_dict.pop("rope_theta", None) - text_config_dict.pop("rope_scaling", None) - text_config_dict["rope_parameters"] = rope_params - output["text_config"] = text_config_dict + output["text_config"] = self.text_config.to_dict() if hasattr(self, "vision_config") and self.vision_config is not None: output["vision_config"] = self.vision_config.to_dict() return output diff --git a/src/transformers/models/isaac/image_processing_isaac.py b/src/transformers/models/isaac/image_processing_isaac.py index ecd28aaae954..9e09a15fc072 100644 --- a/src/transformers/models/isaac/image_processing_isaac.py +++ b/src/transformers/models/isaac/image_processing_isaac.py @@ -4,87 +4,20 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# Copyright (c) 2024 Perceptron, Inc. All rights reserved. -# Perceptron, Inc. Non-Production License (2024-01-01) - - -### 1. Scope and acceptance - -# **1.1. Scope of the Agreement.** -# This Agreement applies to any use, modification, or Distribution of any Perceptron Model by You, regardless of the source You obtained a copy of such Perceptron Model. -# -# **1.2. Acceptance.** By accessing, using, modifying, Distributing a Perceptron Model, or by creating, using or distributing a Derivative of the Perceptron Model, You agree to be bound by this Agreement. -# -# **1.3. Acceptance on behalf of a third-party.** If You accept this Agreement on behalf of Your employer or another person or entity, You warrant and represent that You have the authority to act and accept this Agreement on their behalf. In such a case, the word โ€œYouโ€ in this Agreement will refer to Your employer or such other person or entity. -# -# ## 2. License -# **2.1. Grant of rights.** Subject to Section 3 below, Perceptron, Inc. hereby grants You a non-exclusive, royalty-free, worldwide, non-sublicensable, non-transferable, limited license to use, copy, modify, and Distribute under the conditions provided in Section 2.2 below, the Perceptron Model and any Derivatives made by or for Perceptron, Inc. 
and to create Derivatives of the Perceptron Model. -# -# **2.2. Distribution of Perceptron Model and Derivatives made by or for Perceptron, Inc..** Subject to Section 3 below, You may Distribute copies of the Perceptron Model and/or Derivatives made by or for Perceptron, Inc., under the following conditions: -# - You must make available a copy of this Agreement to third-party recipients of the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. you Distribute, it being specified that any rights to use the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. shall be directly granted by Perceptron, Inc. to said third-party recipients pursuant to the Perceptron, Inc. Non-Production License agreement executed between these parties; -# - You must retain in all copies of the Perceptron Models the following attribution notice within a โ€œNoticeโ€ text file distributed as part of such copies: โ€œLicensed by Perceptron, Inc. under the Perceptron, Inc. Non-Production Licenseโ€. -# -# **2.3. Distribution of Derivatives made by or for You.** Subject to Section 3 below, You may Distribute any Derivatives made by or for You under additional or different terms and conditions, provided that: -# - In any event, the use and modification of Perceptron Model and/or Derivatives made by or for Perceptron, Inc. shall remain governed by the terms and conditions of this Agreement; -# - You include in any such Derivatives made by or for You prominent notices stating that You modified the concerned Perceptron Model; and -# - Any terms and conditions You impose on any third-party recipients relating to Derivatives made by or for You shall neither limit such third-party recipientsโ€™ use of the Perceptron Model or any Derivatives made by or for Perceptron, Inc. in accordance with the Perceptron, Inc. Non-Production License nor conflict with any of its terms and conditions. -# -# ## 3. Limitations -# **3.1. Misrepresentation.** You must not misrepresent or imply, through any means, that the Derivatives made by or for You and/or any modified version of the Perceptron Model You Distribute under your name and responsibility is an official product of Perceptron, Inc. or has been endorsed, approved or validated by Perceptron, Inc., unless You are authorized by Us to do so in writing. -# -# **3.2. Usage Limitation** -# - You shall only use the Perceptron Models and Derivatives (whether or not created by Perceptron, Inc.) for testing, research, Personal, or evaluation purposes in Non-Production Environments; -# - Subject to the foregoing, You shall not supply the Perceptron Models or Derivatives in the course of a commercial activity, whether in return for payment or free of charge, in any medium or form, including but not limited to through a hosted or managed service (e.g. SaaS, cloud instances, etc.), or behind a software layer. -# -# **3.3. Usage not permitted under this Agreement.** If You want to use a Perceptron Model or a Derivative for any purpose that is not expressly authorized under this Agreement, You must request a license from Perceptron, Inc., which Perceptron, Inc. may grant to You in Perceptron, Inc.โ€™s sole discretion. Please contact Perceptron, Inc. at the following e-mail address if You want to discuss such a license: sales@perceptron.inc -# -# ## 4. Intellectual Property -# **4.1. 
Trademarks.** No trademark licenses are granted under this Agreement, and in connection with the Perceptron Models, You may not use any name or mark owned by or associated with Perceptron, Inc. or any of its affiliates, except (i) as required for reasonable and customary use in describing and Distributing the Perceptron Models and Derivatives made by or for Perceptron, Inc. and (ii) for attribution purposes as required by this Agreement. -# -# **4.2. Outputs.** We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs You generate and their subsequent uses in accordance with this Agreement. -# -# **4.3. Derivatives.** By entering into this Agreement, You accept that any Derivatives that You may create or that may be created for You shall be subject to the restrictions set out in Section 3 of this Agreement. -# -# # 5. Liability -# **5.1. Limitation of liability.** In no event, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall Perceptron, Inc. be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use the Perceptron Models and Derivatives (including but not limited to damages for loss of data, loss of goodwill, loss of expected profit or savings, work stoppage, computer failure or malfunction, or any damage caused by malware or security breaches), even if Perceptron, Inc. has been advised of the possibility of such damages. -# -# **5.2. Indemnification.** You agree to indemnify and hold harmless Perceptron, Inc. from and against any claims, damages, or losses arising out of or related to Your use or Distribution of the Perceptron Models and Derivatives. -# -# ## 6. Warranty -# **6.1. Disclaimer.** Unless required by applicable law or agreed to in writing, Perceptron, Inc. provides the Perceptron Models and Derivatives on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. Perceptron, Inc. does not represent nor warrant that the Perceptron Models and Derivatives will be error-free, meet Your or any third partyโ€™s requirements, be secure or will allow You or any third party to achieve any kind of result or generate any kind of content. You are solely responsible for determining the appropriateness of using or Distributing the Perceptron Models and Derivatives and assume any risks associated with Your exercise of rights under this Agreement. -# -# # 7. Termination -# **7.1. Term.** This Agreement is effective as of the date of your acceptance of this Agreement or access to the concerned Perceptron Models or Derivatives and will continue until terminated in accordance with the following terms. -# -# **7.2. Termination.** Perceptron, Inc. may terminate this Agreement at any time if You are in breach of this Agreement. Upon termination of this Agreement, You must cease to use all Perceptron Models and Derivatives and shall permanently delete any copy thereof. Sections 5, 6, 7 and 8 shall survive the termination of this Agreement. -# -# **7.3. 
Litigation.** If You initiate any legal action or proceedings against Us or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the Model or a Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement will immediately terminate as of the date such legal action or claim is filed or initiated. -# -# # 8. General provisions -# 8.1. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the State of Washington, without regard to its conflict of law principles. -# -# 8.2. Jurisdiction. The state and federal courts located in King County, Washington shall have exclusive jurisdiction over any dispute arising out of or relating to this Agreement, and You and We consent to personal jurisdiction and venue in such courts. -# -# **8.3. Severability.** If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. -# -# # 9. Definitions -# **โ€œAgreementโ€**: means this Perceptron, Inc. Non-Production License agreement governing the access, use, and Distribution of the Perceptron Models and Derivatives. -# -# **โ€œDerivativeโ€**: means any (i) modified version of the Perceptron Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the Perceptron Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered as Derivatives under this Agreement. -# -# **โ€œDistributionโ€**, **โ€œDistributingโ€**, **โ€œDistributeโ€** or **โ€œDistributedโ€**: means providing or making available, by any means, a copy of the Perceptron Models and/or the Derivatives as the case may be, subject to Section 3 of this Agreement. -# -# **โ€œPerceptron, Inc.โ€**, **โ€œWeโ€** or **โ€œUsโ€**: means Perceptron, Inc., a Delaware corporation with its principal place of business at 10900 NE 8th St Suite 613, Bellevue, WA 98004. -# -# **โ€œPerceptron Modelโ€**: means the foundational large language model(s), and its elements which include algorithms, software, instructed checkpoints, parameters, source code (inference code, evaluation code and, if applicable, fine-tuning code) and any other elements associated thereto made available by Perceptron, Inc. under this Agreement, including, if any, the technical documentation, manuals and instructions for the use and operation thereof. -# -# **โ€œNon-Production Environmentโ€**: means any setting, use case, or application of the Perceptron Models or Derivatives that expressly excludes live, real-world conditions, commercial operations, revenue-generating activities, or direct interactions with or impacts on end users (such as, for instance, Your employees or customers). Non-Production Environment may include, but is not limited to, any setting, use case, or application for research, development, testing, quality assurance, training, internal evaluation (other than any internal usage by employees in the context of the companyโ€™s business activities), and demonstration purposes. +# coding=utf-8 +# Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. # -# **โ€œOutputsโ€**: means any content generated by the operation of the Perceptron Models or the Derivatives from a prompt (i.e., text instructions) provided by users. 
For the avoidance of doubt, Outputs do not include any components of a Perceptron Models, such as any fine-tuned versions of the Perceptron Models, the weights, or parameters. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# **โ€œPersonalโ€**: means any use of a Perceptron Model or a Derivative that is (i) solely for personal, non-profit and non-commercial purposes and (ii) not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities. For illustration purposes, Personal use of a Model or a Derivative does not include any usage by individuals employed in companies in the context of their daily tasks, any activity that is intended to generate revenue, or that is performed on behalf of a commercial entity. +# http://www.apache.org/licenses/LICENSE-2.0 # -# **โ€œYouโ€**: means the individual or entity entering into this Agreement with Perceptron, Inc.. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Optional diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py b/src/transformers/models/isaac/image_processing_isaac_fast.py index 634460a14199..1cbfccd70a2b 100644 --- a/src/transformers/models/isaac/image_processing_isaac_fast.py +++ b/src/transformers/models/isaac/image_processing_isaac_fast.py @@ -4,87 +4,20 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# Copyright (c) 2024 Perceptron, Inc. All rights reserved. -# Perceptron, Inc. Non-Production License (2024-01-01) - - -### 1. Scope and acceptance - -# **1.1. Scope of the Agreement.** -# This Agreement applies to any use, modification, or Distribution of any Perceptron Model by You, regardless of the source You obtained a copy of such Perceptron Model. -# -# **1.2. Acceptance.** By accessing, using, modifying, Distributing a Perceptron Model, or by creating, using or distributing a Derivative of the Perceptron Model, You agree to be bound by this Agreement. -# -# **1.3. Acceptance on behalf of a third-party.** If You accept this Agreement on behalf of Your employer or another person or entity, You warrant and represent that You have the authority to act and accept this Agreement on their behalf. In such a case, the word โ€œYouโ€ in this Agreement will refer to Your employer or such other person or entity. -# -# ## 2. License -# **2.1. Grant of rights.** Subject to Section 3 below, Perceptron, Inc. hereby grants You a non-exclusive, royalty-free, worldwide, non-sublicensable, non-transferable, limited license to use, copy, modify, and Distribute under the conditions provided in Section 2.2 below, the Perceptron Model and any Derivatives made by or for Perceptron, Inc. and to create Derivatives of the Perceptron Model. -# -# **2.2. 
Distribution of Perceptron Model and Derivatives made by or for Perceptron, Inc..** Subject to Section 3 below, You may Distribute copies of the Perceptron Model and/or Derivatives made by or for Perceptron, Inc., under the following conditions: -# - You must make available a copy of this Agreement to third-party recipients of the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. you Distribute, it being specified that any rights to use the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. shall be directly granted by Perceptron, Inc. to said third-party recipients pursuant to the Perceptron, Inc. Non-Production License agreement executed between these parties; -# - You must retain in all copies of the Perceptron Models the following attribution notice within a โ€œNoticeโ€ text file distributed as part of such copies: โ€œLicensed by Perceptron, Inc. under the Perceptron, Inc. Non-Production Licenseโ€. -# -# **2.3. Distribution of Derivatives made by or for You.** Subject to Section 3 below, You may Distribute any Derivatives made by or for You under additional or different terms and conditions, provided that: -# - In any event, the use and modification of Perceptron Model and/or Derivatives made by or for Perceptron, Inc. shall remain governed by the terms and conditions of this Agreement; -# - You include in any such Derivatives made by or for You prominent notices stating that You modified the concerned Perceptron Model; and -# - Any terms and conditions You impose on any third-party recipients relating to Derivatives made by or for You shall neither limit such third-party recipientsโ€™ use of the Perceptron Model or any Derivatives made by or for Perceptron, Inc. in accordance with the Perceptron, Inc. Non-Production License nor conflict with any of its terms and conditions. -# -# ## 3. Limitations -# **3.1. Misrepresentation.** You must not misrepresent or imply, through any means, that the Derivatives made by or for You and/or any modified version of the Perceptron Model You Distribute under your name and responsibility is an official product of Perceptron, Inc. or has been endorsed, approved or validated by Perceptron, Inc., unless You are authorized by Us to do so in writing. -# -# **3.2. Usage Limitation** -# - You shall only use the Perceptron Models and Derivatives (whether or not created by Perceptron, Inc.) for testing, research, Personal, or evaluation purposes in Non-Production Environments; -# - Subject to the foregoing, You shall not supply the Perceptron Models or Derivatives in the course of a commercial activity, whether in return for payment or free of charge, in any medium or form, including but not limited to through a hosted or managed service (e.g. SaaS, cloud instances, etc.), or behind a software layer. -# -# **3.3. Usage not permitted under this Agreement.** If You want to use a Perceptron Model or a Derivative for any purpose that is not expressly authorized under this Agreement, You must request a license from Perceptron, Inc., which Perceptron, Inc. may grant to You in Perceptron, Inc.โ€™s sole discretion. Please contact Perceptron, Inc. at the following e-mail address if You want to discuss such a license: sales@perceptron.inc -# -# ## 4. Intellectual Property -# **4.1. Trademarks.** No trademark licenses are granted under this Agreement, and in connection with the Perceptron Models, You may not use any name or mark owned by or associated with Perceptron, Inc. 
or any of its affiliates, except (i) as required for reasonable and customary use in describing and Distributing the Perceptron Models and Derivatives made by or for Perceptron, Inc. and (ii) for attribution purposes as required by this Agreement. -# -# **4.2. Outputs.** We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs You generate and their subsequent uses in accordance with this Agreement. -# -# **4.3. Derivatives.** By entering into this Agreement, You accept that any Derivatives that You may create or that may be created for You shall be subject to the restrictions set out in Section 3 of this Agreement. -# -# # 5. Liability -# **5.1. Limitation of liability.** In no event, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall Perceptron, Inc. be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use the Perceptron Models and Derivatives (including but not limited to damages for loss of data, loss of goodwill, loss of expected profit or savings, work stoppage, computer failure or malfunction, or any damage caused by malware or security breaches), even if Perceptron, Inc. has been advised of the possibility of such damages. -# -# **5.2. Indemnification.** You agree to indemnify and hold harmless Perceptron, Inc. from and against any claims, damages, or losses arising out of or related to Your use or Distribution of the Perceptron Models and Derivatives. +# coding=utf-8 +# Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. # -# ## 6. Warranty -# **6.1. Disclaimer.** Unless required by applicable law or agreed to in writing, Perceptron, Inc. provides the Perceptron Models and Derivatives on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. Perceptron, Inc. does not represent nor warrant that the Perceptron Models and Derivatives will be error-free, meet Your or any third partyโ€™s requirements, be secure or will allow You or any third party to achieve any kind of result or generate any kind of content. You are solely responsible for determining the appropriateness of using or Distributing the Perceptron Models and Derivatives and assume any risks associated with Your exercise of rights under this Agreement. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# # 7. Termination -# **7.1. Term.** This Agreement is effective as of the date of your acceptance of this Agreement or access to the concerned Perceptron Models or Derivatives and will continue until terminated in accordance with the following terms. +# http://www.apache.org/licenses/LICENSE-2.0 # -# **7.2. Termination.** Perceptron, Inc. may terminate this Agreement at any time if You are in breach of this Agreement. Upon termination of this Agreement, You must cease to use all Perceptron Models and Derivatives and shall permanently delete any copy thereof. Sections 5, 6, 7 and 8 shall survive the termination of this Agreement. -# -# **7.3. 
Litigation.** If You initiate any legal action or proceedings against Us or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the Model or a Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement will immediately terminate as of the date such legal action or claim is filed or initiated. -# -# # 8. General provisions -# 8.1. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the State of Washington, without regard to its conflict of law principles. -# -# 8.2. Jurisdiction. The state and federal courts located in King County, Washington shall have exclusive jurisdiction over any dispute arising out of or relating to this Agreement, and You and We consent to personal jurisdiction and venue in such courts. -# -# **8.3. Severability.** If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. -# -# # 9. Definitions -# **โ€œAgreementโ€**: means this Perceptron, Inc. Non-Production License agreement governing the access, use, and Distribution of the Perceptron Models and Derivatives. -# -# **โ€œDerivativeโ€**: means any (i) modified version of the Perceptron Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the Perceptron Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered as Derivatives under this Agreement. -# -# **โ€œDistributionโ€**, **โ€œDistributingโ€**, **โ€œDistributeโ€** or **โ€œDistributedโ€**: means providing or making available, by any means, a copy of the Perceptron Models and/or the Derivatives as the case may be, subject to Section 3 of this Agreement. -# -# **โ€œPerceptron, Inc.โ€**, **โ€œWeโ€** or **โ€œUsโ€**: means Perceptron, Inc., a Delaware corporation with its principal place of business at 10900 NE 8th St Suite 613, Bellevue, WA 98004. -# -# **โ€œPerceptron Modelโ€**: means the foundational large language model(s), and its elements which include algorithms, software, instructed checkpoints, parameters, source code (inference code, evaluation code and, if applicable, fine-tuning code) and any other elements associated thereto made available by Perceptron, Inc. under this Agreement, including, if any, the technical documentation, manuals and instructions for the use and operation thereof. -# -# **โ€œNon-Production Environmentโ€**: means any setting, use case, or application of the Perceptron Models or Derivatives that expressly excludes live, real-world conditions, commercial operations, revenue-generating activities, or direct interactions with or impacts on end users (such as, for instance, Your employees or customers). Non-Production Environment may include, but is not limited to, any setting, use case, or application for research, development, testing, quality assurance, training, internal evaluation (other than any internal usage by employees in the context of the companyโ€™s business activities), and demonstration purposes. -# -# **โ€œOutputsโ€**: means any content generated by the operation of the Perceptron Models or the Derivatives from a prompt (i.e., text instructions) provided by users. 
For the avoidance of doubt, Outputs do not include any components of a Perceptron Models, such as any fine-tuned versions of the Perceptron Models, the weights, or parameters. -# -# **โ€œPersonalโ€**: means any use of a Perceptron Model or a Derivative that is (i) solely for personal, non-profit and non-commercial purposes and (ii) not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities. For illustration purposes, Personal use of a Model or a Derivative does not include any usage by individuals employed in companies in the context of their daily tasks, any activity that is intended to generate revenue, or that is performed on behalf of a commercial entity. -# -# **โ€œYouโ€**: means the individual or entity entering into this Agreement with Perceptron, Inc.. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import math from collections.abc import Sequence @@ -433,16 +366,12 @@ def _preprocess( .repeat(batch_size, 1) ) - if pixel_shuffle_scale > 1: - if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): - raise ValueError( - "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled." - ) - virtual_height = height_tokens // pixel_shuffle_scale - virtual_width = width_tokens // pixel_shuffle_scale - else: - virtual_height = height_tokens - virtual_width = width_tokens + if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): + raise ValueError( + "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled." + ) + virtual_height = height_tokens // pixel_shuffle_scale + virtual_width = width_tokens // pixel_shuffle_scale virtual_dim = ( torch.tensor( diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 2982a3512f79..03869bff8c83 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -4,109 +4,41 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# Copyright (c) 2024 Perceptron, Inc. All rights reserved. -# Perceptron, Inc. Non-Production License (2024-01-01) - - -### 1. Scope and acceptance - -# **1.1. Scope of the Agreement.** -# This Agreement applies to any use, modification, or Distribution of any Perceptron Model by You, regardless of the source You obtained a copy of such Perceptron Model. -# -# **1.2. Acceptance.** By accessing, using, modifying, Distributing a Perceptron Model, or by creating, using or distributing a Derivative of the Perceptron Model, You agree to be bound by this Agreement. -# -# **1.3. Acceptance on behalf of a third-party.** If You accept this Agreement on behalf of Your employer or another person or entity, You warrant and represent that You have the authority to act and accept this Agreement on their behalf. 
In such a case, the word โ€œYouโ€ in this Agreement will refer to Your employer or such other person or entity. -# -# ## 2. License -# **2.1. Grant of rights.** Subject to Section 3 below, Perceptron, Inc. hereby grants You a non-exclusive, royalty-free, worldwide, non-sublicensable, non-transferable, limited license to use, copy, modify, and Distribute under the conditions provided in Section 2.2 below, the Perceptron Model and any Derivatives made by or for Perceptron, Inc. and to create Derivatives of the Perceptron Model. -# -# **2.2. Distribution of Perceptron Model and Derivatives made by or for Perceptron, Inc..** Subject to Section 3 below, You may Distribute copies of the Perceptron Model and/or Derivatives made by or for Perceptron, Inc., under the following conditions: -# - You must make available a copy of this Agreement to third-party recipients of the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. you Distribute, it being specified that any rights to use the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. shall be directly granted by Perceptron, Inc. to said third-party recipients pursuant to the Perceptron, Inc. Non-Production License agreement executed between these parties; -# - You must retain in all copies of the Perceptron Models the following attribution notice within a โ€œNoticeโ€ text file distributed as part of such copies: โ€œLicensed by Perceptron, Inc. under the Perceptron, Inc. Non-Production Licenseโ€. -# -# **2.3. Distribution of Derivatives made by or for You.** Subject to Section 3 below, You may Distribute any Derivatives made by or for You under additional or different terms and conditions, provided that: -# - In any event, the use and modification of Perceptron Model and/or Derivatives made by or for Perceptron, Inc. shall remain governed by the terms and conditions of this Agreement; -# - You include in any such Derivatives made by or for You prominent notices stating that You modified the concerned Perceptron Model; and -# - Any terms and conditions You impose on any third-party recipients relating to Derivatives made by or for You shall neither limit such third-party recipientsโ€™ use of the Perceptron Model or any Derivatives made by or for Perceptron, Inc. in accordance with the Perceptron, Inc. Non-Production License nor conflict with any of its terms and conditions. -# -# ## 3. Limitations -# **3.1. Misrepresentation.** You must not misrepresent or imply, through any means, that the Derivatives made by or for You and/or any modified version of the Perceptron Model You Distribute under your name and responsibility is an official product of Perceptron, Inc. or has been endorsed, approved or validated by Perceptron, Inc., unless You are authorized by Us to do so in writing. -# -# **3.2. Usage Limitation** -# - You shall only use the Perceptron Models and Derivatives (whether or not created by Perceptron, Inc.) for testing, research, Personal, or evaluation purposes in Non-Production Environments; -# - Subject to the foregoing, You shall not supply the Perceptron Models or Derivatives in the course of a commercial activity, whether in return for payment or free of charge, in any medium or form, including but not limited to through a hosted or managed service (e.g. SaaS, cloud instances, etc.), or behind a software layer. -# -# **3.3. 
Usage not permitted under this Agreement.** If You want to use a Perceptron Model or a Derivative for any purpose that is not expressly authorized under this Agreement, You must request a license from Perceptron, Inc., which Perceptron, Inc. may grant to You in Perceptron, Inc.โ€™s sole discretion. Please contact Perceptron, Inc. at the following e-mail address if You want to discuss such a license: sales@perceptron.inc -# -# ## 4. Intellectual Property -# **4.1. Trademarks.** No trademark licenses are granted under this Agreement, and in connection with the Perceptron Models, You may not use any name or mark owned by or associated with Perceptron, Inc. or any of its affiliates, except (i) as required for reasonable and customary use in describing and Distributing the Perceptron Models and Derivatives made by or for Perceptron, Inc. and (ii) for attribution purposes as required by this Agreement. -# -# **4.2. Outputs.** We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs You generate and their subsequent uses in accordance with this Agreement. -# -# **4.3. Derivatives.** By entering into this Agreement, You accept that any Derivatives that You may create or that may be created for You shall be subject to the restrictions set out in Section 3 of this Agreement. -# -# # 5. Liability -# **5.1. Limitation of liability.** In no event, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall Perceptron, Inc. be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use the Perceptron Models and Derivatives (including but not limited to damages for loss of data, loss of goodwill, loss of expected profit or savings, work stoppage, computer failure or malfunction, or any damage caused by malware or security breaches), even if Perceptron, Inc. has been advised of the possibility of such damages. +# coding=utf-8 +# Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. # -# **5.2. Indemnification.** You agree to indemnify and hold harmless Perceptron, Inc. from and against any claims, damages, or losses arising out of or related to Your use or Distribution of the Perceptron Models and Derivatives. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# ## 6. Warranty -# **6.1. Disclaimer.** Unless required by applicable law or agreed to in writing, Perceptron, Inc. provides the Perceptron Models and Derivatives on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. Perceptron, Inc. does not represent nor warrant that the Perceptron Models and Derivatives will be error-free, meet Your or any third partyโ€™s requirements, be secure or will allow You or any third party to achieve any kind of result or generate any kind of content. You are solely responsible for determining the appropriateness of using or Distributing the Perceptron Models and Derivatives and assume any risks associated with Your exercise of rights under this Agreement. +# http://www.apache.org/licenses/LICENSE-2.0 # -# # 7. Termination -# **7.1. 
Term.** This Agreement is effective as of the date of your acceptance of this Agreement or access to the concerned Perceptron Models or Derivatives and will continue until terminated in accordance with the following terms. -# -# **7.2. Termination.** Perceptron, Inc. may terminate this Agreement at any time if You are in breach of this Agreement. Upon termination of this Agreement, You must cease to use all Perceptron Models and Derivatives and shall permanently delete any copy thereof. Sections 5, 6, 7 and 8 shall survive the termination of this Agreement. -# -# **7.3. Litigation.** If You initiate any legal action or proceedings against Us or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the Model or a Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement will immediately terminate as of the date such legal action or claim is filed or initiated. -# -# # 8. General provisions -# 8.1. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the State of Washington, without regard to its conflict of law principles. -# -# 8.2. Jurisdiction. The state and federal courts located in King County, Washington shall have exclusive jurisdiction over any dispute arising out of or relating to this Agreement, and You and We consent to personal jurisdiction and venue in such courts. -# -# **8.3. Severability.** If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. -# -# # 9. Definitions -# **โ€œAgreementโ€**: means this Perceptron, Inc. Non-Production License agreement governing the access, use, and Distribution of the Perceptron Models and Derivatives. -# -# **โ€œDerivativeโ€**: means any (i) modified version of the Perceptron Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the Perceptron Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered as Derivatives under this Agreement. -# -# **โ€œDistributionโ€**, **โ€œDistributingโ€**, **โ€œDistributeโ€** or **โ€œDistributedโ€**: means providing or making available, by any means, a copy of the Perceptron Models and/or the Derivatives as the case may be, subject to Section 3 of this Agreement. -# -# **โ€œPerceptron, Inc.โ€**, **โ€œWeโ€** or **โ€œUsโ€**: means Perceptron, Inc., a Delaware corporation with its principal place of business at 10900 NE 8th St Suite 613, Bellevue, WA 98004. -# -# **โ€œPerceptron Modelโ€**: means the foundational large language model(s), and its elements which include algorithms, software, instructed checkpoints, parameters, source code (inference code, evaluation code and, if applicable, fine-tuning code) and any other elements associated thereto made available by Perceptron, Inc. under this Agreement, including, if any, the technical documentation, manuals and instructions for the use and operation thereof. -# -# **โ€œNon-Production Environmentโ€**: means any setting, use case, or application of the Perceptron Models or Derivatives that expressly excludes live, real-world conditions, commercial operations, revenue-generating activities, or direct interactions with or impacts on end users (such as, for instance, Your employees or customers). 
Non-Production Environment may include, but is not limited to, any setting, use case, or application for research, development, testing, quality assurance, training, internal evaluation (other than any internal usage by employees in the context of the companyโ€™s business activities), and demonstration purposes. -# -# **โ€œOutputsโ€**: means any content generated by the operation of the Perceptron Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a Perceptron Models, such as any fine-tuned versions of the Perceptron Models, the weights, or parameters. -# -# **โ€œPersonalโ€**: means any use of a Perceptron Model or a Derivative that is (i) solely for personal, non-profit and non-commercial purposes and (ii) not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities. For illustration purposes, Personal use of a Model or a Derivative does not include any usage by individuals employed in companies in the context of their daily tasks, any activity that is intended to generate revenue, or that is performed on behalf of a commercial entity. -# -# **โ€œYouโ€**: means the individual or entity entering into this Agreement with Perceptron, Inc.. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import copy from collections import defaultdict from collections.abc import Callable -from typing import Any, Optional +from typing import Any, Optional, Union from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation.utils import GenerationMixin from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_masks_for_generate, eager_mask, packed_sequence_mask_function, sdpa_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...models.auto.modeling_auto import AutoModel -from ...models.qwen3.configuration_qwen3 import Qwen3Config from ...models.qwen3.modeling_qwen3 import Qwen3PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring -from ...utils.generic import TransformersKwargs, can_return_tuple +from ...utils.generic import TransformersKwargs, can_return_tuple, check_model_inputs from ...utils.import_utils import ( is_perceptron_available, is_torch_available, @@ -288,13 +220,12 @@ def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor return torch.cat(output_chunks, dim=0) -def build_document_attention_mask( - cu_seqlens: Optional[torch.Tensor], - total_tokens: int, - dtype: torch.dtype, - device: torch.device, -) -> Optional[torch.Tensor]: - """Creates an additive attention mask that blocks cross-document attention.""" +def document_mask_function_from_cu_seqlens(cu_seqlens: Optional[torch.Tensor]) -> Optional[Callable]: + 
"""Return a mask function that blocks cross-document attention from packed ``cu_seqlens``. + + The returned callable matches the signature expected by ``masking_utils`` mask factories and + yields ``True`` only when query/key positions belong to the same packed segment. + """ if cu_seqlens is None: return None @@ -306,11 +237,10 @@ def build_document_attention_mask( if seq_sizes.numel() == 0: return None - seg_ids = torch.repeat_interleave(torch.arange(seq_sizes.numel(), device=device), seq_sizes) - block_mask = seg_ids[:, None] != seg_ids[None, :] - additive_mask = torch.zeros((total_tokens, total_tokens), dtype=dtype, device=device) - additive_mask.masked_fill_(block_mask, float("-inf")) - return additive_mask.view(1, 1, total_tokens, total_tokens) + total_tokens = int(seq_sizes.sum().item()) + seg_ids = torch.repeat_interleave(torch.arange(seq_sizes.numel(), device=cu_seqlens.device), seq_sizes) + packed_sequence_mask = seg_ids.view(1, total_tokens) + return packed_sequence_mask_function(packed_sequence_mask) def ensure_document_attention_mask( @@ -319,16 +249,26 @@ def ensure_document_attention_mask( total_tokens: int, dtype: torch.dtype, device: torch.device, -) -> Optional[torch.Tensor]: - if attention_mask is not None or cu_seqlens is None: + *, + return_mask_function: bool = False, +) -> Optional[Union[torch.Tensor, Callable]]: + """Return the provided mask, a callable mask from ``cu_seqlens``, or ``None``. + + ``return_mask_function=True`` yields a callable suitable for ``masking_utils``; otherwise + ``None`` is returned when no explicit ``attention_mask`` is provided. The legacy additive mask + has been removed in favor of the callable-based path. + """ + + if attention_mask is not None: return attention_mask - return build_document_attention_mask( - cu_seqlens=cu_seqlens, - total_tokens=total_tokens, - dtype=dtype, - device=device, - ) + if cu_seqlens is None: + return None + + if return_mask_function: + return document_mask_function_from_cu_seqlens(cu_seqlens) + + return None class IsaacVisionAttention(nn.Module): @@ -367,16 +307,31 @@ def __init__(self, config): self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) self._variable_length_metadata = None - def forward(self, hidden_states, attention_mask=None, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[torch.Tensor] = None, + output_attentions: bool = False, + is_causal: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - cu_seqlens = kwargs.pop("cu_seqlens", None) - max_seqlen = kwargs.pop("max_seqlen", None) - kwargs.pop("output_attentions", None) + # Unused arguments are accepted for interface compatibility + _ = position_ids + _ = past_key_value + _ = is_causal + _ = output_attentions + kwargs.pop("output_hidden_states", None) kwargs.pop("return_dict", None) if kwargs: unexpected = ", ".join(sorted(kwargs)) raise TypeError(f"Unexpected kwargs for IsaacVisionAttention.forward: {unexpected}") + cached_cu, cached_max = self._consume_variable_length_metadata() if cu_seqlens is None: cu_seqlens = cached_cu @@ -398,7 +353,9 @@ def forward(self, hidden_states, attention_mask=None, **kwargs) -> tuple[torch.T k = self.k_proj(x).view(L, H, D) v = self.v_proj(x).view(L, H, D) - attn_impl = 
getattr(self.config, "_attn_implementation", "flash_attention_3") + resolved_key = "isaac_sdpa" + if self.config._attn_implementation != "sdpa": + resolved_key = self.ATTENTION_KEY_MAP.get(self.config._attn_implementation, resolved_key) attn_mask = ensure_document_attention_mask( attention_mask, @@ -406,10 +363,9 @@ def forward(self, hidden_states, attention_mask=None, **kwargs) -> tuple[torch.T L, q.dtype, q.device, + return_mask_function=True, ) - resolved_key = self.ATTENTION_KEY_MAP.get(attn_impl, attn_impl) - attn_weights = None if resolved_key in self._FLASH_IMPLS: y_lhd = self._flash_attention_forward( @@ -440,7 +396,7 @@ def forward(self, hidden_states, attention_mask=None, **kwargs) -> tuple[torch.T else: attention_fn = ALL_ATTENTION_FUNCTIONS.get(resolved_key) if attention_fn is None: - raise ValueError(f"Attention implementation {attn_impl} not found.") + raise ValueError(f"Attention implementation {resolved_key} not found.") query_states = q.transpose(0, 1).unsqueeze(0) key_states = k.transpose(0, 1).unsqueeze(0) @@ -537,19 +493,30 @@ def _sdpa_attention_forward( q_lhd: torch.Tensor, k_lhd: torch.Tensor, v_lhd: torch.Tensor, - attention_mask: Optional[torch.Tensor], + attention_mask: Optional[Union[torch.Tensor, Callable]], cu_seqlens: Optional[torch.Tensor], dropout: float, ) -> torch.Tensor: L = q_lhd.size(0) attn_mask = attention_mask - if attn_mask is None: - attn_mask = build_document_attention_mask( - cu_seqlens=cu_seqlens, - total_tokens=L, - dtype=q_lhd.dtype, - device=q_lhd.device, + + if callable(attn_mask): + cache_position = torch.arange(L, device=q_lhd.device, dtype=torch.long) + attn_mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=L, + kv_offset=0, + mask_function=attn_mask, + attention_mask=None, + allow_is_causal_skip=False, + allow_is_bidirectional_skip=False, + allow_torch_fix=False, + use_vmap=False, ) + # sdpa_mask returns True for allowed positions; SDPA expects True to mean "mask out" + if attn_mask is not None and attn_mask.dtype == torch.bool: + attn_mask = ~attn_mask q = q_lhd.permute(1, 0, 2).unsqueeze(0) k = k_lhd.permute(1, 0, 2).unsqueeze(0) @@ -575,15 +542,30 @@ def _eager_attention_forward( q_lhd: torch.Tensor, k_lhd: torch.Tensor, v_lhd: torch.Tensor, - attention_mask: Optional[torch.Tensor], + attention_mask: Optional[Union[torch.Tensor, Callable]], dropout: float, ) -> tuple[torch.Tensor, torch.Tensor]: + L = q_lhd.size(0) + attn_mask = attention_mask + if callable(attn_mask): + cache_position = torch.arange(L, device=q_lhd.device, dtype=torch.long) + attn_mask = eager_mask( + batch_size=1, + cache_position=cache_position, + kv_length=L, + kv_offset=0, + mask_function=attn_mask, + attention_mask=None, + allow_is_bidirectional_skip=False, + use_vmap=False, + dtype=q_lhd.dtype, + ) + if attn_mask is not None and attn_mask.dim() == 4: + attn_mask = attn_mask.squeeze(0).squeeze(0) + attn_weights = torch.matmul(q_lhd, k_lhd.transpose(1, 2)) * self.scale - if attention_mask is not None: - mask = attention_mask - if mask.dim() == 4: - mask = mask.squeeze(0).squeeze(0) - attn_weights = attn_weights + mask + if attn_mask is not None: + attn_weights = attn_weights + attn_mask attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_lhd.dtype) if dropout and self.training: @@ -884,12 +866,11 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): # Apply final layer normalization hidden_states = self.post_layernorm(hidden_states) - if self.pixel_shuffle_scale_factor > 1: - 
hidden_states = pixel_shuffle_varlen( - x=hidden_states, - token_grids=token_grids, - scale_factor=self.pixel_shuffle_scale_factor, - ) + hidden_states = pixel_shuffle_varlen( + x=hidden_states, + token_grids=token_grids, + scale_factor=self.pixel_shuffle_scale_factor, + ) # Remove the pseudo batch dimension we added earlier hidden_states = hidden_states.squeeze(0) @@ -897,6 +878,26 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): return hidden_states +class IsaacVisionEmbedding(nn.Module): + """Vision embedding wrapper exposing tower and projector.""" + + def __init__(self, config: IsaacConfig): + super().__init__() + vision_cfg = config.vision_config + hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2) + + self.vision_tower = IsaacVisionTransformer(vision_cfg) + self.multimodal_projector = nn.Sequential( + nn.Linear(hidden_dim, 4 * hidden_dim, bias=False), + nn.SiLU(), + nn.Linear(4 * hidden_dim, config.hidden_size, bias=False), + ) + + def forward(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + hidden_states = self.vision_tower(vision_tokens) + return self.multimodal_projector(hidden_states) + + class IsaacRotaryEmbedding(nn.Module): EXTRA_ROPE_KEYS = {"mrope_section", "mrope_interleaved"} @@ -904,29 +905,17 @@ def __init__(self, config: IsaacConfig, device=None): super().__init__() rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config - rope_params = ( - getattr(rope_source_cfg, "rope_parameters", None) or getattr(rope_source_cfg, "rope_scaling", None) or {} - ) - legacy_rope_theta = getattr(rope_source_cfg, "rope_theta", None) - if legacy_rope_theta is not None and isinstance(rope_params, dict) and "rope_theta" not in rope_params: - rope_params = {**rope_params, "rope_theta": legacy_rope_theta} + rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} - sanitized_params = {k: v for k, v in rope_params.items() if k not in self.EXTRA_ROPE_KEYS} + sanitized_scaling = {k: v for k, v in rope_scaling.items() if k not in self.EXTRA_ROPE_KEYS} config_for_rope = copy.copy(rope_source_cfg) - config_for_rope.rope_parameters = sanitized_params if sanitized_params else None - if hasattr(config_for_rope, "rope_scaling"): - config_for_rope.rope_scaling = sanitized_params if sanitized_params else None - if hasattr(config_for_rope, "rope_theta"): - try: - delattr(config_for_rope, "rope_theta") - except Exception: - config_for_rope.rope_theta = None + config_for_rope.rope_scaling = sanitized_scaling if sanitized_scaling else None init_device = device if device is not None and getattr(device, "type", None) != "meta" else None self._qwen_rotary = qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding(config_for_rope, device=init_device) rotary_half_dim = self._qwen_rotary.inv_freq.shape[0] - self.mrope_section = self._resolve_mrope_section(rope_params.get("mrope_section"), rotary_half_dim) + self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim) self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size @staticmethod @@ -1256,9 +1245,8 @@ class IsaacModel(PreTrainedModel): def __init__(self, config: IsaacConfig): Qwen3PreTrainedModel.__init__(self, config) - text_cfg_source = getattr(config, "get_text_config", lambda: config)() + text_cfg_source = config.text_config text_cfg = copy.deepcopy(text_cfg_source) - text_cfg._attn_implementation = config._attn_implementation self.text_model = 
AutoModel.from_config(text_cfg) # Ensure downstream callers observe the composed config self.text_model.config = config @@ -1268,17 +1256,7 @@ def __init__(self, config: IsaacConfig): if config.vision_config is None: raise ValueError("IsaacConfig should always have vision_config") - hidden_dim = config.vision_config.hidden_size * (config.vision_config.pixel_shuffle_scale_factor**2) - self.vision_embedding = nn.Sequential( - IsaacVisionTransformer(config.vision_config), - nn.Linear( - hidden_dim, - 4 * hidden_dim, - bias=False, - ), - nn.SiLU(), - nn.Linear(4 * hidden_dim, config.hidden_size, bias=False), - ) + self.vision_embedding = IsaacVisionEmbedding(config) # Dispatch table for TensorStream balanced embedding (text + vision) self.embed_fns = { @@ -1313,11 +1291,6 @@ def layers(self) -> nn.ModuleList: def norm(self) -> nn.Module: return self.text_model.norm - def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None): - self.text_model._set_gradient_checkpointing( - enable=enable, gradient_checkpointing_func=gradient_checkpointing_func - ) - def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: """Embed text tokens, squeezing singleton dimensions.""" # Text events are shaped as (..., 1); squeeze the singleton index dim @@ -1365,6 +1338,8 @@ def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor: h = embedded_ts.compact() # (B, T, D) return h + @auto_docstring + @check_model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1384,12 +1359,17 @@ def forward( Forward pass with MRoPE position embeddings. Computes position embeddings once and passes them through all layers. + + Args: + tensor_stream (`TensorStream`, *optional*): + Packed multimodal stream of text and vision events to embed directly. Mutually exclusive with + `input_ids` and `inputs_embeds`. When provided, the method derives `position_ids` and `modality_tensor` + if they are not supplied. + modality_tensor (`torch.LongTensor`, *optional*): + Modality identifiers aligned with the embedded sequence, shaped `(batch_size, seq_len)` and containing + values from `TextType`/`VisionType`. Automatically built from `tensor_stream` or `input_ids` when + omitted. 
""" - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Get inputs if tensor_stream is not None and inputs_embeds is not None: @@ -1443,18 +1423,28 @@ def forward( sin = sin.to(inputs_embeds.dtype) # Prepare attention mask - if attention_mask is not None: - attention_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, False - ) + + if not isinstance(attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + attention_mask = create_masks_for_generate(**mask_kwargs) # Initialize hidden states hidden_states = inputs_embeds for decoder_layer in self.text_model.layers: + layer_attention_mask = ( + attention_mask[decoder_layer.attention_type] if isinstance(attention_mask, dict) else attention_mask + ) layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=layer_attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, @@ -1473,158 +1463,6 @@ def forward( past_key_values=past_key_values, ) - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - config: Qwen3Config, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. 
- config (`Qwen3Config`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( - cache_position.reshape(-1, 1) - config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - @auto_docstring class IsaacPreTrainedModel(PreTrainedModel): @@ -1684,6 +1522,8 @@ def forward( **kwargs, ) -> tuple | CausalLMOutputWithPast: r""" + Forward pass for conditional generation supporting both standard inputs and TensorStream. + tensor_stream (`TensorStream`, *optional*): Packed multimodal stream (text, vision, audio tokens) that already encodes spatial metadata. When provided, the model derives embeddings, modality masks, and 3D rotary coordinates directly from the stream instead of diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 7c3b5010e675..289c6af9f70d 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -1,84 +1,17 @@ -# Copyright (c) 2024 Perceptron, Inc. All rights reserved. -# Perceptron, Inc. Non-Production License (2024-01-01) - - -### 1. Scope and acceptance - -# **1.1. Scope of the Agreement.** -# This Agreement applies to any use, modification, or Distribution of any Perceptron Model by You, regardless of the source You obtained a copy of such Perceptron Model. -# -# **1.2. Acceptance.** By accessing, using, modifying, Distributing a Perceptron Model, or by creating, using or distributing a Derivative of the Perceptron Model, You agree to be bound by this Agreement. -# -# **1.3. Acceptance on behalf of a third-party.** If You accept this Agreement on behalf of Your employer or another person or entity, You warrant and represent that You have the authority to act and accept this Agreement on their behalf. In such a case, the word โ€œYouโ€ in this Agreement will refer to Your employer or such other person or entity. -# -# ## 2. License -# **2.1. 
Grant of rights.** Subject to Section 3 below, Perceptron, Inc. hereby grants You a non-exclusive, royalty-free, worldwide, non-sublicensable, non-transferable, limited license to use, copy, modify, and Distribute under the conditions provided in Section 2.2 below, the Perceptron Model and any Derivatives made by or for Perceptron, Inc. and to create Derivatives of the Perceptron Model. -# -# **2.2. Distribution of Perceptron Model and Derivatives made by or for Perceptron, Inc..** Subject to Section 3 below, You may Distribute copies of the Perceptron Model and/or Derivatives made by or for Perceptron, Inc., under the following conditions: -# - You must make available a copy of this Agreement to third-party recipients of the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. you Distribute, it being specified that any rights to use the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. shall be directly granted by Perceptron, Inc. to said third-party recipients pursuant to the Perceptron, Inc. Non-Production License agreement executed between these parties; -# - You must retain in all copies of the Perceptron Models the following attribution notice within a โ€œNoticeโ€ text file distributed as part of such copies: โ€œLicensed by Perceptron, Inc. under the Perceptron, Inc. Non-Production Licenseโ€. -# -# **2.3. Distribution of Derivatives made by or for You.** Subject to Section 3 below, You may Distribute any Derivatives made by or for You under additional or different terms and conditions, provided that: -# - In any event, the use and modification of Perceptron Model and/or Derivatives made by or for Perceptron, Inc. shall remain governed by the terms and conditions of this Agreement; -# - You include in any such Derivatives made by or for You prominent notices stating that You modified the concerned Perceptron Model; and -# - Any terms and conditions You impose on any third-party recipients relating to Derivatives made by or for You shall neither limit such third-party recipientsโ€™ use of the Perceptron Model or any Derivatives made by or for Perceptron, Inc. in accordance with the Perceptron, Inc. Non-Production License nor conflict with any of its terms and conditions. -# -# ## 3. Limitations -# **3.1. Misrepresentation.** You must not misrepresent or imply, through any means, that the Derivatives made by or for You and/or any modified version of the Perceptron Model You Distribute under your name and responsibility is an official product of Perceptron, Inc. or has been endorsed, approved or validated by Perceptron, Inc., unless You are authorized by Us to do so in writing. -# -# **3.2. Usage Limitation** -# - You shall only use the Perceptron Models and Derivatives (whether or not created by Perceptron, Inc.) for testing, research, Personal, or evaluation purposes in Non-Production Environments; -# - Subject to the foregoing, You shall not supply the Perceptron Models or Derivatives in the course of a commercial activity, whether in return for payment or free of charge, in any medium or form, including but not limited to through a hosted or managed service (e.g. SaaS, cloud instances, etc.), or behind a software layer. -# -# **3.3. Usage not permitted under this Agreement.** If You want to use a Perceptron Model or a Derivative for any purpose that is not expressly authorized under this Agreement, You must request a license from Perceptron, Inc., which Perceptron, Inc. may grant to You in Perceptron, Inc.โ€™s sole discretion. 
Please contact Perceptron, Inc. at the following e-mail address if You want to discuss such a license: sales@perceptron.inc -# -# ## 4. Intellectual Property -# **4.1. Trademarks.** No trademark licenses are granted under this Agreement, and in connection with the Perceptron Models, You may not use any name or mark owned by or associated with Perceptron, Inc. or any of its affiliates, except (i) as required for reasonable and customary use in describing and Distributing the Perceptron Models and Derivatives made by or for Perceptron, Inc. and (ii) for attribution purposes as required by this Agreement. -# -# **4.2. Outputs.** We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs You generate and their subsequent uses in accordance with this Agreement. -# -# **4.3. Derivatives.** By entering into this Agreement, You accept that any Derivatives that You may create or that may be created for You shall be subject to the restrictions set out in Section 3 of this Agreement. -# -# # 5. Liability -# **5.1. Limitation of liability.** In no event, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall Perceptron, Inc. be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use the Perceptron Models and Derivatives (including but not limited to damages for loss of data, loss of goodwill, loss of expected profit or savings, work stoppage, computer failure or malfunction, or any damage caused by malware or security breaches), even if Perceptron, Inc. has been advised of the possibility of such damages. -# -# **5.2. Indemnification.** You agree to indemnify and hold harmless Perceptron, Inc. from and against any claims, damages, or losses arising out of or related to Your use or Distribution of the Perceptron Models and Derivatives. -# -# ## 6. Warranty -# **6.1. Disclaimer.** Unless required by applicable law or agreed to in writing, Perceptron, Inc. provides the Perceptron Models and Derivatives on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. Perceptron, Inc. does not represent nor warrant that the Perceptron Models and Derivatives will be error-free, meet Your or any third partyโ€™s requirements, be secure or will allow You or any third party to achieve any kind of result or generate any kind of content. You are solely responsible for determining the appropriateness of using or Distributing the Perceptron Models and Derivatives and assume any risks associated with Your exercise of rights under this Agreement. -# -# # 7. Termination -# **7.1. Term.** This Agreement is effective as of the date of your acceptance of this Agreement or access to the concerned Perceptron Models or Derivatives and will continue until terminated in accordance with the following terms. +# coding=utf-8 +# Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. # -# **7.2. Termination.** Perceptron, Inc. may terminate this Agreement at any time if You are in breach of this Agreement. Upon termination of this Agreement, You must cease to use all Perceptron Models and Derivatives and shall permanently delete any copy thereof. 
Sections 5, 6, 7 and 8 shall survive the termination of this Agreement. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# **7.3. Litigation.** If You initiate any legal action or proceedings against Us or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the Model or a Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement will immediately terminate as of the date such legal action or claim is filed or initiated. +# http://www.apache.org/licenses/LICENSE-2.0 # -# # 8. General provisions -# 8.1. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the State of Washington, without regard to its conflict of law principles. -# -# 8.2. Jurisdiction. The state and federal courts located in King County, Washington shall have exclusive jurisdiction over any dispute arising out of or relating to this Agreement, and You and We consent to personal jurisdiction and venue in such courts. -# -# **8.3. Severability.** If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. -# -# # 9. Definitions -# **โ€œAgreementโ€**: means this Perceptron, Inc. Non-Production License agreement governing the access, use, and Distribution of the Perceptron Models and Derivatives. -# -# **โ€œDerivativeโ€**: means any (i) modified version of the Perceptron Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the Perceptron Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered as Derivatives under this Agreement. -# -# **โ€œDistributionโ€**, **โ€œDistributingโ€**, **โ€œDistributeโ€** or **โ€œDistributedโ€**: means providing or making available, by any means, a copy of the Perceptron Models and/or the Derivatives as the case may be, subject to Section 3 of this Agreement. -# -# **โ€œPerceptron, Inc.โ€**, **โ€œWeโ€** or **โ€œUsโ€**: means Perceptron, Inc., a Delaware corporation with its principal place of business at 10900 NE 8th St Suite 613, Bellevue, WA 98004. -# -# **โ€œPerceptron Modelโ€**: means the foundational large language model(s), and its elements which include algorithms, software, instructed checkpoints, parameters, source code (inference code, evaluation code and, if applicable, fine-tuning code) and any other elements associated thereto made available by Perceptron, Inc. under this Agreement, including, if any, the technical documentation, manuals and instructions for the use and operation thereof. -# -# **โ€œNon-Production Environmentโ€**: means any setting, use case, or application of the Perceptron Models or Derivatives that expressly excludes live, real-world conditions, commercial operations, revenue-generating activities, or direct interactions with or impacts on end users (such as, for instance, Your employees or customers). 
Non-Production Environment may include, but is not limited to, any setting, use case, or application for research, development, testing, quality assurance, training, internal evaluation (other than any internal usage by employees in the context of the companyโ€™s business activities), and demonstration purposes. -# -# **โ€œOutputsโ€**: means any content generated by the operation of the Perceptron Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a Perceptron Models, such as any fine-tuned versions of the Perceptron Models, the weights, or parameters. -# -# **โ€œPersonalโ€**: means any use of a Perceptron Model or a Derivative that is (i) solely for personal, non-profit and non-commercial purposes and (ii) not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities. For illustration purposes, Personal use of a Model or a Derivative does not include any usage by individuals employed in companies in the context of their daily tasks, any activity that is intended to generate revenue, or that is performed on behalf of a commercial entity. -# -# **โ€œYouโ€**: means the individual or entity entering into this Agreement with Perceptron, Inc.. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from __future__ import annotations @@ -86,7 +19,7 @@ import math import re from collections import defaultdict -from collections.abc import Sequence +from collections.abc import Callable, Sequence from typing import Any, Optional, Union from ...utils.import_utils import ( @@ -139,7 +72,7 @@ group_streams = None -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import DynamicCache from ...configuration_utils import PretrainedConfig, layer_type_validation from ...feature_extraction_utils import BatchFeature from ...generation.utils import GenerationMixin @@ -154,7 +87,7 @@ ChannelDimension, PILImageResampling, ) -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_masks_for_generate, eager_mask, packed_sequence_mask_function, sdpa_mask from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -163,13 +96,12 @@ from ...models.qwen3.configuration_qwen3 import Qwen3Config from ...models.qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3PreTrainedModel from ...processing_utils import ProcessorMixin, Unpack -from ...tokenization_utils_base import BatchEncoding from ...utils import TensorType, auto_docstring # Vision preprocessing constants from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD -from ...utils.generic import TransformersKwargs, can_return_tuple +from ...utils.generic import TransformersKwargs, can_return_tuple, check_model_inputs from ..qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling from ..siglip2.configuration_siglip2 import Siglip2VisionConfig from ..siglip2.modeling_siglip2 import ( @@ -193,7 +125,6 @@ class 
IsaacVisionConfig(Siglip2VisionConfig): model_type = "isaac_vision" base_config_key = "vision_config" - _attn_implementation: Optional[str] = None def __init__( self, @@ -226,16 +157,6 @@ def __init__( # Add our custom fields self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor - if self._attn_implementation is None: - self._attn_implementation = "flash_attention_2" - - # Keep legacy and new attention implementation fields in sync - existing_attn_impl = getattr(self, "attn_implementation", None) - if existing_attn_impl is None: - self.attn_implementation = self._attn_implementation - else: - self._attn_implementation = existing_attn_impl - class IsaacImageProcessorKwargs(ImagesKwargs, total=False): patch_size: Optional[int] @@ -439,16 +360,12 @@ def _preprocess( .repeat(batch_size, 1) ) - if pixel_shuffle_scale > 1: - if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): - raise ValueError( - "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled." - ) - virtual_height = height_tokens // pixel_shuffle_scale - virtual_width = width_tokens // pixel_shuffle_scale - else: - virtual_height = height_tokens - virtual_width = width_tokens + if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): + raise ValueError( + "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled." + ) + virtual_height = height_tokens // pixel_shuffle_scale + virtual_width = width_tokens // pixel_shuffle_scale virtual_dim = ( torch.tensor( @@ -486,13 +403,12 @@ def _preprocess( ) -def build_document_attention_mask( - cu_seqlens: Optional[torch.Tensor], - total_tokens: int, - dtype: torch.dtype, - device: torch.device, -) -> Optional[torch.Tensor]: - """Creates an additive attention mask that blocks cross-document attention.""" +def document_mask_function_from_cu_seqlens(cu_seqlens: Optional[torch.Tensor]) -> Optional[Callable]: + """Return a mask function that blocks cross-document attention from packed ``cu_seqlens``. + + The returned callable matches the signature expected by ``masking_utils`` mask factories and + yields ``True`` only when query/key positions belong to the same packed segment. + """ if cu_seqlens is None: return None @@ -504,11 +420,10 @@ def build_document_attention_mask( if seq_sizes.numel() == 0: return None - seg_ids = torch.repeat_interleave(torch.arange(seq_sizes.numel(), device=device), seq_sizes) - block_mask = seg_ids[:, None] != seg_ids[None, :] - additive_mask = torch.zeros((total_tokens, total_tokens), dtype=dtype, device=device) - additive_mask.masked_fill_(block_mask, float("-inf")) - return additive_mask.view(1, 1, total_tokens, total_tokens) + total_tokens = int(seq_sizes.sum().item()) + seg_ids = torch.repeat_interleave(torch.arange(seq_sizes.numel(), device=cu_seqlens.device), seq_sizes) + packed_sequence_mask = seg_ids.view(1, total_tokens) + return packed_sequence_mask_function(packed_sequence_mask) def ensure_document_attention_mask( @@ -517,16 +432,26 @@ def ensure_document_attention_mask( total_tokens: int, dtype: torch.dtype, device: torch.device, -) -> Optional[torch.Tensor]: - if attention_mask is not None or cu_seqlens is None: + *, + return_mask_function: bool = False, +) -> Optional[Union[torch.Tensor, Callable]]: + """Return the provided mask, a callable mask from ``cu_seqlens``, or ``None``. 
+ + ``return_mask_function=True`` yields a callable suitable for ``masking_utils``; otherwise + ``None`` is returned when no explicit ``attention_mask`` is provided. The legacy additive mask + has been removed in favor of the callable-based path. + """ + + if attention_mask is not None: return attention_mask - return build_document_attention_mask( - cu_seqlens=cu_seqlens, - total_tokens=total_tokens, - dtype=dtype, - device=device, - ) + if cu_seqlens is None: + return None + + if return_mask_function: + return document_mask_function_from_cu_seqlens(cu_seqlens) + + return None class IsaacVisionEmbeddings(nn.Module): @@ -708,15 +633,30 @@ def _consume_variable_length_metadata(self): self._variable_length_metadata = None return cu_seqlens, max_seqlen - def forward(self, hidden_states, attention_mask=None, **kwargs): - cu_seqlens = kwargs.pop("cu_seqlens", None) - max_seqlen = kwargs.pop("max_seqlen", None) - kwargs.pop("output_attentions", None) + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[torch.Tensor] = None, + output_attentions: bool = False, + is_causal: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + **kwargs, + ): + # Unused arguments are accepted for interface compatibility + _ = position_ids + _ = past_key_value + _ = is_causal + _ = output_attentions + kwargs.pop("output_hidden_states", None) kwargs.pop("return_dict", None) if kwargs: unexpected = ", ".join(sorted(kwargs)) raise TypeError(f"Unexpected kwargs for IsaacVisionAttention.forward: {unexpected}") + cached_cu, cached_max = self._consume_variable_length_metadata() if cu_seqlens is None: cu_seqlens = cached_cu @@ -738,7 +678,9 @@ def forward(self, hidden_states, attention_mask=None, **kwargs): k = self.k_proj(x).view(L, H, D) v = self.v_proj(x).view(L, H, D) - attn_impl = getattr(self.config, "_attn_implementation", "flash_attention_3") + resolved_key = "isaac_sdpa" + if self.config._attn_implementation != "sdpa": + resolved_key = self.ATTENTION_KEY_MAP.get(self.config._attn_implementation, resolved_key) attn_mask = ensure_document_attention_mask( attention_mask, @@ -746,10 +688,9 @@ def forward(self, hidden_states, attention_mask=None, **kwargs): L, q.dtype, q.device, + return_mask_function=True, ) - resolved_key = self.ATTENTION_KEY_MAP.get(attn_impl, attn_impl) - attn_weights = None if resolved_key in self._FLASH_IMPLS: y_lhd = self._flash_attention_forward( @@ -780,7 +721,7 @@ def forward(self, hidden_states, attention_mask=None, **kwargs): else: attention_fn = ALL_ATTENTION_FUNCTIONS.get(resolved_key) if attention_fn is None: - raise ValueError(f"Attention implementation {attn_impl} not found.") + raise ValueError(f"Attention implementation {resolved_key} not found.") query_states = q.transpose(0, 1).unsqueeze(0) key_states = k.transpose(0, 1).unsqueeze(0) @@ -866,19 +807,30 @@ def _sdpa_attention_forward( q_lhd: torch.Tensor, k_lhd: torch.Tensor, v_lhd: torch.Tensor, - attention_mask: Optional[torch.Tensor], + attention_mask: Optional[Union[torch.Tensor, Callable]], cu_seqlens: Optional[torch.Tensor], dropout: float, ) -> torch.Tensor: L = q_lhd.size(0) attn_mask = attention_mask - if attn_mask is None: - attn_mask = build_document_attention_mask( - cu_seqlens=cu_seqlens, - total_tokens=L, - dtype=q_lhd.dtype, - device=q_lhd.device, + + if callable(attn_mask): + cache_position = torch.arange(L, device=q_lhd.device, dtype=torch.long) + 
attn_mask = sdpa_mask(
+                batch_size=1,
+                cache_position=cache_position,
+                kv_length=L,
+                kv_offset=0,
+                mask_function=attn_mask,
+                attention_mask=None,
+                allow_is_causal_skip=False,
+                allow_is_bidirectional_skip=False,
+                allow_torch_fix=False,
+                use_vmap=False,
             )
+            # `sdpa_mask` returns a boolean mask where True marks positions that take part in
+            # attention, which is the same convention `F.scaled_dot_product_attention` uses for
+            # boolean masks, so the result is passed through without inversion.
 
         q = q_lhd.permute(1, 0, 2).unsqueeze(0)
         k = k_lhd.permute(1, 0, 2).unsqueeze(0)
@@ -904,15 +856,30 @@ def _eager_attention_forward(
         q_lhd: torch.Tensor,
         k_lhd: torch.Tensor,
         v_lhd: torch.Tensor,
-        attention_mask: Optional[torch.Tensor],
+        attention_mask: Optional[Union[torch.Tensor, Callable]],
         dropout: float,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        L = q_lhd.size(0)
+        attn_mask = attention_mask
+        if callable(attn_mask):
+            cache_position = torch.arange(L, device=q_lhd.device, dtype=torch.long)
+            attn_mask = eager_mask(
+                batch_size=1,
+                cache_position=cache_position,
+                kv_length=L,
+                kv_offset=0,
+                mask_function=attn_mask,
+                attention_mask=None,
+                allow_is_bidirectional_skip=False,
+                use_vmap=False,
+                dtype=q_lhd.dtype,
+            )
+        if attn_mask is not None and attn_mask.dim() == 4:
+            attn_mask = attn_mask.squeeze(0).squeeze(0)
+
         attn_weights = torch.matmul(q_lhd, k_lhd.transpose(1, 2)) * self.scale
-        if attention_mask is not None:
-            mask = attention_mask
-            if mask.dim() == 4:
-                mask = mask.squeeze(0).squeeze(0)
-            attn_weights = attn_weights + mask
+        if attn_mask is not None:
+            attn_weights = attn_weights + attn_mask
 
         attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_lhd.dtype)
         if dropout and self.training:
@@ -1181,12 +1148,11 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]):
 
         # Apply final layer normalization
         hidden_states = self.post_layernorm(hidden_states)
 
-        if self.pixel_shuffle_scale_factor > 1:
-            hidden_states = pixel_shuffle_varlen(
-                x=hidden_states,
-                token_grids=token_grids,
-                scale_factor=self.pixel_shuffle_scale_factor,
-            )
+        hidden_states = pixel_shuffle_varlen(
+            x=hidden_states,
+            token_grids=token_grids,
+            scale_factor=self.pixel_shuffle_scale_factor,
+        )
 
         # Remove the pseudo batch dimension we added earlier
         hidden_states = hidden_states.squeeze(0)
@@ -1194,6 +1160,26 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]):
 
         return hidden_states
 
 
+class IsaacVisionEmbedding(nn.Module):
+    """Vision embedding wrapper exposing tower and projector."""
+
+    def __init__(self, config: IsaacConfig):
+        super().__init__()
+        vision_cfg = config.vision_config
+        hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2)
+
+        self.vision_tower = IsaacVisionTransformer(vision_cfg)
+        self.multimodal_projector = nn.Sequential(
+            nn.Linear(hidden_dim, 4 * hidden_dim, bias=False),
+            nn.SiLU(),
+            nn.Linear(4 * hidden_dim, config.hidden_size, bias=False),
+        )
+
+    def forward(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        hidden_states = self.vision_tower(vision_tokens)
+        return self.multimodal_projector(hidden_states)
+
+
 def get_scaled_image_size(
     scale: float,
     original_size: int,
@@ -1331,9 +1317,7 @@ def __init__(
         vision_token: str = "",
         **kwargs,
     ):
-        self._rope_scaling: Optional[dict[str, Any]] = None
         self._rope_parameters: Optional[dict[str, Any]] = None
-
         resolved_text_config = kwargs.pop("text_config", text_config)
         if isinstance(resolved_text_config, Qwen3Config):
             text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict())
@@ -1346,83
+1330,37 @@ def __init__( text_config_kwargs.update(kwargs) - legacy_rope_theta = text_config_kwargs.pop("rope_theta", None) - incoming_rope_params = text_config_kwargs.pop("rope_parameters", None) - incoming_rope_scaling = text_config_kwargs.pop("rope_scaling", None) - normalized_rope_params = incoming_rope_params or incoming_rope_scaling - if normalized_rope_params is None and legacy_rope_theta is not None: - normalized_rope_params = {"rope_type": "default", "rope_theta": legacy_rope_theta} - elif ( - normalized_rope_params is not None - and legacy_rope_theta is not None - and "rope_theta" not in normalized_rope_params - ): - normalized_rope_params = {**normalized_rope_params, "rope_theta": legacy_rope_theta} - if normalized_rope_params is not None: - text_config_kwargs["rope_parameters"] = normalized_rope_params - self.text_config = self.sub_configs["text_config"](**text_config_kwargs) - - # Normalize rope parameters on the text config (prefer rope_parameters; alias rope_scaling) - self._rope_parameters = getattr(self.text_config, "rope_parameters", None) - if self._rope_parameters is None: - self._rope_parameters = getattr(self.text_config, "rope_scaling", None) - if self._rope_parameters is None and normalized_rope_params is not None: - self._rope_parameters = normalized_rope_params - if self._rope_parameters is None: - self._rope_parameters = {"rope_type": "default"} - - try: - self.text_config.rope_parameters = self._rope_parameters - except AttributeError: - setattr(self.text_config, "rope_parameters", self._rope_parameters) - if hasattr(self.text_config, "rope_scaling"): - self.text_config.rope_scaling = self._rope_parameters - else: - try: - setattr(self.text_config, "rope_scaling", self._rope_parameters) - except Exception: - pass + if not hasattr(self.text_config, "rope_theta"): + rope_theta_override = text_config_kwargs.get("rope_theta", kwargs.get("rope_theta")) + if rope_theta_override is None: + rope_theta_override = getattr(Qwen3Config(), "rope_theta", 10000.0) + self.text_config.rope_theta = rope_theta_override super().__init__(**kwargs) + if self._rope_scaling is None: + self._rope_scaling = getattr(self.text_config, "rope_scaling", None) + else: + self.text_config.rope_scaling = self._rope_scaling + # Keep rope parameters alias in sync with upstream expectations - self._rope_scaling = self._rope_parameters + self._rope_parameters = self._rope_scaling # Mirror frequently accessed Qwen3 attributes at the composite config level for BC. 
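+        # Only this small set of attributes is mirrored; anything else (e.g. sliding_window,
+        # rms_norm_eps) should be read from `self.text_config` directly.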
- self.tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False) self.vocab_size = self.text_config.vocab_size - self.max_position_embeddings = self.text_config.max_position_embeddings self.hidden_size = self.text_config.hidden_size - self.intermediate_size = self.text_config.intermediate_size self.num_hidden_layers = self.text_config.num_hidden_layers self.num_attention_heads = self.text_config.num_attention_heads - self.use_sliding_window = getattr(self.text_config, "use_sliding_window", False) - sliding_window = getattr(self.text_config, "sliding_window", None) - self.sliding_window = sliding_window if self.use_sliding_window else None - self.max_window_layers = getattr(self.text_config, "max_window_layers", None) - self.num_key_value_heads = getattr(self.text_config, "num_key_value_heads", None) - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads self.head_dim = self.text_config.head_dim self.hidden_act = self.text_config.hidden_act - self.initializer_range = self.text_config.initializer_range - self.rms_norm_eps = self.text_config.rms_norm_eps self.use_cache = self.text_config.use_cache - self.attention_bias = getattr(self.text_config, "attention_bias", False) - self.attention_dropout = getattr(self.text_config, "attention_dropout", 0.0) + self.rope_theta = self.text_config.rope_theta # Validate rotary parameters now that they have been mirrored locally. rope_config_validation(self) self.layer_types = getattr(self.text_config, "layer_types", None) - if self.layer_types is None: - self.layer_types = [ - "sliding_attention" - if self.sliding_window is not None and i >= self.max_window_layers - else "full_attention" - for i in range(self.num_hidden_layers) - ] layer_type_validation(self.layer_types, self.num_hidden_layers) # Handle vision config - either dict or IsaacVisionConfig instance @@ -1440,42 +1378,24 @@ def __init__( self.max_sequence_length = max_sequence_length self.vision_token = vision_token - def get_text_config(self, *_, **kwargs) -> Qwen3Config: - # Accept optional decoder/encoder flags to align with HF composite configs - kwargs.pop("decoder", None) - kwargs.pop("encoder", None) - return self.text_config - @property def rope_scaling(self): if hasattr(self, "text_config") and self.text_config is not None: - return getattr(self.text_config, "rope_parameters", None) or getattr( - self.text_config, "rope_scaling", None - ) - return self._rope_parameters + return getattr(self.text_config, "rope_scaling", None) + return self._rope_scaling @rope_scaling.setter def rope_scaling(self, value): - self._rope_parameters = value self._rope_scaling = value if hasattr(self, "text_config") and self.text_config is not None: - try: - self.text_config.rope_parameters = value - except AttributeError: - setattr(self.text_config, "rope_parameters", value) - try: - self.text_config.rope_scaling = value - except AttributeError: - pass + self.text_config.rope_scaling = value @property def rope_parameters(self) -> dict[str, Any] | None: """Alias introduced upstream for rope scaling dictionaries.""" value = self._rope_parameters - if value is None and hasattr(self, "text_config") and self.text_config is not None: - value = getattr(self.text_config, "rope_parameters", None) or getattr( - self.text_config, "rope_scaling", None - ) + if value is None: + value = self.rope_scaling if value is None: return {"rope_type": "default"} return value @@ -1483,37 +1403,13 @@ def rope_parameters(self) -> dict[str, Any] | None: @rope_parameters.setter 
def rope_parameters(self, value: dict[str, Any] | None) -> None: self._rope_parameters = value - self._rope_scaling = value self.rope_scaling = value - @property - def vision_attn_implementation(self) -> Optional[str]: - value = getattr(self.vision_config, "_attn_implementation", None) - if value is None: - value = getattr(self.vision_config, "attn_implementation", None) - return value - - @vision_attn_implementation.setter - def vision_attn_implementation(self, value: Optional[str]) -> None: - self.vision_config._attn_implementation = value - if value is not None: - self.vision_config.attn_implementation = value - elif hasattr(self.vision_config, "attn_implementation"): - delattr(self.vision_config, "attn_implementation") - def to_dict(self): output = super().to_dict() - rope_params = self.rope_parameters - output["rope_parameters"] = rope_params - output.pop("rope_scaling", None) - output.pop("rope_theta", None) # Ensure nested configs round-trip through dict serialization if hasattr(self, "text_config") and self.text_config is not None: - text_config_dict = self.text_config.to_dict() - text_config_dict.pop("rope_theta", None) - text_config_dict.pop("rope_scaling", None) - text_config_dict["rope_parameters"] = rope_params - output["text_config"] = text_config_dict + output["text_config"] = self.text_config.to_dict() if hasattr(self, "vision_config") and self.vision_config is not None: output["vision_config"] = self.vision_config.to_dict() return output @@ -1673,26 +1569,8 @@ def __call__( Returns: BatchFeature with input_ids and tensor_stream """ - # Normalize inputs to lists and support BatchEncoding (v5 apply_chat_template default) - encoding_input = None - if isinstance(text, BatchEncoding): - encoding_input = text - elif isinstance(text, dict) and "input_ids" in text: - encoding_input = BatchEncoding(text) - - if encoding_input is not None: - input_ids_field = encoding_input["input_ids"] - if isinstance(input_ids_field, torch.Tensor): - ids = input_ids_field - else: - ids = torch.tensor(input_ids_field) - if ids.ndim == 1: - ids = ids.unsqueeze(0) - if ids.size(0) != 1: - raise ValueError("IsaacProcessor currently supports batch_size=1 for chat templates.") - decoded_text = self.tokenizer.decode(ids[0].tolist(), skip_special_tokens=False) - texts = [decoded_text] - elif isinstance(text, str): + # Normalize inputs to lists + if isinstance(text, str): texts = [text] else: texts = text @@ -1775,29 +1653,17 @@ def __init__(self, config: IsaacConfig, device=None): super().__init__() rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config - rope_params = ( - getattr(rope_source_cfg, "rope_parameters", None) or getattr(rope_source_cfg, "rope_scaling", None) or {} - ) - legacy_rope_theta = getattr(rope_source_cfg, "rope_theta", None) - if legacy_rope_theta is not None and isinstance(rope_params, dict) and "rope_theta" not in rope_params: - rope_params = {**rope_params, "rope_theta": legacy_rope_theta} + rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} - sanitized_params = {k: v for k, v in rope_params.items() if k not in self.EXTRA_ROPE_KEYS} + sanitized_scaling = {k: v for k, v in rope_scaling.items() if k not in self.EXTRA_ROPE_KEYS} config_for_rope = copy.copy(rope_source_cfg) - config_for_rope.rope_parameters = sanitized_params if sanitized_params else None - if hasattr(config_for_rope, "rope_scaling"): - config_for_rope.rope_scaling = sanitized_params if sanitized_params else None - if hasattr(config_for_rope, "rope_theta"): - try: 
- delattr(config_for_rope, "rope_theta") - except Exception: - config_for_rope.rope_theta = None + config_for_rope.rope_scaling = sanitized_scaling if sanitized_scaling else None init_device = device if device is not None and getattr(device, "type", None) != "meta" else None self._qwen_rotary = qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding(config_for_rope, device=init_device) rotary_half_dim = self._qwen_rotary.inv_freq.shape[0] - self.mrope_section = self._resolve_mrope_section(rope_params.get("mrope_section"), rotary_half_dim) + self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim) self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size @staticmethod @@ -1873,9 +1739,8 @@ class IsaacModel(Qwen3PreTrainedModel): def __init__(self, config: IsaacConfig): Qwen3PreTrainedModel.__init__(self, config) - text_cfg_source = getattr(config, "get_text_config", lambda: config)() + text_cfg_source = config.text_config text_cfg = copy.deepcopy(text_cfg_source) - text_cfg._attn_implementation = config._attn_implementation self.text_model = AutoModel.from_config(text_cfg) # Ensure downstream callers observe the composed config self.text_model.config = config @@ -1885,17 +1750,7 @@ def __init__(self, config: IsaacConfig): if config.vision_config is None: raise ValueError("IsaacConfig should always have vision_config") - hidden_dim = config.vision_config.hidden_size * (config.vision_config.pixel_shuffle_scale_factor**2) - self.vision_embedding = nn.Sequential( - IsaacVisionTransformer(config.vision_config), - nn.Linear( - hidden_dim, - 4 * hidden_dim, - bias=False, - ), - nn.SiLU(), - nn.Linear(4 * hidden_dim, config.hidden_size, bias=False), - ) + self.vision_embedding = IsaacVisionEmbedding(config) # Dispatch table for TensorStream balanced embedding (text + vision) self.embed_fns = { @@ -1930,11 +1785,6 @@ def layers(self) -> nn.ModuleList: def norm(self) -> nn.Module: return self.text_model.norm - def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None): - self.text_model._set_gradient_checkpointing( - enable=enable, gradient_checkpointing_func=gradient_checkpointing_func - ) - def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: """Embed text tokens, squeezing singleton dimensions.""" # Text events are shaped as (..., 1); squeeze the singleton index dim @@ -1982,6 +1832,8 @@ def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor: h = embedded_ts.compact() # (B, T, D) return h + @auto_docstring + @check_model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -2001,12 +1853,17 @@ def forward( Forward pass with MRoPE position embeddings. Computes position embeddings once and passes them through all layers. + + Args: + tensor_stream (`TensorStream`, *optional*): + Packed multimodal stream of text and vision events to embed directly. Mutually exclusive with + `input_ids` and `inputs_embeds`. When provided, the method derives `position_ids` and `modality_tensor` + if they are not supplied. + modality_tensor (`torch.LongTensor`, *optional*): + Modality identifiers aligned with the embedded sequence, shaped `(batch_size, seq_len)` and containing + values from `TextType`/`VisionType`. Automatically built from `tensor_stream` or `input_ids` when + omitted. 
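+
+        Example (a minimal sketch of the text-only path; the checkpoint id below is a
+        placeholder, not a released checkpoint):
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoModel
+
+        >>> model = AutoModel.from_pretrained("<path-or-repo-of-isaac-checkpoint>")  # placeholder
+        >>> input_ids = torch.tensor([[1, 2, 3, 4]])
+        >>> outputs = model(input_ids=input_ids)
+        >>> hidden = outputs.last_hidden_state  # (batch_size, seq_len, config.hidden_size)
+        ```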
""" - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Get inputs if tensor_stream is not None and inputs_embeds is not None: @@ -2060,18 +1917,28 @@ def forward( sin = sin.to(inputs_embeds.dtype) # Prepare attention mask - if attention_mask is not None: - attention_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, False - ) + + if not isinstance(attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + attention_mask = create_masks_for_generate(**mask_kwargs) # Initialize hidden states hidden_states = inputs_embeds for decoder_layer in self.text_model.layers: + layer_attention_mask = ( + attention_mask[decoder_layer.attention_type] if isinstance(attention_mask, dict) else attention_mask + ) layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=layer_attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, @@ -2090,158 +1957,6 @@ def forward( past_key_values=past_key_values, ) - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - config: Qwen3Config, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. 
- config (`Qwen3Config`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( - cache_position.reshape(-1, 1) - config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin): """Isaac multimodal model for conditional generation.""" @@ -2305,6 +2020,8 @@ def forward( **kwargs, ) -> tuple | CausalLMOutputWithPast: r""" + Forward pass for conditional generation supporting both standard inputs and TensorStream. + tensor_stream (`TensorStream`, *optional*): Packed multimodal stream (text, vision, audio tokens) that already encodes spatial metadata. When provided, the model derives embeddings, modality masks, and 3D rotary coordinates directly from the stream instead of diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index cf5dbd35dc5c..3a2792792449 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -4,87 +4,20 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# Copyright (c) 2024 Perceptron, Inc. All rights reserved. -# Perceptron, Inc. Non-Production License (2024-01-01) - - -### 1. Scope and acceptance - -# **1.1. Scope of the Agreement.** -# This Agreement applies to any use, modification, or Distribution of any Perceptron Model by You, regardless of the source You obtained a copy of such Perceptron Model. -# -# **1.2. 
Acceptance.** By accessing, using, modifying, Distributing a Perceptron Model, or by creating, using or distributing a Derivative of the Perceptron Model, You agree to be bound by this Agreement. -# -# **1.3. Acceptance on behalf of a third-party.** If You accept this Agreement on behalf of Your employer or another person or entity, You warrant and represent that You have the authority to act and accept this Agreement on their behalf. In such a case, the word โ€œYouโ€ in this Agreement will refer to Your employer or such other person or entity. -# -# ## 2. License -# **2.1. Grant of rights.** Subject to Section 3 below, Perceptron, Inc. hereby grants You a non-exclusive, royalty-free, worldwide, non-sublicensable, non-transferable, limited license to use, copy, modify, and Distribute under the conditions provided in Section 2.2 below, the Perceptron Model and any Derivatives made by or for Perceptron, Inc. and to create Derivatives of the Perceptron Model. -# -# **2.2. Distribution of Perceptron Model and Derivatives made by or for Perceptron, Inc..** Subject to Section 3 below, You may Distribute copies of the Perceptron Model and/or Derivatives made by or for Perceptron, Inc., under the following conditions: -# - You must make available a copy of this Agreement to third-party recipients of the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. you Distribute, it being specified that any rights to use the Perceptron Models and/or Derivatives made by or for Perceptron, Inc. shall be directly granted by Perceptron, Inc. to said third-party recipients pursuant to the Perceptron, Inc. Non-Production License agreement executed between these parties; -# - You must retain in all copies of the Perceptron Models the following attribution notice within a โ€œNoticeโ€ text file distributed as part of such copies: โ€œLicensed by Perceptron, Inc. under the Perceptron, Inc. Non-Production Licenseโ€. -# -# **2.3. Distribution of Derivatives made by or for You.** Subject to Section 3 below, You may Distribute any Derivatives made by or for You under additional or different terms and conditions, provided that: -# - In any event, the use and modification of Perceptron Model and/or Derivatives made by or for Perceptron, Inc. shall remain governed by the terms and conditions of this Agreement; -# - You include in any such Derivatives made by or for You prominent notices stating that You modified the concerned Perceptron Model; and -# - Any terms and conditions You impose on any third-party recipients relating to Derivatives made by or for You shall neither limit such third-party recipientsโ€™ use of the Perceptron Model or any Derivatives made by or for Perceptron, Inc. in accordance with the Perceptron, Inc. Non-Production License nor conflict with any of its terms and conditions. -# -# ## 3. Limitations -# **3.1. Misrepresentation.** You must not misrepresent or imply, through any means, that the Derivatives made by or for You and/or any modified version of the Perceptron Model You Distribute under your name and responsibility is an official product of Perceptron, Inc. or has been endorsed, approved or validated by Perceptron, Inc., unless You are authorized by Us to do so in writing. -# -# **3.2. Usage Limitation** -# - You shall only use the Perceptron Models and Derivatives (whether or not created by Perceptron, Inc.) 
for testing, research, Personal, or evaluation purposes in Non-Production Environments; -# - Subject to the foregoing, You shall not supply the Perceptron Models or Derivatives in the course of a commercial activity, whether in return for payment or free of charge, in any medium or form, including but not limited to through a hosted or managed service (e.g. SaaS, cloud instances, etc.), or behind a software layer. -# -# **3.3. Usage not permitted under this Agreement.** If You want to use a Perceptron Model or a Derivative for any purpose that is not expressly authorized under this Agreement, You must request a license from Perceptron, Inc., which Perceptron, Inc. may grant to You in Perceptron, Inc.โ€™s sole discretion. Please contact Perceptron, Inc. at the following e-mail address if You want to discuss such a license: sales@perceptron.inc -# -# ## 4. Intellectual Property -# **4.1. Trademarks.** No trademark licenses are granted under this Agreement, and in connection with the Perceptron Models, You may not use any name or mark owned by or associated with Perceptron, Inc. or any of its affiliates, except (i) as required for reasonable and customary use in describing and Distributing the Perceptron Models and Derivatives made by or for Perceptron, Inc. and (ii) for attribution purposes as required by this Agreement. -# -# **4.2. Outputs.** We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs You generate and their subsequent uses in accordance with this Agreement. -# -# **4.3. Derivatives.** By entering into this Agreement, You accept that any Derivatives that You may create or that may be created for You shall be subject to the restrictions set out in Section 3 of this Agreement. -# -# # 5. Liability -# **5.1. Limitation of liability.** In no event, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall Perceptron, Inc. be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use the Perceptron Models and Derivatives (including but not limited to damages for loss of data, loss of goodwill, loss of expected profit or savings, work stoppage, computer failure or malfunction, or any damage caused by malware or security breaches), even if Perceptron, Inc. has been advised of the possibility of such damages. +# coding=utf-8 +# Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. # -# **5.2. Indemnification.** You agree to indemnify and hold harmless Perceptron, Inc. from and against any claims, damages, or losses arising out of or related to Your use or Distribution of the Perceptron Models and Derivatives. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# ## 6. Warranty -# **6.1. Disclaimer.** Unless required by applicable law or agreed to in writing, Perceptron, Inc. provides the Perceptron Models and Derivatives on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. Perceptron, Inc. 
does not represent nor warrant that the Perceptron Models and Derivatives will be error-free, meet Your or any third partyโ€™s requirements, be secure or will allow You or any third party to achieve any kind of result or generate any kind of content. You are solely responsible for determining the appropriateness of using or Distributing the Perceptron Models and Derivatives and assume any risks associated with Your exercise of rights under this Agreement. +# http://www.apache.org/licenses/LICENSE-2.0 # -# # 7. Termination -# **7.1. Term.** This Agreement is effective as of the date of your acceptance of this Agreement or access to the concerned Perceptron Models or Derivatives and will continue until terminated in accordance with the following terms. -# -# **7.2. Termination.** Perceptron, Inc. may terminate this Agreement at any time if You are in breach of this Agreement. Upon termination of this Agreement, You must cease to use all Perceptron Models and Derivatives and shall permanently delete any copy thereof. Sections 5, 6, 7 and 8 shall survive the termination of this Agreement. -# -# **7.3. Litigation.** If You initiate any legal action or proceedings against Us or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the Model or a Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement will immediately terminate as of the date such legal action or claim is filed or initiated. -# -# # 8. General provisions -# 8.1. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the State of Washington, without regard to its conflict of law principles. -# -# 8.2. Jurisdiction. The state and federal courts located in King County, Washington shall have exclusive jurisdiction over any dispute arising out of or relating to this Agreement, and You and We consent to personal jurisdiction and venue in such courts. -# -# **8.3. Severability.** If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. -# -# # 9. Definitions -# **โ€œAgreementโ€**: means this Perceptron, Inc. Non-Production License agreement governing the access, use, and Distribution of the Perceptron Models and Derivatives. -# -# **โ€œDerivativeโ€**: means any (i) modified version of the Perceptron Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the Perceptron Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered as Derivatives under this Agreement. -# -# **โ€œDistributionโ€**, **โ€œDistributingโ€**, **โ€œDistributeโ€** or **โ€œDistributedโ€**: means providing or making available, by any means, a copy of the Perceptron Models and/or the Derivatives as the case may be, subject to Section 3 of this Agreement. -# -# **โ€œPerceptron, Inc.โ€**, **โ€œWeโ€** or **โ€œUsโ€**: means Perceptron, Inc., a Delaware corporation with its principal place of business at 10900 NE 8th St Suite 613, Bellevue, WA 98004. 
-# -# **โ€œPerceptron Modelโ€**: means the foundational large language model(s), and its elements which include algorithms, software, instructed checkpoints, parameters, source code (inference code, evaluation code and, if applicable, fine-tuning code) and any other elements associated thereto made available by Perceptron, Inc. under this Agreement, including, if any, the technical documentation, manuals and instructions for the use and operation thereof. -# -# **โ€œNon-Production Environmentโ€**: means any setting, use case, or application of the Perceptron Models or Derivatives that expressly excludes live, real-world conditions, commercial operations, revenue-generating activities, or direct interactions with or impacts on end users (such as, for instance, Your employees or customers). Non-Production Environment may include, but is not limited to, any setting, use case, or application for research, development, testing, quality assurance, training, internal evaluation (other than any internal usage by employees in the context of the companyโ€™s business activities), and demonstration purposes. -# -# **โ€œOutputsโ€**: means any content generated by the operation of the Perceptron Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a Perceptron Models, such as any fine-tuned versions of the Perceptron Models, the weights, or parameters. -# -# **โ€œPersonalโ€**: means any use of a Perceptron Model or a Derivative that is (i) solely for personal, non-profit and non-commercial purposes and (ii) not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities. For illustration purposes, Personal use of a Model or a Derivative does not include any usage by individuals employed in companies in the context of their daily tasks, any activity that is intended to generate revenue, or that is performed on behalf of a commercial entity. -# -# **โ€œYouโ€**: means the individual or entity entering into this Agreement with Perceptron, Inc.. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
 import math
 import re
@@ -93,7 +26,6 @@
 from ...feature_extraction_utils import BatchFeature
 from ...models.auto.tokenization_auto import AutoTokenizer
 from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding
 from ...utils import TensorType
 from ...utils.import_utils import is_perceptron_available, is_torch_available, is_vision_available
 from .configuration_isaac import IsaacConfig
@@ -278,26 +210,8 @@ def __call__(
         Returns:
             BatchFeature with input_ids and tensor_stream
         """
-        # Normalize inputs to lists and support BatchEncoding (v5 apply_chat_template default)
-        encoding_input = None
-        if isinstance(text, BatchEncoding):
-            encoding_input = text
-        elif isinstance(text, dict) and "input_ids" in text:
-            encoding_input = BatchEncoding(text)
-
-        if encoding_input is not None:
-            input_ids_field = encoding_input["input_ids"]
-            if isinstance(input_ids_field, torch.Tensor):
-                ids = input_ids_field
-            else:
-                ids = torch.tensor(input_ids_field)
-            if ids.ndim == 1:
-                ids = ids.unsqueeze(0)
-            if ids.size(0) != 1:
-                raise ValueError("IsaacProcessor currently supports batch_size=1 for chat templates.")
-            decoded_text = self.tokenizer.decode(ids[0].tolist(), skip_special_tokens=False)
-            texts = [decoded_text]
-        elif isinstance(text, str):
+        # Normalize inputs to lists
+        if isinstance(text, str):
             texts = [text]
         else:
             texts = text
diff --git a/tests/models/isaac/isaac_checkpoint_hashes.json b/tests/fixtures/isaac/isaac_checkpoint_hashes.json
similarity index 100%
rename from tests/models/isaac/isaac_checkpoint_hashes.json
rename to tests/fixtures/isaac/isaac_checkpoint_hashes.json
diff --git a/tests/models/isaac/isaac_generation_golden.json b/tests/fixtures/isaac/isaac_generation_golden.json
similarity index 100%
rename from tests/models/isaac/isaac_generation_golden.json
rename to tests/fixtures/isaac/isaac_generation_golden.json
diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py
index bcf928c882e9..a353d4df9d59 100644
--- a/tests/models/isaac/test_modeling_isaac.py
+++ b/tests/models/isaac/test_modeling_isaac.py
@@ -1,3 +1,19 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
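A usage sketch of the narrowed calling convention above (plain string in, `tensor_stream` out). This is illustrative and not part of the diff: the checkpoint id and the `tensor_stream` output key come from the tests in this series, while the processor construction stands in for the tests' `create_isaac_processor` helper with its vision/image kwargs left at assumed defaults.

    import torch
    from transformers import AutoTokenizer
    from transformers.image_utils import load_image
    from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast
    from transformers.models.isaac.modeling_isaac import IsaacForConditionalGeneration
    from transformers.models.isaac.processing_isaac import IsaacProcessor

    checkpoint = "PerceptronAI/Isaac-0.1-Base"  # model id used by the tests in this series
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True, use_fast=False)
    # Processor kwargs (vision_token, max_sequence_length, ...) assumed to have usable defaults here.
    processor = IsaacProcessor(image_processor=IsaacImageProcessorFast(), tokenizer=tokenizer)
    model = IsaacForConditionalGeneration.from_pretrained(checkpoint).eval()

    image = load_image(
        "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp"
    )
    # Mirror document_to_messages: the image turn carries only the vision placeholder token.
    messages = [
        {"role": "user", "content": processor.vision_token},
        {"role": "user", "content": "Is it safe to cross the street at this moment?"},
    ]
    # Chat templates must be rendered to a plain string first; the BatchEncoding
    # shortcut removed above no longer exists.
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt")

    with torch.no_grad():
        output_ids = model.generate(tensor_stream=inputs["tensor_stream"], max_new_tokens=25, do_sample=False)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))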
+ +"""Testing suite for the Isaac model.""" + import base64 import hashlib import io @@ -19,11 +35,24 @@ PythonBackend, is_torch_available, ) +from transformers.image_utils import load_image +from transformers.masking_utils import eager_mask, sdpa_mask from transformers.models.isaac.configuration_isaac import IsaacVisionConfig from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast -from transformers.models.isaac.modeling_isaac import IsaacVisionAttention +from transformers.models.isaac.modeling_isaac import ( + IsaacVisionAttention, + document_mask_function_from_cu_seqlens, + ensure_document_attention_mask, +) from transformers.models.isaac.processing_isaac import IsaacProcessor -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import ( + get_tests_dir, + require_flash_attn, + require_torch, + require_vision, + slow, + torch_device, +) from transformers.utils import is_vision_available from transformers.utils.import_utils import is_perceptron_available @@ -33,7 +62,6 @@ else: Image = None -from ...test_configuration_common import ConfigTester from ...test_modeling_common import ids_tensor @@ -52,8 +80,9 @@ MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base") MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/3") or None LOCAL_CHECKPOINT = os.environ.get("ISAAC_TEST_MODEL_PATH") -HASH_FILE = Path(__file__).with_name("isaac_checkpoint_hashes.json") -GENERATION_GOLDEN_FILE = Path(__file__).with_name("isaac_generation_golden.json") +FIXTURES_DIR = Path(get_tests_dir("fixtures/isaac")) +HASH_FILE = FIXTURES_DIR / "isaac_checkpoint_hashes.json" +GENERATION_GOLDEN_FILE = FIXTURES_DIR / "isaac_generation_golden.json" HASH_FILTERS = { "full_model": {"include": None, "exclude": None}, "core_model": {"include": None, "exclude": {"vision_embedding", "audio_embedding", "inv_freq"}}, @@ -77,11 +106,47 @@ def tensor_stream_snapshot(ts: TensorStream) -> dict[str, object]: } -def _assert_tensor_stream_snapshot_equal(actual: dict[str, object], expected: dict[str, object]) -> None: - assert actual["shape"] == expected["shape"], "TensorStream shape changed" - assert actual["token_view"] == expected["token_view"], "TensorStream token view changed" - assert actual["modality_mask"] == expected["modality_mask"], "TensorStream modality mask changed" - assert actual["role_mask"] == expected["role_mask"], "TensorStream role mask changed" +def document_to_messages( + document: list[dict], vision_token: str = "" +) -> tuple[list[dict[str, str]], list[Image]]: + """ + Convert a Document to messages format compatible with chat templates. + Each content turn creates its own message entry. 
+ + Args: + document: list of dicts containing Text and/or Image content + vision_token: Token to use for image placeholder + + Returns: + Tuple of (messages, images) where messages is a list of dicts with 'role' and 'content' + """ + messages = [] + images = [] + + for item in document: + itype = item.get("type") + if itype == "text": + content = item.get("content") + if content: + messages.append( + { + "role": item.get("role", "user"), + "content": content, + } + ) + elif itype == "image": + content = item.get("content") + if content: + img = load_image(content) + images.append(img) + messages.append( + { + "role": item.get("role", "user"), + "content": vision_token, + } + ) + + return messages, images def _tensor_to_bytes(tensor): @@ -171,6 +236,147 @@ def tokenizer(isaac_reference_checkpoint): ) +@require_torch +def test_document_mask_function_from_cu_seqlens(): + cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32) + mask_fn = document_mask_function_from_cu_seqlens(cu_seqlens) + + assert mask_fn is not None + # Same document (indices 1 and 2) + assert mask_fn(0, 0, 1, 2) + # Cross-document (index 1 in first doc, 3 in second doc) + assert not mask_fn(0, 0, 1, 3) + # Same second document (indices 3 and 4) + assert mask_fn(0, 0, 4, 3) + + +@require_torch +def test_ensure_document_attention_mask_prefers_callable_when_requested(): + cu_seqlens = torch.tensor([0, 2, 5], dtype=torch.int32) + total_tokens = 5 + dtype = torch.float32 + + mask_callable = ensure_document_attention_mask( + attention_mask=None, + cu_seqlens=cu_seqlens, + total_tokens=total_tokens, + dtype=dtype, + device=cu_seqlens.device, + return_mask_function=True, + ) + assert callable(mask_callable) + + additive = ensure_document_attention_mask( + attention_mask=None, + cu_seqlens=cu_seqlens, + total_tokens=total_tokens, + dtype=dtype, + device=cu_seqlens.device, + return_mask_function=False, + ) + assert additive is None + + +def create_isaac_processor( + tokenizer, + isaac_config, + *, + image_processor=None, + **overrides, +): + """Helper to construct IsaacProcessor without requiring an IsaacConfig instance.""" + params = { + "vision_token": isaac_config.vision_token, + "max_sequence_length": isaac_config.max_sequence_length, + "vision_patch_size": isaac_config.vision_patch_size, + "vision_max_num_patches": isaac_config.vision_max_num_patches, + "vision_min_num_patches": isaac_config.vision_min_num_patches, + "pixel_shuffle_scale": isaac_config.pixel_shuffle_scale, + "rescale_factor": isaac_config.vision_rescale_factor, + "image_mean": tuple(isaac_config.vision_mean), + "image_std": tuple(isaac_config.vision_std), + } + params.update(overrides) + + processor_image = image_processor + if processor_image is None: + processor_image = IsaacImageProcessorFast( + patch_size=params["vision_patch_size"], + max_num_patches=params["vision_max_num_patches"], + min_num_patches=params["vision_min_num_patches"], + pixel_shuffle_scale=params["pixel_shuffle_scale"], + rescale_factor=params["rescale_factor"], + image_mean=params["image_mean"], + image_std=params["image_std"], + ) + processor_params = { + "vision_token": isaac_config.vision_token, + "max_sequence_length": isaac_config.max_sequence_length, + "rescale_factor": isaac_config.vision_rescale_factor, + } + + return IsaacProcessor( + image_processor=processor_image, + tokenizer=tokenizer, + **processor_params, + ) + + +@require_torch +def test_document_mask_function_materializes_with_masking_utils(): + cu_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32) + total_tokens = 4 + 
mask_fn = document_mask_function_from_cu_seqlens(cu_seqlens) + + cache_position = torch.arange(total_tokens, device=cu_seqlens.device, dtype=torch.long) + expected_bool = torch.tensor( + [ + [ + [ + [True, True, False, False], + [True, True, False, False], + [False, False, True, True], + [False, False, True, True], + ] + ] + ], + device=cu_seqlens.device, + ) + + sdpa = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=total_tokens, + kv_offset=0, + mask_function=mask_fn, + attention_mask=None, + allow_is_causal_skip=False, + allow_is_bidirectional_skip=False, + allow_torch_fix=False, + use_vmap=False, + ) + # sdpa_mask returns True for allowed positions; SDPA expects True to mean "mask out" + assert torch.equal(sdpa, expected_bool) + + eager = eager_mask( + batch_size=1, + cache_position=cache_position, + kv_length=total_tokens, + kv_offset=0, + mask_function=mask_fn, + attention_mask=None, + allow_is_bidirectional_skip=False, + use_vmap=False, + dtype=torch.float32, + ) + expected_additive = torch.where( + expected_bool, + torch.tensor(0.0, device=cu_seqlens.device, dtype=torch.float32), + torch.tensor(torch.finfo(torch.float32).min, device=cu_seqlens.device, dtype=torch.float32), + ) + assert torch.equal(eager, expected_additive) + + @require_torch def test_isaac_sdpa_attention_backend(): config = IsaacVisionConfig( @@ -200,10 +406,34 @@ def test_isaac_sdpa_attention_backend(): assert attn_weights is None -def _hash_tensor(tensor): - hasher = hashlib.sha256() - hasher.update(_tensor_to_bytes(tensor)) - return hasher.hexdigest() +@require_torch +@require_flash_attn +def test_isaac_flash_attention_backend(): + config = IsaacVisionConfig( + hidden_size=32, + intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=4, + num_channels=3, + num_patches=16, + patch_size=4, + ) + config._attn_implementation = "flash_attention_3" + + attn_module = IsaacVisionAttention(config).half().eval().cuda() + seq_len = 8 + hidden_states = torch.randn(1, seq_len, config.hidden_size, device=torch.device("cuda"), dtype=torch.float16) + cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=torch.device("cuda")) + + with torch.no_grad(): + outputs, attn_weights = attn_module( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=seq_len, + ) + + assert outputs.shape == hidden_states.shape + assert attn_weights is None @lru_cache(maxsize=1) @@ -469,17 +699,6 @@ class IsaacModelTest(unittest.TestCase): def setUp(self): self.model_tester = IsaacModelTester(self) - self.config_tester = ConfigTester( - self, - config_class=IsaacConfig, - has_text_modality=True, - common_properties=["hidden_size"], - text_config=self.model_tester.text_config, - vision_config=self.model_tester.vision_config, - ) - - def test_config(self): - self.config_tester.run_common_tests() @require_tensorstream def test_model_forward(self): @@ -533,16 +752,6 @@ def test_isaac_config_extends_qwen3_defaults(isaac_tiny_config): assert isaac_tiny_config.vision_token == "" -def test_isaac_config_migrates_legacy_rope_theta(): - cfg = IsaacConfig(text_config={"rope_theta": 12345}) - assert cfg.rope_parameters.get("rope_theta") == 12345 - assert cfg.rope_parameters.get("rope_type") == "default" - serialized = cfg.to_dict() - assert "rope_theta" not in serialized - assert "rope_theta" not in serialized.get("text_config", {}) - assert serialized["rope_parameters"].get("rope_theta") == 12345 - - @require_torch @require_tensorstream def 
test_isaac_for_conditional_generation_initialization(isaac_tiny_config): @@ -698,73 +907,6 @@ def test_isaac_generation_with_tensor_stream(isaac_processor, isaac_tiny_config) assert decoded_prompt.strip() != "" -@require_torch -@slow -@require_tensorstream -def test_isaac_checkpoint_hashes(isaac_reference_model): - isaac_reference_model = isaac_reference_model.to("cpu") - expected_hashes = _load_expected_hashes() - if not expected_hashes: - pytest.skip(f"Missing golden hashes file at {HASH_FILE}.") - - missing = [subset for subset in HASH_FILTERS if subset not in expected_hashes] - if missing: - pytest.skip(f"Golden hashes missing entries for: {', '.join(missing)}") - - isaac_reference_model.to("cpu") - state_dict = isaac_reference_model.state_dict() - for subset, filters in HASH_FILTERS.items(): - current_hash = _hash_state_dict(state_dict, include=filters["include"], exclude=filters["exclude"]) - assert current_hash == expected_hashes[subset], f"Hash mismatch for subset '{subset}'" - - -def create_isaac_processor( - tokenizer, - isaac_config, - *, - image_processor=None, - **overrides, -): - """Helper to construct IsaacProcessor without requiring an IsaacConfig instance.""" - params = { - "vision_token": isaac_config.vision_token, - "max_sequence_length": isaac_config.max_sequence_length, - "vision_patch_size": isaac_config.vision_patch_size, - "vision_max_num_patches": isaac_config.vision_max_num_patches, - "vision_min_num_patches": isaac_config.vision_min_num_patches, - "pixel_shuffle_scale": isaac_config.pixel_shuffle_scale, - "rescale_factor": isaac_config.vision_rescale_factor, - "image_mean": tuple(isaac_config.vision_mean), - "image_std": tuple(isaac_config.vision_std), - "vision_attn_implementation": isaac_config.vision_attn_implementation, - } - params.update(overrides) - - processor_image = image_processor - if processor_image is None: - processor_image = IsaacImageProcessorFast( - patch_size=params["vision_patch_size"], - max_num_patches=params["vision_max_num_patches"], - min_num_patches=params["vision_min_num_patches"], - pixel_shuffle_scale=params["pixel_shuffle_scale"], - rescale_factor=params["rescale_factor"], - image_mean=params["image_mean"], - image_std=params["image_std"], - ) - processor_params = { - "vision_token": isaac_config.vision_token, - "max_sequence_length": isaac_config.max_sequence_length, - "rescale_factor": isaac_config.vision_rescale_factor, - } - - return IsaacProcessor( - image_processor=processor_image, - tokenizer=tokenizer, - config=isaac_config, - **processor_params, - ) - - @require_torch @require_vision @slow @@ -833,3 +975,93 @@ def test_hf_generate_vs_training_generate_logits(isaac_reference_model, isaac_re ) isaac_reference_model.to("cpu") + + +@require_torch +@require_vision +@slow +@require_tensorstream +@require_flash_attn +class IsaacGenerationIntegrationTest(unittest.TestCase): + max_new_tokens = 25 + dtype = torch.bfloat16 + + def setUp(self): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.checkpoint = _reference_checkpoint_or_skip() + self.hf_config = IsaacConfig.from_pretrained(self.checkpoint, revision=MODEL_REVISION) + self.tokenizer = AutoTokenizer.from_pretrained( + self.checkpoint, trust_remote_code=True, use_fast=False, revision=MODEL_REVISION + ) + self.processor = create_isaac_processor(self.tokenizer, self.hf_config) + self.hf_config.vision_config._attn_implementation = "flash_attention_2" + self.hf_config.vision_config.attn_implementation = "flash_attention_2" + self.model = 
IsaacForConditionalGeneration.from_pretrained(
+            self.checkpoint, config=self.hf_config, revision=MODEL_REVISION
+        )
+        self.model = self.model.to(device=self.device, dtype=self.dtype)
+        self.model.eval()
+
+    def _generate_from_messages(self, messages, images, num_tokens=None):
+        prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip()
+        processor_output = self.processor(text=prompt, images=images, return_tensors="pt")
+        tensor_stream = processor_output["tensor_stream"].to(self.device)
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                tensor_stream=tensor_stream,
+                max_new_tokens=num_tokens or self.max_new_tokens,
+                do_sample=False,
+                pad_token_id=self.tokenizer.eos_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
+                return_dict_in_generate=True,
+                output_logits=True,
+            )
+
+        generated_ids = outputs.sequences
+        generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        return generated_text
+
+    def test_generate_from_image_text(self):
+        image = _load_red_dot_image()
+        if image is None:
+            pytest.skip("PIL.Image is required for Isaac generation tests.")
+
+        messages = [
+            {"role": "user", "content": "Describe this image:"},
+            {"role": "user", "content": ""},
+        ]
+        generated_text = self._generate_from_messages(messages, [image])
+        expected_fragment = "The image is a close-up photograph of a red cross symbol."
+        assert expected_fragment in generated_text
+
+    def test_generate_from_text_only(self):
+        document = [
+            {
+                "type": "text",
+                "content": "What is the pythogorean theorem?",
+                "role": "user",
+            }
+        ]
+        messages, _ = document_to_messages(document)
+        generated_text = self._generate_from_messages(messages, [], num_tokens=100)
+        expected_fragment = "The Pythagorean theorem is a fundamental principle in geometry that relates the lengths of the sides of a right-angled triangle. Let's break down the theorem step by step:"
+        assert expected_fragment in generated_text
+
+    def test_vqa_from_image(self):
+        document = [
+            {
+                "type": "image",
+                "content": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp",
+                "role": "user",
+            },
+            {
+                "type": "text",
+                "content": "Is it safe to cross the street at this moment?",
+                "role": "user",
+            },
+        ]
+        messages, images = document_to_messages(document)
+        generated_text = self._generate_from_messages(messages, images, num_tokens=256)
+        expected_response = "\nNo, it is not safe to cross the street at this moment. The traffic light for pedestrians is red, indicating that it is not safe to cross."
+        assert generated_text == expected_response

From aa31c36a1aa213d7935c9e5b93fe5406d56a2dda Mon Sep 17 00:00:00 2001
From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com>
Date: Wed, 17 Dec 2025 17:42:03 +0400
Subject: [PATCH 52/77] transformers attention interface + modeling test suite
 (#11)

* style: drop unused function
* fix: do not expose gradient checkpoint flag
* refactor: move test files to fixtures directory, refer to them in test file
* style: remove redundant routing
* feat: isaac model forward autodoc + check inputs
* docs: add docstring for isaac forward
* style: remove conditions on pixel shuffle scale
* style: explicitly use torch dtype's min value to avoid float casting
* style: remove unnecessary forward() configuration handling
* refactor: move to updated masking API
* chore: license + module docstring for test
* docs: update license
* refactor: doc mask rework wip 2

doc mask refactor finished

callable rework

* style: pass all args down for interface compatibility
* style: remove extra cast on attention implementation
* test: update tests
* test: remove outdated tests
* refactor: isolated vision embedding class
* refactor: simplify attention flow to prepare for proper handling
* refactor: simplify config
* chore: convert artifacts
* test: add text only test
* style: cache
* tests: expand integration testing
* feat: use HF transformers attention_interface
* fix: make config roundtrip
* test: bring back config test
* test: add common test mixins
* fix: can generate is class method
* wip 60 fail
* fix: return hidden states
* fix: make setting input embeddings vocab size aware
* wip 2: 29 fail
* fix: tied weight keys to correct submodule "text_model"
* wip scary change allowing input embds
* fix: post init call for tp_plan
* fix: enable gradient checkpointing if specified by config
* feat: init all weights if not from pretrained
* fix: explicitly do not support flex attention
* fix: allow attention setting
* chore: convert artifacts
* fix: temporarily drop _init_weights
* test: skip assisted decoding tests, qwen3 doesn't support it
* feat: handle 2d position ids for compatibility with HF tests
* test: state expectation that model is composite
* test: do not test for attention outputs given Qwen3 decoder
* wip 3 failures
* test: update skips
* test: no longer assert that position ids is none
* fix: sdpa default
* wip 1 test failing
* reduce diffs 1 (all tests passing)
* reduce diffs 2 (all tests passing)
* needed for modular tests (all tests passing)
* final: latest diff reducer (all tests passing)
* test: test flash attention 2 not 3
* attempt 1 logit equiv
* move around
* fix: hardcode forward to prevent copy issues from conversion tool
* fix: allow gradient checkpointing
* test: remove redundant tests implemented by HF harness
* test: drop unused utility
* test: move processor tests to isolated file
* test: refactor document mask tests to isolated class for organization
* test: drop redundant utilities
* test: move logit equivalence test to proper setup in class
* test: drop unused fixtures
* test: organize imports
* test: refactor isolated test to IsaacModelTest
* test: remove unneeded generation fixture
* test: drop unused logit stats assert helper
* style: cleanup forward (all tests passing)
* style: cleanup forward more
* test: add point extraction test
* test: clean up constants
* test: separate isaac and base model
* test: drop remaining fixtures

tests all pass

* tests: delete unused fixtures
* chore: convert artifact
---
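The headline change in this commit routes `IsaacVisionAttention` through the shared transformers attention interface (`ALL_ATTENTION_FUNCTIONS`) instead of bespoke SDPA/flash/eager branches. A minimal sketch of the resulting call surface, modeled on this patch's backend unit tests; the tiny config values and the two-document `cu_seqlens` packing are arbitrary:

    import torch
    from transformers.models.isaac.configuration_isaac import IsaacVisionConfig
    from transformers.models.isaac.modeling_isaac import IsaacVisionAttention

    config = IsaacVisionConfig(
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=1,
        num_attention_heads=4,
        num_channels=3,
        num_patches=16,
        patch_size=4,
    )
    config._attn_implementation = "sdpa"  # swap for "flash_attention_2" on supported GPUs

    attn = IsaacVisionAttention(config).eval()
    hidden_states = torch.randn(1, 8, config.hidden_size)  # packed sequences, batch_size=1
    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)  # two documents of lengths 3 and 5

    with torch.no_grad():
        output, attn_weights = attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            max_seqlen=5,
        )

    assert output.shape == hidden_states.shape
    assert attn_weights is None  # the SDPA backend does not return attention probabilities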
 .../models/isaac/configuration_isaac.py      |  39 +-
 .../models/isaac/modeling_isaac.py           | 743 +++++++-------
 .../models/isaac/modular_isaac.py            | 713 ++++++--------
 .../isaac/isaac_checkpoint_hashes.json       |   5 -
 .../isaac/isaac_generation_golden.json       | 450 ---------
 tests/models/isaac/test_modeling_isaac.py    | 921 +++++++-----------
 tests/models/isaac/test_processing_isaac.py  | 264 +++++
 7 files changed, 1286 insertions(+), 1849 deletions(-)
 delete mode 100644 tests/fixtures/isaac/isaac_checkpoint_hashes.json
 delete mode 100644 tests/fixtures/isaac/isaac_generation_golden.json
 create mode 100644 tests/models/isaac/test_processing_isaac.py

diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py
index 38e785ef1528..ec8e0f74f967 100644
--- a/src/transformers/models/isaac/configuration_isaac.py
+++ b/src/transformers/models/isaac/configuration_isaac.py
@@ -19,8 +19,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-import copy
 from typing import Any, Optional, Union
 
 from ...configuration_utils import PreTrainedConfig, PretrainedConfig, layer_type_validation
@@ -74,6 +72,10 @@ def __init__(
         # Add our custom fields
         self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor
 
+        # Ensure a sensible default attention backend
+        if getattr(self, "_attn_implementation", None) is None:
+            self._attn_implementation = "sdpa"
+
 
 class IsaacConfig(PretrainedConfig):
     """Configuration class for Isaac multimodal model.
@@ -96,24 +98,12 @@ def __init__(
         **kwargs,
     ):
         self._rope_parameters: Optional[dict[str, Any]] = None
-        resolved_text_config = kwargs.pop("text_config", text_config)
-        if isinstance(resolved_text_config, Qwen3Config):
-            text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict())
-        elif isinstance(resolved_text_config, dict):
-            text_config_kwargs = copy.deepcopy(resolved_text_config)
-        elif resolved_text_config is None:
-            text_config_kwargs = {}
-        else:
-            raise TypeError("`text_config` must be a mapping or `Qwen3Config` instance when provided.")
+        attn_implementation = kwargs.get("attn_implementation")
 
-        text_config_kwargs.update(kwargs)
-
-        self.text_config = self.sub_configs["text_config"](**text_config_kwargs)
-        if not hasattr(self.text_config, "rope_theta"):
-            rope_theta_override = text_config_kwargs.get("rope_theta", kwargs.get("rope_theta"))
-            if rope_theta_override is None:
-                rope_theta_override = getattr(Qwen3Config(), "rope_theta", 10000.0)
-            self.text_config.rope_theta = rope_theta_override
+        if isinstance(text_config, dict):
+            self.text_config = self.sub_configs["text_config"](**text_config)
+        elif text_config is None:
+            self.text_config = self.sub_configs["text_config"]()
 
         super().__init__(**kwargs)
 
@@ -133,7 +123,7 @@ def __init__(
         self.head_dim = self.text_config.head_dim
         self.hidden_act = self.text_config.hidden_act
         self.use_cache = self.text_config.use_cache
-        self.rope_theta = self.text_config.rope_theta
+        self.rope_theta = self.text_config.rope_parameters["rope_theta"]
 
         # Validate rotary parameters now that they have been mirrored locally.
         rope_config_validation(self)
@@ -149,6 +139,15 @@ def __init__(
         elif vision_config is None:
             self.vision_config = self.sub_configs["vision_config"]()
 
+        # Propagate user-requested attention backend to the vision sub-config when provided.
+ if attn_implementation is not None: + if isinstance(attn_implementation, dict): + vision_attn = attn_implementation.get("vision_config", attn_implementation.get("", None)) + else: + vision_attn = attn_implementation + if vision_attn is not None: + self.vision_config._attn_implementation = vision_attn + # Vision normalization parameters self.vision_rescale_factor = float(vision_rescale_factor) diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 03869bff8c83..748d7a8db3a2 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -28,8 +28,8 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation.utils import GenerationMixin -from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub -from ...masking_utils import create_masks_for_generate, eager_mask, packed_sequence_mask_function, sdpa_mask +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func +from ...masking_utils import create_masks_for_generate, packed_sequence_mask_function from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast @@ -220,72 +220,9 @@ def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor return torch.cat(output_chunks, dim=0) -def document_mask_function_from_cu_seqlens(cu_seqlens: Optional[torch.Tensor]) -> Optional[Callable]: - """Return a mask function that blocks cross-document attention from packed ``cu_seqlens``. - - The returned callable matches the signature expected by ``masking_utils`` mask factories and - yields ``True`` only when query/key positions belong to the same packed segment. - """ - - if cu_seqlens is None: - return None - - if cu_seqlens.numel() < 2: - return None - - seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() - if seq_sizes.numel() == 0: - return None - - total_tokens = int(seq_sizes.sum().item()) - seg_ids = torch.repeat_interleave(torch.arange(seq_sizes.numel(), device=cu_seqlens.device), seq_sizes) - packed_sequence_mask = seg_ids.view(1, total_tokens) - return packed_sequence_mask_function(packed_sequence_mask) - - -def ensure_document_attention_mask( - attention_mask: Optional[torch.Tensor], - cu_seqlens: Optional[torch.Tensor], - total_tokens: int, - dtype: torch.dtype, - device: torch.device, - *, - return_mask_function: bool = False, -) -> Optional[Union[torch.Tensor, Callable]]: - """Return the provided mask, a callable mask from ``cu_seqlens``, or ``None``. - - ``return_mask_function=True`` yields a callable suitable for ``masking_utils``; otherwise - ``None`` is returned when no explicit ``attention_mask`` is provided. The legacy additive mask - has been removed in favor of the callable-based path. 
- """ - - if attention_mask is not None: - return attention_mask - - if cu_seqlens is None: - return None - - if return_mask_function: - return document_mask_function_from_cu_seqlens(cu_seqlens) - - return None - - class IsaacVisionAttention(nn.Module): """Custom attention that supports variable-length sequences with flash attention.""" - ATTENTION_KEY_MAP: dict[str, str] = { - "flash_attention_2": "isaac_flash_attention_2", - "flash_attention_3": "isaac_flash_attention_3", - "isaac_flash_attention_2": "isaac_flash_attention_2", - "isaac_flash_attention_3": "isaac_flash_attention_3", - "sdpa": "isaac_sdpa", - "isaac_sdpa": "isaac_sdpa", - "eager": "isaac_eager", - "isaac_eager": "isaac_eager", - } - _FLASH_IMPLS = frozenset(("isaac_flash_attention_2", "isaac_flash_attention_3")) - def __init__(self, config): super().__init__() self.config = config @@ -305,7 +242,6 @@ def __init__(self, config): self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - self._variable_length_metadata = None def forward( self, @@ -320,259 +256,81 @@ def forward( **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - # Unused arguments are accepted for interface compatibility + # Ignore unused arguments for interface compatibility _ = position_ids _ = past_key_value _ = is_causal - _ = output_attentions - kwargs.pop("output_hidden_states", None) kwargs.pop("return_dict", None) - if kwargs: - unexpected = ", ".join(sorted(kwargs)) - raise TypeError(f"Unexpected kwargs for IsaacVisionAttention.forward: {unexpected}") - - cached_cu, cached_max = self._consume_variable_length_metadata() - if cu_seqlens is None: - cu_seqlens = cached_cu - if max_seqlen is None: - max_seqlen = cached_max - - # Expect packed sequences with batch_size == 1 - batch_size, L, _ = hidden_states.shape - if batch_size != 1: - raise ValueError("packed variable-length attention expects batch_size=1") - x = hidden_states[0] # (L, E) - - H = self.num_heads - D = self.head_dim - p_drop = self.dropout if self.training else 0.0 - - # Project and reshape to (L, H, D) - q = self.q_proj(x).view(L, H, D) - k = self.k_proj(x).view(L, H, D) - v = self.v_proj(x).view(L, H, D) - - resolved_key = "isaac_sdpa" - if self.config._attn_implementation != "sdpa": - resolved_key = self.ATTENTION_KEY_MAP.get(self.config._attn_implementation, resolved_key) - attn_mask = ensure_document_attention_mask( - attention_mask, - cu_seqlens, - L, - q.dtype, - q.device, - return_mask_function=True, - ) + batch_size, seq_length, embed_dim = hidden_states.shape + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - attn_weights = None - if resolved_key in self._FLASH_IMPLS: - y_lhd = self._flash_attention_forward( - q_lhd=q, - k_lhd=k, - v_lhd=v, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - dropout=p_drop, - ) - elif resolved_key == "isaac_sdpa": - y_lhd = self._sdpa_attention_forward( - q_lhd=q, - k_lhd=k, - v_lhd=v, - attention_mask=attn_mask, - cu_seqlens=cu_seqlens, - dropout=p_drop, - ) - elif resolved_key == "isaac_eager": - y_lhd, attn_weights = self._eager_attention_forward( - q_lhd=q, - k_lhd=k, - v_lhd=v, - attention_mask=attn_mask, - dropout=p_drop, - ) - else: - attention_fn = ALL_ATTENTION_FUNCTIONS.get(resolved_key) - if attention_fn is None: - raise ValueError(f"Attention implementation {resolved_key} not found.") - - query_states = 
q.transpose(0, 1).unsqueeze(0) - key_states = k.transpose(0, 1).unsqueeze(0) - value_states = v.transpose(0, 1).unsqueeze(0) - - attention_kwargs: dict[str, Any] = { - "dropout": p_drop, - "scaling": self.scale, - "is_causal": False, - } - if cu_seqlens is not None: - attention_kwargs["cu_seq_lens_q"] = cu_seqlens - attention_kwargs["cu_seq_lens_k"] = cu_seqlens - if max_seqlen is not None: - attention_kwargs["max_length_q"] = max_seqlen - attention_kwargs["max_length_k"] = max_seqlen - - attn_output, attn_weights = attention_fn( - self, - query_states, - key_states, - value_states, - attn_mask, - **attention_kwargs, - ) - - y_lhd = attn_output.squeeze(0).permute(1, 0, 2).contiguous() - - # Merge heads and project - y = self.out_proj(y_lhd.reshape(L, self.embed_dim)) - return y.unsqueeze(0), attn_weights # (1, L, E) + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - def _variable_length_context(self, *, cu_seqlens=None, max_seqlen=None): - """Store packed-sequence metadata for the next forward call.""" - self._variable_length_metadata = (cu_seqlens, max_seqlen) + if not queries.is_contiguous(): + queries = queries.contiguous() + if not keys.is_contiguous(): + keys = keys.contiguous() + if not values.is_contiguous(): + values = values.contiguous() - def _consume_variable_length_metadata(self): - if self._variable_length_metadata is None: - return None, None - cu_seqlens, max_seqlen = self._variable_length_metadata - self._variable_length_metadata = None - return cu_seqlens, max_seqlen - - @staticmethod - def _max_from_cu(cu: Optional[torch.Tensor], fallback: int) -> int: - if cu is None or cu.numel() < 2: - return fallback - return int((cu[1:] - cu[:-1]).max().item()) - - def _flash_attention_forward( - self, - *, - q_lhd: torch.Tensor, - k_lhd: torch.Tensor, - v_lhd: torch.Tensor, - cu_seqlens: Optional[torch.Tensor], - max_seqlen: Optional[int], - dropout: float, - ) -> torch.Tensor: - L = q_lhd.size(0) + L = queries.size(0) if max_seqlen is not None: max_q = max_k = int(max_seqlen) else: max_q = max_k = self._max_from_cu(cu_seqlens, L) - if not q_lhd.is_contiguous(): - q_lhd = q_lhd.contiguous() - if not k_lhd.is_contiguous(): - k_lhd = k_lhd.contiguous() - if not v_lhd.is_contiguous(): - v_lhd = v_lhd.contiguous() - - out_lhd, *_ = torch.ops.aten._flash_attention_forward( - query=q_lhd, - key=k_lhd, - value=v_lhd, - cum_seq_q=cu_seqlens, - cum_seq_k=cu_seqlens, - max_q=max_q, - max_k=max_k, - dropout_p=dropout, - is_causal=False, - return_debug_mask=False, - scale=self.scale, - window_size_left=-1, - window_size_right=-1, - alibi_slopes=None, - ) - return out_lhd + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"] + if self.config._attn_implementation != "sdpa": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - def _sdpa_attention_forward( - self, - *, - q_lhd: torch.Tensor, - k_lhd: torch.Tensor, - v_lhd: torch.Tensor, - attention_mask: Optional[Union[torch.Tensor, Callable]], - cu_seqlens: Optional[torch.Tensor], - dropout: float, - ) -> torch.Tensor: - L = q_lhd.size(0) - attn_mask = attention_mask + dropout = 0.0 if not self.training else self.dropout + attention_kwargs: dict[str, Any] = { + "is_causal": False, + "scaling": self.scale, + "dropout": dropout, + } + if cu_seqlens is not None: + 
attention_kwargs["cu_seq_lens_q"] = cu_seqlens + attention_kwargs["cu_seq_lens_k"] = cu_seqlens + if max_seqlen is not None: + attention_kwargs["max_length_q"] = max_q + attention_kwargs["max_length_k"] = max_k + if output_attentions: + attention_kwargs["output_attentions"] = True - if callable(attn_mask): - cache_position = torch.arange(L, device=q_lhd.device, dtype=torch.long) - attn_mask = sdpa_mask( - batch_size=1, - cache_position=cache_position, - kv_length=L, - kv_offset=0, - mask_function=attn_mask, - attention_mask=None, - allow_is_causal_skip=False, - allow_is_bidirectional_skip=False, - allow_torch_fix=False, - use_vmap=False, - ) - # sdpa_mask returns True for allowed positions; SDPA expects True to mean "mask out" - if attn_mask is not None and attn_mask.dtype == torch.bool: - attn_mask = ~attn_mask - - q = q_lhd.permute(1, 0, 2).unsqueeze(0) - k = k_lhd.permute(1, 0, 2).unsqueeze(0) - v = v_lhd.permute(1, 0, 2).unsqueeze(0) - - if attn_mask is not None and attn_mask.dtype != q.dtype: - attn_mask = attn_mask.to(q.dtype) - - output = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=attn_mask, - dropout_p=dropout, - scale=self.scale, - is_causal=False, + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + **attention_kwargs, ) - return output.squeeze(0).permute(1, 0, 2).contiguous() - def _eager_attention_forward( - self, - *, - q_lhd: torch.Tensor, - k_lhd: torch.Tensor, - v_lhd: torch.Tensor, - attention_mask: Optional[Union[torch.Tensor, Callable]], - dropout: float, - ) -> tuple[torch.Tensor, torch.Tensor]: - L = q_lhd.size(0) - attn_mask = attention_mask - if callable(attn_mask): - cache_position = torch.arange(L, device=q_lhd.device, dtype=torch.long) - attn_mask = eager_mask( - batch_size=1, - cache_position=cache_position, - kv_length=L, - kv_offset=0, - mask_function=attn_mask, - attention_mask=None, - allow_is_bidirectional_skip=False, - use_vmap=False, - dtype=q_lhd.dtype, - ) - if attn_mask is not None and attn_mask.dim() == 4: - attn_mask = attn_mask.squeeze(0).squeeze(0) + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() - attn_weights = torch.matmul(q_lhd, k_lhd.transpose(1, 2)) * self.scale - if attn_mask is not None: - attn_weights = attn_weights + attn_mask + # Align projection inputs with parameter dtype to avoid mixed-dtype matmul errors + out_proj_dtype = self.out_proj.weight.dtype + if attn_output.dtype != out_proj_dtype: + attn_output = attn_output.to(out_proj_dtype) - attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_lhd.dtype) - if dropout and self.training: - attn_weights = F.dropout(attn_weights, p=dropout, training=True) + attn_output = self.out_proj(attn_output) + if attn_output.dtype != hidden_states.dtype: + attn_output = attn_output.to(hidden_states.dtype) - attn_output_lhd = torch.matmul(attn_weights, v_lhd) - return attn_output_lhd, attn_weights + return attn_output, attn_weights + + @staticmethod + def _max_from_cu(cu: Optional[torch.Tensor], fallback: int) -> int: + if cu is None or cu.numel() < 2: + return fallback + return int((cu[1:] - cu[:-1]).max().item()) class IsaacMLP(nn.Module): @@ -590,6 +348,57 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states +def document_mask_function_from_cu_seqlens(cu_seqlens: Optional[torch.Tensor]) -> Optional[Callable]: + """Return a mask function that blocks cross-document attention from packed ``cu_seqlens``. 
+ + The returned callable matches the signature expected by ``masking_utils`` mask factories and + yields ``True`` only when query/key positions belong to the same packed segment. + """ + + if cu_seqlens is None: + return None + + if cu_seqlens.numel() < 2: + return None + + seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() + if seq_sizes.numel() == 0: + return None + + total_tokens = int(seq_sizes.sum().item()) + seg_ids = torch.repeat_interleave(torch.arange(seq_sizes.numel(), device=cu_seqlens.device), seq_sizes) + packed_sequence_mask = seg_ids.view(1, total_tokens) + return packed_sequence_mask_function(packed_sequence_mask) + + +def ensure_document_attention_mask( + attention_mask: Optional[torch.Tensor], + cu_seqlens: Optional[torch.Tensor], + total_tokens: int, + dtype: torch.dtype, + device: torch.device, + *, + return_mask_function: bool = False, +) -> Optional[Union[torch.Tensor, Callable]]: + """Return the provided mask, a callable mask from ``cu_seqlens``, or ``None``. + + ``return_mask_function=True`` yields a callable suitable for ``masking_utils``; otherwise + ``None`` is returned when no explicit ``attention_mask`` is provided. The legacy additive mask + has been removed in favor of the callable-based path. + """ + + if attention_mask is not None: + return attention_mask + + if cu_seqlens is None: + return None + + if return_mask_function: + return document_mask_function_from_cu_seqlens(cu_seqlens) + + return None + + class IsaacVisionEncoderLayer(GradientCheckpointingLayer): """Isaac vision encoder layer with variable-length attention.""" @@ -619,34 +428,39 @@ def forward( Maximum document length referenced by `cu_seqlens`. Passed to FlashAttention so it can size temporary buffers for packed variable-length attention. """ - if cu_seqlens is not None or max_seqlen is not None: - self.self_attn._variable_length_context( - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - attention_mask = ensure_document_attention_mask( attention_mask, cu_seqlens, hidden_states.size(1), hidden_states.dtype, hidden_states.device, + return_mask_function=False, ) - residual = hidden_states + # Run attention directly so variable-length metadata reaches FlashAttention. 
+ residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states, _ = self.self_attn( - hidden_states=hidden_states, + attn_outputs = self.self_attn( + hidden_states, attention_mask=attention_mask, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_attentions=output_attentions, **kwargs, ) - hidden_states = residual + hidden_states + if isinstance(attn_outputs, tuple): + attn_output, attn_weights = attn_outputs + else: + attn_output, attn_weights = attn_outputs, None + hidden_states = residual + attn_output residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states + if output_attentions: + return hidden_states, attn_weights return hidden_states @@ -672,36 +486,33 @@ def forward( return_dict: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: - self.__variable_length_context(cu_seqlens, max_seqlen) - attention_mask = ensure_document_attention_mask( attention_mask, cu_seqlens, inputs_embeds.size(1), inputs_embeds.dtype, inputs_embeds.device, + return_mask_function=False, ) + hidden_states = inputs_embeds + kwargs.update( + { + "max_seqlen": max_seqlen, + "cu_seqlens": cu_seqlens, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + } + ) for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, attention_mask, **kwargs, ) - return BaseModelOutput(last_hidden_state=hidden_states) - def __variable_length_context(self, cu_seqlens, max_seqlen) -> None: - if cu_seqlens is None and max_seqlen is None: - return - - for layer in self.layers: - if isinstance(layer, IsaacVisionEncoderLayer): - layer.self_attn._variable_length_context( - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - def create_pixel_shuffle_index_map( seq_sizes: torch.Tensor, @@ -831,6 +642,8 @@ def pixel_shuffle_varlen( class IsaacVisionTransformer(nn.Module): + _supports_sdpa = True + def __init__(self, config: IsaacVisionConfig): super().__init__() self.config = config @@ -881,6 +694,8 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): class IsaacVisionEmbedding(nn.Module): """Vision embedding wrapper exposing tower and projector.""" + _supports_sdpa = True + def __init__(self, config: IsaacConfig): super().__init__() vision_cfg = config.vision_config @@ -967,7 +782,8 @@ def forward( with torch.no_grad(): pos = position_ids.clone() - not_spatial = modality_tensor != VisionType.image.value + image_value = VisionType.image.value if VisionType is not None else 1 + not_spatial = modality_tensor != image_value if not_spatial.any(): data_1d = pos[not_spatial][..., 0].unsqueeze(-1) pos[not_spatial] = data_1d.expand(-1, pos.shape[-1]) @@ -1079,6 +895,7 @@ def eager_attention_forward( return attn_output, attn_weights +@use_kernelized_func(apply_rotary_pos_emb) class IsaacAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1105,7 +922,6 @@ def __init__(self, config: IsaacConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - self.rotary_fn = apply_rotary_pos_emb self.q_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
self.k_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None @@ -1201,29 +1017,6 @@ def forward( return hidden_states -# ============================================================================ -# Model -# ============================================================================ - - -def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor: - r"""Create 3D positional indices for token input. - - Args: - input_ids (`torch.Tensor`): - Tensor of shape `(batch_size, seq_len)` containing token ids. - - Returns: - `torch.Tensor`: Positional indices with shape `(batch_size, seq_len, 3)` where each channel duplicates the - 1D position so it can be consumed by the 3-axis MRoPE rotary embedding. - """ - batch_size, seq_length = input_ids.shape - position_ids = torch.arange(seq_length, device=input_ids.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - position_ids = position_ids.unsqueeze(2).expand(-1, -1, 3) # Add 3D for MRoPE - return position_ids - - @auto_docstring class IsaacModel(PreTrainedModel): config: IsaacConfig @@ -1233,14 +1026,15 @@ class IsaacModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True - _supports_flex_attn = True - - _can_compile_fullgraph = True + _supports_flex_attn = False + _can_compile_fullgraph = False _supports_attention_backend = True _can_record_outputs = { "hidden_states": IsaacDecoderLayer, "attentions": IsaacAttention, } + # Expose tied-weights mapping even if empty for base model tests. + all_tied_weights_keys: dict[str, str] = {} def __init__(self, config: IsaacConfig): Qwen3PreTrainedModel.__init__(self, config) @@ -1257,6 +1051,7 @@ def __init__(self, config: IsaacConfig): raise ValueError("IsaacConfig should always have vision_config") self.vision_embedding = IsaacVisionEmbedding(config) + self.vision_embedding._supports_sdpa = True # Dispatch table for TensorStream balanced embedding (text + vision) self.embed_fns = { @@ -1269,11 +1064,24 @@ def __init__(self, config: IsaacConfig): self.vision_rescale_factor = config.vision_rescale_factor self.vision_token = config.vision_token + # Initialize weights and parallel plans (including tp_plan from the text model) + self.post_init() + + # Respect config-specified gradient checkpointing + if getattr(config, "gradient_checkpointing", False): + self.gradient_checkpointing_enable() + def get_input_embeddings(self) -> nn.Module: return self.text_model.get_input_embeddings() def set_input_embeddings(self, value: nn.Module) -> None: self.text_model.set_input_embeddings(value) + vocab_size = getattr(value, "num_embeddings", None) + if vocab_size is not None: + self.config.vocab_size = vocab_size + if hasattr(self.config, "text_config"): + self.config.text_config.vocab_size = vocab_size + self.text_model.config.vocab_size = vocab_size @property def embed_tokens(self) -> nn.Module: @@ -1291,6 +1099,14 @@ def layers(self) -> nn.ModuleList: def norm(self) -> nn.Module: return self.text_model.norm + @property + def vision_model(self) -> nn.Module: + return self.vision_embedding.vision_tower + + @property + def vision_tower(self) -> nn.Module: + return self.vision_embedding.vision_tower + def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: """Embed text tokens, squeezing singleton dimensions.""" # Text events are shaped as (..., 1); squeeze the 
singleton index dim @@ -1350,6 +1166,7 @@ def forward( past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, @@ -1371,47 +1188,72 @@ def forward( omitted. """ + text_value = TextType.text.value if TextType is not None else 0 + # Get inputs if tensor_stream is not None and inputs_embeds is not None: raise ValueError("You cannot specify both tensor_stream and inputs_embeds") - elif tensor_stream is not None: - # Embed TensorStream directly - inputs_embeds = self.embed_stream(tensor_stream) - # Create modality tensor if not provided - if modality_tensor is None: - modality_tensor = modality_mask(tensor_stream) - elif input_ids is not None and inputs_embeds is not None: + if tensor_stream is None and input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + + # Resolve the input source (TensorStream takes precedence over token ids). + if tensor_stream is not None: + inputs_embeds = self.embed_stream(tensor_stream) elif input_ids is not None: inputs_embeds = self.text_model.embed_tokens(input_ids) - # Create text modality tensor if not provided - if modality_tensor is None: - batch_size, seq_length = input_ids.shape - modality_tensor = torch.full( - (batch_size, seq_length), TextType.text.value, device=input_ids.device, dtype=torch.long - ) elif inputs_embeds is None: raise ValueError("You have to specify either tensor_stream, input_ids or inputs_embeds") + batch_size, seq_len = inputs_embeds.shape[:2] + # Ensure cache exists when requested if use_cache and past_key_values is None: cache_config = self.config.get_text_config() if hasattr(self.config, "get_text_config") else self.config past_key_values = DynamicCache(config=cache_config) - if cache_position is None and (past_key_values is not None or use_cache): + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long) - # Create default position_ids if not provided + # Normalize modality tensor + if modality_tensor is None: + if tensor_stream is not None: + modality_tensor = modality_mask(tensor_stream) + else: + modality_tensor = torch.full( + (batch_size, seq_len), text_value, device=inputs_embeds.device, dtype=torch.long + ) + else: + modality_tensor = modality_tensor.to(dtype=torch.long) + + if modality_tensor.shape[1] != seq_len: + if modality_tensor.shape[1] > seq_len: + modality_tensor = modality_tensor[:, :seq_len] + else: + pad = modality_tensor[:, -1:].expand(-1, seq_len - modality_tensor.shape[1]) + modality_tensor = torch.cat([modality_tensor, pad], dim=1) + + # Normalize position ids if position_ids is None: if tensor_stream is not None: position_ids = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) else: - position_ids = compute_position_ids_input_ids(input_ids) + position_ids = cache_position.view(1, -1).expand(modality_tensor.shape[0], -1) + + # Expand 2D 
position ids (from generic padding tests or decode cache positions) to 3D MRoPE coords + if position_ids.ndim == 2: + position_ids = position_ids.to(device=inputs_embeds.device) + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) + + # Align lengths so rotary embedding sees matching shapes + if position_ids.shape[1] != seq_len: + start_positions = position_ids[:, :1, 0] + position_ids = torch.arange(seq_len, device=inputs_embeds.device).view(1, -1) + position_ids = position_ids + start_positions + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) # Compute MRoPE position embeddings if we have custom rotary_emb cos, sin = self.rotary_emb( @@ -1422,38 +1264,46 @@ def forward( cos = cos.to(inputs_embeds.dtype) sin = sin.to(inputs_embeds.dtype) - # Prepare attention mask + # Flash attention expects 1D position_ids; keep 3D only for rotary phases + decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids + # Prepare attention mask if not isinstance(attention_mask, dict): - mask_kwargs = { - "config": self.config, - "input_embeds": inputs_embeds, - "attention_mask": attention_mask, - "cache_position": cache_position, - "past_key_values": past_key_values, - "position_ids": position_ids, - } - attention_mask = create_masks_for_generate(**mask_kwargs) + attention_mask = create_masks_for_generate( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=decoder_position_ids, + ) + + is_attention_mask_dict = isinstance(attention_mask, dict) # Initialize hidden states hidden_states = inputs_embeds + all_attentions = [] if output_attentions else None for decoder_layer in self.text_model.layers: layer_attention_mask = ( - attention_mask[decoder_layer.attention_type] if isinstance(attention_mask, dict) else attention_mask + attention_mask[decoder_layer.attention_type] if is_attention_mask_dict else attention_mask ) layer_outputs = decoder_layer( hidden_states, attention_mask=layer_attention_mask, - position_ids=position_ids, + position_ids=decoder_position_ids, past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=(cos, sin), + output_attentions=output_attentions, **kwargs, ) - hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs + layer_outputs_is_tuple = isinstance(layer_outputs, tuple) + hidden_states = layer_outputs[0] if layer_outputs_is_tuple else layer_outputs + if output_attentions and layer_outputs_is_tuple: + all_attentions.append(layer_outputs[1]) # Final layer norm hidden_states = self.text_model.norm(hidden_states) @@ -1461,6 +1311,8 @@ def forward( return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, + hidden_states=(hidden_states,), + attentions=tuple(all_attentions) if output_attentions else None, ) @@ -1483,15 +1335,40 @@ class IsaacPreTrainedModel(PreTrainedModel): } +# ============================================================================ +# Model +# ============================================================================ + + +def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor: + r"""Create 3D positional indices for token input. + + Args: + input_ids (`torch.Tensor`): + Tensor of shape `(batch_size, seq_len)` containing token ids. 
+ + Returns: + `torch.Tensor`: Positional indices with shape `(batch_size, seq_len, 3)` where each channel duplicates the + 1D position so it can be consumed by the 3-axis MRoPE rotary embedding. + """ + batch_size, seq_length = input_ids.shape + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.unsqueeze(2).expand(-1, -1, 3) # Add 3D for MRoPE + return position_ids + + @auto_docstring class IsaacForConditionalGeneration(IsaacPreTrainedModel, GenerationMixin): """Isaac multimodal model for conditional generation.""" - _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config_class = IsaacConfig + _can_compile_fullgraph = False + all_tied_weights_keys: dict[str, str] = {"lm_head.weight": "model.text_model.embed_tokens.weight"} def __init__(self, config: IsaacConfig): super().__init__(config) @@ -1516,6 +1393,7 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, @@ -1536,30 +1414,47 @@ def forward( if input_ids is None and inputs_embeds is None and tensor_stream is None: raise ValueError("Either input_ids, inputs_embeds, or tensor_stream must be provided.") + text_value = TextType.text.value if TextType is not None else 0 + + if tensor_stream is None: + if input_ids is not None: + batch_size, seq_len = input_ids.shape + input_device = input_ids.device + else: + batch_size, seq_len = inputs_embeds.shape[:2] + input_device = inputs_embeds.device + # Build position ids (MRoPE) if needed and tensor_stream is available # During decode we reuse `self.rope_deltas` computed on the initial forward pass; `rope_delta` captures how far # cached rotary phases have progressed so we can advance `position_ids` without rebuilding the TensorStream. - if position_ids is None and tensor_stream is not None: - position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask) - elif position_ids is None and input_ids is not None: - # For text inputs build position ids and modality tensor - position_ids = compute_position_ids_input_ids(input_ids) - if cache_position is not None and self.rope_deltas is not None: - # Combine the incremental decode step (`cache_position`) with cached offsets so hidden states continue - # rotating in lockstep across generation steps. 
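
As a quick sanity check of the contract documented above, the helper simply replicates a 1D `arange` across the three MRoPE axes (body copied from the function, toy batch added for illustration):

import torch

def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor:
    batch_size, seq_length = input_ids.shape
    position_ids = torch.arange(seq_length, device=input_ids.device)
    position_ids = position_ids.view(1, -1).expand(batch_size, -1)
    return position_ids.unsqueeze(2).expand(-1, -1, 3)

pos = compute_position_ids_input_ids(torch.zeros(2, 4, dtype=torch.long))
assert pos.shape == (2, 4, 3)
assert torch.equal(pos[..., 0], pos[..., 2])  # every axis carries the same 1D positions
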
-            rope_delta = (cache_position[0] + self.rope_deltas).to(input_ids.device)
+        if position_ids is None:
+            if tensor_stream is not None:
+                position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask)
+            elif input_ids is None:
+                dummy_ids = torch.zeros((batch_size, seq_len), device=input_device, dtype=torch.long)
+                position_ids = compute_position_ids_input_ids(dummy_ids)
             else:
+                position_ids = compute_position_ids_input_ids(input_ids)
+
+            rope_delta = 0
-            if cache_position is not None and not isinstance(rope_delta, int):  # otherwise `deltas` is an int `0`
-                batch_size = input_ids.shape[0]
-                rope_delta = rope_delta.repeat_interleave(batch_size // rope_delta.shape[0], dim=0)
-            position_ids = position_ids.add(rope_delta)
+            if cache_position is not None and self.rope_deltas is not None:
+                # Combine the incremental decode step (`cache_position`) with cached offsets so hidden states continue
+                # rotating in lockstep across generation steps. Skip the TensorStream prefill itself: there
+                # `get_rope_index` already returned absolute coordinates.
+                if tensor_stream is None:
+                    rope_delta = (cache_position[0] + self.rope_deltas).to(position_ids.device)
+                    if not isinstance(rope_delta, int):  # otherwise `deltas` is an int `0`
+                        rope_delta = rope_delta.repeat_interleave(position_ids.shape[0] // rope_delta.shape[0], dim=0)
+
+            position_ids = position_ids.add(rope_delta)
+
+        if attention_mask is None and tensor_stream is None:
+            attention_mask = torch.ones((batch_size, seq_len), device=input_device, dtype=torch.long)

         if tensor_stream is not None:
             modality_tensor = modality_mask(tensor_stream)
         else:
-            batch_size, seq_len = input_ids.shape
-            modality_tensor = torch.empty(batch_size, seq_len, device=position_ids.device).fill_(TextType.text.value)
+            modality_tensor = torch.full(
+                (batch_size, seq_len), text_value, device=position_ids.device, dtype=torch.long
+            )

         outputs = self.model(
             input_ids=input_ids,
@@ -1570,6 +1465,7 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
+            output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
@@ -1588,9 +1484,22 @@ def forward(
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
-            attentions=None,
+            attentions=outputs.attentions if output_attentions else None,
         )

+    def set_input_embeddings(self, value: nn.Module) -> None:
+        self.model.set_input_embeddings(value)
+        vocab_size = getattr(value, "num_embeddings", None)
+        if vocab_size is not None:
+            self.config.vocab_size = vocab_size
+            self.model.config.vocab_size = vocab_size
+            if hasattr(self.model, "text_model"):
+                self.model.text_model.config.vocab_size = vocab_size
+            if self.lm_head.weight.shape[0] != vocab_size:
+                self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False)
+                if hasattr(self.model, "embed_tokens"):
+                    self.lm_head.weight = self.model.text_model.embed_tokens.weight
+
     def get_rope_index(
         self,
         input_ids: Optional[torch.Tensor],
@@ -1670,17 +1579,35 @@ def prepare_inputs_for_generation(

         cache_position = model_inputs.get("cache_position", cache_position)

-        # Handle TensorStream for first forward pass only
-        if tensor_stream is not None and (cache_position is None or cache_position[0] == 0):
+        # Handle TensorStream only for the prefill step
+        first_step = cache_position is None or cache_position[0] == 0
+        if tensor_stream is not None and first_step:
             model_inputs["tensor_stream"] = tensor_stream
-            # Let forward rebuild position_ids using cached deltas during decode
-            model_inputs["position_ids"] = None
-        # Drop tensor_stream after step 0
-        if
cache_position is not None and cache_position[0] != 0: + # Let forward rebuild MRoPE coordinates from the TensorStream + model_inputs["position_ids"] = None + else: model_inputs["tensor_stream"] = None + + # TensorStream decode path: preserve rotary offsets from prefill + if tensor_stream is not None and not first_step and self.rope_deltas is not None: + model_inputs["position_ids"] = None + return model_inputs + + # For decode steps, synthesize position_ids that continue from the cache offsets + if model_inputs.get("position_ids") is None and cache_position is not None and not first_step: + batch_size = 1 + if model_inputs.get("input_ids") is not None: + batch_size = model_inputs["input_ids"].shape[0] + elif model_inputs.get("inputs_embeds") is not None: + batch_size = model_inputs["inputs_embeds"].shape[0] + pos_ids = cache_position.view(1, -1).expand(batch_size, -1) + pos_ids = pos_ids.unsqueeze(-1).expand(-1, -1, 3) + model_inputs["position_ids"] = pos_ids + return model_inputs - def can_generate(self) -> bool: + @classmethod + def can_generate(cls) -> bool: return True diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 289c6af9f70d..7f973a1731aa 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -87,8 +87,8 @@ ChannelDimension, PILImageResampling, ) -from ...masking_utils import create_masks_for_generate, eager_mask, packed_sequence_mask_function, sdpa_mask -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...masking_utils import create_masks_for_generate, packed_sequence_mask_function +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...models.auto.modeling_auto import AutoModel @@ -157,6 +157,10 @@ def __init__( # Add our custom fields self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor + # Ensure a sensible default attention backend + if getattr(self, "_attn_implementation", None) is None: + self._attn_implementation = "sdpa" + class IsaacImageProcessorKwargs(ImagesKwargs, total=False): patch_size: Optional[int] @@ -605,34 +609,6 @@ def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor class IsaacVisionAttention(Siglip2Attention): """Custom attention that supports variable-length sequences with flash attention.""" - ATTENTION_KEY_MAP: dict[str, str] = { - "flash_attention_2": "isaac_flash_attention_2", - "flash_attention_3": "isaac_flash_attention_3", - "isaac_flash_attention_2": "isaac_flash_attention_2", - "isaac_flash_attention_3": "isaac_flash_attention_3", - "sdpa": "isaac_sdpa", - "isaac_sdpa": "isaac_sdpa", - "eager": "isaac_eager", - "isaac_eager": "isaac_eager", - } - _FLASH_IMPLS = frozenset(("isaac_flash_attention_2", "isaac_flash_attention_3")) - - def __init__(self, config): - super().__init__(config) - self.config = config - self._variable_length_metadata = None - - def _variable_length_context(self, *, cu_seqlens=None, max_seqlen=None): - """Store packed-sequence metadata for the next forward call.""" - self._variable_length_metadata = (cu_seqlens, max_seqlen) - - def _consume_variable_length_metadata(self): - if self._variable_length_metadata is None: - return None, None - cu_seqlens, max_seqlen = self._variable_length_metadata - self._variable_length_metadata = None - return cu_seqlens, max_seqlen - def 
forward( self, hidden_states: torch.Tensor, @@ -645,248 +621,81 @@ def forward( max_seqlen: Optional[int] = None, **kwargs, ): - # Unused arguments are accepted for interface compatibility + # Ignore unused arguments for interface compatibility _ = position_ids _ = past_key_value _ = is_causal - _ = output_attentions - kwargs.pop("output_hidden_states", None) kwargs.pop("return_dict", None) - if kwargs: - unexpected = ", ".join(sorted(kwargs)) - raise TypeError(f"Unexpected kwargs for IsaacVisionAttention.forward: {unexpected}") - - cached_cu, cached_max = self._consume_variable_length_metadata() - if cu_seqlens is None: - cu_seqlens = cached_cu - if max_seqlen is None: - max_seqlen = cached_max - - # Expect packed sequences with batch_size == 1 - batch_size, L, _ = hidden_states.shape - if batch_size != 1: - raise ValueError("packed variable-length attention expects batch_size=1") - x = hidden_states[0] # (L, E) - - H = self.num_heads - D = self.head_dim - p_drop = self.dropout if self.training else 0.0 - - # Project and reshape to (L, H, D) - q = self.q_proj(x).view(L, H, D) - k = self.k_proj(x).view(L, H, D) - v = self.v_proj(x).view(L, H, D) - - resolved_key = "isaac_sdpa" - if self.config._attn_implementation != "sdpa": - resolved_key = self.ATTENTION_KEY_MAP.get(self.config._attn_implementation, resolved_key) - attn_mask = ensure_document_attention_mask( - attention_mask, - cu_seqlens, - L, - q.dtype, - q.device, - return_mask_function=True, - ) + batch_size, seq_length, embed_dim = hidden_states.shape + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - attn_weights = None - if resolved_key in self._FLASH_IMPLS: - y_lhd = self._flash_attention_forward( - q_lhd=q, - k_lhd=k, - v_lhd=v, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - dropout=p_drop, - ) - elif resolved_key == "isaac_sdpa": - y_lhd = self._sdpa_attention_forward( - q_lhd=q, - k_lhd=k, - v_lhd=v, - attention_mask=attn_mask, - cu_seqlens=cu_seqlens, - dropout=p_drop, - ) - elif resolved_key == "isaac_eager": - y_lhd, attn_weights = self._eager_attention_forward( - q_lhd=q, - k_lhd=k, - v_lhd=v, - attention_mask=attn_mask, - dropout=p_drop, - ) - else: - attention_fn = ALL_ATTENTION_FUNCTIONS.get(resolved_key) - if attention_fn is None: - raise ValueError(f"Attention implementation {resolved_key} not found.") - - query_states = q.transpose(0, 1).unsqueeze(0) - key_states = k.transpose(0, 1).unsqueeze(0) - value_states = v.transpose(0, 1).unsqueeze(0) - - attention_kwargs: dict[str, Any] = { - "dropout": p_drop, - "scaling": self.scale, - "is_causal": False, - } - if cu_seqlens is not None: - attention_kwargs["cu_seq_lens_q"] = cu_seqlens - attention_kwargs["cu_seq_lens_k"] = cu_seqlens - if max_seqlen is not None: - attention_kwargs["max_length_q"] = max_seqlen - attention_kwargs["max_length_k"] = max_seqlen - - attn_output, attn_weights = attention_fn( - self, - query_states, - key_states, - value_states, - attn_mask, - **attention_kwargs, - ) - - y_lhd = attn_output.squeeze(0).permute(1, 0, 2).contiguous() - - # Merge heads and project - y = self.out_proj(y_lhd.reshape(L, self.embed_dim)) - return y.unsqueeze(0), attn_weights # (1, L, E) + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - @staticmethod - def _max_from_cu(cu: 
Optional[torch.Tensor], fallback: int) -> int:
-        if cu is None or cu.numel() < 2:
-            return fallback
-        return int((cu[1:] - cu[:-1]).max().item())
+        if not queries.is_contiguous():
+            queries = queries.contiguous()
+        if not keys.is_contiguous():
+            keys = keys.contiguous()
+        if not values.is_contiguous():
+            values = values.contiguous()

-    def _flash_attention_forward(
-        self,
-        *,
-        q_lhd: torch.Tensor,
-        k_lhd: torch.Tensor,
-        v_lhd: torch.Tensor,
-        cu_seqlens: Optional[torch.Tensor],
-        max_seqlen: Optional[int],
-        dropout: float,
-    ) -> torch.Tensor:
-        L = q_lhd.size(0)
+        L = seq_length  # packed token count; `queries` is (batch, heads, seq, head_dim) after the transpose
         if max_seqlen is not None:
             max_q = max_k = int(max_seqlen)
         else:
             max_q = max_k = self._max_from_cu(cu_seqlens, L)

-        if not q_lhd.is_contiguous():
-            q_lhd = q_lhd.contiguous()
-        if not k_lhd.is_contiguous():
-            k_lhd = k_lhd.contiguous()
-        if not v_lhd.is_contiguous():
-            v_lhd = v_lhd.contiguous()
-
-        out_lhd, *_ = torch.ops.aten._flash_attention_forward(
-            query=q_lhd,
-            key=k_lhd,
-            value=v_lhd,
-            cum_seq_q=cu_seqlens,
-            cum_seq_k=cu_seqlens,
-            max_q=max_q,
-            max_k=max_k,
-            dropout_p=dropout,
-            is_causal=False,
-            return_debug_mask=False,
-            scale=self.scale,
-            window_size_left=-1,
-            window_size_right=-1,
-            alibi_slopes=None,
-        )
-        return out_lhd
-
-    def _sdpa_attention_forward(
-        self,
-        *,
-        q_lhd: torch.Tensor,
-        k_lhd: torch.Tensor,
-        v_lhd: torch.Tensor,
-        attention_mask: Optional[Union[torch.Tensor, Callable]],
-        cu_seqlens: Optional[torch.Tensor],
-        dropout: float,
-    ) -> torch.Tensor:
-        L = q_lhd.size(0)
-        attn_mask = attention_mask
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"]
+        if self.config._attn_implementation != "sdpa":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

-        if callable(attn_mask):
-            cache_position = torch.arange(L, device=q_lhd.device, dtype=torch.long)
-            attn_mask = sdpa_mask(
-                batch_size=1,
-                cache_position=cache_position,
-                kv_length=L,
-                kv_offset=0,
-                mask_function=attn_mask,
-                attention_mask=None,
-                allow_is_causal_skip=False,
-                allow_is_bidirectional_skip=False,
-                allow_torch_fix=False,
-                use_vmap=False,
-            )
-            # sdpa_mask returns True for allowed positions; SDPA expects True to mean "mask out"
-            if attn_mask is not None and attn_mask.dtype == torch.bool:
-                attn_mask = ~attn_mask
-
-        q = q_lhd.permute(1, 0, 2).unsqueeze(0)
-        k = k_lhd.permute(1, 0, 2).unsqueeze(0)
-        v = v_lhd.permute(1, 0, 2).unsqueeze(0)
-
-        if attn_mask is not None and attn_mask.dtype != q.dtype:
-            attn_mask = attn_mask.to(q.dtype)
-
-        output = F.scaled_dot_product_attention(
-            q,
-            k,
-            v,
-            attn_mask=attn_mask,
-            dropout_p=dropout,
-            scale=self.scale,
-            is_causal=False,
+        dropout = 0.0 if not self.training else self.dropout
+        attention_kwargs: dict[str, Any] = {
+            "is_causal": False,
+            "scaling": self.scale,
+            "dropout": dropout,
+        }
+        if cu_seqlens is not None:
+            attention_kwargs["cu_seq_lens_q"] = cu_seqlens
+            attention_kwargs["cu_seq_lens_k"] = cu_seqlens
+            # Varlen kernels need the max segment length alongside cu_seqlens; fall back to
+            # _max_from_cu when the caller did not pass max_seqlen explicitly.
+            attention_kwargs["max_length_q"] = max_q
+            attention_kwargs["max_length_k"] = max_k
+        if output_attentions:
+            attention_kwargs["output_attentions"] = True
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            queries,
+            keys,
+            values,
+            attention_mask,
+            **attention_kwargs,
         )
-        return output.squeeze(0).permute(1, 0, 2).contiguous()

-    def _eager_attention_forward(
-        self,
-        *,
-        q_lhd: torch.Tensor,
-        k_lhd: torch.Tensor,
-        v_lhd: torch.Tensor,
-        attention_mask: Optional[Union[torch.Tensor, Callable]],
-        dropout: float,
-    ) -> tuple[torch.Tensor,
torch.Tensor]: - L = q_lhd.size(0) - attn_mask = attention_mask - if callable(attn_mask): - cache_position = torch.arange(L, device=q_lhd.device, dtype=torch.long) - attn_mask = eager_mask( - batch_size=1, - cache_position=cache_position, - kv_length=L, - kv_offset=0, - mask_function=attn_mask, - attention_mask=None, - allow_is_bidirectional_skip=False, - use_vmap=False, - dtype=q_lhd.dtype, - ) - if attn_mask is not None and attn_mask.dim() == 4: - attn_mask = attn_mask.squeeze(0).squeeze(0) + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + + # Align projection inputs with parameter dtype to avoid mixed-dtype matmul errors + out_proj_dtype = self.out_proj.weight.dtype + if attn_output.dtype != out_proj_dtype: + attn_output = attn_output.to(out_proj_dtype) - attn_weights = torch.matmul(q_lhd, k_lhd.transpose(1, 2)) * self.scale - if attn_mask is not None: - attn_weights = attn_weights + attn_mask + attn_output = self.out_proj(attn_output) + if attn_output.dtype != hidden_states.dtype: + attn_output = attn_output.to(hidden_states.dtype) - attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_lhd.dtype) - if dropout and self.training: - attn_weights = F.dropout(attn_weights, p=dropout, training=True) + return attn_output, attn_weights - attn_output_lhd = torch.matmul(attn_weights, v_lhd) - return attn_output_lhd, attn_weights + @staticmethod + def _max_from_cu(cu: Optional[torch.Tensor], fallback: int) -> int: + if cu is None or cu.numel() < 2: + return fallback + return int((cu[1:] - cu[:-1]).max().item()) class IsaacVisionEncoderLayer(Siglip2EncoderLayer): @@ -913,26 +722,40 @@ def forward( Maximum document length referenced by `cu_seqlens`. Passed to FlashAttention so it can size temporary buffers for packed variable-length attention. """ - if cu_seqlens is not None or max_seqlen is not None: - self.self_attn._variable_length_context( - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - attention_mask = ensure_document_attention_mask( attention_mask, cu_seqlens, hidden_states.size(1), hidden_states.dtype, hidden_states.device, + return_mask_function=False, ) - return super().forward( + # Run attention directly so variable-length metadata reaches FlashAttention. 
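
`cu_seqlens` here is the usual cumulative-boundary encoding of packed documents, and `_max_from_cu` above recovers the longest segment whenever `max_seqlen` is not supplied. A small worked example of that bookkeeping:

import torch

cu_seqlens = torch.tensor([0, 3, 5, 9], dtype=torch.int32)  # three packed documents
lengths = cu_seqlens[1:] - cu_seqlens[:-1]                  # tensor([3, 2, 4])
max_seqlen = int(lengths.max().item())                      # 4, matching _max_from_cu
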
+ residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + attn_outputs = self.self_attn( hidden_states, attention_mask=attention_mask, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, output_attentions=output_attentions, **kwargs, ) + if isinstance(attn_outputs, tuple): + attn_output, attn_weights = attn_outputs + else: + attn_output, attn_weights = attn_outputs, None + hidden_states = residual + attn_output + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + if output_attentions: + return hidden_states, attn_weights + return hidden_states class IsaacVisionEncoder(Siglip2Encoder): @@ -942,17 +765,6 @@ def __init__(self, config: IsaacVisionConfig): super().__init__(config) self.layers = nn.ModuleList([IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) - def __variable_length_context(self, cu_seqlens, max_seqlen) -> None: - if cu_seqlens is None and max_seqlen is None: - return - - for layer in self.layers: - if isinstance(layer, IsaacVisionEncoderLayer): - layer.self_attn._variable_length_context( - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - @can_return_tuple def forward( self, @@ -965,24 +777,32 @@ def forward( return_dict: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): - self.__variable_length_context(cu_seqlens, max_seqlen) - attention_mask = ensure_document_attention_mask( attention_mask, cu_seqlens, inputs_embeds.size(1), inputs_embeds.dtype, inputs_embeds.device, + return_mask_function=False, ) - return super().forward( - inputs_embeds, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - **kwargs, + hidden_states = inputs_embeds + kwargs.update( + { + "max_seqlen": max_seqlen, + "cu_seqlens": cu_seqlens, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + } ) + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask, + **kwargs, + ) + return BaseModelOutput(last_hidden_state=hidden_states) def create_pixel_shuffle_index_map( @@ -1113,6 +933,8 @@ def pixel_shuffle_varlen( class IsaacVisionTransformer(nn.Module): + _supports_sdpa = True + def __init__(self, config: IsaacVisionConfig): super().__init__() self.config = config @@ -1163,6 +985,8 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): class IsaacVisionEmbedding(nn.Module): """Vision embedding wrapper exposing tower and projector.""" + _supports_sdpa = True + def __init__(self, config: IsaacConfig): super().__init__() vision_cfg = config.vision_config @@ -1318,24 +1142,12 @@ def __init__( **kwargs, ): self._rope_parameters: Optional[dict[str, Any]] = None - resolved_text_config = kwargs.pop("text_config", text_config) - if isinstance(resolved_text_config, Qwen3Config): - text_config_kwargs = copy.deepcopy(resolved_text_config.to_dict()) - elif isinstance(resolved_text_config, dict): - text_config_kwargs = copy.deepcopy(resolved_text_config) - elif resolved_text_config is None: - text_config_kwargs = {} - else: - raise TypeError("`text_config` must be a mapping or `Qwen3Config` instance when provided.") - - text_config_kwargs.update(kwargs) + attn_implementation = kwargs.get("attn_implementation") - self.text_config = self.sub_configs["text_config"](**text_config_kwargs) - if not hasattr(self.text_config, 
"rope_theta"): - rope_theta_override = text_config_kwargs.get("rope_theta", kwargs.get("rope_theta")) - if rope_theta_override is None: - rope_theta_override = getattr(Qwen3Config(), "rope_theta", 10000.0) - self.text_config.rope_theta = rope_theta_override + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + self.text_config = self.sub_configs["text_config"]() super().__init__(**kwargs) @@ -1355,7 +1167,7 @@ def __init__( self.head_dim = self.text_config.head_dim self.hidden_act = self.text_config.hidden_act self.use_cache = self.text_config.use_cache - self.rope_theta = self.text_config.rope_theta + self.rope_theta = self.text_config.rope_parameters["rope_theta"] # Validate rotary parameters now that they have been mirrored locally. rope_config_validation(self) @@ -1371,6 +1183,15 @@ def __init__( elif vision_config is None: self.vision_config = self.sub_configs["vision_config"]() + # Propagate user-requested attention backend to the vision sub-config when provided. + if attn_implementation is not None: + if isinstance(attn_implementation, dict): + vision_attn = attn_implementation.get("vision_config", attn_implementation.get("", None)) + else: + vision_attn = attn_implementation + if vision_attn is not None: + self.vision_config._attn_implementation = vision_attn + # Vision normalization parameters self.vision_rescale_factor = float(vision_rescale_factor) @@ -1715,7 +1536,8 @@ def forward( with torch.no_grad(): pos = position_ids.clone() - not_spatial = modality_tensor != VisionType.image.value + image_value = VisionType.image.value if VisionType is not None else 1 + not_spatial = modality_tensor != image_value if not_spatial.any(): data_1d = pos[not_spatial][..., 0].unsqueeze(-1) pos[not_spatial] = data_1d.expand(-1, pos.shape[-1]) @@ -1735,6 +1557,10 @@ def forward( class IsaacModel(Qwen3PreTrainedModel): supports_gradient_checkpointing = True + _can_compile_fullgraph = False + _supports_flex_attn = False + # Expose tied-weights mapping even if empty for base model tests. 
+ all_tied_weights_keys: dict[str, str] = {} def __init__(self, config: IsaacConfig): Qwen3PreTrainedModel.__init__(self, config) @@ -1751,6 +1577,7 @@ def __init__(self, config: IsaacConfig): raise ValueError("IsaacConfig should always have vision_config") self.vision_embedding = IsaacVisionEmbedding(config) + self.vision_embedding._supports_sdpa = True # Dispatch table for TensorStream balanced embedding (text + vision) self.embed_fns = { @@ -1763,11 +1590,24 @@ def __init__(self, config: IsaacConfig): self.vision_rescale_factor = config.vision_rescale_factor self.vision_token = config.vision_token + # Initialize weights and parallel plans (including tp_plan from the text model) + self.post_init() + + # Respect config-specified gradient checkpointing + if getattr(config, "gradient_checkpointing", False): + self.gradient_checkpointing_enable() + def get_input_embeddings(self) -> nn.Module: return self.text_model.get_input_embeddings() def set_input_embeddings(self, value: nn.Module) -> None: self.text_model.set_input_embeddings(value) + vocab_size = getattr(value, "num_embeddings", None) + if vocab_size is not None: + self.config.vocab_size = vocab_size + if hasattr(self.config, "text_config"): + self.config.text_config.vocab_size = vocab_size + self.text_model.config.vocab_size = vocab_size @property def embed_tokens(self) -> nn.Module: @@ -1785,6 +1625,14 @@ def layers(self) -> nn.ModuleList: def norm(self) -> nn.Module: return self.text_model.norm + @property + def vision_model(self) -> nn.Module: + return self.vision_embedding.vision_tower + + @property + def vision_tower(self) -> nn.Module: + return self.vision_embedding.vision_tower + def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: """Embed text tokens, squeezing singleton dimensions.""" # Text events are shaped as (..., 1); squeeze the singleton index dim @@ -1844,6 +1692,7 @@ def forward( past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, @@ -1865,47 +1714,72 @@ def forward( omitted. """ + text_value = TextType.text.value if TextType is not None else 0 + # Get inputs if tensor_stream is not None and inputs_embeds is not None: raise ValueError("You cannot specify both tensor_stream and inputs_embeds") - elif tensor_stream is not None: - # Embed TensorStream directly - inputs_embeds = self.embed_stream(tensor_stream) - # Create modality tensor if not provided - if modality_tensor is None: - modality_tensor = modality_mask(tensor_stream) - elif input_ids is not None and inputs_embeds is not None: + if tensor_stream is None and input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + + # Resolve the input source (TensorStream takes precedence over token ids). 
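
When a caller passes a `modality_tensor` whose length disagrees with the embedded sequence, the forward further below truncates it or right-pads it by repeating the final label. A toy run of that alignment step, mirroring the code that follows:

import torch

modality = torch.tensor([[0, 0, 1]])              # three labels, but seq_len is 5
seq_len = 5
pad = modality[:, -1:].expand(-1, seq_len - modality.shape[1])
modality = torch.cat([modality, pad], dim=1)      # tensor([[0, 0, 1, 1, 1]])
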
+ if tensor_stream is not None: + inputs_embeds = self.embed_stream(tensor_stream) elif input_ids is not None: inputs_embeds = self.text_model.embed_tokens(input_ids) - # Create text modality tensor if not provided - if modality_tensor is None: - batch_size, seq_length = input_ids.shape - modality_tensor = torch.full( - (batch_size, seq_length), TextType.text.value, device=input_ids.device, dtype=torch.long - ) elif inputs_embeds is None: raise ValueError("You have to specify either tensor_stream, input_ids or inputs_embeds") + batch_size, seq_len = inputs_embeds.shape[:2] + # Ensure cache exists when requested if use_cache and past_key_values is None: cache_config = self.config.get_text_config() if hasattr(self.config, "get_text_config") else self.config past_key_values = DynamicCache(config=cache_config) - if cache_position is None and (past_key_values is not None or use_cache): + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long) + + # Normalize modality tensor + if modality_tensor is None: + if tensor_stream is not None: + modality_tensor = modality_mask(tensor_stream) + else: + modality_tensor = torch.full( + (batch_size, seq_len), text_value, device=inputs_embeds.device, dtype=torch.long + ) + else: + modality_tensor = modality_tensor.to(dtype=torch.long) - # Create default position_ids if not provided + if modality_tensor.shape[1] != seq_len: + if modality_tensor.shape[1] > seq_len: + modality_tensor = modality_tensor[:, :seq_len] + else: + pad = modality_tensor[:, -1:].expand(-1, seq_len - modality_tensor.shape[1]) + modality_tensor = torch.cat([modality_tensor, pad], dim=1) + + # Normalize position ids if position_ids is None: if tensor_stream is not None: position_ids = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) else: - position_ids = compute_position_ids_input_ids(input_ids) + position_ids = cache_position.view(1, -1).expand(modality_tensor.shape[0], -1) + + # Expand 2D position ids (from generic padding tests or decode cache positions) to 3D MRoPE coords + if position_ids.ndim == 2: + position_ids = position_ids.to(device=inputs_embeds.device) + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) + + # Align lengths so rotary embedding sees matching shapes + if position_ids.shape[1] != seq_len: + start_positions = position_ids[:, :1, 0] + position_ids = torch.arange(seq_len, device=inputs_embeds.device).view(1, -1) + position_ids = position_ids + start_positions + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) # Compute MRoPE position embeddings if we have custom rotary_emb cos, sin = self.rotary_emb( @@ -1916,38 +1790,46 @@ def forward( cos = cos.to(inputs_embeds.dtype) sin = sin.to(inputs_embeds.dtype) - # Prepare attention mask + # Flash attention expects 1D position_ids; keep 3D only for rotary phases + decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids + # Prepare attention mask if not isinstance(attention_mask, dict): - mask_kwargs = { - "config": self.config, - "input_embeds": inputs_embeds, - "attention_mask": attention_mask, - "cache_position": cache_position, - 
"past_key_values": past_key_values, - "position_ids": position_ids, - } - attention_mask = create_masks_for_generate(**mask_kwargs) + attention_mask = create_masks_for_generate( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=decoder_position_ids, + ) + + is_attention_mask_dict = isinstance(attention_mask, dict) # Initialize hidden states hidden_states = inputs_embeds + all_attentions = [] if output_attentions else None for decoder_layer in self.text_model.layers: layer_attention_mask = ( - attention_mask[decoder_layer.attention_type] if isinstance(attention_mask, dict) else attention_mask + attention_mask[decoder_layer.attention_type] if is_attention_mask_dict else attention_mask ) layer_outputs = decoder_layer( hidden_states, attention_mask=layer_attention_mask, - position_ids=position_ids, + position_ids=decoder_position_ids, past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=(cos, sin), + output_attentions=output_attentions, **kwargs, ) - hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs + layer_outputs_is_tuple = isinstance(layer_outputs, tuple) + hidden_states = layer_outputs[0] if layer_outputs_is_tuple else layer_outputs + if output_attentions and layer_outputs_is_tuple: + all_attentions.append(layer_outputs[1]) # Final layer norm hidden_states = self.text_model.norm(hidden_states) @@ -1955,6 +1837,8 @@ def forward( return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, + hidden_states=(hidden_states,), + attentions=tuple(all_attentions) if output_attentions else None, ) @@ -1962,6 +1846,9 @@ class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin): """Isaac multimodal model for conditional generation.""" config_class = IsaacConfig + _can_compile_fullgraph = False + _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"} + all_tied_weights_keys: dict[str, str] = {"lm_head.weight": "model.text_model.embed_tokens.weight"} def __init__(self, config: IsaacConfig): super().__init__(config) @@ -1971,39 +1858,6 @@ def __init__(self, config: IsaacConfig): # Tracks rotary position offsets computed during a full forward pass so decode steps can reuse them. self.rope_deltas = None - def get_rope_index( - self, - input_ids: Optional[torch.Tensor], - tensor_stream: Optional[TensorStream], - attention_mask: Optional[torch.Tensor], - ) -> tuple[torch.Tensor, torch.Tensor]: - """Compute MRoPE position ids from a TensorStream (or 1D fallback). - - Returns (position_ids, rope_deltas). position_ids is (B,L,3) for MRoPE. - rope_deltas is (B,1) used to advance positions in decode. 
- """ - # tensor_stream present: compute 3D coords - if tensor_stream is None and input_ids is None: - raise ValueError("`tensor_stream` or `input_ids` must be provided to compute rope indices") - - if tensor_stream is not None: - pos_3d = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) - else: - pos_3d = compute_position_ids_input_ids(input_ids) - B, L, _ = pos_3d.shape - - # Max position per batch across the 3 planes and sequence dimension: (B,) - m_per_batch = pos_3d.amax(dim=(1, 2)) - - # Sequence lengths per batch: (B,) - if attention_mask is None: - seq_lens = torch.full_like(m_per_batch, L) - else: - seq_lens = attention_mask.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=m_per_batch.device) - - rope_deltas = (m_per_batch + 1 - seq_lens).to(dtype=pos_3d.dtype).unsqueeze(1) - return pos_3d, rope_deltas - def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -2014,6 +1868,7 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, @@ -2034,30 +1889,47 @@ def forward( if input_ids is None and inputs_embeds is None and tensor_stream is None: raise ValueError("Either input_ids, inputs_embeds, or tensor_stream must be provided.") + text_value = TextType.text.value if TextType is not None else 0 + + if tensor_stream is None: + if input_ids is not None: + batch_size, seq_len = input_ids.shape + input_device = input_ids.device + else: + batch_size, seq_len = inputs_embeds.shape[:2] + input_device = inputs_embeds.device + # Build position ids (MRoPE) if needed and tensor_stream is available # During decode we reuse `self.rope_deltas` computed on the initial forward pass; `rope_delta` captures how far # cached rotary phases have progressed so we can advance `position_ids` without rebuilding the TensorStream. - if position_ids is None and tensor_stream is not None: - position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask) - elif position_ids is None and input_ids is not None: - # For text inputs build position ids and modality tensor - position_ids = compute_position_ids_input_ids(input_ids) - if cache_position is not None and self.rope_deltas is not None: - # Combine the incremental decode step (`cache_position`) with cached offsets so hidden states continue - # rotating in lockstep across generation steps. - rope_delta = (cache_position[0] + self.rope_deltas).to(input_ids.device) + if position_ids is None: + if tensor_stream is not None: + position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask) + elif input_ids is None: + dummy_ids = torch.zeros((batch_size, seq_len), device=input_device, dtype=torch.long) + position_ids = compute_position_ids_input_ids(dummy_ids) else: + position_ids = compute_position_ids_input_ids(input_ids) + rope_delta = 0 - if cache_position is not None and not isinstance(rope_delta, int): # otherwise `deltas` is an int `0` - batch_size = input_ids.shape[0] - rope_delta = rope_delta.repeat_interleave(batch_size // rope_delta.shape[0], dim=0) - position_ids = position_ids.add(rope_delta) + if cache_position is not None and self.rope_deltas is not None: + # Combine the incremental decode step (`cache_position`) with cached offsets so hidden states continue + # rotating in lockstep across generation steps. 
+                if tensor_stream is None:
+                    rope_delta = (cache_position[0] + self.rope_deltas).to(position_ids.device)
+                    if not isinstance(rope_delta, int):  # otherwise `deltas` is an int `0`
+                        rope_delta = rope_delta.repeat_interleave(position_ids.shape[0] // rope_delta.shape[0], dim=0)
+
+            position_ids = position_ids.add(rope_delta)
+
+        if attention_mask is None and tensor_stream is None:
+            attention_mask = torch.ones((batch_size, seq_len), device=input_device, dtype=torch.long)

         if tensor_stream is not None:
             modality_tensor = modality_mask(tensor_stream)
         else:
-            batch_size, seq_len = input_ids.shape
-            modality_tensor = torch.empty(batch_size, seq_len, device=position_ids.device).fill_(TextType.text.value)
+            modality_tensor = torch.full(
+                (batch_size, seq_len), text_value, device=position_ids.device, dtype=torch.long
+            )

         outputs = self.model(
             input_ids=input_ids,
@@ -2068,6 +1940,7 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
+            output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
@@ -2086,9 +1959,55 @@ def forward(
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
-            attentions=None,
+            attentions=outputs.attentions if output_attentions else None,
         )

+    def set_input_embeddings(self, value: nn.Module) -> None:
+        self.model.set_input_embeddings(value)
+        vocab_size = getattr(value, "num_embeddings", None)
+        if vocab_size is not None:
+            self.config.vocab_size = vocab_size
+            self.model.config.vocab_size = vocab_size
+            if hasattr(self.model, "text_model"):
+                self.model.text_model.config.vocab_size = vocab_size
+            if self.lm_head.weight.shape[0] != vocab_size:
+                self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False)
+                if hasattr(self.model, "embed_tokens"):
+                    self.lm_head.weight = self.model.text_model.embed_tokens.weight
+
+    def get_rope_index(
+        self,
+        input_ids: Optional[torch.Tensor],
+        tensor_stream: Optional[TensorStream],
+        attention_mask: Optional[torch.Tensor],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Compute MRoPE position ids from a TensorStream (or 1D fallback).
+
+        Returns (position_ids, rope_deltas). position_ids is (B,L,3) for MRoPE.
+        rope_deltas is (B,1) used to advance positions in decode.
+ """ + # tensor_stream present: compute 3D coords + if tensor_stream is None and input_ids is None: + raise ValueError("`tensor_stream` or `input_ids` must be provided to compute rope indices") + + if tensor_stream is not None: + pos_3d = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) + else: + pos_3d = compute_position_ids_input_ids(input_ids) + B, L, _ = pos_3d.shape + + # Max position per batch across the 3 planes and sequence dimension: (B,) + m_per_batch = pos_3d.amax(dim=(1, 2)) + + # Sequence lengths per batch: (B,) + if attention_mask is None: + seq_lens = torch.full_like(m_per_batch, L) + else: + seq_lens = attention_mask.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=m_per_batch.device) + + rope_deltas = (m_per_batch + 1 - seq_lens).to(dtype=pos_3d.dtype).unsqueeze(1) + return pos_3d, rope_deltas + def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, @@ -2135,17 +2054,35 @@ def prepare_inputs_for_generation( cache_position = model_inputs.get("cache_position", cache_position) - # Handle TensorStream for first forward pass only - if tensor_stream is not None and (cache_position is None or cache_position[0] == 0): + # Handle TensorStream only for the prefill step + first_step = cache_position is None or cache_position[0] == 0 + if tensor_stream is not None and first_step: model_inputs["tensor_stream"] = tensor_stream - # Let forward rebuild position_ids using cached deltas during decode - model_inputs["position_ids"] = None - # Drop tensor_stream after step 0 - if cache_position is not None and cache_position[0] != 0: + # Let forward rebuild MRoPE coordinates from the TensorStream + model_inputs["position_ids"] = None + else: model_inputs["tensor_stream"] = None + + # TensorStream decode path: preserve rotary offsets from prefill + if tensor_stream is not None and not first_step and self.rope_deltas is not None: + model_inputs["position_ids"] = None + return model_inputs + + # For decode steps, synthesize position_ids that continue from the cache offsets + if model_inputs.get("position_ids") is None and cache_position is not None and not first_step: + batch_size = 1 + if model_inputs.get("input_ids") is not None: + batch_size = model_inputs["input_ids"].shape[0] + elif model_inputs.get("inputs_embeds") is not None: + batch_size = model_inputs["inputs_embeds"].shape[0] + pos_ids = cache_position.view(1, -1).expand(batch_size, -1) + pos_ids = pos_ids.unsqueeze(-1).expand(-1, -1, 3) + model_inputs["position_ids"] = pos_ids + return model_inputs - def can_generate(self) -> bool: + @classmethod + def can_generate(cls) -> bool: return True diff --git a/tests/fixtures/isaac/isaac_checkpoint_hashes.json b/tests/fixtures/isaac/isaac_checkpoint_hashes.json deleted file mode 100644 index 1898c3e23955..000000000000 --- a/tests/fixtures/isaac/isaac_checkpoint_hashes.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "full_model": "e00d024be29cc0a6790dc9f3c2504ad12176dea2332fe342a6272d2c92efdef5", - "core_model": "24bd017cf86d113aefb08bf7d109a196fae831e7a57300c4be05a8f78d2b4b6e", - "vision_modules": "c90c2e60f270c96a7a5c7c2e93815158b5f3247ab076c819da0f4f6358d033c5" -} diff --git a/tests/fixtures/isaac/isaac_generation_golden.json b/tests/fixtures/isaac/isaac_generation_golden.json deleted file mode 100644 index c662e22e2670..000000000000 --- a/tests/fixtures/isaac/isaac_generation_golden.json +++ /dev/null @@ -1,450 +0,0 @@ -{ - "logits_statistics": { - "shape": [ - 10, - 151936 - ], - "numel": 1519360, - "mean": 0.0666899336, - "std": 2.8427821364, - "min": -12.0625, - "max": 
31.0, - "sum": 101326.0175427794, - "l2_norm": 3505.0433579135 - }, - "input_ids": [ - [ - 151644, - 872, - 198, - 74785, - 419, - 2168, - 25, - 151645, - 198, - 151644, - 872, - 198, - 768, - 743, - 480, - 159, - -154, - -256, - -256, - -256, - -256, - -256, - -256, - -154, - 159, - 480, - 743, - 768, - 743, - 718, - 462, - 149, - -157, - -256, - -256, - -256, - -256, - -256, - -256, - -157, - 149, - 462, - 718, - 743, - 480, - 462, - 273, - 42, - -183, - -256, - -256, - -256, - -256, - -256, - -256, - -183, - 42, - 273, - 462, - 480, - 159, - 149, - 43, - -87, - -214, - -256, - -256, - -256, - -256, - -256, - -256, - -214, - -87, - 43, - 149, - 159, - 151645, - 198, - 151644, - 77091 - ] - ], - "tensor_stream": { - "shape": [ - 1, - 80 - ], - "token_view": [ - [ - 151644, - 872, - 198, - 74785, - 419, - 2168, - 25, - 151645, - 198, - 151644, - 872, - 198, - 768, - 743, - 480, - 159, - -154, - -256, - -256, - -256, - -256, - -256, - -256, - -154, - 159, - 480, - 743, - 768, - 743, - 718, - 462, - 149, - -157, - -256, - -256, - -256, - -256, - -256, - -256, - -157, - 149, - 462, - 718, - 743, - 480, - 462, - 273, - 42, - -183, - -256, - -256, - -256, - -256, - -256, - -256, - -183, - 42, - 273, - 462, - 480, - 159, - 149, - 43, - -87, - -214, - -256, - -256, - -256, - -256, - -256, - -256, - -214, - -87, - 43, - 149, - 159, - 151645, - 198, - 151644, - 77091 - ] - ], - "modality_mask": [ - [ - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1, - 1, - 1, - 1 - ] - ], - "role_mask": [ - [ - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1, - -1 - ] - ] - }, - "decoded_text": "user\nDescribe this image:\nuser\nug\tifable\ufffd\ufffdable\tifug\tifadd pro\ufffd\ufffd proadd\tifable proleKKle proable\ufffd\ufffdLL\ufffd\ufffd\nassistant\n\n\n\n\nThe image is a close", - "token_ids": [ - 151644, - 872, - 198, - 74785, - 419, - 2168, - 25, - 151645, - 198, - 151644, - 872, - 198, - 768, - 743, - 480, - 159, - -154, - -256, - -256, - -256, - -256, - -256, - -256, - -154, - 159, - 480, - 743, - 768, - 743, - 718, - 462, - 149, - -157, - -256, - -256, - -256, - -256, - -256, - -256, - -157, - 149, - 462, - 718, - 743, - 480, - 462, - 273, - 42, - -183, - -256, - -256, - -256, - -256, - -256, - -256, - -183, - 42, - 273, - 462, - 480, - 159, - 149, - 43, - -87, - -214, - -256, - -256, - -256, - -256, - -256, - -256, - -214, - -87, - 43, - 149, - 159, - 151645, - 198, - 151644, - 77091, - 198, - 151667, - 271, - 151668, - 271, - 785, - 2168, - 374, - 264, - 3265 - ] -} diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py index a353d4df9d59..99c3ffb0962b 100644 --- a/tests/models/isaac/test_modeling_isaac.py +++ b/tests/models/isaac/test_modeling_isaac.py @@ -15,9 +15,7 @@ """Testing suite for the Isaac model.""" 
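
The masking tests below all reduce to the same fact: `cu_seqlens` induces a block-diagonal attention pattern in which a token may attend only within its own packed document. An equivalent dense construction (`torch.bucketize` is purely illustrative here, not one of the helpers under test):

import torch

cu_seqlens = torch.tensor([0, 3, 5])  # two packed documents of lengths 3 and 2
doc_ids = torch.bucketize(torch.arange(5), cu_seqlens[1:], right=True)  # [0, 0, 0, 1, 1]
allowed = doc_ids[:, None] == doc_ids[None, :]  # True only inside the same document
assert allowed[1, 2] and not allowed[1, 3] and allowed[4, 3]
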
import base64 -import hashlib import io -import json import os import unittest from functools import lru_cache @@ -26,8 +24,10 @@ import pytest from huggingface_hub import is_offline_mode +from tests.generation.test_utils import GenerationTesterMixin +from tests.test_configuration_common import ConfigTester +from tests.test_pipeline_mixin import PipelineTesterMixin from transformers import ( - AutoProcessor, AutoTokenizer, IsaacConfig, IsaacForConditionalGeneration, @@ -37,16 +37,13 @@ ) from transformers.image_utils import load_image from transformers.masking_utils import eager_mask, sdpa_mask -from transformers.models.isaac.configuration_isaac import IsaacVisionConfig from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast from transformers.models.isaac.modeling_isaac import ( - IsaacVisionAttention, document_mask_function_from_cu_seqlens, ensure_document_attention_mask, ) from transformers.models.isaac.processing_isaac import IsaacProcessor from transformers.testing_utils import ( - get_tests_dir, require_flash_attn, require_torch, require_vision, @@ -62,48 +59,30 @@ else: Image = None -from ...test_modeling_common import ids_tensor +from ...test_modeling_common import ModelTesterMixin, ids_tensor if is_torch_available(): import torch if is_perceptron_available(): - from perceptron.tensorstream.ops import modality_mask, role_mask, tensor_stream_token_view + from perceptron.pointing.parser import extract_points from perceptron.tensorstream.tensorstream import TensorStream else: TensorStream = None + extract_points = None require_tensorstream = pytest.mark.skipif(TensorStream is None, reason="TensorStream backend is not available") -MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base") -MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/3") or None -LOCAL_CHECKPOINT = os.environ.get("ISAAC_TEST_MODEL_PATH") -FIXTURES_DIR = Path(get_tests_dir("fixtures/isaac")) -HASH_FILE = FIXTURES_DIR / "isaac_checkpoint_hashes.json" -GENERATION_GOLDEN_FILE = FIXTURES_DIR / "isaac_generation_golden.json" -HASH_FILTERS = { - "full_model": {"include": None, "exclude": None}, - "core_model": {"include": None, "exclude": {"vision_embedding", "audio_embedding", "inv_freq"}}, - "vision_modules": {"include": {"vision_embedding"}, "exclude": None}, -} -RED_DOT_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg==" - +BASE_MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base") +MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1") -def tensor_stream_snapshot(ts: TensorStream) -> dict[str, object]: - """Summarize TensorStream tokens/modalities using public utilities.""" +BASE_MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/3") or None +MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/5") or None - token_view = tensor_stream_token_view(ts).cpu().tolist() - modality = modality_mask(ts).cpu().tolist() - roles = role_mask(ts).cpu().tolist() - - return { - "shape": list(ts.shape), - "token_view": token_view, - "modality_mask": modality, - "role_mask": roles, - } +LOCAL_CHECKPOINT = os.environ.get("ISAAC_TEST_MODEL_PATH") +RED_DOT_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg==" def document_to_messages( @@ -149,32 +128,6 @@ def document_to_messages( return messages, images -def _tensor_to_bytes(tensor): - 
cpu_tensor = tensor.detach().cpu().contiguous() - if cpu_tensor.is_floating_point(): - cpu_tensor = cpu_tensor.to(dtype=torch.float32) - return cpu_tensor.numpy().tobytes() - - -def _iter_filtered_items(state_dict, include=None, exclude=None): - for name, tensor in state_dict.items(): - if include and not any(token in name for token in include): - continue - if exclude and any(token in name for token in exclude): - continue - yield name, tensor - - -def _hash_state_dict(state_dict, *, include=None, exclude=None): - hasher = hashlib.sha256() - items = sorted(_iter_filtered_items(state_dict, include=include, exclude=exclude), key=lambda kv: kv[0]) - for name, tensor in items: - hasher.update(name.encode("utf-8")) - hasher.update(b"\0") - hasher.update(_tensor_to_bytes(tensor)) - return hasher.hexdigest() - - def compute_logits_statistics(tensor: torch.Tensor) -> dict[str, object]: """ Summarize logits with simple statistics that are stable across minor @@ -199,82 +152,97 @@ def _rounded(value: torch.Tensor | float) -> float: } -def _assert_logits_statistics_close( - actual: dict[str, object], - expected: dict[str, object], - *, - rel: float = 1e-5, - abs_tol: float = 1e-6, -) -> None: - assert actual["shape"] == expected["shape"], "Logits shape changed" - assert actual["numel"] == expected["numel"], "Logits numel changed" - for key in ("mean", "std", "min", "max", "sum", "l2_norm"): - assert actual[key] == pytest.approx( - expected[key], - rel=rel, - abs=abs_tol, - ), f"Logits statistic '{key}' drifted" - - -def _hf_from_pretrained(cls, pretrained_id, **kwargs): - """ - Wrapper around `cls.from_pretrained` that automatically injects - the test revision (if any) from MODEL_REVISION. - """ - if MODEL_REVISION is not None: - kwargs.setdefault("revision", MODEL_REVISION) - return cls.from_pretrained(pretrained_id, **kwargs) - - -@pytest.fixture(scope="session") -def tokenizer(isaac_reference_checkpoint): - """Load the tokenizer from the converted Perceptron HF checkpoint.""" - return _hf_from_pretrained( - AutoTokenizer, - isaac_reference_checkpoint, - trust_remote_code=True, - ) - - @require_torch -def test_document_mask_function_from_cu_seqlens(): - cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32) - mask_fn = document_mask_function_from_cu_seqlens(cu_seqlens) +class IsaacDocumentMaskingTest(unittest.TestCase): + def test_document_mask_function_from_cu_seqlens(self): + cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32) + mask_fn = document_mask_function_from_cu_seqlens(cu_seqlens) + + self.assertIsNotNone(mask_fn) + # Same document (indices 1 and 2) + self.assertTrue(mask_fn(0, 0, 1, 2)) + # Cross-document (index 1 in first doc, 3 in second doc) + self.assertFalse(mask_fn(0, 0, 1, 3)) + # Same second document (indices 3 and 4) + self.assertTrue(mask_fn(0, 0, 4, 3)) + + def test_ensure_document_attention_mask_prefers_callable_when_requested(self): + cu_seqlens = torch.tensor([0, 2, 5], dtype=torch.int32) + total_tokens = 5 + dtype = torch.float32 + + mask_callable = ensure_document_attention_mask( + attention_mask=None, + cu_seqlens=cu_seqlens, + total_tokens=total_tokens, + dtype=dtype, + device=cu_seqlens.device, + return_mask_function=True, + ) + self.assertTrue(callable(mask_callable)) - assert mask_fn is not None - # Same document (indices 1 and 2) - assert mask_fn(0, 0, 1, 2) - # Cross-document (index 1 in first doc, 3 in second doc) - assert not mask_fn(0, 0, 1, 3) - # Same second document (indices 3 and 4) - assert mask_fn(0, 0, 4, 3) + additive = 
ensure_document_attention_mask(
+            attention_mask=None,
+            cu_seqlens=cu_seqlens,
+            total_tokens=total_tokens,
+            dtype=dtype,
+            device=cu_seqlens.device,
+            return_mask_function=False,
+        )
+        self.assertIsNone(additive)
+
+    def test_document_mask_function_materializes_with_masking_utils(self):
+        cu_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32)
+        total_tokens = 4
+        mask_fn = document_mask_function_from_cu_seqlens(cu_seqlens)
+
+        cache_position = torch.arange(total_tokens, device=cu_seqlens.device, dtype=torch.long)
+        expected_bool = torch.tensor(
+            [
+                [
+                    [
+                        [True, True, False, False],
+                        [True, True, False, False],
+                        [False, False, True, True],
+                        [False, False, True, True],
+                    ]
+                ]
+            ],
+            device=cu_seqlens.device,
+        )
+
+        sdpa = sdpa_mask(
+            batch_size=1,
+            cache_position=cache_position,
+            kv_length=total_tokens,
+            kv_offset=0,
+            mask_function=mask_fn,
+            attention_mask=None,
+            allow_is_causal_skip=False,
+            allow_is_bidirectional_skip=False,
+            allow_torch_fix=False,
+            use_vmap=False,
+        )
+        # sdpa_mask returns True for allowed positions; torch SDPA boolean masks use the same True = "attend" convention
+        self.assertTrue(torch.equal(sdpa, expected_bool))
+
+        eager = eager_mask(
+            batch_size=1,
+            cache_position=cache_position,
+            kv_length=total_tokens,
+            kv_offset=0,
+            mask_function=mask_fn,
+            attention_mask=None,
+            allow_is_bidirectional_skip=False,
+            use_vmap=False,
+            dtype=torch.float32,
+        )
+        expected_additive = torch.where(
+            expected_bool,
+            torch.tensor(0.0, device=cu_seqlens.device, dtype=torch.float32),
+            torch.tensor(torch.finfo(torch.float32).min, device=cu_seqlens.device, dtype=torch.float32),
+        )
+        self.assertTrue(torch.equal(eager, expected_additive))
 
 
 def create_isaac_processor(
@@ -322,148 +290,6 @@ )
 
 
-@require_torch
-def test_document_mask_function_materializes_with_masking_utils():
-    cu_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32)
-    total_tokens = 4
-    mask_fn = document_mask_function_from_cu_seqlens(cu_seqlens)
-
-    cache_position = torch.arange(total_tokens, device=cu_seqlens.device, dtype=torch.long)
-    expected_bool = torch.tensor(
-        [
-            [
-                [
-                    [True, True, False, False],
-                    [True, True, False, False],
-                    [False, False, True, True],
-                    [False, False, True, True],
-                ]
-            ]
-        ],
-        device=cu_seqlens.device,
-    )
-
-    sdpa = sdpa_mask(
-        batch_size=1,
-        cache_position=cache_position,
-        kv_length=total_tokens,
-        kv_offset=0,
-        mask_function=mask_fn,
-        attention_mask=None,
-        allow_is_causal_skip=False,
-        allow_is_bidirectional_skip=False,
-        allow_torch_fix=False,
-        use_vmap=False,
-    )
-    # sdpa_mask returns True for allowed positions; SDPA expects True to mean "mask out"
-    assert torch.equal(sdpa, expected_bool)
-
-    eager = eager_mask(
-        batch_size=1,
-        cache_position=cache_position,
-        kv_length=total_tokens,
-        kv_offset=0,
-        mask_function=mask_fn,
-        attention_mask=None,
-        allow_is_bidirectional_skip=False,
-        use_vmap=False,
-        
dtype=torch.float32, - ) - expected_additive = torch.where( - expected_bool, - torch.tensor(0.0, device=cu_seqlens.device, dtype=torch.float32), - torch.tensor(torch.finfo(torch.float32).min, device=cu_seqlens.device, dtype=torch.float32), - ) - assert torch.equal(eager, expected_additive) - - -@require_torch -def test_isaac_sdpa_attention_backend(): - config = IsaacVisionConfig( - hidden_size=32, - intermediate_size=64, - num_hidden_layers=1, - num_attention_heads=4, - num_channels=3, - num_patches=16, - patch_size=4, - ) - config._attn_implementation = "sdpa" - - attn_module = IsaacVisionAttention(config).eval() - seq_len = 8 - hidden_states = torch.randn(1, seq_len, config.hidden_size) - cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32) - - with torch.no_grad(): - outputs, attn_weights = attn_module( - hidden_states=hidden_states, - cu_seqlens=cu_seqlens, - max_seqlen=seq_len, - ) - - assert outputs.shape == hidden_states.shape - assert attn_weights is None - - -@require_torch -@require_flash_attn -def test_isaac_flash_attention_backend(): - config = IsaacVisionConfig( - hidden_size=32, - intermediate_size=64, - num_hidden_layers=1, - num_attention_heads=4, - num_channels=3, - num_patches=16, - patch_size=4, - ) - config._attn_implementation = "flash_attention_3" - - attn_module = IsaacVisionAttention(config).half().eval().cuda() - seq_len = 8 - hidden_states = torch.randn(1, seq_len, config.hidden_size, device=torch.device("cuda"), dtype=torch.float16) - cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=torch.device("cuda")) - - with torch.no_grad(): - outputs, attn_weights = attn_module( - hidden_states=hidden_states, - cu_seqlens=cu_seqlens, - max_seqlen=seq_len, - ) - - assert outputs.shape == hidden_states.shape - assert attn_weights is None - - -@lru_cache(maxsize=1) -def _load_expected_hashes(): - if not HASH_FILE.exists(): - return None - with HASH_FILE.open("r", encoding="utf-8") as fh: - return json.load(fh) - - -@lru_cache(maxsize=1) -def _load_generation_golden(): - if not GENERATION_GOLDEN_FILE.exists(): - return None - with GENERATION_GOLDEN_FILE.open("r", encoding="utf-8") as fh: - return json.load(fh) - - -def safe_decode(tokenizer, token_ids): - if isinstance(token_ids, torch.Tensor): - token_ids = token_ids.tolist() - try: - text = tokenizer.decode(token_ids, skip_special_tokens=True) - except Exception: - tokens = tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=True) - tokens = [tok for tok in tokens if tok is not None] - text = tokenizer.convert_tokens_to_string(tokens) - return text.strip() if isinstance(text, str) else text - - @lru_cache(maxsize=1) def _load_red_dot_image(): if Image is None: @@ -472,9 +298,18 @@ def _load_red_dot_image(): return Image.open(io.BytesIO(data)).convert("RGB") +def _base_reference_checkpoint_or_skip(): + if LOCAL_CHECKPOINT: + resolved = Path(LOCAL_CHECKPOINT).expanduser() + if not resolved.exists(): + pytest.skip(f"Local checkpoint path {resolved} does not exist.") + return str(resolved) + if is_offline_mode(): + pytest.skip("Offline mode: set ISAAC_TEST_MODEL_PATH to a local checkpoint to run these tests.") + return BASE_MODEL_ID + + def _reference_checkpoint_or_skip(): - if TensorStream is None: - pytest.skip("TensorStream dependency is required for Isaac integration tests.") if LOCAL_CHECKPOINT: resolved = Path(LOCAL_CHECKPOINT).expanduser() if not resolved.exists(): @@ -545,79 +380,6 @@ def save_vocabulary(self, save_directory, filename_prefix=None): return () -def 
_make_dummy_image(size=(32, 32), color=(255, 0, 0)): - if Image is None: - raise RuntimeError("PIL.Image is not available in this environment.") - return Image.new("RGB", size, color=color) - - -@pytest.fixture -def isaac_tiny_config(): - tester = IsaacModelTester(parent=None) - return tester.get_config() - - -@pytest.fixture -def isaac_tokenizer(): - return SimpleIsaacTokenizer() - - -@pytest.fixture -def isaac_processor(isaac_tokenizer, isaac_tiny_config): - vision_config = isaac_tiny_config.vision_config - image_processor = IsaacImageProcessorFast( - patch_size=vision_config.patch_size, - max_num_patches=vision_config.num_patches, - pixel_shuffle_scale=vision_config.pixel_shuffle_scale_factor, - rescale_factor=isaac_tiny_config.vision_rescale_factor, - ) - return IsaacProcessor( - image_processor=image_processor, - tokenizer=isaac_tokenizer, - config=isaac_tiny_config, - ) - - -@pytest.fixture(scope="session") -def isaac_reference_checkpoint(): - return _reference_checkpoint_or_skip() - - -@pytest.fixture(scope="session") -def isaac_config(isaac_reference_checkpoint): - """Load IsaacConfig from the converted checkpoint.""" - # Load the config directly from the converted checkpoint - config = _hf_from_pretrained(IsaacConfig, isaac_reference_checkpoint) - # Most tests assume flash attention in vision unless they explicitly override it. - config.vision_attn_implementation = "flash_attention_2" - return config - - -@pytest.fixture(scope="session") -def isaac_reference_model(isaac_reference_checkpoint, isaac_config): - model_config = IsaacConfig.from_dict(isaac_config.to_dict()) - model_config.vision_attn_implementation = isaac_config.vision_attn_implementation - model = _hf_from_pretrained( - IsaacForConditionalGeneration, - isaac_reference_checkpoint, - config=model_config, - attn_implementation="sdpa", - ) - return model - - -@pytest.fixture(scope="session") -def isaac_reference_processor(isaac_reference_checkpoint): - try: - processor = _hf_from_pretrained(AutoProcessor, isaac_reference_checkpoint) - except (OSError, ValueError) as error: - raise RuntimeError(f"Unable to load reference Isaac processor from {isaac_reference_checkpoint}") from error - print(f"[Isaac tests] Loaded processor type: {type(processor)} from {isaac_reference_checkpoint}") - if not isinstance(processor, IsaacProcessor): - pytest.skip("Loaded processor is not an IsaacProcessor instance.") - return processor - - class IsaacModelTester: def __init__( self, @@ -636,6 +398,8 @@ def __init__( self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads + self.is_training = True + self.expected_num_hidden_layers = 1 self.text_config = { "bos_token_id": 0, @@ -692,13 +456,66 @@ def prepare_config_and_inputs(self): labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) return config, input_ids, attention_mask, labels + def prepare_config_and_inputs_for_common(self): + config, input_ids, attention_mask, labels = self.prepare_config_and_inputs() + position_ids = torch.arange(self.seq_length, device=torch_device).view(1, -1) + position_ids = position_ids.expand(self.batch_size, -1).unsqueeze(2).expand(-1, -1, 3) + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + } + if labels is not None: + inputs_dict["labels"] = labels + return config, inputs_dict + @require_torch -class IsaacModelTest(unittest.TestCase): +class IsaacModelTest(ModelTesterMixin, GenerationTesterMixin, 
PipelineTesterMixin, unittest.TestCase): all_model_classes = (IsaacModel, IsaacForConditionalGeneration) if is_torch_available() else () + pipeline_model_mapping = ( + { + "image-to-text": IsaacForConditionalGeneration, + "image-text-to-text": IsaacForConditionalGeneration, + } + if is_torch_available() + else {} + ) + _is_composite = True + test_attention_outputs = False + test_all_params_have_gradient = False def setUp(self): self.model_tester = IsaacModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=IsaacConfig, + has_text_modality=False, + ) + + def test_config(self): + self.maxDiff = None + self.config_tester.run_common_tests() + + @unittest.skip(reason="Assisted decoding not supported; Qwen3 backbone does not implement returning attentions") + def test_assisted_decoding_matches_greedy_search_0_random(self): + pass + + @unittest.skip(reason="Assisted decoding not supported; Qwen3 backbone does not implement returning attentions") + def test_assisted_decoding_matches_greedy_search_1_same(self): + pass + + @unittest.skip(reason="Assisted decoding not supported; Qwen3 backbone does not implement returning attentions") + def test_assisted_decoding_sample(self): + pass + + @unittest.skip(reason="Prompt lookup decoding not supported; Qwen3 backbone does not return attentions") + def test_prompt_lookup_decoding_matches_greedy_search(self): + pass + + @unittest.skip(reason="Output attentions not supported") + def test_retain_grad_hidden_states_attentions(self): + pass @require_tensorstream def test_model_forward(self): @@ -737,244 +554,78 @@ def test_prepare_inputs_for_generation(self): prepared_inputs = model.prepare_inputs_for_generation(input_ids=input_ids, attention_mask=attention_mask) self.assertIn("input_ids", prepared_inputs) self.assertIn("position_ids", prepared_inputs) - self.assertIsNone(prepared_inputs["position_ids"]) - - -def test_isaac_config_extends_qwen3_defaults(isaac_tiny_config): - assert isaac_tiny_config.hidden_size == isaac_tiny_config.text_config.hidden_size - assert isaac_tiny_config.num_attention_heads == isaac_tiny_config.text_config.num_attention_heads - assert isaac_tiny_config.model_type == "isaac" - assert isaac_tiny_config.vision_config is not None - assert isaac_tiny_config.vision_config.patch_size == 4 - assert isaac_tiny_config.vision_config.num_patches == 64 - assert isaac_tiny_config.max_sequence_length == 16384 - assert isaac_tiny_config.vision_rescale_factor == pytest.approx(1 / 255) - assert isaac_tiny_config.vision_token == "" + @require_tensorstream + def test_isaac_for_conditional_generation_initialization(self): + config = self.model_tester.get_config() + model = IsaacForConditionalGeneration(config) + model.to(torch_device) -@require_torch -@require_tensorstream -def test_isaac_for_conditional_generation_initialization(isaac_tiny_config): - model = IsaacForConditionalGeneration(isaac_tiny_config) - model.to(torch_device) - assert hasattr(model, "model") - assert hasattr(model, "lm_head") - assert hasattr(model.model, "vision_embedding") - assert hasattr(model.model, "embed_fns") - - input_ids = torch.randint(0, isaac_tiny_config.vocab_size, (1, 10), device=torch_device, dtype=torch.long) - with torch.no_grad(): - outputs = model(input_ids=input_ids, return_dict=True) - assert outputs.logits.shape == (1, 10, isaac_tiny_config.vocab_size) - - -@require_torch -@require_tensorstream -def test_isaac_for_conditional_generation_loss_and_generate_flag(isaac_tiny_config): - model = 
IsaacForConditionalGeneration(isaac_tiny_config).to(torch_device) - assert model.can_generate() - - batch_size, seq_len = 1, 8 - input_ids = torch.randint(0, isaac_tiny_config.vocab_size, (batch_size, seq_len), device=torch_device) - labels = torch.randint(0, isaac_tiny_config.vocab_size, (batch_size, seq_len), device=torch_device) - with torch.no_grad(): - outputs = model(input_ids=input_ids, labels=labels, return_dict=True) - assert outputs.loss is not None - assert outputs.loss.ndim == 0 - assert outputs.logits.shape == (batch_size, seq_len, isaac_tiny_config.vocab_size) - - -@require_torch -@require_vision -@require_tensorstream -def test_isaac_processor_matches_config_defaults(isaac_processor, isaac_tiny_config): - assert isaac_processor.vision_token == isaac_tiny_config.vision_token - assert isaac_processor.max_sequence_length == isaac_tiny_config.max_sequence_length - assert isaac_processor.config is isaac_tiny_config - assert isinstance(isaac_processor.image_processor, IsaacImageProcessorFast) - assert isaac_processor.image_processor.rescale_factor == pytest.approx(isaac_tiny_config.vision_rescale_factor) - - -@require_torch -@require_vision -@require_tensorstream -def test_isaac_processor_text_only_round_trip(isaac_processor): - messages = [{"role": "user", "content": "Hello, how are you?"}] - prompt = isaac_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - outputs = isaac_processor(text=prompt, images=None, return_tensors="pt") - - assert "input_ids" in outputs - assert "tensor_stream" in outputs - assert isinstance(outputs["tensor_stream"], TensorStream) - assert outputs["input_ids"].shape[0] == 1 - - -@require_torch -@require_tensorstream -def test_isaac_processor_accepts_batchencoding_chat_template(isaac_processor): - messages = [{"role": "user", "content": "Hello, how are you?"}] - batch_encoding = isaac_processor.apply_chat_template(messages, add_generation_prompt=True) - - outputs = isaac_processor(text=batch_encoding, images=None, return_tensors="pt") - - assert "input_ids" in outputs - assert "tensor_stream" in outputs - assert isinstance(outputs["tensor_stream"], TensorStream) - assert outputs["input_ids"].shape[0] == 1 - - -@require_torch -@require_vision -@require_tensorstream -def test_isaac_processor_with_single_image(isaac_processor): - vision_token = isaac_processor.vision_token - text = f"Look at this {vision_token} and describe it." 
- image = _make_dummy_image() - - outputs = isaac_processor(text=text, images=[image], return_tensors="pt") - assert isinstance(outputs["tensor_stream"], TensorStream) - assert outputs["input_ids"].ndim == 2 - - -@require_torch -@require_vision -@require_tensorstream -def test_isaac_processor_with_multiple_images(isaac_processor): - vision_token = isaac_processor.vision_token - text = f"First {vision_token} then {vision_token}" - images = [_make_dummy_image(color=(255, 0, 0)), _make_dummy_image(color=(0, 255, 0))] - - outputs = isaac_processor(text=text, images=images, return_tensors="pt") - assert isinstance(outputs["tensor_stream"], TensorStream) - assert outputs["input_ids"].shape[0] == 1 - - -@require_torch -@require_vision -@require_tensorstream -def test_isaac_processor_error_on_image_mismatch(isaac_processor): - vision_token = isaac_processor.vision_token - text = f"{vision_token} {vision_token}" - image = _make_dummy_image() - - with pytest.raises(ValueError, match="must match number of images"): - isaac_processor(text=text, images=[image], return_tensors="pt") - - -@require_torch -@require_vision -@require_tensorstream -def test_isaac_processor_consistent_tensor_stream_types(isaac_processor): - text_only = "Simple question?" - text_with_image = f"Describe this {isaac_processor.vision_token}" - image = _make_dummy_image() - - outputs_text = isaac_processor(text=text_only, images=None, return_tensors="pt") - outputs_image = isaac_processor(text=text_with_image, images=[image], return_tensors="pt") - - assert isinstance(outputs_text["tensor_stream"], TensorStream) - assert isinstance(outputs_image["tensor_stream"], TensorStream) - assert outputs_text["input_ids"].shape[0] == outputs_image["input_ids"].shape[0] == 1 - + self.assertTrue(hasattr(model, "model")) + self.assertTrue(hasattr(model, "lm_head")) + self.assertTrue(hasattr(model.model, "vision_embedding")) + self.assertTrue(hasattr(model.model, "embed_fns")) -@require_torch -@require_vision -@require_tensorstream -def test_isaac_generation_with_tensor_stream(isaac_processor, isaac_tiny_config): - model = IsaacForConditionalGeneration(isaac_tiny_config).to(torch_device) - model.eval() - - messages = [{"role": "user", "content": "Hello there!"}] - prompt = isaac_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - processed = isaac_processor(text=prompt, images=None, return_tensors="pt") - - input_ids = processed["input_ids"].to(torch_device) - tensor_stream = processed["tensor_stream"] - tensor_stream = tensor_stream.to(torch_device) - generated = model.generate( - input_ids=input_ids, - tensor_stream=tensor_stream, - max_new_tokens=5, - do_sample=False, - pad_token_id=isaac_processor.tokenizer.pad_token_id, - eos_token_id=isaac_processor.tokenizer.eos_token_id, - ) + input_ids = torch.randint(0, config.vocab_size, (1, 10), device=torch_device, dtype=torch.long) + with torch.no_grad(): + outputs = model(input_ids=input_ids, return_dict=True) + self.assertEqual(outputs.logits.shape, (1, 10, config.vocab_size)) - assert generated.shape[0] == 1 - assert generated.shape[1] >= input_ids.shape[1] - decoded_prompt = isaac_processor.tokenizer.decode(generated[0], skip_special_tokens=True) - assert isinstance(decoded_prompt, str) - assert decoded_prompt.strip() != "" + @require_tensorstream + def test_isaac_for_conditional_generation_loss_and_generate_flag(self): + config = self.model_tester.get_config() + model = IsaacForConditionalGeneration(config).to(torch_device) + 
self.assertTrue(model.can_generate()) + + batch_size, seq_len = 1, 8 + input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=torch_device) + labels = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=torch_device) + with torch.no_grad(): + outputs = model(input_ids=input_ids, labels=labels, return_dict=True) + self.assertIsNotNone(outputs.loss) + self.assertEqual(outputs.loss.ndim, 0) + self.assertEqual(outputs.logits.shape, (batch_size, seq_len, config.vocab_size)) + @require_vision + @require_tensorstream + def test_isaac_generation_with_tensor_stream(self): + config = self.model_tester.get_config() + tokenizer = SimpleIsaacTokenizer() + image_processor = IsaacImageProcessorFast( + patch_size=config.vision_config.patch_size, + max_num_patches=config.vision_config.num_patches, + pixel_shuffle_scale=config.vision_config.pixel_shuffle_scale_factor, + rescale_factor=config.vision_rescale_factor, + ) + processor = IsaacProcessor( + image_processor=image_processor, + tokenizer=tokenizer, + config=config, + ) -@require_torch -@require_vision -@slow -@require_tensorstream -def test_hf_generate_vs_training_generate_logits(isaac_reference_model, isaac_reference_processor): - device = "cuda" - dtype = torch.bfloat16 - isaac_reference_model = isaac_reference_model.to(device=device, dtype=dtype) - isaac_reference_model.eval() - golden = _load_generation_golden() - if not golden: - pytest.skip(f"Missing generation golden file at {GENERATION_GOLDEN_FILE}.") + model = IsaacForConditionalGeneration(config).to(torch_device) + model.eval() - image = _load_red_dot_image() - if image is None: - pytest.skip("PIL.Image is required for Isaac generation tests.") + messages = [{"role": "user", "content": "Hello there!"}] + prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + processed = processor(text=prompt, images=None, return_tensors="pt") - messages = [ - { - "role": "user", - "content": "Describe this image:", - }, - { - "role": "user", - "content": "", - }, - ] - prompt = isaac_reference_processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ).strip() - batch = isaac_reference_processor(text=prompt, images=[image], return_tensors="pt") - - input_ids = batch["input_ids"] - tensor_stream = batch["tensor_stream"] - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - isaac_reference_model.to(device) - input_ids = input_ids.to(device) - if tensor_stream is not None and hasattr(tensor_stream, "to"): - tensor_stream = tensor_stream.to(device) - - torch.manual_seed(0) - with torch.no_grad(): - outputs = isaac_reference_model.generate( + input_ids = processed["input_ids"].to(torch_device) + tensor_stream = processed["tensor_stream"].to(torch_device) + generated = model.generate( input_ids=input_ids, tensor_stream=tensor_stream, - max_new_tokens=10, + max_new_tokens=5, do_sample=False, - pad_token_id=isaac_reference_processor.tokenizer.eos_token_id, - eos_token_id=isaac_reference_processor.tokenizer.eos_token_id, - return_dict_in_generate=True, - output_logits=True, + pad_token_id=processor.tokenizer.pad_token_id, + eos_token_id=processor.tokenizer.eos_token_id, ) - logits = torch.cat(outputs.logits, dim=0).to(torch.float32).cpu() - logits_stats = compute_logits_statistics(logits) - generated_ids = outputs.sequences[0].tolist() - - assert generated_ids == golden["token_ids"], "Generated token ids changed" - if "logits_statistics" in golden: - 
_assert_logits_statistics_close(logits_stats, golden["logits_statistics"]) - else: - pytest.fail( - "Golden file missing both logits_statistics and logits_hash. " - f"Regenerate {GENERATION_GOLDEN_FILE} via scripts/update_isaac_hashes.py." - ) - - isaac_reference_model.to("cpu") + self.assertEqual(generated.shape[0], 1) + self.assertGreaterEqual(generated.shape[1], input_ids.shape[1]) + decoded_prompt = processor.tokenizer.decode(generated[0], skip_special_tokens=True) + self.assertIsInstance(decoded_prompt, str) + self.assertNotEqual(decoded_prompt.strip(), "") @require_torch @@ -988,16 +639,16 @@ class IsaacGenerationIntegrationTest(unittest.TestCase): def setUp(self): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.checkpoint = _reference_checkpoint_or_skip() - self.hf_config = IsaacConfig.from_pretrained(self.checkpoint, revision=MODEL_REVISION) + self.checkpoint = _base_reference_checkpoint_or_skip() + self.hf_config = IsaacConfig.from_pretrained(self.checkpoint, revision=BASE_MODEL_REVISION) self.tokenizer = AutoTokenizer.from_pretrained( - self.checkpoint, trust_remote_code=True, use_fast=False, revision=MODEL_REVISION + self.checkpoint, trust_remote_code=True, use_fast=False, revision=BASE_MODEL_REVISION ) self.processor = create_isaac_processor(self.tokenizer, self.hf_config) self.hf_config.vision_config._attn_implementation = "flash_attention_2" self.hf_config.vision_config.attn_implementation = "flash_attention_2" self.model = IsaacForConditionalGeneration.from_pretrained( - self.checkpoint, config=self.hf_config, revision=MODEL_REVISION + self.checkpoint, config=self.hf_config, revision=BASE_MODEL_REVISION ) self.model = self.model.to(device=self.device, dtype=self.dtype) self.model.eval() @@ -1065,3 +716,117 @@ def test_vqa_from_image(self): generated_text = self._generate_from_messages(messages, images, num_tokens=256) expected_response = "\nNo, it is not safe to cross the street at this moment. The traffic light for pedestrians is red, indicating that it is not safe to cross." 
 assert generated_text == expected_response
+
+    def test_logit_equivalence(self):
+        image = _load_red_dot_image()
+        if image is None:
+            pytest.skip("PIL.Image is required for Isaac generation tests.")
+        images = [image]
+        num_tokens = 10
+
+        messages = [
+            {"role": "user", "content": "Describe this image:"},
+            {"role": "user", "content": "<image>"},
+        ]
+        prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip()
+        processor_output = self.processor(text=prompt, images=images, return_tensors="pt")
+        tensor_stream = processor_output["tensor_stream"].to(self.device)
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                tensor_stream=tensor_stream,
+                max_new_tokens=num_tokens or self.max_new_tokens,
+                do_sample=False,
+                pad_token_id=self.tokenizer.eos_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
+                return_dict_in_generate=True,
+                output_logits=True,
+            )
+        hf_logits = torch.cat(outputs.logits, dim=0)
+        logit_stats = compute_logits_statistics(hf_logits)
+        expected_logit_stats = {
+            "shape": [10, 151936],
+            "numel": 1519360,
+            "mean": 0.0879677375,
+            "std": 2.8382794404,
+            "min": -12.125,
+            "max": 31.0,
+            "sum": 133654.661714755,
+            "l2_norm": 3500.2090570868,
+        }
+        assert logit_stats == expected_logit_stats
+
+
+@require_torch
+@require_vision
+@slow
+@require_tensorstream
+@require_flash_attn
+class IsaacBoxPointingIntegrationTest(unittest.TestCase):
+    max_new_tokens = 256
+    dtype = torch.bfloat16
+
+    def setUp(self):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.checkpoint = _reference_checkpoint_or_skip()
+        self.hf_config = IsaacConfig.from_pretrained(self.checkpoint, revision=MODEL_REVISION)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.checkpoint, trust_remote_code=True, use_fast=False, revision=MODEL_REVISION
+        )
+        self.processor = create_isaac_processor(self.tokenizer, self.hf_config)
+        self.hf_config.vision_config._attn_implementation = "flash_attention_2"
+        self.hf_config.vision_config.attn_implementation = "flash_attention_2"
+        self.model = IsaacForConditionalGeneration.from_pretrained(
+            self.checkpoint, config=self.hf_config, revision=MODEL_REVISION
+        )
+        self.model = self.model.to(device=self.device, dtype=self.dtype)
+        self.model.eval()
+
+    def test_hf_generate_box_points(self):
+        document = [
+            {
+                "type": "text",
+                "content": "BOX",
+                "role": "user",
+            },
+            {
+                "type": "image",
+                "content": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp",
+                "role": "user",
+            },
+            {
+                "type": "text",
+                "content": "Determine whether it is safe to cross the street. 
Look for signage and moving traffic.", + "role": "user", + }, + ] + messages, images = document_to_messages(document, vision_token=self.hf_config.vision_token) + prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip() + processor_output = self.processor(text=prompt, images=images, return_tensors="pt") + tensor_stream = processor_output["tensor_stream"].to(self.device) + + with torch.no_grad(): + outputs = self.model.generate( + tensor_stream=tensor_stream, + max_new_tokens=self.max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + ) + + generated_ids = outputs.sequences + hf_generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) + points = extract_points(hf_generated_text) + assert len(points) == 1 + first_point = points[0] + assert first_point.top_left.x < first_point.bottom_right.x + assert first_point.top_left.y < first_point.bottom_right.y + assert first_point.mention == "traffic light" + assert first_point.top_left.x == 808 + assert first_point.top_left.y == 247 + assert first_point.bottom_right.x == 863 + assert first_point.bottom_right.y == 386 diff --git a/tests/models/isaac/test_processing_isaac.py b/tests/models/isaac/test_processing_isaac.py new file mode 100644 index 000000000000..4d173fdac305 --- /dev/null +++ b/tests/models/isaac/test_processing_isaac.py @@ -0,0 +1,264 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
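+
+# A minimal sketch of the round trip these tests exercise (assumes the fixtures defined
+# below and the optional `perceptron` backend): the processor packs text tokens and image
+# patches into a single TensorStream alongside the usual `input_ids`.
+#
+#     processor = IsaacProcessor(image_processor=image_processor, tokenizer=tokenizer, config=config)
+#     out = processor(text=f"Describe {processor.vision_token}", images=[image], return_tensors="pt")
+#     out["input_ids"]      # torch.LongTensor of shape (1, seq_len)
+#     out["tensor_stream"]  # TensorStream interleaving text and vision events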
+
+"""Testing suite for the Isaac processor."""
+
+import pytest
+
+from transformers import IsaacConfig, PythonBackend
+from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast
+from transformers.models.isaac.processing_isaac import IsaacProcessor
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_vision_available
+from transformers.utils.import_utils import is_perceptron_available
+
+
+if is_vision_available():
+    from PIL import Image
+else:
+    Image = None
+
+
+if is_perceptron_available():
+    from perceptron.tensorstream.tensorstream import TensorStream
+else:
+    TensorStream = None
+
+
+require_tensorstream = pytest.mark.skipif(TensorStream is None, reason="TensorStream backend is not available")
+
+
+class SimpleIsaacTokenizer(PythonBackend):
+    vocab_files_names = {}
+    model_input_names = ["input_ids"]
+
+    def __init__(self):
+        # Minimal special-token set; ids 0-2 line up with the bos/eos/pad ids in the tiny text config below.
+        self._vocab = {
+            "<s>": 0,
+            "</s>": 1,
+            "<pad>": 2,
+            "<unk>": 3,
+            "<image>": 4,
+        }
+        self._ids_to_tokens = {idx: tok for tok, idx in self._vocab.items()}
+        super().__init__(
+            bos_token="<s>",
+            eos_token="</s>",
+            pad_token="<pad>",
+            unk_token="<unk>",
+            extra_special_tokens=["<image>"],
+            model_max_length=512,
+        )
+        self.chat_template = (
+            "{% for message in messages %}"
+            "{{ message['role'] }}: {{ message['content'] | trim }}\n"
+            "{% endfor %}"
+            "{% if add_generation_prompt %}assistant:{% endif %}"
+        )
+
+    def get_vocab(self):
+        return dict(self._vocab)
+
+    def _tokenize(self, text):
+        clean = text.replace("\n", " ").strip()
+        if not clean:
+            return []
+        return [token for token in clean.split(" ") if token]
+
+    def _convert_token_to_id(self, token):
+        if token not in self._vocab:
+            next_id = len(self._vocab)
+            self._vocab[token] = next_id
+            self._ids_to_tokens[next_id] = token
+        return self._vocab[token]
+
+    def _convert_id_to_token(self, index):
+        return self._ids_to_tokens.get(index, self.unk_token)
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self._vocab)
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        if token_ids_1 is not None:
+            token_ids_0 = token_ids_0 + token_ids_1
+        return [self.bos_token_id] + list(token_ids_0) + [self.eos_token_id]
+
+    def save_vocabulary(self, save_directory, filename_prefix=None):
+        return ()
+
+
+def _make_dummy_image(size=(32, 32), color=(255, 0, 0)):
+    if Image is None:
+        raise RuntimeError("PIL.Image is not available in this environment.")
+    return Image.new("RGB", size, color=color)
+
+
+@pytest.fixture
+def isaac_tiny_config():
+    text_config = {
+        "bos_token_id": 0,
+        "eos_token_id": 1,
+        "pad_token_id": 2,
+        "hidden_act": "silu",
+        "head_dim": 32 // 4,
+        "hidden_size": 32,
+        "vocab_size": 99,
+        "intermediate_size": 32 * 3,
+        "max_position_embeddings": 128,
+        "model_type": "qwen3",
+        "num_attention_heads": 4,
+        "num_hidden_layers": 2,
+        "num_key_value_heads": 4,
+        "rope_parameters": {"rope_type": "default", "mrope_section": [2, 1, 1], "mrope_interleaved": True},
+        "tie_word_embeddings": True,
+    }
+
+    vision_config = {
+        "hidden_size": 32,
+        "intermediate_size": 32 * 2,
+        "num_hidden_layers": 1,
+        "num_attention_heads": 4,
+        "num_channels": 3,
+        "num_patches": 64,
+        "patch_size": 4,
+        "pixel_shuffle_scale_factor": 1,
+        "attention_dropout": 0.0,
+        "layer_norm_eps": 1e-6,
+    }
+
+    config = IsaacConfig(text_config=text_config, vision_config=vision_config)
+    config._attn_implementation = "sdpa"
+    config.text_config._attn_implementation = "sdpa"
+    config.vision_attn_implementation = "sdpa"
+    return config
+
+
+@pytest.fixture
+def 
isaac_tokenizer(): + return SimpleIsaacTokenizer() + + +@pytest.fixture +def isaac_processor(isaac_tokenizer, isaac_tiny_config): + vision_config = isaac_tiny_config.vision_config + image_processor = IsaacImageProcessorFast( + patch_size=vision_config.patch_size, + max_num_patches=vision_config.num_patches, + pixel_shuffle_scale=vision_config.pixel_shuffle_scale_factor, + rescale_factor=isaac_tiny_config.vision_rescale_factor, + ) + return IsaacProcessor( + image_processor=image_processor, + tokenizer=isaac_tokenizer, + config=isaac_tiny_config, + ) + + +@require_torch +@require_vision +@require_tensorstream +def test_isaac_processor_matches_config_defaults(isaac_processor, isaac_tiny_config): + assert isaac_processor.vision_token == isaac_tiny_config.vision_token + assert isaac_processor.max_sequence_length == isaac_tiny_config.max_sequence_length + assert isaac_processor.config is isaac_tiny_config + assert isinstance(isaac_processor.image_processor, IsaacImageProcessorFast) + assert isaac_processor.image_processor.rescale_factor == pytest.approx(isaac_tiny_config.vision_rescale_factor) + + +@require_torch +@require_vision +@require_tensorstream +def test_isaac_processor_text_only_round_trip(isaac_processor): + messages = [{"role": "user", "content": "Hello, how are you?"}] + prompt = isaac_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + outputs = isaac_processor(text=prompt, images=None, return_tensors="pt") + + assert "input_ids" in outputs + assert "tensor_stream" in outputs + assert TensorStream is not None + assert isinstance(outputs["tensor_stream"], TensorStream) + assert outputs["input_ids"].shape[0] == 1 + + +@require_torch +@require_tensorstream +def test_isaac_processor_accepts_batchencoding_chat_template(isaac_processor): + messages = [{"role": "user", "content": "Hello, how are you?"}] + batch_encoding = isaac_processor.apply_chat_template(messages, add_generation_prompt=True) + + outputs = isaac_processor(text=batch_encoding, images=None, return_tensors="pt") + + assert "input_ids" in outputs + assert "tensor_stream" in outputs + assert TensorStream is not None + assert isinstance(outputs["tensor_stream"], TensorStream) + assert outputs["input_ids"].shape[0] == 1 + + +@require_torch +@require_vision +@require_tensorstream +def test_isaac_processor_with_single_image(isaac_processor): + vision_token = isaac_processor.vision_token + text = f"Look at this {vision_token} and describe it." 
+ image = _make_dummy_image() + + outputs = isaac_processor(text=text, images=[image], return_tensors="pt") + assert TensorStream is not None + assert isinstance(outputs["tensor_stream"], TensorStream) + assert outputs["input_ids"].ndim == 2 + + +@require_torch +@require_vision +@require_tensorstream +def test_isaac_processor_with_multiple_images(isaac_processor): + vision_token = isaac_processor.vision_token + text = f"First {vision_token} then {vision_token}" + images = [_make_dummy_image(color=(255, 0, 0)), _make_dummy_image(color=(0, 255, 0))] + + outputs = isaac_processor(text=text, images=images, return_tensors="pt") + assert TensorStream is not None + assert isinstance(outputs["tensor_stream"], TensorStream) + assert outputs["input_ids"].shape[0] == 1 + + +@require_torch +@require_vision +@require_tensorstream +def test_isaac_processor_error_on_image_mismatch(isaac_processor): + vision_token = isaac_processor.vision_token + text = f"{vision_token} {vision_token}" + image = _make_dummy_image() + + with pytest.raises(ValueError, match="must match number of images"): + isaac_processor(text=text, images=[image], return_tensors="pt") + + +@require_torch +@require_vision +@require_tensorstream +def test_isaac_processor_consistent_tensor_stream_types(isaac_processor): + text_only = "Simple question?" + text_with_image = f"Describe this {isaac_processor.vision_token}" + image = _make_dummy_image() + + outputs_text = isaac_processor(text=text_only, images=None, return_tensors="pt") + outputs_image = isaac_processor(text=text_with_image, images=[image], return_tensors="pt") + + assert TensorStream is not None + assert isinstance(outputs_text["tensor_stream"], TensorStream) + assert isinstance(outputs_image["tensor_stream"], TensorStream) + assert outputs_text["input_ids"].shape[0] == outputs_image["input_ids"].shape[0] == 1 From 9226a9caeee6a9b589dac040bf2bbedd1812c6ad Mon Sep 17 00:00:00 2001 From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com> Date: Wed, 17 Dec 2025 20:06:45 +0400 Subject: [PATCH 53/77] Update src/transformers/models/isaac/modular_isaac.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/isaac/modular_isaac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 7f973a1731aa..3425b27b1069 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -1287,7 +1287,7 @@ def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> class IsaacProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = ("IsaacImageProcessorFast",) - tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + tokenizer_class = ("Qwen2Tokenizer",) def __init__( self, From a1892a576d4daea6105c00ad7f41257437f96887 Mon Sep 17 00:00:00 2001 From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com> Date: Wed, 17 Dec 2025 20:08:17 +0400 Subject: [PATCH 54/77] Update src/transformers/models/isaac/modular_isaac.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/isaac/modular_isaac.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 3425b27b1069..8ced75bd2296 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ 
b/src/transformers/models/isaac/modular_isaac.py
@@ -1848,7 +1848,6 @@ class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin):
     config_class = IsaacConfig
     _can_compile_fullgraph = False
     _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"}
-    all_tied_weights_keys: dict[str, str] = {"lm_head.weight": "model.text_model.embed_tokens.weight"}
 
     def __init__(self, config: IsaacConfig):
         super().__init__(config)

From 82f25d6da7fe61cc9e91783c85bb792436638225 Mon Sep 17 00:00:00 2001
From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com>
Date: Wed, 17 Dec 2025 20:36:57 +0400
Subject: [PATCH 55/77] Update src/transformers/models/isaac/modular_isaac.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
 src/transformers/models/isaac/modular_isaac.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py
index 8ced75bd2296..4c67f7f0d355 100644
--- a/src/transformers/models/isaac/modular_isaac.py
+++ b/src/transformers/models/isaac/modular_isaac.py
@@ -162,7 +162,7 @@ def __init__(
         self._attn_implementation = "sdpa"
 
 
-class IsaacImageProcessorKwargs(ImagesKwargs, total=False):
+class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False):
     patch_size: Optional[int]
     max_num_patches: Optional[int]
     min_num_patches: Optional[int]

From 5422d9d25077077b17c753014c751ec8abf471c4 Mon Sep 17 00:00:00 2001
From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com>
Date: Thu, 18 Dec 2025 18:54:19 +0400
Subject: [PATCH 56/77] style: review revisions (#12)

---
 docs/source/en/model_doc/isaac.md             |   3 +-
 .../models/isaac/configuration_isaac.py      |  53 +-
 .../isaac/image_processing_isaac_fast.py     |  73 +-
 .../models/isaac/modeling_isaac.py           | 979 +++++++++---------
 .../models/isaac/modular_isaac.py            | 497 ++++-----
 .../models/isaac/processing_isaac.py         |   7 +-
 tests/models/isaac/test_modeling_isaac.py    |  57 +-
 7 files changed, 728 insertions(+), 941 deletions(-)

diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md
index 7cd17f2a7f3d..538d33033d53 100644
--- a/docs/source/en/model_doc/isaac.md
+++ b/docs/source/en/model_doc/isaac.md
@@ -88,9 +88,8 @@ prompt = processor.apply_chat_template(
 # IsaacProcessor builds TensorStream events internally when both text and images are provided.
 batch = processor(text=prompt, images=images, return_tensors="pt")
-tensor_stream = batch.pop("tensor_stream").to(model.device)
-inputs = {name: tensor.to(model.device) for name, tensor in batch.items()}
+inputs = {name: value.to(model.device) for name, value in batch.items()}
 
 with torch.inference_mode():
     generated = model.generate(
         **inputs,

diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py
index ec8e0f74f967..96ab227aa87e 100644
--- a/src/transformers/models/isaac/configuration_isaac.py
+++ b/src/transformers/models/isaac/configuration_isaac.py
@@ -19,10 +19,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
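+
+# Usage sketch (illustrative only; mirrors the constructor below): `text_config` may be a
+# dict, an already-built `Qwen3Config`, or `None`, and commonly accessed attributes are
+# mirrored onto the composite config:
+#
+#     from transformers import IsaacConfig, Qwen3Config
+#
+#     cfg = IsaacConfig(text_config=Qwen3Config(hidden_size=32, num_attention_heads=4))
+#     assert cfg.hidden_size == cfg.text_config.hidden_size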
-from typing import Any, Optional, Union
+from typing import Optional, Union
 
 from ...configuration_utils import PreTrainedConfig, PretrainedConfig, layer_type_validation
-from ...modeling_rope_utils import rope_config_validation
 from ...models.qwen3.configuration_qwen3 import Qwen3Config
 
 
@@ -97,25 +96,25 @@ def __init__(
         vision_token: str = "<image>",
         **kwargs,
     ):
-        self._rope_parameters: Optional[dict[str, Any]] = None
         attn_implementation = kwargs.get("attn_implementation")
         if isinstance(text_config, dict):
             self.text_config = self.sub_configs["text_config"](**text_config)
+        elif isinstance(text_config, Qwen3Config):
+            self.text_config = text_config
         elif text_config is None:
             self.text_config = self.sub_configs["text_config"]()
 
-        super().__init__(**kwargs)
+        # Seed RoPE parameters before base init so the shared mixin can standardize/validate them.
+        self.rope_parameters = getattr(self.text_config, "rope_parameters", None)
+        self.layer_types = getattr(self.text_config, "layer_types", None)
 
-        if self._rope_scaling is None:
-            self._rope_scaling = getattr(self.text_config, "rope_scaling", None)
-        else:
-            self.text_config.rope_scaling = self._rope_scaling
+        super().__init__(**kwargs)
 
-        # Keep rope parameters alias in sync with upstream expectations
-        self._rope_parameters = self._rope_scaling
+        # Keep rope parameters aligned between the composite and text sub-configs.
+        self.text_config.rope_parameters = self.rope_parameters
 
-        # Mirror frequently accessed Qwen3 attributes at the composite config level for BC.
+        # Mirror frequently accessed Qwen3 attributes at the composite config level
         self.vocab_size = self.text_config.vocab_size
         self.hidden_size = self.text_config.hidden_size
         self.num_hidden_layers = self.text_config.num_hidden_layers
@@ -123,10 +122,7 @@ def __init__(
         self.head_dim = self.text_config.head_dim
         self.hidden_act = self.text_config.hidden_act
         self.use_cache = self.text_config.use_cache
-        self.rope_theta = self.text_config.rope_parameters["rope_theta"]
-
-        # Validate rotary parameters now that they have been mirrored locally. 
- rope_config_validation(self) + self.rope_theta = self.rope_parameters["rope_theta"] self.layer_types = getattr(self.text_config, "layer_types", None) layer_type_validation(self.layer_types, self.num_hidden_layers) @@ -155,33 +151,6 @@ def __init__( self.max_sequence_length = max_sequence_length self.vision_token = vision_token - @property - def rope_scaling(self): - if hasattr(self, "text_config") and self.text_config is not None: - return getattr(self.text_config, "rope_scaling", None) - return self._rope_scaling - - @rope_scaling.setter - def rope_scaling(self, value): - self._rope_scaling = value - if hasattr(self, "text_config") and self.text_config is not None: - self.text_config.rope_scaling = value - - @property - def rope_parameters(self) -> dict[str, Any] | None: - """Alias introduced upstream for rope scaling dictionaries.""" - value = self._rope_parameters - if value is None: - value = self.rope_scaling - if value is None: - return {"rope_type": "default"} - return value - - @rope_parameters.setter - def rope_parameters(self, value: dict[str, Any] | None) -> None: - self._rope_parameters = value - self.rope_scaling = value - def to_dict(self): output = super().to_dict() # Ensure nested configs round-trip through dict serialization diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py b/src/transformers/models/isaac/image_processing_isaac_fast.py index 1cbfccd70a2b..58735df5fd60 100644 --- a/src/transformers/models/isaac/image_processing_isaac_fast.py +++ b/src/transformers/models/isaac/image_processing_isaac_fast.py @@ -25,15 +25,17 @@ from ...feature_extraction_utils import BatchFeature from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict, group_images_by_shape, reorder_images -from ...image_utils import ChannelDimension, PILImageResampling +from ...image_utils import PILImageResampling from ...processing_utils import Unpack from ...utils import TensorType, auto_docstring # Vision preprocessing constants from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD -from ...utils.import_utils import is_torch_available -from .image_processing_isaac import IsaacImageProcessorKwargs +from ...utils.import_utils import ( + is_torch_available, +) +from .modeling_isaac import IsaacImageProcessorFastKwargs if is_torch_available(): @@ -41,6 +43,29 @@ import torch.nn.functional as F +# Disable as it causes issues with torch.compile +@torch.compiler.disable +def torch_extract_patches(image_tensor, patch_height, patch_width): + """ + Extract patches from image tensor. Returns tensor of shape (batch, rows, columns, patch_height*patch_width*channels). + + Args: + image_tensor (`torch.Tensor`): + Image tensor of shape (batch, channels, height, width). + patch_height (`int`): + Height of patches to extract. + patch_width (`int`): + Width of patches to extract. 
+ """ + batch_size, channels, height, width = image_tensor.shape + patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width)) + patches = patches.reshape(batch_size, channels, patch_height, patch_width, -1) + patches = patches.permute(0, 4, 2, 3, 1).reshape( + batch_size, height // patch_height, width // patch_width, channels * patch_height * patch_width + ) + return patches + + def get_scaled_image_size( scale: float, original_size: int, @@ -131,33 +156,6 @@ def get_image_size_for_max_num_patches( return target_height, target_width -def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor: - r"""Convert normalized images into flattened ViT-style patches. - - Args: - image (`torch.Tensor`): - Tensor of shape `(num_images, height, width, channels)`. - patch_size (`int`): - Edge length of the square patches - - Returns: - `torch.Tensor`: - Patch tensor where each position stores the flattened pixels belonging to that patch. - - Raises: - ValueError: If `height` or `width` is not divisible by `patch_size`. - """ - num_images, height, width, channels = image.shape - if height % patch_size or width % patch_size: - raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.") - patches = image.reshape(num_images, height // patch_size, patch_size, width // patch_size, patch_size, channels) - patches = patches.permute(0, 1, 3, 2, 4, 5) - patches = patches.reshape( - num_images, height // patch_size, width // patch_size, channels * patch_size * patch_size - ) - return patches - - def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> torch.Tensor: """Compute residuals for P-frames to stay in sync with the training pipeline.""" if not any(is_p_frame): @@ -178,36 +176,27 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): resample = PILImageResampling.BILINEAR model_input_names = ["patches", "token_grids"] - valid_kwargs = IsaacImageProcessorKwargs + valid_kwargs = IsaacImageProcessorFastKwargs unused_kwargs = ["size", "do_center_crop", "crop_size"] do_resize = True - size: Optional[SizeDict] = None - default_to_square: Optional[bool] = None do_center_crop = False - crop_size: Optional[SizeDict] = None patch_size: Optional[int] = 16 max_num_patches: Optional[int] = 256 min_num_patches: Optional[int] = None pixel_shuffle_scale: Optional[int] = 1 do_pad = False - pad_size: Optional[SizeDict] = None do_rescale = True - rescale_factor = 1 / 255 do_normalize = True image_mean = list(VISION_MEAN) image_std = list(VISION_STD) do_convert_rgb = True - return_tensors = None - data_format = ChannelDimension.FIRST - input_data_format = None - device = None disable_grouping = False size_divisor: Optional[int] = None def __init__( self, - **kwargs: Unpack[IsaacImageProcessorKwargs], + **kwargs: Unpack[IsaacImageProcessorFastKwargs], ) -> None: super().__init__(**kwargs) @@ -343,7 +332,7 @@ def _preprocess( nhwc_images = image_batch.permute(0, 2, 3, 1) nhwc_images = _compute_residual_p_frames(nhwc_images, is_p_frame=[False] * batch_size) - patches = patchify_vision(nhwc_images, patch_size=patch_size) + patches = torch_extract_patches(nhwc_images.permute(0, 3, 1, 2), patch_size, patch_size) _, height_tokens, width_tokens, _ = patches.shape token_grid = ( diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 748d7a8db3a2..cda31175685e 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ 
b/src/transformers/models/isaac/modeling_isaac.py @@ -27,9 +27,11 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache +from ...configuration_utils import PretrainedConfig from ...generation.utils import GenerationMixin +from ...image_processing_utils_fast import ImagesKwargs from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func -from ...masking_utils import create_masks_for_generate, packed_sequence_mask_function +from ...masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, create_masks_for_generate, packed_sequence_mask_function from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast @@ -38,7 +40,7 @@ from ...models.qwen3.modeling_qwen3 import Qwen3PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring -from ...utils.generic import TransformersKwargs, can_return_tuple, check_model_inputs +from ...utils.generic import OutputRecorder, TransformersKwargs, can_return_tuple, check_model_inputs from ...utils.import_utils import ( is_perceptron_available, is_torch_available, @@ -53,7 +55,6 @@ import torch.nn as nn import torch.nn.functional as F - if is_perceptron_available(): from perceptron.tensorstream.ops import ( compute_mrope_pos_tensor, @@ -72,6 +73,13 @@ group_streams = None +class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False): + patch_size: Optional[int] + max_num_patches: Optional[int] + min_num_patches: Optional[int] + pixel_shuffle_scale: Optional[int] + + class IsaacVisionEmbeddings(nn.Module): """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences.""" @@ -247,19 +255,12 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - past_key_value: Optional[torch.Tensor] = None, output_attentions: bool = False, - is_causal: bool = False, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - # Ignore unused arguments for interface compatibility - _ = position_ids - _ = past_key_value - _ = is_causal kwargs.pop("output_hidden_states", None) kwargs.pop("return_dict", None) @@ -272,22 +273,10 @@ def forward( keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - if not queries.is_contiguous(): - queries = queries.contiguous() - if not keys.is_contiguous(): - keys = keys.contiguous() - if not values.is_contiguous(): - values = values.contiguous() - - L = queries.size(0) - if max_seqlen is not None: - max_q = max_k = int(max_seqlen) - else: - max_q = max_k = self._max_from_cu(cu_seqlens, L) - + attn_impl = self.config._attn_implementation attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"] - if self.config._attn_implementation != "sdpa": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if attn_impl != "sdpa": + attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl] dropout = 0.0 if not self.training else self.dropout attention_kwargs: dict[str, Any] = { @@ -295,15 +284,36 @@ def forward( "scaling": self.scale, "dropout": dropout, } - if cu_seqlens is not None: - attention_kwargs["cu_seq_lens_q"] = 
cu_seqlens - attention_kwargs["cu_seq_lens_k"] = cu_seqlens - if max_seqlen is not None: - attention_kwargs["max_length_q"] = max_q - attention_kwargs["max_length_k"] = max_k - if output_attentions: + + supports_varlen = cu_seqlens is not None and attn_impl in { + "flash_attention_2", + "flash_attention_3", + "flex_attention", + "paged|flash_attention_2", + "paged|flash_attention_3", + } + + if output_attentions and attn_impl == "eager": attention_kwargs["output_attentions"] = True + if supports_varlen: + if max_seqlen is not None: + max_q = max_k = int(max_seqlen) + elif cu_seqlens.numel() >= 2: + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + max_q = max_k = lengths.max() if lengths.numel() > 0 else seq_length + else: + max_q = max_k = seq_length + + attention_kwargs.update( + { + "cu_seq_lens_q": cu_seqlens, + "cu_seq_lens_k": cu_seqlens, + "max_length_q": max_q, + "max_length_k": max_k, + } + ) + attn_output, attn_weights = attention_interface( self, queries, @@ -326,12 +336,6 @@ def forward( return attn_output, attn_weights - @staticmethod - def _max_from_cu(cu: Optional[torch.Tensor], fallback: int) -> int: - if cu is None or cu.numel() < 2: - return fallback - return int((cu[1:] - cu[:-1]).max().item()) - class IsaacMLP(nn.Module): def __init__(self, config): @@ -348,57 +352,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -def document_mask_function_from_cu_seqlens(cu_seqlens: Optional[torch.Tensor]) -> Optional[Callable]: - """Return a mask function that blocks cross-document attention from packed ``cu_seqlens``. - - The returned callable matches the signature expected by ``masking_utils`` mask factories and - yields ``True`` only when query/key positions belong to the same packed segment. - """ - - if cu_seqlens is None: - return None - - if cu_seqlens.numel() < 2: - return None - - seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() - if seq_sizes.numel() == 0: - return None - - total_tokens = int(seq_sizes.sum().item()) - seg_ids = torch.repeat_interleave(torch.arange(seq_sizes.numel(), device=cu_seqlens.device), seq_sizes) - packed_sequence_mask = seg_ids.view(1, total_tokens) - return packed_sequence_mask_function(packed_sequence_mask) - - -def ensure_document_attention_mask( - attention_mask: Optional[torch.Tensor], - cu_seqlens: Optional[torch.Tensor], - total_tokens: int, - dtype: torch.dtype, - device: torch.device, - *, - return_mask_function: bool = False, -) -> Optional[Union[torch.Tensor, Callable]]: - """Return the provided mask, a callable mask from ``cu_seqlens``, or ``None``. - - ``return_mask_function=True`` yields a callable suitable for ``masking_utils``; otherwise - ``None`` is returned when no explicit ``attention_mask`` is provided. The legacy additive mask - has been removed in favor of the callable-based path. - """ - - if attention_mask is not None: - return attention_mask - - if cu_seqlens is None: - return None - - if return_mask_function: - return document_mask_function_from_cu_seqlens(cu_seqlens) - - return None - - class IsaacVisionEncoderLayer(GradientCheckpointingLayer): """Isaac vision encoder layer with variable-length attention.""" @@ -428,30 +381,16 @@ def forward( Maximum document length referenced by `cu_seqlens`. Passed to FlashAttention so it can size temporary buffers for packed variable-length attention. 
""" - attention_mask = ensure_document_attention_mask( - attention_mask, - cu_seqlens, - hidden_states.size(1), - hidden_states.dtype, - hidden_states.device, - return_mask_function=False, - ) - # Run attention directly so variable-length metadata reaches FlashAttention. residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - attn_outputs = self.self_attn( + attn_output, _ = self.self_attn( hidden_states, attention_mask=attention_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, - output_attentions=output_attentions, **kwargs, ) - if isinstance(attn_outputs, tuple): - attn_output, attn_weights = attn_outputs - else: - attn_output, attn_weights = attn_outputs, None hidden_states = residual + attn_output residual = hidden_states @@ -475,26 +414,13 @@ def __init__(self, config: IsaacVisionConfig): # Ignore copy @can_return_tuple + @check_model_inputs def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: - attention_mask = ensure_document_attention_mask( - attention_mask, - cu_seqlens, - inputs_embeds.size(1), - inputs_embeds.dtype, - inputs_embeds.device, - return_mask_function=False, - ) - hidden_states = inputs_embeds kwargs.update( { @@ -514,6 +440,64 @@ def forward( return BaseModelOutput(last_hidden_state=hidden_states) +def document_mask_function_from_cu_seqlens(cu_seqlens: Optional[torch.Tensor]) -> Optional[Callable]: + """Return a mask function that blocks cross-document attention from packed ``cu_seqlens``. + + The returned callable matches the signature expected by ``masking_utils`` mask factories and + yields ``True`` only when query/key positions belong to the same packed segment. + """ + + if cu_seqlens is None: + return None + + if cu_seqlens.numel() < 2: + return None + + seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() + if seq_sizes.numel() == 0: + return None + + total_tokens = int(seq_sizes.sum().item()) + seg_ids = torch.repeat_interleave(torch.arange(seq_sizes.numel(), device=cu_seqlens.device), seq_sizes) + packed_sequence_mask = seg_ids.view(1, total_tokens) + return packed_sequence_mask_function(packed_sequence_mask) + + +def create_document_attention_mask( + config: PretrainedConfig, + input_embeds: torch.Tensor, + cu_seqlens: Optional[torch.Tensor], +) -> Optional[Union[torch.Tensor, Any]]: + """Materialize a backend-specific block-diagonal attention mask. + + This uses the standard `masking_utils` mask interface (same mechanism as Llama4), + so the returned object matches the selected attention backend (e.g. SDPA bool mask, + eager additive mask, or flex `BlockMask`). 
+ """ + + mask_function = document_mask_function_from_cu_seqlens(cu_seqlens) + if mask_function is None: + return None + + seq_len = input_embeds.shape[1] + cache_position = torch.arange(seq_len, device=input_embeds.device, dtype=torch.long) + + mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] + return mask_interface( + batch_size=input_embeds.shape[0], + cache_position=cache_position, + kv_length=seq_len, + kv_offset=0, + mask_function=mask_function, + attention_mask=None, + allow_is_causal_skip=False, + allow_is_bidirectional_skip=False, + dtype=input_embeds.dtype, + config=config, + use_vmap=False, + ) + + def create_pixel_shuffle_index_map( seq_sizes: torch.Tensor, token_grids: torch.Tensor, @@ -607,15 +591,15 @@ def pixel_shuffle_varlen( Raises: ValueError: If more than one batch item is provided. """ - keep_batch_dim = x.dim() == 3 - if keep_batch_dim: + return_with_batch_dim = x.dim() == 3 + if return_with_batch_dim: if x.size(0) != 1: raise AssertionError("Packed sequence is expected to have batch_size == 1") - x_ = x.squeeze(0) # (seq, embed) + embeddings = x.squeeze(0) # (seq, embed) else: - x_ = x # (seq, embed) + embeddings = x # (seq, embed) - embed_dim = x_.size(-1) + embed_dim = embeddings.size(-1) scale_factor = int(scale_factor) # Calculate seq_sizes from token_grids @@ -626,17 +610,17 @@ def pixel_shuffle_varlen( seq_sizes=seq_sizes, token_grids=token_grids, scale_factor=scale_factor, - device=x_.device, + device=embeddings.device, ) # (new_seq, scale_factor**2) # Gather โ†’ (new_seq, scale_factor**2, embed_dim) - gathered = x_[gather_idx] # fancy indexing keeps gradient + gathered = embeddings[gather_idx] # fancy indexing keeps gradient # Merge the scale_factor**2 group dimension into channels to finish the shuffle out = gathered.reshape(gathered.size(0), embed_dim * scale_factor * scale_factor) # Restore batch dimension if needed - if keep_batch_dim: + if return_with_batch_dim: out = out.unsqueeze(0) return out @@ -665,14 +649,14 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): # Generate cumulative sequence lengths for variable-length attention cu_seqlens = torch.zeros(seq_sizes.size(0) + 1, dtype=torch.int32, device=hidden_states.device) cu_seqlens[1:] = seq_sizes.cumsum(0) - max_seqlen = int(seq_sizes.max().item()) if seq_sizes.numel() > 0 else 0 + + attention_mask = create_document_attention_mask(self.config, hidden_states, cu_seqlens) # Pass through encoder with variable-length attention parameters encoder_outputs = self.encoder( inputs_embeds=hidden_states, + attention_mask=attention_mask, cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - return_dict=True, ) hidden_states = encoder_outputs.last_hidden_state @@ -713,12 +697,10 @@ def forward(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Ten return self.multimodal_projector(hidden_states) -class IsaacRotaryEmbedding(nn.Module): +class IsaacRotaryEmbedding(qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding): EXTRA_ROPE_KEYS = {"mrope_section", "mrope_interleaved"} def __init__(self, config: IsaacConfig, device=None): - super().__init__() - rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} @@ -727,9 +709,9 @@ def __init__(self, config: IsaacConfig, device=None): config_for_rope.rope_scaling = sanitized_scaling if sanitized_scaling else None init_device = device if device is not None and getattr(device, "type", None) != "meta" else None - 
self._qwen_rotary = qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding(config_for_rope, device=init_device) + super().__init__(config_for_rope, device=init_device) - rotary_half_dim = self._qwen_rotary.inv_freq.shape[0] + rotary_half_dim = self.inv_freq.shape[0] self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim) self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size @@ -755,10 +737,6 @@ def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor: chunks = tensor.split(split_sections, dim=-1) return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) - @property - def inv_freq(self) -> torch.Tensor: - return self._qwen_rotary.inv_freq - def forward( self, position_ids: torch.Tensor, @@ -790,7 +768,7 @@ def forward( pos_axes = pos.permute(2, 0, 1).contiguous() - cos_axes, sin_axes = self._qwen_rotary(hidden_states, pos_axes) + cos_axes, sin_axes = super().forward(hidden_states, pos_axes) cos_axes = cos_axes.to(hidden_states.dtype) sin_axes = sin_axes.to(hidden_states.dtype) @@ -801,271 +779,75 @@ def forward( return cos_combined, sin_combined -@use_kernel_forward_from_hub("RMSNorm") -class IsaacRMSNorm(nn.Module): - def __init__(self, hidden_size, eps: float = 1e-6) -> None: - """ - IsaacRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) +# ============================================================================ +# Model +# ============================================================================ -@use_kernel_func_from_hub("rotary_pos_emb") -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. +def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor: + r"""Create 3D positional indices for token input. Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 
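For readers tracing the rotary refactor in this hunk, here is a minimal, self-contained sketch (plain PyTorch, mirroring the `rotate_half`/`apply_rotary_pos_emb` pair kept in this file) of how the `unsqueeze_dim` broadcasting described above lines up `cos`/`sin` with `(batch, heads, seq, head_dim)` queries and keys:

```python
import torch

def rotate_half(x):
    # Split the head dim in half and rotate the pair: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

batch, heads, seq, head_dim = 2, 4, 5, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
cos = torch.randn(batch, seq, head_dim)  # shapes as in the docstring above
sin = torch.randn(batch, seq, head_dim)

# unsqueeze_dim=1 inserts the heads axis so cos/sin broadcast against q/k.
q_embed = q * cos.unsqueeze(1) + rotate_half(q) * sin.unsqueeze(1)
k_embed = k * cos.unsqueeze(1) + rotate_half(k) * sin.unsqueeze(1)
print(q_embed.shape, k_embed.shape)  # torch.Size([2, 4, 5, 8]) twice
```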
+ input_ids (`torch.Tensor`): + Tensor of shape `(batch_size, seq_len)` containing token ids. + Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + `torch.Tensor`: Positional indices with shape `(batch_size, seq_len, 3)` where each channel duplicates the + 1D position so it can be consumed by the 3-axis MRoPE rotary embedding. """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed + batch_size, seq_length = input_ids.shape + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.unsqueeze(2).expand(-1, -1, 3) # Add 3D for MRoPE + return position_ids -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +@auto_docstring +class IsaacModel(PreTrainedModel): + config: IsaacConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["IsaacDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = False + _can_compile_fullgraph = False + _supports_attention_backend = True + _can_record_outputs = {"attentions": OutputRecorder(IsaacVisionAttention, index=1)} + # Expose tied-weights mapping even if empty for base model tests. 
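The `compute_position_ids_input_ids` helper added above is easy to sanity-check in isolation; a small usage sketch (assuming only `torch`) shows the 1D positions duplicated onto all three MRoPE axes:

```python
import torch

def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor:
    # Text-only inputs: every MRoPE axis (time/height/width) shares the 1D position.
    batch_size, seq_length = input_ids.shape
    position_ids = torch.arange(seq_length, device=input_ids.device)
    position_ids = position_ids.view(1, -1).expand(batch_size, -1)
    return position_ids.unsqueeze(2).expand(-1, -1, 3)

ids = torch.tensor([[101, 7, 8, 102]])
pos = compute_position_ids_input_ids(ids)
print(pos.shape)              # torch.Size([1, 4, 3])
print(pos[0, :, 0].tolist())  # [0, 1, 2, 3] -- identical on all three axes
```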
+ all_tied_weights_keys: dict[str, str] = {} + def __init__(self, config: IsaacConfig): + Qwen3PreTrainedModel.__init__(self, config) -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs: Unpack[TransformersKwargs], -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) + text_cfg_source = config.text_config + text_cfg = copy.deepcopy(text_cfg_source) + self.text_model = AutoModel.from_config(text_cfg) + # Ensure downstream callers observe the composed config + self.text_model.config = config - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask + self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() + if config.vision_config is None: + raise ValueError("IsaacConfig should always have vision_config") - return attn_output, attn_weights + self.vision_embedding = IsaacVisionEmbedding(config) + self.vision_embedding._supports_sdpa = True + # Dispatch table for TensorStream balanced embedding (text + vision) + self.embed_fns = { + TextType: self.embed_text_tokens, + VisionType: self.embed_vision, + } -@use_kernelized_func(apply_rotary_pos_emb) -class IsaacAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" + # Keep track of config attributes that downstream utilities may query directly on the model. + self.max_sequence_length = config.max_sequence_length + self.vision_rescale_factor = config.vision_rescale_factor + self.vision_token = config.vision_token - def __init__(self, config: IsaacConfig, layer_idx: int): - super().__init__() - self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) - self.q_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
- self.k_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape - self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_values: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) - key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_values is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - sliding_window=self.sliding_window, # diff with Llama - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - -class IsaacDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: IsaacConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = IsaacAttention(config=config, layer_idx=layer_idx) - - self.mlp = IsaacMLP(config) - self.input_layernorm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs: Unpack[TransformersKwargs], - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - # Self Attention - hidden_states, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states - - 
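The block-diagonal document mask that `document_mask_function_from_cu_seqlens`/`create_document_attention_mask` assemble earlier in this file reduces to comparing per-token segment ids. A minimal sketch of that construction, independent of the `masking_utils` backends:

```python
import torch

def block_diagonal_bool_mask(cu_seqlens: torch.Tensor) -> torch.Tensor:
    # cu_seqlens = [0, n0, n0+n1, ...]; a token may only attend within its segment.
    seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long()
    seg_ids = torch.repeat_interleave(torch.arange(seq_sizes.numel()), seq_sizes)
    # allowed[q, k] is True iff q and k belong to the same packed document.
    return seg_ids[:, None] == seg_ids[None, :]

cu = torch.tensor([0, 3, 5], dtype=torch.int32)  # two documents: 3 and 2 tokens
print(block_diagonal_bool_mask(cu).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1]], dtype=torch.int32)
```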
-@auto_docstring -class IsaacModel(PreTrainedModel): - config: IsaacConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["IsaacDecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn = True - _supports_sdpa = True - _supports_flex_attn = False - _can_compile_fullgraph = False - _supports_attention_backend = True - _can_record_outputs = { - "hidden_states": IsaacDecoderLayer, - "attentions": IsaacAttention, - } - # Expose tied-weights mapping even if empty for base model tests. - all_tied_weights_keys: dict[str, str] = {} - - def __init__(self, config: IsaacConfig): - Qwen3PreTrainedModel.__init__(self, config) - - text_cfg_source = config.text_config - text_cfg = copy.deepcopy(text_cfg_source) - self.text_model = AutoModel.from_config(text_cfg) - # Ensure downstream callers observe the composed config - self.text_model.config = config - - self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) - - if config.vision_config is None: - raise ValueError("IsaacConfig should always have vision_config") - - self.vision_embedding = IsaacVisionEmbedding(config) - self.vision_embedding._supports_sdpa = True - - # Dispatch table for TensorStream balanced embedding (text + vision) - self.embed_fns = { - TextType: self.embed_text_tokens, - VisionType: self.embed_vision, - } - - # Keep track of config attributes that downstream utilities may query directly on the model. - self.max_sequence_length = config.max_sequence_length - self.vision_rescale_factor = config.vision_rescale_factor - self.vision_token = config.vision_token - - # Initialize weights and parallel plans (including tp_plan from the text model) - self.post_init() + # Initialize weights and parallel plans (including tp_plan from the text model) + self.post_init() # Respect config-specified gradient checkpointing if getattr(config, "gradient_checkpointing", False): @@ -1091,22 +873,10 @@ def embed_tokens(self) -> nn.Module: def embed_tokens(self, value: nn.Module) -> None: self.text_model.embed_tokens = value - @property - def layers(self) -> nn.ModuleList: - return self.text_model.layers - - @property - def norm(self) -> nn.Module: - return self.text_model.norm - @property def vision_model(self) -> nn.Module: return self.vision_embedding.vision_tower - @property - def vision_tower(self) -> nn.Module: - return self.vision_embedding.vision_tower - def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: """Embed text tokens, squeezing singleton dimensions.""" # Text events are shaped as (..., 1); squeeze the singleton index dim @@ -1154,6 +924,62 @@ def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor: h = embedded_ts.compact() # (B, T, D) return h + @staticmethod + def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor: + return compute_position_ids_input_ids(input_ids) + + def _prepare_position_and_modality( + self, + position_ids: Optional[torch.LongTensor], + modality_tensor: Optional[torch.LongTensor], + tensor_stream: Optional[TensorStream], + inputs_embeds: torch.Tensor, + cache_position: torch.LongTensor, + ) -> tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.Tensor, torch.Tensor]: + text_value = TextType.text.value if TextType is not None else 0 + batch_size, seq_len = inputs_embeds.shape[:2] + + if modality_tensor is None: + if tensor_stream is not None: + modality_tensor = modality_mask(tensor_stream) + else: + modality_tensor = torch.full( + (batch_size, seq_len), text_value, 
device=inputs_embeds.device, dtype=torch.long + ) + else: + modality_tensor = modality_tensor.to(device=inputs_embeds.device, dtype=torch.long) + expected_shape = (batch_size, seq_len) + if modality_tensor.shape != torch.Size(expected_shape): + raise ValueError( + f"modality_tensor must have shape (batch_size, seq_len) {expected_shape}, " + f"but got {tuple(modality_tensor.shape)}" + ) + + if position_ids is None: + if tensor_stream is not None: + position_ids = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) + else: + position_ids = cache_position.view(1, -1).expand(modality_tensor.shape[0], -1) + + if position_ids.ndim == 2: + position_ids = position_ids.to(device=inputs_embeds.device) + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) + + if position_ids.shape[1] != seq_len: + start_positions = position_ids[:, :1, 0] + position_ids = torch.arange(seq_len, device=inputs_embeds.device).view(1, -1) + position_ids = position_ids + start_positions + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) + + cos, sin = self.rotary_emb( + position_ids, + modality_tensor, + hidden_states=inputs_embeds, + ) + + decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids + return position_ids, modality_tensor, decoder_position_ids, cos, sin + @auto_docstring @check_model_inputs def forward( @@ -1166,11 +992,8 @@ def forward( past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPast: """ Forward pass with MRoPE position embeddings. @@ -1188,7 +1011,7 @@ def forward( omitted. 
""" - text_value = TextType.text.value if TextType is not None else 0 + output_attentions = kwargs.pop("output_attentions", None) # Get inputs if tensor_stream is not None and inputs_embeds is not None: @@ -1218,54 +1041,13 @@ def forward( if attention_mask is None: attention_mask = torch.ones((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long) - # Normalize modality tensor - if modality_tensor is None: - if tensor_stream is not None: - modality_tensor = modality_mask(tensor_stream) - else: - modality_tensor = torch.full( - (batch_size, seq_len), text_value, device=inputs_embeds.device, dtype=torch.long - ) - else: - modality_tensor = modality_tensor.to(dtype=torch.long) - - if modality_tensor.shape[1] != seq_len: - if modality_tensor.shape[1] > seq_len: - modality_tensor = modality_tensor[:, :seq_len] - else: - pad = modality_tensor[:, -1:].expand(-1, seq_len - modality_tensor.shape[1]) - modality_tensor = torch.cat([modality_tensor, pad], dim=1) - - # Normalize position ids - if position_ids is None: - if tensor_stream is not None: - position_ids = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) - else: - position_ids = cache_position.view(1, -1).expand(modality_tensor.shape[0], -1) - - # Expand 2D position ids (from generic padding tests or decode cache positions) to 3D MRoPE coords - if position_ids.ndim == 2: - position_ids = position_ids.to(device=inputs_embeds.device) - position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) - - # Align lengths so rotary embedding sees matching shapes - if position_ids.shape[1] != seq_len: - start_positions = position_ids[:, :1, 0] - position_ids = torch.arange(seq_len, device=inputs_embeds.device).view(1, -1) - position_ids = position_ids + start_positions - position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) - - # Compute MRoPE position embeddings if we have custom rotary_emb - cos, sin = self.rotary_emb( - position_ids, - modality_tensor, - hidden_states=inputs_embeds, + position_ids, modality_tensor, decoder_position_ids, cos, sin = self._prepare_position_and_modality( + position_ids=position_ids, + modality_tensor=modality_tensor, + tensor_stream=tensor_stream, + inputs_embeds=inputs_embeds, + cache_position=cache_position, ) - cos = cos.to(inputs_embeds.dtype) - sin = sin.to(inputs_embeds.dtype) - - # Flash attention expects 1D position_ids; keep 3D only for rotary phases - decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids # Prepare attention mask if not isinstance(attention_mask, dict): @@ -1316,6 +1098,222 @@ def forward( ) +@use_kernel_forward_from_hub("RMSNorm") +class IsaacRMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + IsaacRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") 
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +@use_kernelized_func(apply_rotary_pos_emb) +class IsaacAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: IsaacConfig, layer_idx: int): + super().__init__() + self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout 
= config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! + self.k_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape + self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class IsaacDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: IsaacConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = IsaacAttention(config=config, layer_idx=layer_idx) + + self.mlp = IsaacMLP(config) + self.input_layernorm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + 
residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + @auto_docstring class IsaacPreTrainedModel(PreTrainedModel): config: IsaacConfig @@ -1393,11 +1391,8 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | CausalLMOutputWithPast: r""" Forward pass for conditional generation supporting both standard inputs and TensorStream. @@ -1408,66 +1403,43 @@ def forward( `input_ids`. """ - # Don't compute embeddings here - let the model handle it + output_attentions = kwargs.pop("output_attentions", None) + + # Don't compute embeddings here - let the inner model handle it if tensor_stream is not None: input_ids = None if input_ids is None and inputs_embeds is None and tensor_stream is None: raise ValueError("Either input_ids, inputs_embeds, or tensor_stream must be provided.") - text_value = TextType.text.value if TextType is not None else 0 - - if tensor_stream is None: + # Record rope deltas on prefill when TensorStream is provided; leave position_ids building to IsaacModel. + if position_ids is None and tensor_stream is not None: + position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask) + elif position_ids is None and cache_position is not None and self.rope_deltas is not None: + # Decode continuation after TensorStream prefill: advance positions using cached rope offsets. if input_ids is not None: - batch_size, seq_len = input_ids.shape - input_device = input_ids.device + base_position_ids = compute_position_ids_input_ids(input_ids) else: + if inputs_embeds is None: + raise ValueError("inputs_embeds must be provided when input_ids is None during decode") batch_size, seq_len = inputs_embeds.shape[:2] - input_device = inputs_embeds.device - - # Build position ids (MRoPE) if needed and tensor_stream is available - # During decode we reuse `self.rope_deltas` computed on the initial forward pass; `rope_delta` captures how far - # cached rotary phases have progressed so we can advance `position_ids` without rebuilding the TensorStream. 
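The decode continuation described in the comment above can be pictured with toy numbers (hypothetical offsets, not the model's actual state): the cached `rope_deltas` shifts freshly synthesized position ids so rotary phases keep advancing in lockstep with the cache.

```python
import torch

rope_deltas = torch.tensor([[4]])    # hypothetical per-sequence offset from prefill
cache_position = torch.tensor([10])  # index of the token being decoded

# One new token: base positions restart at 0 in (B, S, 3) MRoPE layout...
input_ids = torch.tensor([[42]])
batch_size, seq_len = input_ids.shape
base = torch.arange(seq_len).view(1, -1).expand(batch_size, -1)
base = base.unsqueeze(-1).expand(-1, -1, 3)

# ...then get shifted by the combined cache position and cached rope delta.
delta = cache_position[0] + rope_deltas  # 10 + 4 = 14
position_ids = base + delta.view(-1, 1, 1)
print(position_ids[0, 0].tolist())  # [14, 14, 14]
```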
- if position_ids is None: - if tensor_stream is not None: - position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask) - elif input_ids is None: - dummy_ids = torch.zeros((batch_size, seq_len), device=input_device, dtype=torch.long) - position_ids = compute_position_ids_input_ids(dummy_ids) - else: - position_ids = compute_position_ids_input_ids(input_ids) - - rope_delta = 0 - if cache_position is not None and self.rope_deltas is not None: - # Combine the incremental decode step (`cache_position`) with cached offsets so hidden states continue - # rotating in lockstep across generation steps. - rope_delta = (cache_position[0] + self.rope_deltas).to(input_ids.device) - if not isinstance(rope_delta, int): # otherwise `deltas` is an int `0` - rope_delta = rope_delta.repeat_interleave(batch_size // rope_delta.shape[0], dim=0) + dummy_ids = torch.zeros((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long) + base_position_ids = compute_position_ids_input_ids(dummy_ids) - position_ids = position_ids.add(rope_delta) - - if attention_mask is None and tensor_stream is None: - attention_mask = torch.ones((batch_size, seq_len), device=input_device, dtype=torch.long) - - if tensor_stream is not None: - modality_tensor = modality_mask(tensor_stream) - else: - modality_tensor = torch.full( - (batch_size, seq_len), text_value, device=position_ids.device, dtype=torch.long - ) + rope_delta = (cache_position[0] + self.rope_deltas).to(base_position_ids.device) + if not isinstance(rope_delta, int): + rope_delta = rope_delta.repeat_interleave(base_position_ids.shape[0] // rope_delta.shape[0], dim=0) + position_ids = base_position_ids.add(rope_delta) outputs = self.model( input_ids=input_ids, tensor_stream=tensor_stream, attention_mask=attention_mask, position_ids=position_ids, - modality_tensor=modality_tensor, + modality_tensor=None, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) @@ -1588,22 +1560,11 @@ def prepare_inputs_for_generation( else: model_inputs["tensor_stream"] = None - # TensorStream decode path: preserve rotary offsets from prefill + # TensorStream decode path: preserve rotary offsets from prefill; let forward rebuild positions if tensor_stream is not None and not first_step and self.rope_deltas is not None: model_inputs["position_ids"] = None return model_inputs - # For decode steps, synthesize position_ids that continue from the cache offsets - if model_inputs.get("position_ids") is None and cache_position is not None and not first_step: - batch_size = 1 - if model_inputs.get("input_ids") is not None: - batch_size = model_inputs["input_ids"].shape[0] - elif model_inputs.get("inputs_embeds") is not None: - batch_size = model_inputs["inputs_embeds"].shape[0] - pos_ids = cache_position.view(1, -1).expand(batch_size, -1) - pos_ids = pos_ids.unsqueeze(-1).expand(-1, -1, 3) - model_inputs["position_ids"] = pos_ids - return model_inputs @classmethod diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 4c67f7f0d355..6cba46b79a85 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -26,6 +26,7 @@ is_perceptron_available, is_torch_available, is_torchdynamo_compiling, + is_torchvision_available, is_vision_available, ) @@ -41,6 +42,8 @@ else: Image 
= None +if is_torchvision_available(): + from ..pix2struct.image_processing_pix2struct_fast import torch_extract_patches if is_perceptron_available(): from perceptron.tensorstream.ops import ( @@ -84,12 +87,10 @@ reorder_images, ) from ...image_utils import ( - ChannelDimension, PILImageResampling, ) -from ...masking_utils import create_masks_for_generate, packed_sequence_mask_function +from ...masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, create_masks_for_generate, packed_sequence_mask_function from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...models.auto.modeling_auto import AutoModel from ...models.auto.tokenization_auto import AutoTokenizer @@ -101,7 +102,7 @@ # Vision preprocessing constants from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD -from ...utils.generic import TransformersKwargs, can_return_tuple, check_model_inputs +from ...utils.generic import OutputRecorder, TransformersKwargs, can_return_tuple, check_model_inputs from ..qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling from ..siglip2.configuration_siglip2 import Siglip2VisionConfig from ..siglip2.modeling_siglip2 import ( @@ -176,36 +177,27 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): resample = PILImageResampling.BILINEAR model_input_names = ["patches", "token_grids"] - valid_kwargs = IsaacImageProcessorKwargs + valid_kwargs = IsaacImageProcessorFastKwargs unused_kwargs = ["size", "do_center_crop", "crop_size"] do_resize = True - size: Optional[SizeDict] = None - default_to_square: Optional[bool] = None do_center_crop = False - crop_size: Optional[SizeDict] = None patch_size: Optional[int] = 16 max_num_patches: Optional[int] = 256 min_num_patches: Optional[int] = None pixel_shuffle_scale: Optional[int] = 1 do_pad = False - pad_size: Optional[SizeDict] = None do_rescale = True - rescale_factor = 1 / 255 do_normalize = True image_mean = list(VISION_MEAN) image_std = list(VISION_STD) do_convert_rgb = True - return_tensors = None - data_format = ChannelDimension.FIRST - input_data_format = None - device = None disable_grouping = False size_divisor: Optional[int] = None def __init__( self, - **kwargs: Unpack[IsaacImageProcessorKwargs], + **kwargs: Unpack[IsaacImageProcessorFastKwargs], ) -> None: super().__init__(**kwargs) @@ -341,7 +333,7 @@ def _preprocess( nhwc_images = image_batch.permute(0, 2, 3, 1) nhwc_images = _compute_residual_p_frames(nhwc_images, is_p_frame=[False] * batch_size) - patches = patchify_vision(nhwc_images, patch_size=patch_size) + patches = torch_extract_patches(nhwc_images.permute(0, 3, 1, 2), patch_size, patch_size) _, height_tokens, width_tokens, _ = patches.shape token_grid = ( @@ -430,32 +422,39 @@ def document_mask_function_from_cu_seqlens(cu_seqlens: Optional[torch.Tensor]) - return packed_sequence_mask_function(packed_sequence_mask) -def ensure_document_attention_mask( - attention_mask: Optional[torch.Tensor], +def create_document_attention_mask( + config: PretrainedConfig, + input_embeds: torch.Tensor, cu_seqlens: Optional[torch.Tensor], - total_tokens: int, - dtype: torch.dtype, - device: torch.device, - *, - return_mask_function: bool = False, -) -> Optional[Union[torch.Tensor, Callable]]: - """Return the provided mask, a callable mask from ``cu_seqlens``, or ``None``. 
- - ``return_mask_function=True`` yields a callable suitable for ``masking_utils``; otherwise - ``None`` is returned when no explicit ``attention_mask`` is provided. The legacy additive mask - has been removed in favor of the callable-based path. - """ +) -> Optional[Union[torch.Tensor, Any]]: + """Materialize a backend-specific block-diagonal attention mask. - if attention_mask is not None: - return attention_mask + This uses the standard `masking_utils` mask interface (same mechanism as Llama4), + so the returned object matches the selected attention backend (e.g. SDPA bool mask, + eager additive mask, or flex `BlockMask`). + """ - if cu_seqlens is None: + mask_function = document_mask_function_from_cu_seqlens(cu_seqlens) + if mask_function is None: return None - if return_mask_function: - return document_mask_function_from_cu_seqlens(cu_seqlens) - - return None + seq_len = input_embeds.shape[1] + cache_position = torch.arange(seq_len, device=input_embeds.device, dtype=torch.long) + + mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] + return mask_interface( + batch_size=input_embeds.shape[0], + cache_position=cache_position, + kv_length=seq_len, + kv_offset=0, + mask_function=mask_function, + attention_mask=None, + allow_is_causal_skip=False, + allow_is_bidirectional_skip=False, + dtype=input_embeds.dtype, + config=config, + use_vmap=False, + ) class IsaacVisionEmbeddings(nn.Module): @@ -613,18 +612,11 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - past_key_value: Optional[torch.Tensor] = None, output_attentions: bool = False, - is_causal: bool = False, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, **kwargs, ): - # Ignore unused arguments for interface compatibility - _ = position_ids - _ = past_key_value - _ = is_causal kwargs.pop("output_hidden_states", None) kwargs.pop("return_dict", None) @@ -637,22 +629,10 @@ def forward( keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - if not queries.is_contiguous(): - queries = queries.contiguous() - if not keys.is_contiguous(): - keys = keys.contiguous() - if not values.is_contiguous(): - values = values.contiguous() - - L = queries.size(0) - if max_seqlen is not None: - max_q = max_k = int(max_seqlen) - else: - max_q = max_k = self._max_from_cu(cu_seqlens, L) - + attn_impl = self.config._attn_implementation attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"] - if self.config._attn_implementation != "sdpa": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if attn_impl != "sdpa": + attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl] dropout = 0.0 if not self.training else self.dropout attention_kwargs: dict[str, Any] = { @@ -660,15 +640,36 @@ def forward( "scaling": self.scale, "dropout": dropout, } - if cu_seqlens is not None: - attention_kwargs["cu_seq_lens_q"] = cu_seqlens - attention_kwargs["cu_seq_lens_k"] = cu_seqlens - if max_seqlen is not None: - attention_kwargs["max_length_q"] = max_q - attention_kwargs["max_length_k"] = max_k - if output_attentions: + + supports_varlen = cu_seqlens is not None and attn_impl in { + "flash_attention_2", + "flash_attention_3", + "flex_attention", + "paged|flash_attention_2", + "paged|flash_attention_3", + } + + if output_attentions and attn_impl == "eager": 
attention_kwargs["output_attentions"] = True + if supports_varlen: + if max_seqlen is not None: + max_q = max_k = int(max_seqlen) + elif cu_seqlens.numel() >= 2: + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + max_q = max_k = lengths.max() if lengths.numel() > 0 else seq_length + else: + max_q = max_k = seq_length + + attention_kwargs.update( + { + "cu_seq_lens_q": cu_seqlens, + "cu_seq_lens_k": cu_seqlens, + "max_length_q": max_q, + "max_length_k": max_k, + } + ) + attn_output, attn_weights = attention_interface( self, queries, @@ -691,12 +692,6 @@ def forward( return attn_output, attn_weights - @staticmethod - def _max_from_cu(cu: Optional[torch.Tensor], fallback: int) -> int: - if cu is None or cu.numel() < 2: - return fallback - return int((cu[1:] - cu[:-1]).max().item()) - class IsaacVisionEncoderLayer(Siglip2EncoderLayer): """Isaac vision encoder layer with variable-length attention.""" @@ -722,30 +717,16 @@ def forward( Maximum document length referenced by `cu_seqlens`. Passed to FlashAttention so it can size temporary buffers for packed variable-length attention. """ - attention_mask = ensure_document_attention_mask( - attention_mask, - cu_seqlens, - hidden_states.size(1), - hidden_states.dtype, - hidden_states.device, - return_mask_function=False, - ) - # Run attention directly so variable-length metadata reaches FlashAttention. residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - attn_outputs = self.self_attn( + attn_output, _ = self.self_attn( hidden_states, attention_mask=attention_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, - output_attentions=output_attentions, **kwargs, ) - if isinstance(attn_outputs, tuple): - attn_output, attn_weights = attn_outputs - else: - attn_output, attn_weights = attn_outputs, None hidden_states = residual + attn_output residual = hidden_states @@ -753,8 +734,6 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - if output_attentions: - return hidden_states, attn_weights return hidden_states @@ -766,36 +745,14 @@ def __init__(self, config: IsaacVisionConfig): self.layers = nn.ModuleList([IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) @can_return_tuple + @check_model_inputs def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): - attention_mask = ensure_document_attention_mask( - attention_mask, - cu_seqlens, - inputs_embeds.size(1), - inputs_embeds.dtype, - inputs_embeds.device, - return_mask_function=False, - ) - hidden_states = inputs_embeds - kwargs.update( - { - "max_seqlen": max_seqlen, - "cu_seqlens": cu_seqlens, - "output_attentions": output_attentions, - "output_hidden_states": output_hidden_states, - "return_dict": return_dict, - } - ) for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, @@ -898,15 +855,15 @@ def pixel_shuffle_varlen( Raises: ValueError: If more than one batch item is provided. 
""" - keep_batch_dim = x.dim() == 3 - if keep_batch_dim: + return_with_batch_dim = x.dim() == 3 + if return_with_batch_dim: if x.size(0) != 1: raise AssertionError("Packed sequence is expected to have batch_size == 1") - x_ = x.squeeze(0) # (seq, embed) + embeddings = x.squeeze(0) # (seq, embed) else: - x_ = x # (seq, embed) + embeddings = x # (seq, embed) - embed_dim = x_.size(-1) + embed_dim = embeddings.size(-1) scale_factor = int(scale_factor) # Calculate seq_sizes from token_grids @@ -917,17 +874,17 @@ def pixel_shuffle_varlen( seq_sizes=seq_sizes, token_grids=token_grids, scale_factor=scale_factor, - device=x_.device, + device=embeddings.device, ) # (new_seq, scale_factor**2) # Gather โ†’ (new_seq, scale_factor**2, embed_dim) - gathered = x_[gather_idx] # fancy indexing keeps gradient + gathered = embeddings[gather_idx] # fancy indexing keeps gradient # Merge the scale_factor**2 group dimension into channels to finish the shuffle out = gathered.reshape(gathered.size(0), embed_dim * scale_factor * scale_factor) # Restore batch dimension if needed - if keep_batch_dim: + if return_with_batch_dim: out = out.unsqueeze(0) return out @@ -956,14 +913,14 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): # Generate cumulative sequence lengths for variable-length attention cu_seqlens = torch.zeros(seq_sizes.size(0) + 1, dtype=torch.int32, device=hidden_states.device) cu_seqlens[1:] = seq_sizes.cumsum(0) - max_seqlen = int(seq_sizes.max().item()) if seq_sizes.numel() > 0 else 0 + + attention_mask = create_document_attention_mask(self.config, hidden_states, cu_seqlens) # Pass through encoder with variable-length attention parameters encoder_outputs = self.encoder( inputs_embeds=hidden_states, + attention_mask=attention_mask, cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - return_dict=True, ) hidden_states = encoder_outputs.last_hidden_state @@ -1094,33 +1051,6 @@ def get_image_size_for_max_num_patches( return target_height, target_width -def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor: - r"""Convert normalized images into flattened ViT-style patches. - - Args: - image (`torch.Tensor`): - Tensor of shape `(num_images, height, width, channels)`. - patch_size (`int`): - Edge length of the square patches - - Returns: - `torch.Tensor`: - Patch tensor where each position stores the flattened pixels belonging to that patch. - - Raises: - ValueError: If `height` or `width` is not divisible by `patch_size`. - """ - num_images, height, width, channels = image.shape - if height % patch_size or width % patch_size: - raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.") - patches = image.reshape(num_images, height // patch_size, patch_size, width // patch_size, patch_size, channels) - patches = patches.permute(0, 1, 3, 2, 4, 5) - patches = patches.reshape( - num_images, height // patch_size, width // patch_size, channels * patch_size * patch_size - ) - return patches - - class IsaacConfig(PretrainedConfig): """Configuration class for Isaac multimodal model. 
@@ -1141,25 +1071,25 @@ def __init__( vision_token: str = "", **kwargs, ): - self._rope_parameters: Optional[dict[str, Any]] = None attn_implementation = kwargs.get("attn_implementation") if isinstance(text_config, dict): self.text_config = self.sub_configs["text_config"](**text_config) + elif isinstance(text_config, Qwen3Config): + self.text_config = text_config elif text_config is None: self.text_config = self.sub_configs["text_config"]() - super().__init__(**kwargs) + # Seed RoPE parameters before base init so the shared mixin can standardize/validate them. + self.rope_parameters = getattr(self.text_config, "rope_parameters", None) + self.layer_types = getattr(self.text_config, "layer_types", None) - if self._rope_scaling is None: - self._rope_scaling = getattr(self.text_config, "rope_scaling", None) - else: - self.text_config.rope_scaling = self._rope_scaling + super().__init__(**kwargs) - # Keep rope parameters alias in sync with upstream expectations - self._rope_parameters = self._rope_scaling + # Keep rope parameters aligned between the composite and text sub-configs. + self.text_config.rope_parameters = self.rope_parameters - # Mirror frequently accessed Qwen3 attributes at the composite config level for BC. + # Mirror frequently accessed Qwen3 attributes at the composite config level self.vocab_size = self.text_config.vocab_size self.hidden_size = self.text_config.hidden_size self.num_hidden_layers = self.text_config.num_hidden_layers @@ -1167,10 +1097,7 @@ def __init__( self.head_dim = self.text_config.head_dim self.hidden_act = self.text_config.hidden_act self.use_cache = self.text_config.use_cache - self.rope_theta = self.text_config.rope_parameters["rope_theta"] - - # Validate rotary parameters now that they have been mirrored locally. 
- rope_config_validation(self) + self.rope_theta = self.rope_parameters["rope_theta"] self.layer_types = getattr(self.text_config, "layer_types", None) layer_type_validation(self.layer_types, self.num_hidden_layers) @@ -1199,33 +1126,6 @@ def __init__( self.max_sequence_length = max_sequence_length self.vision_token = vision_token - @property - def rope_scaling(self): - if hasattr(self, "text_config") and self.text_config is not None: - return getattr(self.text_config, "rope_scaling", None) - return self._rope_scaling - - @rope_scaling.setter - def rope_scaling(self, value): - self._rope_scaling = value - if hasattr(self, "text_config") and self.text_config is not None: - self.text_config.rope_scaling = value - - @property - def rope_parameters(self) -> dict[str, Any] | None: - """Alias introduced upstream for rope scaling dictionaries.""" - value = self._rope_parameters - if value is None: - value = self.rope_scaling - if value is None: - return {"rope_type": "default"} - return value - - @rope_parameters.setter - def rope_parameters(self, value: dict[str, Any] | None) -> None: - self._rope_parameters = value - self.rope_scaling = value - def to_dict(self): output = super().to_dict() # Ensure nested configs round-trip through dict serialization @@ -1467,12 +1367,10 @@ def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor: return position_ids -class IsaacRotaryEmbedding(nn.Module): +class IsaacRotaryEmbedding(qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding): EXTRA_ROPE_KEYS = {"mrope_section", "mrope_interleaved"} def __init__(self, config: IsaacConfig, device=None): - super().__init__() - rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} @@ -1481,9 +1379,9 @@ def __init__(self, config: IsaacConfig, device=None): config_for_rope.rope_scaling = sanitized_scaling if sanitized_scaling else None init_device = device if device is not None and getattr(device, "type", None) != "meta" else None - self._qwen_rotary = qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding(config_for_rope, device=init_device) + super().__init__(config_for_rope, device=init_device) - rotary_half_dim = self._qwen_rotary.inv_freq.shape[0] + rotary_half_dim = self.inv_freq.shape[0] self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim) self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size @@ -1509,10 +1407,6 @@ def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor: chunks = tensor.split(split_sections, dim=-1) return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) - @property - def inv_freq(self) -> torch.Tensor: - return self._qwen_rotary.inv_freq - def forward( self, position_ids: torch.Tensor, @@ -1544,7 +1438,7 @@ def forward( pos_axes = pos.permute(2, 0, 1).contiguous() - cos_axes, sin_axes = self._qwen_rotary(hidden_states, pos_axes) + cos_axes, sin_axes = super().forward(hidden_states, pos_axes) cos_axes = cos_axes.to(hidden_states.dtype) sin_axes = sin_axes.to(hidden_states.dtype) @@ -1559,6 +1453,7 @@ class IsaacModel(Qwen3PreTrainedModel): supports_gradient_checkpointing = True _can_compile_fullgraph = False _supports_flex_attn = False + _can_record_outputs = {"attentions": OutputRecorder(IsaacVisionAttention, index=1)} # Expose tied-weights mapping even if empty for base model tests. 
all_tied_weights_keys: dict[str, str] = {} @@ -1618,12 +1513,8 @@ def embed_tokens(self, value: nn.Module) -> None: self.text_model.embed_tokens = value @property - def layers(self) -> nn.ModuleList: - return self.text_model.layers - - @property - def norm(self) -> nn.Module: - return self.text_model.norm + def vision_model(self) -> nn.Module: + return self.vision_embedding.vision_tower @property def vision_model(self) -> nn.Module: @@ -1680,6 +1571,62 @@ def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor: h = embedded_ts.compact() # (B, T, D) return h + @staticmethod + def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor: + return compute_position_ids_input_ids(input_ids) + + def _prepare_position_and_modality( + self, + position_ids: Optional[torch.LongTensor], + modality_tensor: Optional[torch.LongTensor], + tensor_stream: Optional[TensorStream], + inputs_embeds: torch.Tensor, + cache_position: torch.LongTensor, + ) -> tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.Tensor, torch.Tensor]: + text_value = TextType.text.value if TextType is not None else 0 + batch_size, seq_len = inputs_embeds.shape[:2] + + if modality_tensor is None: + if tensor_stream is not None: + modality_tensor = modality_mask(tensor_stream) + else: + modality_tensor = torch.full( + (batch_size, seq_len), text_value, device=inputs_embeds.device, dtype=torch.long + ) + else: + modality_tensor = modality_tensor.to(device=inputs_embeds.device, dtype=torch.long) + expected_shape = (batch_size, seq_len) + if modality_tensor.shape != torch.Size(expected_shape): + raise ValueError( + f"modality_tensor must have shape (batch_size, seq_len) {expected_shape}, " + f"but got {tuple(modality_tensor.shape)}" + ) + + if position_ids is None: + if tensor_stream is not None: + position_ids = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) + else: + position_ids = cache_position.view(1, -1).expand(modality_tensor.shape[0], -1) + + if position_ids.ndim == 2: + position_ids = position_ids.to(device=inputs_embeds.device) + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) + + if position_ids.shape[1] != seq_len: + start_positions = position_ids[:, :1, 0] + position_ids = torch.arange(seq_len, device=inputs_embeds.device).view(1, -1) + position_ids = position_ids + start_positions + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) + + cos, sin = self.rotary_emb( + position_ids, + modality_tensor, + hidden_states=inputs_embeds, + ) + + decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids + return position_ids, modality_tensor, decoder_position_ids, cos, sin + @auto_docstring @check_model_inputs def forward( @@ -1692,11 +1639,8 @@ def forward( past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPast: """ Forward pass with MRoPE position embeddings. @@ -1714,7 +1658,7 @@ def forward( omitted. 
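        Example:
            For a text-only prompt of length `L`, `_prepare_position_and_modality` expands
            the 2D positions to three identical MRoPE axes, so every token gets coordinates
            `(p, p, p)` for `p in 0..L-1`; TensorStream inputs instead receive distinct
            temporal/height/width coordinates from `compute_mrope_pos_tensor` (illustrative
            of the helper above, not an exhaustive contract).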
""" - text_value = TextType.text.value if TextType is not None else 0 + output_attentions = kwargs.pop("output_attentions", None) # Get inputs if tensor_stream is not None and inputs_embeds is not None: @@ -1744,54 +1688,13 @@ def forward( if attention_mask is None: attention_mask = torch.ones((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long) - # Normalize modality tensor - if modality_tensor is None: - if tensor_stream is not None: - modality_tensor = modality_mask(tensor_stream) - else: - modality_tensor = torch.full( - (batch_size, seq_len), text_value, device=inputs_embeds.device, dtype=torch.long - ) - else: - modality_tensor = modality_tensor.to(dtype=torch.long) - - if modality_tensor.shape[1] != seq_len: - if modality_tensor.shape[1] > seq_len: - modality_tensor = modality_tensor[:, :seq_len] - else: - pad = modality_tensor[:, -1:].expand(-1, seq_len - modality_tensor.shape[1]) - modality_tensor = torch.cat([modality_tensor, pad], dim=1) - - # Normalize position ids - if position_ids is None: - if tensor_stream is not None: - position_ids = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) - else: - position_ids = cache_position.view(1, -1).expand(modality_tensor.shape[0], -1) - - # Expand 2D position ids (from generic padding tests or decode cache positions) to 3D MRoPE coords - if position_ids.ndim == 2: - position_ids = position_ids.to(device=inputs_embeds.device) - position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) - - # Align lengths so rotary embedding sees matching shapes - if position_ids.shape[1] != seq_len: - start_positions = position_ids[:, :1, 0] - position_ids = torch.arange(seq_len, device=inputs_embeds.device).view(1, -1) - position_ids = position_ids + start_positions - position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) - - # Compute MRoPE position embeddings if we have custom rotary_emb - cos, sin = self.rotary_emb( - position_ids, - modality_tensor, - hidden_states=inputs_embeds, + position_ids, modality_tensor, decoder_position_ids, cos, sin = self._prepare_position_and_modality( + position_ids=position_ids, + modality_tensor=modality_tensor, + tensor_stream=tensor_stream, + inputs_embeds=inputs_embeds, + cache_position=cache_position, ) - cos = cos.to(inputs_embeds.dtype) - sin = sin.to(inputs_embeds.dtype) - - # Flash attention expects 1D position_ids; keep 3D only for rotary phases - decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids # Prepare attention mask if not isinstance(attention_mask, dict): @@ -1848,6 +1751,7 @@ class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin): config_class = IsaacConfig _can_compile_fullgraph = False _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"} + all_tied_weights_keys: dict[str, str] = {"lm_head.weight": "model.text_model.embed_tokens.weight"} def __init__(self, config: IsaacConfig): super().__init__(config) @@ -1867,11 +1771,8 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | CausalLMOutputWithPast: r""" Forward pass for conditional generation supporting both standard inputs and TensorStream. @@ -1882,66 +1783,43 @@ def forward( `input_ids`. 
""" - # Don't compute embeddings here - let the model handle it + output_attentions = kwargs.pop("output_attentions", None) + + # Don't compute embeddings here - let the inner model handle it if tensor_stream is not None: input_ids = None if input_ids is None and inputs_embeds is None and tensor_stream is None: raise ValueError("Either input_ids, inputs_embeds, or tensor_stream must be provided.") - text_value = TextType.text.value if TextType is not None else 0 - - if tensor_stream is None: + # Record rope deltas on prefill when TensorStream is provided; leave position_ids building to IsaacModel. + if position_ids is None and tensor_stream is not None: + position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask) + elif position_ids is None and cache_position is not None and self.rope_deltas is not None: + # Decode continuation after TensorStream prefill: advance positions using cached rope offsets. if input_ids is not None: - batch_size, seq_len = input_ids.shape - input_device = input_ids.device + base_position_ids = compute_position_ids_input_ids(input_ids) else: + if inputs_embeds is None: + raise ValueError("inputs_embeds must be provided when input_ids is None during decode") batch_size, seq_len = inputs_embeds.shape[:2] - input_device = inputs_embeds.device - - # Build position ids (MRoPE) if needed and tensor_stream is available - # During decode we reuse `self.rope_deltas` computed on the initial forward pass; `rope_delta` captures how far - # cached rotary phases have progressed so we can advance `position_ids` without rebuilding the TensorStream. - if position_ids is None: - if tensor_stream is not None: - position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask) - elif input_ids is None: - dummy_ids = torch.zeros((batch_size, seq_len), device=input_device, dtype=torch.long) - position_ids = compute_position_ids_input_ids(dummy_ids) - else: - position_ids = compute_position_ids_input_ids(input_ids) - - rope_delta = 0 - if cache_position is not None and self.rope_deltas is not None: - # Combine the incremental decode step (`cache_position`) with cached offsets so hidden states continue - # rotating in lockstep across generation steps. 
- rope_delta = (cache_position[0] + self.rope_deltas).to(input_ids.device) - if not isinstance(rope_delta, int): # otherwise `deltas` is an int `0` - rope_delta = rope_delta.repeat_interleave(batch_size // rope_delta.shape[0], dim=0) - - position_ids = position_ids.add(rope_delta) - - if attention_mask is None and tensor_stream is None: - attention_mask = torch.ones((batch_size, seq_len), device=input_device, dtype=torch.long) + dummy_ids = torch.zeros((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long) + base_position_ids = compute_position_ids_input_ids(dummy_ids) - if tensor_stream is not None: - modality_tensor = modality_mask(tensor_stream) - else: - modality_tensor = torch.full( - (batch_size, seq_len), text_value, device=position_ids.device, dtype=torch.long - ) + rope_delta = (cache_position[0] + self.rope_deltas).to(base_position_ids.device) + if not isinstance(rope_delta, int): + rope_delta = rope_delta.repeat_interleave(base_position_ids.shape[0] // rope_delta.shape[0], dim=0) + position_ids = base_position_ids.add(rope_delta) outputs = self.model( input_ids=input_ids, tensor_stream=tensor_stream, attention_mask=attention_mask, position_ids=position_ids, - modality_tensor=modality_tensor, + modality_tensor=None, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) @@ -2062,22 +1940,11 @@ def prepare_inputs_for_generation( else: model_inputs["tensor_stream"] = None - # TensorStream decode path: preserve rotary offsets from prefill + # TensorStream decode path: preserve rotary offsets from prefill; let forward rebuild positions if tensor_stream is not None and not first_step and self.rope_deltas is not None: model_inputs["position_ids"] = None return model_inputs - # For decode steps, synthesize position_ids that continue from the cache offsets - if model_inputs.get("position_ids") is None and cache_position is not None and not first_step: - batch_size = 1 - if model_inputs.get("input_ids") is not None: - batch_size = model_inputs["input_ids"].shape[0] - elif model_inputs.get("inputs_embeds") is not None: - batch_size = model_inputs["inputs_embeds"].shape[0] - pos_ids = cache_position.view(1, -1).expand(batch_size, -1) - pos_ids = pos_ids.unsqueeze(-1).expand(-1, -1, 3) - model_inputs["position_ids"] = pos_ids - return model_inputs @classmethod diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 3a2792792449..594f9765db63 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -27,7 +27,11 @@ from ...models.auto.tokenization_auto import AutoTokenizer from ...processing_utils import ProcessorMixin from ...utils import TensorType -from ...utils.import_utils import is_perceptron_available, is_torch_available, is_vision_available +from ...utils.import_utils import ( + is_perceptron_available, + is_torch_available, + is_vision_available, +) from .configuration_isaac import IsaacConfig @@ -40,7 +44,6 @@ else: Image = None - if is_perceptron_available(): from perceptron.tensorstream.ops import slice as ts_slice from perceptron.tensorstream.ops import tensor_stream_token_view diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py index 99c3ffb0962b..d020082fee2a 100644 --- a/tests/models/isaac/test_modeling_isaac.py 
+++ b/tests/models/isaac/test_modeling_isaac.py
@@ -38,10 +38,7 @@
 from transformers.image_utils import load_image
 from transformers.masking_utils import eager_mask, sdpa_mask
 from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast
-from transformers.models.isaac.modeling_isaac import (
-    document_mask_function_from_cu_seqlens,
-    ensure_document_attention_mask,
-)
+from transformers.models.isaac.modeling_isaac import document_mask_function_from_cu_seqlens
 from transformers.models.isaac.processing_isaac import IsaacProcessor
 from transformers.testing_utils import (
     require_flash_attn,
@@ -166,31 +163,6 @@ def test_document_mask_function_from_cu_seqlens(self):
         # Same second document (indices 3 and 4)
         self.assertTrue(mask_fn(0, 0, 4, 3))

-    def test_ensure_document_attention_mask_prefers_callable_when_requested(self):
-        cu_seqlens = torch.tensor([0, 2, 5], dtype=torch.int32)
-        total_tokens = 5
-        dtype = torch.float32
-
-        mask_callable = ensure_document_attention_mask(
-            attention_mask=None,
-            cu_seqlens=cu_seqlens,
-            total_tokens=total_tokens,
-            dtype=dtype,
-            device=cu_seqlens.device,
-            return_mask_function=True,
-        )
-        self.assertTrue(callable(mask_callable))
-
-        additive = ensure_document_attention_mask(
-            attention_mask=None,
-            cu_seqlens=cu_seqlens,
-            total_tokens=total_tokens,
-            dtype=dtype,
-            device=cu_seqlens.device,
-            return_mask_function=False,
-        )
-        self.assertIsNone(additive)
-
     def test_document_mask_function_materializes_with_masking_utils(self):
         cu_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32)
         total_tokens = 4
@@ -531,6 +503,33 @@ def test_model_forward(self):
             (self.model_tester.batch_size, self.model_tester.seq_length, config.hidden_size),
         )

+    @require_tensorstream
+    def test_modality_tensor_requires_matching_shape(self):
+        config, input_ids, attention_mask, _ = self.model_tester.prepare_config_and_inputs()
+        model = IsaacModel(config).to(torch_device)
+        model.eval()
+
+        modality_tensor = torch.zeros(
+            (self.model_tester.batch_size, self.model_tester.seq_length),
+            device=torch_device,
+            dtype=torch.long,
+        )
+        with torch.no_grad():
+            result = model(input_ids=input_ids, attention_mask=attention_mask, modality_tensor=modality_tensor)
+
+        self.assertEqual(
+            result.last_hidden_state.shape,
+            (self.model_tester.batch_size, self.model_tester.seq_length, config.hidden_size),
+        )
+
+        bad_modality_tensor = torch.zeros(
+            (self.model_tester.batch_size, self.model_tester.seq_length + 1),
+            device=torch_device,
+            dtype=torch.long,
+        )
+        with self.assertRaisesRegex(ValueError, "modality_tensor must have shape"):
+            model(input_ids=input_ids, attention_mask=attention_mask, modality_tensor=bad_modality_tensor)
+
     @require_tensorstream
     def test_for_conditional_generation(self):
         config, input_ids, attention_mask, labels = self.model_tester.prepare_config_and_inputs()

From f4a63748eb3845ce3d1436fc3704aa01ef42a371 Mon Sep 17 00:00:00 2001
From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com>
Date: Fri, 19 Dec 2025 19:32:10 +0400
Subject: [PATCH 57/77] review changes (#13): separate projector class, removed redundant casting; docs around varlen siglip packing

* style: drop unused function
* fix: do not expose gradient checkpoint flag
* refactor: move test files to fixtures directory, refer to them in test file
* style: remove redundant routing
* feat: isaac model forward autodoc + check inputs
* docs: add docstring for isaac forward
* style: remove conditions on pixel shuffle scale
* style: explicitly use torch dtype's min value to
avoid float casting
* style: remove unnecessary forward() configuration handling
* refactor: move to updated masking API
* chore: license + module docstring for test
* docs: update license
* refactor: doc mask rework wip 2 doc mask refactor finished callable rework
* style: pass all args down for interface compatibility
* style: remove extra cast on attention implementation
* test: update tests
* test: remove outdated tests
* refactor: isolated vision embedding class
* refactor: simplify attention flow to prepare for proper handling
* refactor: simplify config
* chore: convert artifacts
* test: add text only test
* style: cache
* tests: expand integration testing
* feat: use HF transformers attention_interface
* fix: make config roundtrip
* test: bring back config test
* test: add common test mixins
* fix: can generate is class method
* wip 60 fail
* fix: return hidden states
* fix: make setting input embeddings vocab size aware
* wip 2: 29 fail
* fix: tied weight keys to correct submodule "text_model"
* wip scary change allowing input embeds
* fix: post init call for tp_plan
* fix: enable gradient checkpointing if specified by config
* feat: init all weights if not from pretrained
* fix: explicitly do not support flex attention
* fix: allow attention setting
* chore: convert artifacts
* fix: temporarily drop _init_weights
* test: skip assisted decoding tests, qwen3 doesn't support it
* feat: handle 2d position ids for compatibility with HF tests
* test: state expectation that model is composite
* test: do not test for attention outputs given Qwen3 decoder
* wip 3 failures
* test: update skips
* test: no longer assert that position ids is none
* fix: sdpa default
* wip 1 test failing
* ALL TESTS PASSING
* reduce diffs 1 (all tests passing)
* reduce diffs 2 (all tests passing)
* needed for modular tests (all tests passing)
* final: latest diff reducer (all tests passing)
* test: test flash attention 2 not 3
* attempt 1 logit equiv
* move around
* fix: hardcode forward to prevent copy issues from conversion tool
* fix: allow gradient checkpointing
* test: remove redundant tests implemented by HF harness
* test: drop unused utility
* test: move processor tests to isolated file
* test: refactor document mask tests to isolated class for organization
* test: drop redundant utilities
* test: move logit equivalence test to proper setup in class
* test: drop unused fixtures
* test: organize imports
* test: refactor isolated test to IsaacModelTest
* test: remove unneeded generation fixture
* test: drop unused logit stats assert helper
* style: cleanup forward (all tests passing)
* style: cleanup forward more
* test: add point extraction test
* test: clean up constants
* test: separate isaac and base model
* test: drop remaining fixtures tests all pass
* tests: delete unused fixtures
* chore: convert artifact
* style: remove comment
* docs: remove redundant device movement code
* style: rename isaac image processor class name
* style: improve variable nomenclature in pixel shuffle
* style: remove nonstandard approach to not handling provided args
* style: drop ensuring redundant contiguity
* style: remove unneeded cast
* style: keep in kwargs
* chore: artifacts with all tests passing
* refactor: use external patchify vision replacement
* refactor: defer setting class attributes to base class
* style: drop unneeded aliases
* style: drop args in transformerskwargs
* style: simplify arg flow in vision attention
* refactor: rely more on ALL_MASK_ATTENTION_FUNCTIONS
* style: attn_output is always a tuple
*
refactor: delegate attention output to output recorder * refactor: drop sync helper * style: remove max_seqlen derivation at transformer level * refactor: tag encoder with check model inputs for return dict tracking * refactor: drop aliases * refactor: output_attentions in kwargs * refactor: dedup position id handling wip refactor 2: use new helper wip refactor 3, deduplicate * refactor: inherit rope embedding functionality * feat: split up kwarg building wip refactor 2 only calculate cu seqlens in relevant implementations * fix: error out if modality tensor has incorrect shape * refactor: dedicated class * docs: explain variable resolution siglip training * test: dtype tests for debugging need for attention output casting * style: remove redundant casts * style: test lint * refactor: drop duplicated vision property --- .../models/isaac/modeling_isaac.py | 87 ++++++------- .../models/isaac/modular_isaac.py | 61 +++++---- .../models/isaac/processing_isaac.py | 2 +- tests/models/isaac/test_modeling_isaac.py | 116 +++++++++++++++++- 4 files changed, 194 insertions(+), 72 deletions(-) diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index cda31175685e..947208ab133d 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -81,7 +81,12 @@ class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False): class IsaacVisionEmbeddings(nn.Module): - """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences.""" + """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences. + + Isaac accepts variable-resolution vision inputs as a single packed sequence with per-image + `token_grids`; packing/unpacking here reconstructs per-image shapes so we can resize positional + embeddings and build `cu_seqlens` for variable-length attention (not generic generation packing). + """ def __init__(self, config: IsaacVisionConfig): super().__init__() @@ -99,6 +104,8 @@ def __init__(self, config: IsaacVisionConfig): self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor: + # Rebatch packed variable-resolution patches to resize per-image position embeddings + # and track lengths for varlen attention metadata. packed_pixel_values, seq_lengths = self._pack_to_batch(seq_patches, spatial_shapes) if packed_pixel_values is None: return seq_patches.new_zeros((0, self.embed_dim)) @@ -182,6 +189,17 @@ def _pack_to_batch( seq_patches: torch.Tensor, spatial_shapes: torch.Tensor, ) -> tuple[Optional[torch.Tensor], torch.Tensor]: + """Rebatch a packed patch sequence using per-image grids to align embeddings. + + Args: + seq_patches (`torch.Tensor`): Packed patches of shape `(total_patches, patch_dim)`. + spatial_shapes (`torch.Tensor`): Per-image patch grids of shape `(num_images, 2)` as `(H_tokens, W_tokens)`. + + Returns: + `tuple[Optional[torch.Tensor], torch.Tensor]`: A padded batch tensor shaped + `(batch, max_len, patch_dim)` plus `seq_lengths` used to form `cu_seqlens` for + variable-length attention. 
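+        Example:
+            Grids `[(2, 3), (1, 4)]` give `seq_lengths = [6, 4]`, so 10 packed
+            patches become a zero-padded `(2, 6, patch_dim)` batch and the
+            downstream `cu_seqlens` is `[0, 6, 10]` (illustrative values).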
+ """ if seq_patches.ndim != 2: raise ValueError("`seq_patches` is expected to be 2D (total_patches, patch_dim).") if spatial_shapes.ndim != 2 or spatial_shapes.size(-1) != 2: @@ -216,6 +234,7 @@ def _pack_to_batch( return packed_pixel_values, seq_lengths def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor) -> torch.Tensor: + """Flatten a padded batch back to packed sequence order using `seq_lengths`.""" output_chunks: list[torch.Tensor] = [] for batch_idx, length in enumerate(seq_lengths.tolist()): if length == 0: @@ -325,14 +344,7 @@ def forward( attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() - # Align projection inputs with parameter dtype to avoid mixed-dtype matmul errors - out_proj_dtype = self.out_proj.weight.dtype - if attn_output.dtype != out_proj_dtype: - attn_output = attn_output.to(out_proj_dtype) - attn_output = self.out_proj(attn_output) - if attn_output.dtype != hidden_states.dtype: - attn_output = attn_output.to(hidden_states.dtype) return attn_output, attn_weights @@ -398,8 +410,6 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - if output_attentions: - return hidden_states, attn_weights return hidden_states @@ -422,15 +432,6 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: hidden_states = inputs_embeds - kwargs.update( - { - "max_seqlen": max_seqlen, - "cu_seqlens": cu_seqlens, - "output_attentions": output_attentions, - "output_hidden_states": output_hidden_states, - "return_dict": return_dict, - } - ) for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, @@ -675,6 +676,24 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): return hidden_states +class IsaacMultiModalProjector(nn.Module): + def __init__(self, config: IsaacConfig): + super().__init__() + self.vision_hidden_size = config.vision_config.hidden_size * ( + config.vision_config.pixel_shuffle_scale_factor**2 + ) + self.backbone_hidden_size = config.hidden_size + self.linear_1 = nn.Linear(self.vision_hidden_size, 4 * self.vision_hidden_size, bias=False) + self.silu = nn.SiLU() + self.linear_2 = nn.Linear(4 * self.vision_hidden_size, self.backbone_hidden_size, bias=False) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.silu(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + class IsaacVisionEmbedding(nn.Module): """Vision embedding wrapper exposing tower and projector.""" @@ -683,14 +702,9 @@ class IsaacVisionEmbedding(nn.Module): def __init__(self, config: IsaacConfig): super().__init__() vision_cfg = config.vision_config - hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2) self.vision_tower = IsaacVisionTransformer(vision_cfg) - self.multimodal_projector = nn.Sequential( - nn.Linear(hidden_dim, 4 * hidden_dim, bias=False), - nn.SiLU(), - nn.Linear(4 * hidden_dim, config.hidden_size, bias=False), - ) + self.multimodal_projector = IsaacMultiModalProjector(config) def forward(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: hidden_states = self.vision_tower(vision_tokens) @@ -1333,29 +1347,6 @@ class IsaacPreTrainedModel(PreTrainedModel): } -# ============================================================================ -# Model -# ============================================================================ - - -def compute_position_ids_input_ids(input_ids: torch.Tensor) -> 
torch.Tensor: - r"""Create 3D positional indices for token input. - - Args: - input_ids (`torch.Tensor`): - Tensor of shape `(batch_size, seq_len)` containing token ids. - - Returns: - `torch.Tensor`: Positional indices with shape `(batch_size, seq_len, 3)` where each channel duplicates the - 1D position so it can be consumed by the 3-axis MRoPE rotary embedding. - """ - batch_size, seq_length = input_ids.shape - position_ids = torch.arange(seq_length, device=input_ids.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - position_ids = position_ids.unsqueeze(2).expand(-1, -1, 3) # Add 3D for MRoPE - return position_ids - - @auto_docstring class IsaacForConditionalGeneration(IsaacPreTrainedModel, GenerationMixin): """Isaac multimodal model for conditional generation.""" diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 6cba46b79a85..c8b9a681a68f 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -458,7 +458,12 @@ def create_document_attention_mask( class IsaacVisionEmbeddings(nn.Module): - """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences.""" + """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences. + + Isaac accepts variable-resolution vision inputs as a single packed sequence with per-image + `token_grids`; packing/unpacking here reconstructs per-image shapes so we can resize positional + embeddings and build `cu_seqlens` for variable-length attention (not generic generation packing). + """ def __init__(self, config: IsaacVisionConfig): super().__init__() @@ -476,6 +481,8 @@ def __init__(self, config: IsaacVisionConfig): self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor: + # Rebatch packed variable-resolution patches to resize per-image position embeddings + # and track lengths for varlen attention metadata. packed_pixel_values, seq_lengths = self._pack_to_batch(seq_patches, spatial_shapes) if packed_pixel_values is None: return seq_patches.new_zeros((0, self.embed_dim)) @@ -559,6 +566,17 @@ def _pack_to_batch( seq_patches: torch.Tensor, spatial_shapes: torch.Tensor, ) -> tuple[Optional[torch.Tensor], torch.Tensor]: + """Rebatch a packed patch sequence using per-image grids to align embeddings. + + Args: + seq_patches (`torch.Tensor`): Packed patches of shape `(total_patches, patch_dim)`. + spatial_shapes (`torch.Tensor`): Per-image patch grids of shape `(num_images, 2)` as `(H_tokens, W_tokens)`. + + Returns: + `tuple[Optional[torch.Tensor], torch.Tensor]`: A padded batch tensor shaped + `(batch, max_len, patch_dim)` plus `seq_lengths` used to form `cu_seqlens` for + variable-length attention. 
+ """ if seq_patches.ndim != 2: raise ValueError("`seq_patches` is expected to be 2D (total_patches, patch_dim).") if spatial_shapes.ndim != 2 or spatial_shapes.size(-1) != 2: @@ -593,6 +611,7 @@ def _pack_to_batch( return packed_pixel_values, seq_lengths def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor) -> torch.Tensor: + """Flatten a padded batch back to packed sequence order using `seq_lengths`.""" output_chunks: list[torch.Tensor] = [] for batch_idx, length in enumerate(seq_lengths.tolist()): if length == 0: @@ -681,14 +700,7 @@ def forward( attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() - # Align projection inputs with parameter dtype to avoid mixed-dtype matmul errors - out_proj_dtype = self.out_proj.weight.dtype - if attn_output.dtype != out_proj_dtype: - attn_output = attn_output.to(out_proj_dtype) - attn_output = self.out_proj(attn_output) - if attn_output.dtype != hidden_states.dtype: - attn_output = attn_output.to(hidden_states.dtype) return attn_output, attn_weights @@ -939,6 +951,24 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): return hidden_states +class IsaacMultiModalProjector(nn.Module): + def __init__(self, config: IsaacConfig): + super().__init__() + self.vision_hidden_size = config.vision_config.hidden_size * ( + config.vision_config.pixel_shuffle_scale_factor**2 + ) + self.backbone_hidden_size = config.hidden_size + self.linear_1 = nn.Linear(self.vision_hidden_size, 4 * self.vision_hidden_size, bias=False) + self.silu = nn.SiLU() + self.linear_2 = nn.Linear(4 * self.vision_hidden_size, self.backbone_hidden_size, bias=False) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.silu(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + class IsaacVisionEmbedding(nn.Module): """Vision embedding wrapper exposing tower and projector.""" @@ -947,14 +977,9 @@ class IsaacVisionEmbedding(nn.Module): def __init__(self, config: IsaacConfig): super().__init__() vision_cfg = config.vision_config - hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2) self.vision_tower = IsaacVisionTransformer(vision_cfg) - self.multimodal_projector = nn.Sequential( - nn.Linear(hidden_dim, 4 * hidden_dim, bias=False), - nn.SiLU(), - nn.Linear(4 * hidden_dim, config.hidden_size, bias=False), - ) + self.multimodal_projector = IsaacMultiModalProjector(config) def forward(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: hidden_states = self.vision_tower(vision_tokens) @@ -1516,14 +1541,6 @@ def embed_tokens(self, value: nn.Module) -> None: def vision_model(self) -> nn.Module: return self.vision_embedding.vision_tower - @property - def vision_model(self) -> nn.Module: - return self.vision_embedding.vision_tower - - @property - def vision_tower(self) -> nn.Module: - return self.vision_embedding.vision_tower - def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: """Embed text tokens, squeezing singleton dimensions.""" # Text events are shaped as (..., 1); squeeze the singleton index dim diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 594f9765db63..df90ae550756 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -110,7 +110,7 @@ def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> class 
IsaacProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = ("IsaacImageProcessorFast",) - tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + tokenizer_class = ("Qwen2Tokenizer",) def __init__( self, diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py index d020082fee2a..6425edfd6110 100644 --- a/tests/models/isaac/test_modeling_isaac.py +++ b/tests/models/isaac/test_modeling_isaac.py @@ -38,7 +38,11 @@ from transformers.image_utils import load_image from transformers.masking_utils import eager_mask, sdpa_mask from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast -from transformers.models.isaac.modeling_isaac import document_mask_function_from_cu_seqlens +from transformers.models.isaac.modeling_isaac import ( + IsaacVisionAttention, + IsaacVisionConfig, + document_mask_function_from_cu_seqlens, +) from transformers.models.isaac.processing_isaac import IsaacProcessor from transformers.testing_utils import ( require_flash_attn, @@ -627,6 +631,116 @@ def test_isaac_generation_with_tensor_stream(self): self.assertNotEqual(decoded_prompt.strip(), "") +@require_torch +@require_flash_attn +class IsaacAttentionDtypeTest(unittest.TestCase): + def _make_config(self): + return IsaacVisionConfig( + hidden_size=32, + intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=4, + num_channels=3, + num_patches=64, + patch_size=4, + attention_dropout=0.0, + pixel_shuffle_scale_factor=1, + ) + + def _skip_if_no_cuda_bf16(self): + if not torch.cuda.is_available(): + pytest.skip("CUDA required for flash attention dtype/parity tests.") + if not torch.cuda.is_bf16_supported(): + pytest.skip("CUDA bfloat16 support required.") + + def test_flash_attention_matches_weight_dtype_bf16(self): + self._skip_if_no_cuda_bf16() + torch.manual_seed(0) + + device = torch.device("cuda") + config = self._make_config() + config._attn_implementation = "flash_attention_2" + + attn = IsaacVisionAttention(config).to(device=device, dtype=torch.bfloat16).eval() + + hidden_states = torch.randn(2, 4, config.hidden_size, device=device, dtype=torch.bfloat16) + + with torch.no_grad(): + attn_output, _ = attn(hidden_states) + + assert attn_output.dtype == attn.out_proj.weight.dtype + assert attn_output.dtype == hidden_states.dtype + + def test_flash_attention_matches_weight_dtype_bf16_with_padding(self): + self._skip_if_no_cuda_bf16() + torch.manual_seed(0) + + device = torch.device("cuda") + config = self._make_config() + config._attn_implementation = "flash_attention_2" + + attn = IsaacVisionAttention(config).to(device=device, dtype=torch.bfloat16).eval() + + hidden_states = torch.randn(2, 4, config.hidden_size, device=device, dtype=torch.bfloat16) + attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], device=device, dtype=torch.bool) + + with torch.no_grad(): + attn_output, _ = attn(hidden_states, attention_mask=attention_mask) + + assert attn_output.dtype == attn.out_proj.weight.dtype + assert attn_output.dtype == hidden_states.dtype + + def test_flash_attention_matches_weight_dtype_bf16_with_cu_seqlens(self): + self._skip_if_no_cuda_bf16() + torch.manual_seed(0) + + device = torch.device("cuda") + config = self._make_config() + config._attn_implementation = "flash_attention_2" + + attn = IsaacVisionAttention(config).to(device=device, dtype=torch.bfloat16).eval() + + hidden_states = torch.randn(1, 5, config.hidden_size, device=device, dtype=torch.bfloat16) + cu_seqlens = 
torch.tensor([0, 3, 5], device=device, dtype=torch.int32) + + with torch.no_grad(): + attn_output, _ = attn(hidden_states, cu_seqlens=cu_seqlens, max_seqlen=3) + + assert attn_output.dtype == attn.out_proj.weight.dtype + assert attn_output.dtype == hidden_states.dtype + + def test_flash_attention_parity_with_sdpa_bf16(self): + self._skip_if_no_cuda_bf16() + torch.manual_seed(0) + + device = torch.device("cuda") + config_sdpa = self._make_config() + config_sdpa._attn_implementation = "sdpa" + + config_fa2 = self._make_config() + config_fa2._attn_implementation = "flash_attention_2" + + attn_sdpa = IsaacVisionAttention(config_sdpa).to(device=device, dtype=torch.bfloat16).eval() + attn_fa2 = IsaacVisionAttention(config_fa2).to(device=device, dtype=torch.bfloat16).eval() + + # Align weights so the only difference is the backend + attn_fa2.load_state_dict(attn_sdpa.state_dict()) + + hidden_states = torch.randn(2, 4, config_sdpa.hidden_size, device=device, dtype=torch.bfloat16) + + with torch.no_grad(): + out_sdpa, _ = attn_sdpa(hidden_states) + out_fa2, _ = attn_fa2(hidden_states) + + torch.testing.assert_close( + out_fa2.float(), + out_sdpa.float(), + rtol=1e-3, + atol=1e-3, + msg="FlashAttention2 output deviates from SDPA baseline beyond tolerance", + ) + + @require_torch @require_vision @slow From abba38b1855fe1a1e9474d438b98a9e456d372f0 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Thu, 25 Dec 2025 01:42:45 +0400 Subject: [PATCH 58/77] Squash merge pg/refactor_remove_tensorstream into main --- .../isaac/image_processing_isaac_fast.py | 127 +- .../models/isaac/modeling_isaac.py | 743 ++++------ .../models/isaac/modular_isaac.py | 1233 +++++++---------- .../models/isaac/processing_isaac.py | 346 +++-- tests/models/isaac/test_modeling_isaac.py | 290 ++-- tests/models/isaac/test_processing_isaac.py | 313 ++++- 6 files changed, 1377 insertions(+), 1675 deletions(-) diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py b/src/transformers/models/isaac/image_processing_isaac_fast.py index 58735df5fd60..f487c249571d 100644 --- a/src/transformers/models/isaac/image_processing_isaac_fast.py +++ b/src/transformers/models/isaac/image_processing_isaac_fast.py @@ -32,9 +32,7 @@ # Vision preprocessing constants from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD -from ...utils.import_utils import ( - is_torch_available, -) +from ...utils.import_utils import is_torch_available from .modeling_isaac import IsaacImageProcessorFastKwargs @@ -121,7 +119,8 @@ def get_image_size_for_max_num_patches( num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size) if min_num_patches is not None and num_patches < min_num_patches: - # Scale up + # Scale up via binary search to satisfy the minimum patch budget while + # preserving divisibility by patch_size * pixel_shuffle_scale. 
scale_min, scale_max = 1.0, 100.0 while (scale_max - scale_min) >= eps: scale = (scale_min + scale_max) / 2 @@ -156,19 +155,6 @@ def get_image_size_for_max_num_patches( return target_height, target_width -def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> torch.Tensor: - """Compute residuals for P-frames to stay in sync with the training pipeline.""" - if not any(is_p_frame): - return frames - - frame_indices = torch.arange(len(is_p_frame), device=frames.device) - i_frame_mask = torch.tensor([not flag for flag in is_p_frame], device=frames.device) - last_i_indices = torch.cummax((i_frame_mask * (1 + frame_indices)), dim=0).values.long() - 1 - p_indices = frame_indices[torch.tensor(is_p_frame, device=frames.device)] - frames[p_indices] = frames[p_indices] - frames[last_i_indices[p_indices]] - return frames - - @auto_docstring class IsaacImageProcessorFast(BaseImageProcessorFast): MAX_PIXELS = 60_000_000 # 60โ€‘megapixel ceiling โ‰ˆ 8200 ร— 7300 px @@ -177,7 +163,7 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): resample = PILImageResampling.BILINEAR model_input_names = ["patches", "token_grids"] valid_kwargs = IsaacImageProcessorFastKwargs - unused_kwargs = ["size", "do_center_crop", "crop_size"] + unused_kwargs = ["size", "do_center_crop", "crop_size", "pad_size", "do_pad"] do_resize = True do_center_crop = False @@ -192,7 +178,6 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): image_std = list(VISION_STD) do_convert_rgb = True disable_grouping = False - size_divisor: Optional[int] = None def __init__( self, @@ -200,11 +185,6 @@ def __init__( ) -> None: super().__init__(**kwargs) - pixel_shuffle_scale = 1 if self.pixel_shuffle_scale is None else int(self.pixel_shuffle_scale) - if pixel_shuffle_scale < 1: - raise ValueError("`pixel_shuffle_scale` must be >= 1") - self.pixel_shuffle_scale = pixel_shuffle_scale - def _validate_preprocess_kwargs(self, **kwargs): # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for. 
kwargs.pop("do_resize", None) @@ -218,29 +198,10 @@ def resize( self, image: torch.Tensor, size: SizeDict, - interpolation: Optional[Any] = None, - antialias: bool = True, **kwargs, ) -> torch.Tensor: - if size.height is None or size.width is None: - raise ValueError("IsaacImageProcessorFast requires explicit `height` and `width` when resizing.") - - resize_mode: Any = interpolation - if hasattr(resize_mode, "value"): - resize_mode = resize_mode.value - elif hasattr(resize_mode, "name"): - resize_mode = resize_mode.name.lower() - elif resize_mode is None: - resize_mode = "bilinear" - - if isinstance(resize_mode, str): - mode_key = resize_mode.lower() - else: - mode_key = resize_mode - - resize_kwargs: dict[str, Any] = {} - if mode_key in {"linear", "bilinear", "bicubic", "trilinear"}: - resize_kwargs["align_corners"] = False + resize_kwargs: dict[str, Any] = {"align_corners": False} + resize_mode = "bilinear" return F.interpolate( image, @@ -253,10 +214,7 @@ def _preprocess( self, images: list[torch.Tensor], do_resize: bool, - size: Optional[SizeDict], interpolation: Optional[Any], - do_center_crop: bool, - crop_size: Optional[SizeDict], do_rescale: Optional[bool], rescale_factor: Optional[float], do_normalize: Optional[bool], @@ -264,8 +222,6 @@ def _preprocess( image_std: Optional[Union[float, Sequence[float]]], disable_grouping: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, - do_pad: Optional[bool] = None, - pad_size: Optional[SizeDict] = None, *, patch_size: Optional[int] = None, max_num_patches: Optional[int] = None, @@ -273,20 +229,15 @@ def _preprocess( pixel_shuffle_scale: Optional[int] = None, **kwargs, ) -> BatchFeature: - if do_center_crop: - raise ValueError("`do_center_crop` is not supported by IsaacImageProcessorFast.") - if do_pad: - raise ValueError("`do_pad` is not supported by IsaacImageProcessorFast.") - grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) - processed_patches_grouped: dict[tuple[int, ...], torch.Tensor] = {} - token_grids_grouped: dict[tuple[int, ...], torch.Tensor] = {} - virtual_dims_grouped: dict[tuple[int, ...], torch.Tensor] = {} - real_dims_grouped: dict[tuple[int, ...], torch.Tensor] = {} + + grouped_outputs = {} for shape, stacked_images in grouped_images.items(): if stacked_images.ndim != 4: - raise ValueError("Expected batched channel-first image tensors.") + raise ValueError( + f"Expected images shaped as (batch, channels, height, width); got shape {tuple(stacked_images.shape)}." + ) batch_size, channels, original_height, original_width = stacked_images.shape @@ -295,7 +246,9 @@ def _preprocess( channels = 3 if original_height * original_width > self.MAX_PIXELS: - raise ValueError(f"Image (w={original_width}, h={original_height}) > MAX=`{self.MAX_PIXELS}`") + raise ValueError( + f"Image area {original_height * original_width} (h={original_height}, w={original_width}) exceeds MAX_PIXELS={self.MAX_PIXELS}; enable resizing or provide smaller inputs." + ) target_height, target_width = get_image_size_for_max_num_patches( original_height, @@ -315,7 +268,9 @@ def _preprocess( ) else: if ((original_height % patch_size) != 0) or ((original_width % patch_size) != 0): - raise ValueError("Image dimensions must be divisible by patch_size when resize is disabled.") + raise ValueError( + f"Image dimensions (h={original_height}, w={original_width}) must be divisible by patch_size={patch_size} when resize is disabled; enable resizing or adjust the input resolution." 
+ ) image_batch = stacked_images target_height, target_width = original_height, original_width @@ -329,10 +284,7 @@ def _preprocess( image_std=image_std, ) - nhwc_images = image_batch.permute(0, 2, 3, 1) - nhwc_images = _compute_residual_p_frames(nhwc_images, is_p_frame=[False] * batch_size) - - patches = torch_extract_patches(nhwc_images.permute(0, 3, 1, 2), patch_size, patch_size) + patches = torch_extract_patches(image_batch, patch_size, patch_size) _, height_tokens, width_tokens, _ = patches.shape token_grid = ( @@ -357,7 +309,7 @@ def _preprocess( if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): raise ValueError( - "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled." + f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale}; adjust resize/patch parameters or disable pixel shuffle." ) virtual_height = height_tokens // pixel_shuffle_scale virtual_width = width_tokens // pixel_shuffle_scale @@ -371,31 +323,24 @@ def _preprocess( .unsqueeze(0) .repeat(batch_size, 1) ) + grouped_outputs[shape] = (patches, token_grid, virtual_dim, real_dim) - processed_patches_grouped[shape] = patches - token_grids_grouped[shape] = token_grid - virtual_dims_grouped[shape] = virtual_dim - real_dims_grouped[shape] = real_dim - - patches_slices = reorder_images(processed_patches_grouped, grouped_images_index) - token_grid_slices = reorder_images(token_grids_grouped, grouped_images_index) - virtual_dim_slices = reorder_images(virtual_dims_grouped, grouped_images_index) - real_dim_slices = reorder_images(real_dims_grouped, grouped_images_index) - - patches_tensor = torch.stack(patches_slices, dim=0) - token_grids_tensor = torch.stack(token_grid_slices, dim=0) - virtual_dims_tensor = torch.stack(virtual_dim_slices, dim=0) - real_dims_tensor = torch.stack(real_dim_slices, dim=0) - - return BatchFeature( - data={ - "patches": patches_tensor, - "token_grids": token_grids_tensor, - "virtual_pixel_size": virtual_dims_tensor, - "real_pixel_size": real_dims_tensor, - }, - tensor_type=return_tensors, - ) + # Helper to reorder a single item of the tuple payloads using the same grouped_images_index + def _reorder_grouped_item( + grouped: dict[tuple[int, ...], tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], + grouped_index: dict[tuple[int, ...], list[int]], + item_idx: int, + ) -> list[torch.Tensor]: + return reorder_images({k: v[item_idx] for k, v in grouped.items()}, grouped_index) + + keys = ("patches", "token_grids", "virtual_pixel_size", "real_pixel_size") + tensors: dict[str, torch.Tensor] = {} + + for i, key in enumerate(keys): + slices = _reorder_grouped_item(grouped_outputs, grouped_images_index, i) + tensors[key] = torch.stack(slices, dim=0) + + return BatchFeature(data=tensors, tensor_type=return_tensors) __all__ = ["IsaacImageProcessorFast"] diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 947208ab133d..7f31ae47c480 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -21,8 +21,8 @@ import copy -from collections import defaultdict from collections.abc import Callable +from enum import IntEnum from typing import Any, Optional, Union from ...activations import ACT2FN @@ -36,13 +36,11 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast 
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...models.auto.modeling_auto import AutoModel
-from ...models.qwen3.modeling_qwen3 import Qwen3PreTrainedModel
+from ...models.qwen3.modeling_qwen3 import Qwen3Model, Qwen3PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import auto_docstring
 from ...utils.generic import OutputRecorder, TransformersKwargs, can_return_tuple, check_model_inputs
 from ...utils.import_utils import (
-    is_perceptron_available,
     is_torch_available,
     is_torchdynamo_compiling,
 )
@@ -55,22 +53,18 @@
 import torch.nn as nn
 import torch.nn.functional as F

-if is_perceptron_available():
-    from perceptron.tensorstream.ops import (
-        compute_mrope_pos_tensor,
-        modality_mask,
-        reconstruct_tensor_stream_from_compact_dict,
-    )
-    from perceptron.tensorstream.tensorstream import TensorStream, TextType, VisionType, group_streams
-else:
-    ts_slice = None
-    Event = None
-    Stream = None
-    TensorStream = None
-    TextType = None
-    VisionType = None
-    create_stream = None
-    group_streams = None
+
+class ModalityType(IntEnum):
+    """
+    Modality identifiers for events.
+
+    Members:
+        image: Vision tokens (e.g., patches).
+        text: Textual tokens.
+    """
+
+    image = 0
+    text = 1


 class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False):
@@ -103,28 +97,6 @@ def __init__(self, config: IsaacVisionConfig):
         self.position_embedding_size = int(self.num_patches**0.5)
         self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

-    def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor:
-        # Rebatch packed variable-resolution patches to resize per-image position embeddings
-        # and track lengths for varlen attention metadata.
-        packed_pixel_values, seq_lengths = self._pack_to_batch(seq_patches, spatial_shapes)
-        if packed_pixel_values is None:
-            return seq_patches.new_zeros((0, self.embed_dim))
-
-        target_dtype = self.patch_embedding.weight.dtype
-        patch_embeds = self.patch_embedding(packed_pixel_values.to(dtype=target_dtype))
-
-        positional_embeddings = self.position_embedding.weight.reshape(
-            self.position_embedding_size,
-            self.position_embedding_size,
-            -1,
-        )
-        resized_positional_embeddings = self.resize_positional_embeddings(
-            positional_embeddings, spatial_shapes, max_length=packed_pixel_values.shape[1]
-        )
-
-        embeddings = patch_embeds + resized_positional_embeddings
-        return self._unpack_from_batch(embeddings, seq_lengths)
-
     @staticmethod
     def resize_positional_embeddings(
         positional_embeddings: torch.Tensor,
@@ -184,6 +156,36 @@ def resize_positional_embeddings(
         return resulted_positional_embeddings

+    @check_model_inputs
+    def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            seq_patches (`torch.FloatTensor`):
+                Packed patch values of shape (total_patches, num_channels * patch_size * patch_size)
+            spatial_shapes (`list[tuple[int, int]]`):
+                Per-image patch grids of shape (num_images, 2) used to resize the positional embeddings
+        """
+        # Rebatch packed variable-resolution patches to resize per-image position embeddings
+        # and track lengths for varlen attention metadata.
+ packed_pixel_values, seq_lengths = self._pack_to_batch(seq_patches, spatial_shapes) + if packed_pixel_values is None: + return seq_patches.new_zeros((0, self.embed_dim)) + + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(packed_pixel_values.to(dtype=target_dtype)) + + positional_embeddings = self.position_embedding.weight.reshape( + self.position_embedding_size, + self.position_embedding_size, + -1, + ) + resized_positional_embeddings = self.resize_positional_embeddings( + positional_embeddings, spatial_shapes, max_length=packed_pixel_values.shape[1] + ) + + embeddings = patch_embeds + resized_positional_embeddings + return self._unpack_from_batch(embeddings, seq_lengths) + def _pack_to_batch( self, seq_patches: torch.Tensor, @@ -192,59 +194,33 @@ def _pack_to_batch( """Rebatch a packed patch sequence using per-image grids to align embeddings. Args: - seq_patches (`torch.Tensor`): Packed patches of shape `(total_patches, patch_dim)`. - spatial_shapes (`torch.Tensor`): Per-image patch grids of shape `(num_images, 2)` as `(H_tokens, W_tokens)`. + seq_patches: Packed patches of shape (total_patches, patch_dim). + spatial_shapes: Per-image patch grids of shape (num_images, 2) as (H_tokens, W_tokens). Returns: - `tuple[Optional[torch.Tensor], torch.Tensor]`: A padded batch tensor shaped - `(batch, max_len, patch_dim)` plus `seq_lengths` used to form `cu_seqlens` for - variable-length attention. + (packed_pixel_values, seq_lengths) where: + - packed_pixel_values: (batch, max_len, patch_dim) padded with zeros, or None if batch_size == 0 + - seq_lengths: (batch,) lengths for each image """ - if seq_patches.ndim != 2: - raise ValueError("`seq_patches` is expected to be 2D (total_patches, patch_dim).") - if spatial_shapes.ndim != 2 or spatial_shapes.size(-1) != 2: - raise ValueError("`spatial_shapes` must have shape (num_images, 2) with (height_tokens, width_tokens).") - - seq_lengths = spatial_shapes.long().prod(dim=-1) - total_patches = int(seq_lengths.sum().item()) - if total_patches != seq_patches.size(0): - raise ValueError( - "Mismatch between packed patches and spatial shapes: got " - f"{seq_patches.size(0)} patches but spatial shapes imply {total_patches}." 
- ) - - batch_size = spatial_shapes.size(0) + # Per-image token counts + seq_lengths = spatial_shapes.long().prod(dim=-1) # (B,) + batch_size = int(seq_lengths.numel()) if batch_size == 0: return None, seq_lengths - max_length = int(seq_lengths.max().item()) - patch_dim = seq_patches.size(-1) - device = seq_patches.device - - packed_pixel_values = seq_patches.new_zeros((batch_size, max_length, patch_dim), device=device) - - start = 0 - for batch_idx, length in enumerate(seq_lengths.tolist()): - if length == 0: - continue - end = start + length - packed_pixel_values[batch_idx, :length] = seq_patches[start:end] - start = end - + # Split the packed sequence into per-image chunks, then pad to a batch + lengths_list = seq_lengths.tolist() + chunks = seq_patches.split(lengths_list, dim=0) + packed_pixel_values = nn.utils.rnn.pad_sequence(chunks, batch_first=True) # zero-padded by default return packed_pixel_values, seq_lengths def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor) -> torch.Tensor: """Flatten a padded batch back to packed sequence order using `seq_lengths`.""" - output_chunks: list[torch.Tensor] = [] - for batch_idx, length in enumerate(seq_lengths.tolist()): - if length == 0: - continue - output_chunks.append(embeddings[batch_idx, :length]) - - if not output_chunks: + lengths = seq_lengths.to(device=embeddings.device).tolist() + chunks = [embeddings[i, :l] for i, l in enumerate(lengths) if l > 0] + if not chunks: return embeddings.new_zeros((0, embeddings.size(-1))) - - return torch.cat(output_chunks, dim=0) + return torch.cat(chunks, dim=0) class IsaacVisionAttention(nn.Module): @@ -297,11 +273,9 @@ def forward( if attn_impl != "sdpa": attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl] - dropout = 0.0 if not self.training else self.dropout attention_kwargs: dict[str, Any] = { "is_causal": False, "scaling": self.scale, - "dropout": dropout, } supports_varlen = cu_seqlens is not None and attn_impl in { @@ -311,10 +285,6 @@ def forward( "paged|flash_attention_2", "paged|flash_attention_3", } - - if output_attentions and attn_impl == "eager": - attention_kwargs["output_attentions"] = True - if supports_varlen: if max_seqlen is not None: max_q = max_k = int(max_seqlen) @@ -341,9 +311,7 @@ def forward( attention_mask, **attention_kwargs, ) - attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() - attn_output = self.out_proj(attn_output) return attn_output, attn_weights @@ -423,8 +391,7 @@ def __init__(self, config: IsaacVisionConfig): self.gradient_checkpointing = False # Ignore copy - @can_return_tuple - @check_model_inputs + @auto_docstring def forward( self, inputs_embeds, @@ -438,30 +405,8 @@ def forward( attention_mask, **kwargs, ) - return BaseModelOutput(last_hidden_state=hidden_states) - - -def document_mask_function_from_cu_seqlens(cu_seqlens: Optional[torch.Tensor]) -> Optional[Callable]: - """Return a mask function that blocks cross-document attention from packed ``cu_seqlens``. - - The returned callable matches the signature expected by ``masking_utils`` mask factories and - yields ``True`` only when query/key positions belong to the same packed segment. 
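Note: the rewritten `create_document_attention_mask` (next hunk) derives segment ids from `cu_seqlens` and only allows attention within the same packed image. A small sketch of the resulting block-diagonal pattern, with invented sizes:

import torch

cu_seqlens = torch.tensor([0, 3, 5])                           # two packed images: 3 and 2 tokens
seq_sizes = cu_seqlens[1:] - cu_seqlens[:-1]                   # tensor([3, 2])
seg_ids = torch.repeat_interleave(torch.arange(2), seq_sizes)  # [0, 0, 0, 1, 1]
allowed = seg_ids[:, None] == seg_ids[None, :]                 # True only within the same image
# `allowed` is the 5x5 block-diagonal pattern the mask function encodes.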
- """ - - if cu_seqlens is None: - return None - - if cu_seqlens.numel() < 2: - return None - - seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() - if seq_sizes.numel() == 0: - return None - total_tokens = int(seq_sizes.sum().item()) - seg_ids = torch.repeat_interleave(torch.arange(seq_sizes.numel(), device=cu_seqlens.device), seq_sizes) - packed_sequence_mask = seg_ids.view(1, total_tokens) - return packed_sequence_mask_function(packed_sequence_mask) + return BaseModelOutput(last_hidden_state=hidden_states) def create_document_attention_mask( @@ -469,16 +414,23 @@ def create_document_attention_mask( input_embeds: torch.Tensor, cu_seqlens: Optional[torch.Tensor], ) -> Optional[Union[torch.Tensor, Any]]: - """Materialize a backend-specific block-diagonal attention mask. + """ + Materialize a backend-specific block-diagonal attention mask from packed cu_seqlens. - This uses the standard `masking_utils` mask interface (same mechanism as Llama4), - so the returned object matches the selected attention backend (e.g. SDPA bool mask, - eager additive mask, or flex `BlockMask`). + Returns None if cu_seqlens is missing/degenerate. """ + if cu_seqlens is None or cu_seqlens.numel() < 2: + return None # Degenerate input: nothing to mask - mask_function = document_mask_function_from_cu_seqlens(cu_seqlens) - if mask_function is None: - return None + seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() + if seq_sizes.numel() == 0 or int(seq_sizes.sum()) == 0: + return None # All-empty segments produce no attention blocks + + seg_ids = torch.repeat_interleave( + torch.arange(seq_sizes.numel(), device=cu_seqlens.device), + seq_sizes, + ) + mask_function = packed_sequence_mask_function(seg_ids.view(1, -1)) seq_len = input_embeds.shape[1] cache_position = torch.arange(seq_len, device=input_embeds.device, dtype=torch.long) @@ -523,47 +475,29 @@ def create_pixel_shuffle_index_map( packed sequence for the j-th sub-patch that forms the i-th output token. 
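Note: a worked instance of the gather map documented above, for a single 4x4 token grid with scale_factor = 2 (sizes invented):

import torch

h = w = 4
s = 2
grid = torch.arange(h * w).view(h, w)
idx = grid.view(h // s, s, w // s, s).permute(0, 2, 1, 3).reshape(-1, s * s)
# idx == [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
# Row i lists the s*s packed positions that merge into output token i.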
""" - if device is None: - device = seq_sizes.device - - scale_factor = int(scale_factor) - if scale_factor < 2: - raise ValueError("`scale_factor` must be โ‰ฅ 2") - - # Safety: all spatial dims must be divisible by the scale factor - # Cannot run under torch compile fullgraph mode hence if not is_torchdynamo_compiling(): - if not ((token_grids[:, 0] % scale_factor == 0).all() and (token_grids[:, 1] % scale_factor == 0).all()): + if (token_grids % scale_factor).any(): raise AssertionError( - "Every (H,W) in `token_grids` must be divisible by " - f"scale_factor={scale_factor}, got {token_grids.tolist()}" + f"Every (H,W) in token_grids must be divisible by scale_factor={scale_factor}, got {token_grids.tolist()}" ) gather_chunks: list[torch.Tensor] = [] tok_offset = 0 - - for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist(), strict=False): - # Build the (H, W) grid of flat indices for this image - grid = torch.arange(seq_len, device=device, dtype=torch.int64) + tok_offset - grid = grid.view(h, w) # (H, W) - - # -------- identical ordering to your fixed-res routine -------- - # Step 1: split width into blocks of scale_factor - grid = grid.view(h, w // scale_factor, scale_factor) # (H, W/scale_factor, scale_factor) - # Step 2: now split height into blocks of scale_factor - grid = grid.view(h // scale_factor, scale_factor, w // scale_factor, scale_factor) - # (H/scale_factor, scale_factor, W/scale_factor, scale_factor) - # Step 3: final permutation to (H/scale_factor, W/scale_factor, scale_factor, scale_factor) - grid = grid.permute(0, 2, 1, 3).contiguous() # (H/scale_factor, W/scale_factor, scale_factor, scale_factor) - # Step 4: each (scale_factor, scale_factor) block forms one output token - gather_chunks.append(grid.reshape(-1, scale_factor * scale_factor)) - # (H*W / scale_factor**2, scale_factor**2) + for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist()): + # Flat indices for this image's packed segment + grid = torch.arange(seq_len, device=device, dtype=torch.int64).view(h, w) + tok_offset + + # Block into (H/s, W/s) groups; each group contributes s*s indices + grid = ( + grid.view(h // scale_factor, scale_factor, w // scale_factor, scale_factor) + .permute(0, 2, 1, 3) + .contiguous() + ) + gather_chunks.append(grid.view(-1, scale_factor * scale_factor)) tok_offset += seq_len - # Concatenate over all images in the packed batch - gather_idx = torch.cat(gather_chunks, dim=0) # (ฮฃ_i HแตขWแตข/scale_factor**2, scale_factor**2) - return gather_idx + return torch.cat(gather_chunks, dim=0) def pixel_shuffle_varlen( @@ -595,7 +529,9 @@ def pixel_shuffle_varlen( return_with_batch_dim = x.dim() == 3 if return_with_batch_dim: if x.size(0) != 1: - raise AssertionError("Packed sequence is expected to have batch_size == 1") + raise ValueError( + f"Packed vision sequences expect a singleton batch dimension; received batch_size={x.size(0)}." + ) embeddings = x.squeeze(0) # (seq, embed) else: embeddings = x # (seq, embed) @@ -606,7 +542,8 @@ def pixel_shuffle_varlen( # Calculate seq_sizes from token_grids seq_sizes = torch.prod(token_grids, dim=-1) - # Build index map and gather in one go + # Build a single gather index so pixel shuffle works on the packed stream + # without unpacking per-image grids. 
gather_idx = create_pixel_shuffle_index_map( seq_sizes=seq_sizes, token_grids=token_grids, @@ -627,6 +564,19 @@ def pixel_shuffle_varlen( class IsaacVisionTransformer(nn.Module): + """Vision tower that packs variable-resolution patches, applies varlen attention, and pixel-shuffles outputs. + + Args: + config (IsaacVisionConfig): Vision configuration with pixel-shuffle and patching parameters. + + Inputs: + packed_seq_patches (Tuple[Tensor, Tensor]): ``(patches, token_grids)`` where ``patches`` is a packed + patch sequence and ``token_grids`` holds per-image (H_tokens, W_tokens). + + Returns: + torch.Tensor: Vision embeddings after encoder + pixel shuffle, shaped ``(seq_len, hidden_size * s^2)``. + """ + _supports_sdpa = True def __init__(self, config: IsaacVisionConfig): @@ -644,12 +594,12 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): # Get embeddings from packed sequence hidden_states = self.embeddings(seq_patches, token_grids) - # Add a pseudo batch dimension for the encoder + # Add a pseudo batch dimension so we can reuse the batch-first encoder stack + # while still driving per-image cu_seqlens through the varlen attention path. hidden_states = hidden_states.unsqueeze(0) # Generate cumulative sequence lengths for variable-length attention - cu_seqlens = torch.zeros(seq_sizes.size(0) + 1, dtype=torch.int32, device=hidden_states.device) - cu_seqlens[1:] = seq_sizes.cumsum(0) + cu_seqlens = F.pad(seq_sizes.cumsum(0).to(torch.int32), (1, 0)) attention_mask = create_document_attention_mask(self.config, hidden_states, cu_seqlens) @@ -677,6 +627,8 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): class IsaacMultiModalProjector(nn.Module): + """Maps vision tower outputs to the text hidden size with a SiLU MLP.""" + def __init__(self, config: IsaacConfig): super().__init__() self.vision_hidden_size = config.vision_config.hidden_size * ( @@ -695,7 +647,17 @@ def forward(self, image_features): class IsaacVisionEmbedding(nn.Module): - """Vision embedding wrapper exposing tower and projector.""" + """Wraps the vision tower plus projection into the text hidden size. + + Args: + config (IsaacConfig): Composite config containing both vision and text settings. + + Inputs: + vision_tokens (Tuple[Tensor, Tensor]): Packed vision patches and token grids. + + Returns: + torch.Tensor: Projected vision embeddings aligned to the text hidden size. 
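Note: the projector widths implied here, with assumed sizes (not taken from this patch):

vision_hidden, s, text_hidden = 1152, 2, 2048
projector_in = vision_hidden * s * s  # 4608 features per pixel-shuffled token
# The SiLU MLP maps projector_in -> text_hidden (4608 -> 2048 in this sketch).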
+ """ _supports_sdpa = True @@ -712,15 +674,11 @@ def forward(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Ten class IsaacRotaryEmbedding(qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding): - EXTRA_ROPE_KEYS = {"mrope_section", "mrope_interleaved"} - def __init__(self, config: IsaacConfig, device=None): rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} - - sanitized_scaling = {k: v for k, v in rope_scaling.items() if k not in self.EXTRA_ROPE_KEYS} config_for_rope = copy.copy(rope_source_cfg) - config_for_rope.rope_scaling = sanitized_scaling if sanitized_scaling else None + config_for_rope.rope_scaling = rope_scaling init_device = device if device is not None and getattr(device, "type", None) != "meta" else None super().__init__(config_for_rope, device=init_device) @@ -738,12 +696,6 @@ def _resolve_mrope_section(section: Optional[list[int]], rotary_half_dim: int) - return base section = [int(v) for v in section] - if len(section) != 3: - raise ValueError("`mrope_section` must contain exactly three elements (temporal, height, width)") - if sum(section) != rotary_half_dim: - raise ValueError( - f"`mrope_section` must sum to the rotary half-dimension ({rotary_half_dim}). Received {section}." - ) return section def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor: @@ -757,11 +709,6 @@ def forward( modality_tensor: torch.Tensor, hidden_states: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - if position_ids.ndim != 3 or position_ids.size(-1) != 3: - raise ValueError("`position_ids` must have shape (batch, seq_len, 3) for MRoPE") - if modality_tensor.shape != position_ids.shape[:2]: - raise ValueError("`modality_tensor` must align with the first two dims of `position_ids`") - if hidden_states is None: batch, seq_len, _ = position_ids.shape hidden_states = torch.zeros( @@ -774,48 +721,23 @@ def forward( with torch.no_grad(): pos = position_ids.clone() - image_value = VisionType.image.value if VisionType is not None else 1 - not_spatial = modality_tensor != image_value + not_spatial = modality_tensor != ModalityType.image.value if not_spatial.any(): + # Collapse non-vision modalities to 1D positions so rotary embedding + # treats them like text tokens while keeping image tokens 3D. data_1d = pos[not_spatial][..., 0].unsqueeze(-1) pos[not_spatial] = data_1d.expand(-1, pos.shape[-1]) pos_axes = pos.permute(2, 0, 1).contiguous() cos_axes, sin_axes = super().forward(hidden_states, pos_axes) - cos_axes = cos_axes.to(hidden_states.dtype) sin_axes = sin_axes.to(hidden_states.dtype) - - cos_combined = self._combine_axes(cos_axes) - sin_combined = self._combine_axes(sin_axes) + cos_combined, sin_combined = self._combine_axes(cos_axes), self._combine_axes(sin_axes) return cos_combined, sin_combined -# ============================================================================ -# Model -# ============================================================================ - - -def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor: - r"""Create 3D positional indices for token input. - - Args: - input_ids (`torch.Tensor`): - Tensor of shape `(batch_size, seq_len)` containing token ids. - - Returns: - `torch.Tensor`: Positional indices with shape `(batch_size, seq_len, 3)` where each channel duplicates the - 1D position so it can be consumed by the 3-axis MRoPE rotary embedding. 
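Note: the rotary wrapper above collapses non-image tokens to one shared coordinate across the three MRoPE axes, so text degenerates to standard RoPE. A tiny sketch of that masking, with invented values:

import torch

pos = torch.arange(4)[None, :, None].expand(1, 4, 3).clone()  # (B=1, L=4, 3 axes)
modality = torch.tensor([[0, 0, 1, 1]])                       # ModalityType: image=0, text=1
not_spatial = modality != 0
pos[not_spatial] = pos[not_spatial][..., :1].expand(-1, 3)    # text: same id on t/h/w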
- """ - batch_size, seq_length = input_ids.shape - position_ids = torch.arange(seq_length, device=input_ids.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - position_ids = position_ids.unsqueeze(2).expand(-1, -1, 3) # Add 3D for MRoPE - return position_ids - - @auto_docstring class IsaacModel(PreTrainedModel): config: IsaacConfig @@ -829,7 +751,6 @@ class IsaacModel(PreTrainedModel): _can_compile_fullgraph = False _supports_attention_backend = True _can_record_outputs = {"attentions": OutputRecorder(IsaacVisionAttention, index=1)} - # Expose tied-weights mapping even if empty for base model tests. all_tied_weights_keys: dict[str, str] = {} def __init__(self, config: IsaacConfig): @@ -837,30 +758,17 @@ def __init__(self, config: IsaacConfig): text_cfg_source = config.text_config text_cfg = copy.deepcopy(text_cfg_source) - self.text_model = AutoModel.from_config(text_cfg) - # Ensure downstream callers observe the composed config - self.text_model.config = config + self.text_model = Qwen3Model._from_config(text_cfg) + self.text_model.config = config # Ensure downstream callers observe the composed config self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) - if config.vision_config is None: - raise ValueError("IsaacConfig should always have vision_config") - self.vision_embedding = IsaacVisionEmbedding(config) self.vision_embedding._supports_sdpa = True - - # Dispatch table for TensorStream balanced embedding (text + vision) - self.embed_fns = { - TextType: self.embed_text_tokens, - VisionType: self.embed_vision, - } - - # Keep track of config attributes that downstream utilities may query directly on the model. self.max_sequence_length = config.max_sequence_length self.vision_rescale_factor = config.vision_rescale_factor self.vision_token = config.vision_token - # Initialize weights and parallel plans (including tp_plan from the text model) self.post_init() # Respect config-specified gradient checkpointing @@ -891,118 +799,73 @@ def embed_tokens(self, value: nn.Module) -> None: def vision_model(self) -> nn.Module: return self.vision_embedding.vision_tower - def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: - """Embed text tokens, squeezing singleton dimensions.""" - # Text events are shaped as (..., 1); squeeze the singleton index dim - h = self.text_model.embed_tokens(token_ids) - if h.dim() >= 2 and h.size(-2) == 1: - h = h[..., 0, :] - return h - - def embed_vision(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - """Embed vision tokens using the vision encoder.""" - # vision tokens is (seq_patches, token_grids) - return self.vision_embedding(vision_tokens) - - def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor: + def embed_packed_inputs( + self, input_ids: torch.Tensor, packed_inputs: dict[str, Optional[torch.Tensor]] + ) -> tuple[torch.Tensor, torch.Tensor]: """ - Embed each modality stream independently, preserving the original TensorStream - structure. 
+ Expects input_ids for text tokens and packed_inputs containing: + - modality_tensor: (batch, seq_len) modality ids aligned to the sequence + - position_ids: (batch, seq_len, 3) MRoPE coordinates (optional) + - vision_patches: concatenated vision tokens shaped (total_tokens, embed_dim) or None + - vision_token_grids: (num_images, 2) token grid sizes or None + - vision_token_offsets: (num_images,) offsets into each image's virtual token span (optional) + - vision_token_lengths: (num_images,) surviving virtual token lengths per image (optional) """ - flat_stream = tensor_stream.flat_stream() - per_modality_stream = group_streams(flat_stream, group_fn=lambda ev: ev.type, schedule=False) - per_modality_compact_stream = {k: v.compact() for k, v in per_modality_stream.items()} - - # Collect per-event grids for vision tokens (H, W like dims sans time) - token_grids = defaultdict(list) - for stream in tensor_stream.streams: - for event in stream: - token_grids[event.type].append(event.dims(virtual=False)) - - embedded_compact = {} - for stream_type, modality_payload_tensor in per_modality_compact_stream.items(): - if stream_type.modality == VisionType: - # Build a (N_events, 2) grid tensor with spatial dims only - grids = token_grids.get(stream_type, []) - if len(grids) == 0: - input_tensor = modality_payload_tensor - else: - token_grids_tensor = torch.tensor(grids, dtype=torch.long, device=tensor_stream.device)[:, 1:] - input_tensor = (modality_payload_tensor, token_grids_tensor) - embedded_compact[stream_type] = self.embed_fns[stream_type.modality](input_tensor) - else: - embedded_compact[stream_type] = self.embed_fns[stream_type.modality](modality_payload_tensor) - - # Reconstruct a TensorStream with embedded payloads and compact - embedded_ts = reconstruct_tensor_stream_from_compact_dict(tensor_stream, embedded_compact) - h = embedded_ts.compact() # (B, T, D) - return h - - @staticmethod - def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor: - return compute_position_ids_input_ids(input_ids) - - def _prepare_position_and_modality( - self, - position_ids: Optional[torch.LongTensor], - modality_tensor: Optional[torch.LongTensor], - tensor_stream: Optional[TensorStream], - inputs_embeds: torch.Tensor, - cache_position: torch.LongTensor, - ) -> tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.Tensor, torch.Tensor]: - text_value = TextType.text.value if TextType is not None else 0 - batch_size, seq_len = inputs_embeds.shape[:2] - - if modality_tensor is None: - if tensor_stream is not None: - modality_tensor = modality_mask(tensor_stream) - else: - modality_tensor = torch.full( - (batch_size, seq_len), text_value, device=inputs_embeds.device, dtype=torch.long - ) - else: - modality_tensor = modality_tensor.to(device=inputs_embeds.device, dtype=torch.long) - expected_shape = (batch_size, seq_len) - if modality_tensor.shape != torch.Size(expected_shape): - raise ValueError( - f"modality_tensor must have shape (batch_size, seq_len) {expected_shape}, " - f"but got {tuple(modality_tensor.shape)}" - ) - - if position_ids is None: - if tensor_stream is not None: - position_ids = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) - else: - position_ids = cache_position.view(1, -1).expand(modality_tensor.shape[0], -1) - - if position_ids.ndim == 2: - position_ids = position_ids.to(device=inputs_embeds.device) - position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) - - if position_ids.shape[1] != seq_len: - start_positions = position_ids[:, :1, 0] - position_ids 
= torch.arange(seq_len, device=inputs_embeds.device).view(1, -1) - position_ids = position_ids + start_positions - position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) - - cos, sin = self.rotary_emb( - position_ids, - modality_tensor, - hidden_states=inputs_embeds, - ) + modality = packed_inputs["modality_tensor"].to(device=input_ids.device, dtype=torch.long) + embeds = self.text_model.embed_tokens(input_ids) + + vision_patches = packed_inputs.get("vision_patches") + if vision_patches is None: + return embeds, modality + + token_grids = packed_inputs["vision_token_grids"].to(device=vision_patches.device, dtype=torch.long) + vision = self.vision_embedding((vision_patches, token_grids)) # (total_tokens, hidden) + + # per-image token counts AFTER pixel-shuffle + s = int(self.config.vision_config.pixel_shuffle_scale_factor) + sizes = token_grids.prod(-1).div(s * s, rounding_mode="floor").tolist() + offsets = packed_inputs.get("vision_token_offsets") + lengths = packed_inputs.get("vision_token_lengths") + + if offsets is not None or lengths is not None: + off = ( + offsets.to(device=vision.device, dtype=torch.long) + if offsets is not None + else torch.zeros(len(sizes), device=vision.device, dtype=torch.long) + ) + ln = ( + lengths.to(device=vision.device, dtype=torch.long) + if lengths is not None + else torch.tensor(sizes, device=vision.device, dtype=torch.long) + ) - decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids - return position_ids, modality_tensor, decoder_position_ids, cos, sin + # Honor per-image crop windows (after pixel shuffle) so we only splice back + # the surviving virtual tokens instead of the full vision span. + chunks = vision.split(sizes, dim=0) + picked: list[torch.Tensor] = [] + for c, n, o, l in zip(chunks, sizes, off.tolist(), ln.tolist()): + if n <= 0: + continue + o = max(0, min(int(o), n)) + l = max(0, min(int(l), n - o)) + if l: + picked.append(c[o : o + l]) + vision = torch.cat(picked, 0) if picked else vision.new_zeros((0, vision.size(-1))) + + m = modality == ModalityType.image.value + embeds = embeds.clone() + embeds[m] = vision.to(device=embeds.device, dtype=embeds.dtype) + + return embeds, modality @auto_docstring @check_model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, - tensor_stream: Optional[TensorStream] = None, + packed_inputs: Optional[dict[str, torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - modality_tensor: Optional[torch.LongTensor] = None, past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -1015,32 +878,29 @@ def forward( Computes position embeddings once and passes them through all layers. Args: - tensor_stream (`TensorStream`, *optional*): - Packed multimodal stream of text and vision events to embed directly. Mutually exclusive with - `input_ids` and `inputs_embeds`. When provided, the method derives `position_ids` and `modality_tensor` - if they are not supplied. + packed_inputs (`dict`, *optional*): + Plain tensor payloads extracted from a TensorStream. When provided, it replaces the TensorStream path + and requires `input_ids` for text tokens (or `text_token_ids` so `input_ids` can be rebuilt). modality_tensor (`torch.LongTensor`, *optional*): Modality identifiers aligned with the embedded sequence, shaped `(batch_size, seq_len)` and containing - values from `TextType`/`VisionType`. 
Automatically built from `tensor_stream` or `input_ids` when - omitted. + values from `ModalityType`. Automatically built from `packed_inputs` or treated as text-only when omitted. """ output_attentions = kwargs.pop("output_attentions", None) - # Get inputs - if tensor_stream is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both tensor_stream and inputs_embeds") - if tensor_stream is None and input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + # Resolve the input source (prefer packed_inputs > ids > embeds). + modality_tensor: Optional[torch.Tensor] = None + precomputed_position_ids: Optional[torch.Tensor] = None - # Resolve the input source (TensorStream takes precedence over token ids). - if tensor_stream is not None: - inputs_embeds = self.embed_stream(tensor_stream) + if packed_inputs is not None: + inputs_embeds, modality_tensor = self.embed_packed_inputs(input_ids, packed_inputs) + precomputed_position_ids = packed_inputs.get("position_ids") + if precomputed_position_ids is not None: + precomputed_position_ids = precomputed_position_ids.to(inputs_embeds.device) elif input_ids is not None: inputs_embeds = self.text_model.embed_tokens(input_ids) - elif inputs_embeds is None: - raise ValueError("You have to specify either tensor_stream, input_ids or inputs_embeds") + device = inputs_embeds.device batch_size, seq_len = inputs_embeds.shape[:2] # Ensure cache exists when requested @@ -1050,21 +910,36 @@ def forward( if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=device) if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long) + attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.long) - position_ids, modality_tensor, decoder_position_ids, cos, sin = self._prepare_position_and_modality( - position_ids=position_ids, - modality_tensor=modality_tensor, - tensor_stream=tensor_stream, - inputs_embeds=inputs_embeds, - cache_position=cache_position, - ) + position_ids = position_ids if position_ids is not None else precomputed_position_ids + if position_ids is None: + position_ids = cache_position.view(1, -1).expand(batch_size, -1) + + if modality_tensor is None: + modality_tensor = torch.full( + (batch_size, seq_len), ModalityType.text.value, device=device, dtype=torch.long + ) + else: + modality_tensor = modality_tensor.to(device=device, dtype=torch.long) + + position_ids = position_ids.to(device=device) + + if position_ids.ndim == 2: + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) + if position_ids.shape[1] != seq_len: + start_positions = position_ids[:, :1, 0] + position_ids = torch.arange(seq_len, device=device).view(1, -1) + start_positions + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) - # Prepare attention mask - if not isinstance(attention_mask, dict): + cos, sin = self.rotary_emb(position_ids, modality_tensor, hidden_states=inputs_embeds) + + decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids + + if not isinstance(attention_mask, dict): # Prepare attention mask attention_mask = create_masks_for_generate( config=self.config, 
input_embeds=inputs_embeds, @@ -1074,19 +949,15 @@ def forward( position_ids=decoder_position_ids, ) - is_attention_mask_dict = isinstance(attention_mask, dict) - - # Initialize hidden states + is_mask_dict = isinstance(attention_mask, dict) hidden_states = inputs_embeds all_attentions = [] if output_attentions else None - for decoder_layer in self.text_model.layers: - layer_attention_mask = ( - attention_mask[decoder_layer.attention_type] if is_attention_mask_dict else attention_mask - ) - layer_outputs = decoder_layer( + for layer in self.text_model.layers: + layer_mask = attention_mask[layer.attention_type] if is_mask_dict else attention_mask + layer_outputs = layer( hidden_states, - attention_mask=layer_attention_mask, + attention_mask=layer_mask, position_ids=decoder_position_ids, past_key_values=past_key_values, use_cache=use_cache, @@ -1101,7 +972,6 @@ def forward( if output_attentions and layer_outputs_is_tuple: all_attentions.append(layer_outputs[1]) - # Final layer norm hidden_states = self.text_model.norm(hidden_states) return BaseModelOutputWithPast( @@ -1364,18 +1234,18 @@ def __init__(self, config: IsaacConfig): self.model = IsaacModel(config) # Use our custom model self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Tracks rotary position offsets computed during a full forward pass so decode steps can reuse them. self.rope_deltas = None # Initialize weights and apply final processing self.post_init() - @can_return_tuple @auto_docstring + @can_return_tuple + @check_model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, - tensor_stream: Optional[TensorStream] = None, + packed_inputs: Optional[dict[str, torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[list[torch.FloatTensor]] = None, @@ -1385,36 +1255,45 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | CausalLMOutputWithPast: - r""" - Forward pass for conditional generation supporting both standard inputs and TensorStream. + """Run multimodal CausalLM forward, accepting packed vision/text inputs. - tensor_stream (`TensorStream`, *optional*): - Packed multimodal stream (text, vision, audio tokens) that already encodes spatial metadata. When provided, - the model derives embeddings, modality masks, and 3D rotary coordinates directly from the stream instead of - `input_ids`. - """ + Args: + input_ids: Text token ids. + packed_inputs (`dict`, *optional*): + Packed vision/text payload from ``IsaacProcessor`` containing modality ids, MRoPE position ids, and + vision patch tensors/grids (with optional offsets/lengths) used to rebuild embeddings. + attention_mask: Attention mask or mask dict; created if not provided. + position_ids: Optional 3D MRoPE positions; auto-derived when absent. + past_key_values: Cache for decoding. + inputs_embeds: Precomputed embeddings (bypass embedding layer). + labels: Target ids for computing language modeling loss. + use_cache: Whether to return caches. + cache_position: Positions for cache-aware generation. + Returns: + CausalLMOutputWithPast: logits, optional loss, caches, hidden states, attentions. 
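Note: a sketch of the `packed_inputs` payload consumed above, with invented shapes (one 4x4-patch image; with pixel-shuffle s=2 it yields the four image slots in `modality_tensor`):

import torch

packed_inputs = {
    "modality_tensor": torch.tensor([[1, 1, 0, 0, 0, 0, 1]]),  # text=1, image=0
    "position_ids": torch.zeros(1, 7, 3, dtype=torch.long),    # placeholder MRoPE (t, h, w)
    "vision_patches": torch.randn(16, 768),                    # 3 * 16 * 16 features per patch
    "vision_token_grids": torch.tensor([[4, 4]]),              # one 4x4 token grid
}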
+        """
         output_attentions = kwargs.pop("output_attentions", None)
 
-        # Don't compute embeddings here - let the inner model handle it
-        if tensor_stream is not None:
-            input_ids = None
-        if input_ids is None and inputs_embeds is None and tensor_stream is None:
-            raise ValueError("Either input_ids, inputs_embeds, or tensor_stream must be provided.")
+        if position_ids is None and packed_inputs is not None:
+            pos_3d = packed_inputs.get("position_ids")
+            if pos_3d is not None:
+                position_ids, self.rope_deltas = self.get_rope_index(
+                    position_ids=pos_3d,
+                    attention_mask=attention_mask,
+                )
 
-        # Record rope deltas on prefill when TensorStream is provided; leave position_ids building to IsaacModel.
-        if position_ids is None and tensor_stream is not None:
-            position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask)
         elif position_ids is None and cache_position is not None and self.rope_deltas is not None:
-            # Decode continuation after TensorStream prefill: advance positions using cached rope offsets.
             if input_ids is not None:
-                base_position_ids = compute_position_ids_input_ids(input_ids)
+                base_position_ids = torch.arange(input_ids.size(1), device=input_ids.device)[None, :, None].expand(
+                    input_ids.size(0), -1, 3
+                )
             else:
-                if inputs_embeds is None:
-                    raise ValueError("inputs_embeds must be provided when input_ids is None during decode")
                 batch_size, seq_len = inputs_embeds.shape[:2]
                 dummy_ids = torch.zeros((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long)
-                base_position_ids = compute_position_ids_input_ids(dummy_ids)
+                base_position_ids = torch.arange(dummy_ids.size(1), device=dummy_ids.device)[None, :, None].expand(
+                    dummy_ids.size(0), -1, 3
+                )
 
             rope_delta = (cache_position[0] + self.rope_deltas).to(base_position_ids.device)
             if not isinstance(rope_delta, int):
@@ -1423,10 +1302,9 @@ def forward(
         outputs = self.model(
             input_ids=input_ids,
-            tensor_stream=tensor_stream,
+            packed_inputs=packed_inputs,
             attention_mask=attention_mask,
             position_ids=position_ids,
-            modality_tensor=None,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
@@ -1453,43 +1331,44 @@ def forward(
     def set_input_embeddings(self, value: nn.Module) -> None:
         self.model.set_input_embeddings(value)
         vocab_size = getattr(value, "num_embeddings", None)
-        if vocab_size is not None:
-            self.config.vocab_size = vocab_size
-            self.model.config.vocab_size = vocab_size
-            if hasattr(self.model, "text_model"):
-                self.model.text_model.config.vocab_size = vocab_size
-            if self.lm_head.weight.shape[0] != vocab_size:
-                self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False)
-            if hasattr(self.model, "embed_tokens"):
-                self.lm_head.weight = self.model.text_model.embed_tokens.weight
+        if vocab_size is not None:
+            self.config.vocab_size = vocab_size
+            self.model.config.vocab_size = vocab_size
+            self.model.text_model.config.vocab_size = vocab_size
+            if self.lm_head.weight.shape[0] != vocab_size:
+                self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False)
+                self.lm_head.weight = self.model.text_model.embed_tokens.weight
 
     def get_rope_index(
         self,
-        input_ids: Optional[torch.Tensor],
-        tensor_stream: Optional[TensorStream],
-        attention_mask: Optional[torch.Tensor],
+        *,
+        input_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """Compute MRoPE position ids from a TensorStream (or 1D fallback).
+ """ + Compute (position_ids_3d, rope_deltas) without TensorStream. - Returns (position_ids, rope_deltas). position_ids is (B,L,3) for MRoPE. - rope_deltas is (B,1) used to advance positions in decode. + - If `position_ids` is provided, it must be shape (B, L, 3). + - Else, if `input_ids` is provided, position ids are synthesized as (B, L, 3). + - `rope_deltas` is (B, 1) used to advance positions during decode. """ - # tensor_stream present: compute 3D coords - if tensor_stream is None and input_ids is None: - raise ValueError("`tensor_stream` or `input_ids` must be provided to compute rope indices") - if tensor_stream is not None: - pos_3d = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) + if position_ids is None: + pos_3d = torch.arange(input_ids.size(1), device=input_ids.device)[None, :, None].expand( + input_ids.size(0), -1, 3 + ) else: - pos_3d = compute_position_ids_input_ids(input_ids) - B, L, _ = pos_3d.shape + pos_3d = position_ids + if pos_3d.ndim != 3 or pos_3d.size(-1) != 3: + raise ValueError( + f"`position_ids` must have shape (batch, seq_len, 3) for MRoPE; got shape {tuple(pos_3d.shape)}." + ) - # Max position per batch across the 3 planes and sequence dimension: (B,) + B, L, _ = pos_3d.shape m_per_batch = pos_3d.amax(dim=(1, 2)) - # Sequence lengths per batch: (B,) if attention_mask is None: - seq_lens = torch.full_like(m_per_batch, L) + seq_lens = torch.full((B,), L, device=pos_3d.device, dtype=m_per_batch.dtype) else: seq_lens = attention_mask.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=m_per_batch.device) @@ -1502,33 +1381,11 @@ def prepare_inputs_for_generation( past_key_values: Optional[list[torch.FloatTensor]] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - tensor_stream: Optional[TensorStream] = None, + packed_inputs: Optional[dict[str, torch.Tensor]] = None, cache_position: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - use_cache: bool = True, **kwargs, ) -> dict[str, Any]: - """ - Prepare inputs for generation, handling TensorStream inputs properly. - """ - if cache_position is None: - seq_length = None - device = None - if input_ids is not None: - seq_length = input_ids.shape[1] - device = input_ids.device - elif inputs_embeds is not None: - seq_length = inputs_embeds.shape[1] - device = inputs_embeds.device - elif tensor_stream is not None: - _, seq_length = tensor_stream.shape - device = tensor_stream.device - if seq_length is not None: - # prepare_inputs_for_generation may be invoked outside `generate`, so synthesize the - # same cache positions that GenerationMixin would have created during prefill. 
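Note: `rope_deltas` lets decode steps resume MRoPE positions after prefill. Assuming the Qwen2.5-VL convention (delta = max position + 1 - sequence length; the tail of `get_rope_index` is elided here), a text-only sketch:

import torch

pos_3d = torch.arange(5)[None, :, None].expand(1, 5, 3)  # text-only prefill, L=5
m = pos_3d.amax(dim=(1, 2))                              # tensor([4])
seq_lens = torch.tensor([5])
rope_deltas = (m + 1 - seq_lens).unsqueeze(1)            # tensor([[0]])
# Decode then uses cache_position[0] + rope_deltas as its base position.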
- cache_position = torch.arange(seq_length, dtype=torch.long, device=device) - - # Call parent preparation model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, @@ -1536,25 +1393,15 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, cache_position=cache_position, position_ids=position_ids, - use_cache=use_cache, **kwargs, ) + if packed_inputs is None: + return model_inputs cache_position = model_inputs.get("cache_position", cache_position) - - # Handle TensorStream only for the prefill step first_step = cache_position is None or cache_position[0] == 0 - if tensor_stream is not None and first_step: - model_inputs["tensor_stream"] = tensor_stream - # Let forward rebuild MRoPE coordinates from the TensorStream - model_inputs["position_ids"] = None - else: - model_inputs["tensor_stream"] = None - - # TensorStream decode path: preserve rotary offsets from prefill; let forward rebuild positions - if tensor_stream is not None and not first_step and self.rope_deltas is not None: - model_inputs["position_ids"] = None - return model_inputs + model_inputs["packed_inputs"] = packed_inputs if first_step else None + model_inputs["position_ids"] = None return model_inputs diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index c8b9a681a68f..31c915de5b03 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -17,13 +17,10 @@ import copy import math -import re -from collections import defaultdict from collections.abc import Callable, Sequence from typing import Any, Optional, Union from ...utils.import_utils import ( - is_perceptron_available, is_torch_available, is_torchdynamo_compiling, is_torchvision_available, @@ -45,35 +42,7 @@ if is_torchvision_available(): from ..pix2struct.image_processing_pix2struct_fast import torch_extract_patches -if is_perceptron_available(): - from perceptron.tensorstream.ops import ( - compute_mrope_pos_tensor, - modality_mask, - reconstruct_tensor_stream_from_compact_dict, - tensor_stream_token_view, - ) - from perceptron.tensorstream.ops import ( - slice as ts_slice, - ) - from perceptron.tensorstream.tensorstream import ( - Event, - Stream, - TensorStream, - TextType, - VisionType, - create_stream, - group_streams, - ) -else: - ts_slice = None - Event = None - Stream = None - TensorStream = None - TextType = None - VisionType = None - create_stream = None - group_streams = None - +from enum import IntEnum from ...cache_utils import DynamicCache from ...configuration_utils import PretrainedConfig, layer_type_validation @@ -90,28 +59,45 @@ PILImageResampling, ) from ...masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, create_masks_for_generate, packed_sequence_mask_function -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS -from ...models.auto.modeling_auto import AutoModel -from ...models.auto.tokenization_auto import AutoTokenizer from ...models.qwen3.configuration_qwen3 import Qwen3Config -from ...models.qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3PreTrainedModel +from ...models.qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3Model, Qwen3PreTrainedModel from ...processing_utils import ProcessorMixin, Unpack from ...utils import TensorType, auto_docstring # Vision preprocessing constants from 
...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD -from ...utils.generic import OutputRecorder, TransformersKwargs, can_return_tuple, check_model_inputs +from ...utils.generic import ( + OutputRecorder, + TransformersKwargs, + can_return_tuple, + check_model_inputs, +) from ..qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling from ..siglip2.configuration_siglip2 import Siglip2VisionConfig from ..siglip2.modeling_siglip2 import ( Siglip2Attention, Siglip2Encoder, Siglip2EncoderLayer, + Siglip2VisionEmbeddings, ) +class ModalityType(IntEnum): + """ + Modality identifiers for events. + + Members: + image: Vision tokens (e.g., patches). + text: Textual tokens. + """ + + image = 0 + text = 1 + + class IsaacVisionConfig(Siglip2VisionConfig): """Vision configuration for Isaac with Pixel Shuffle support. @@ -178,7 +164,7 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): resample = PILImageResampling.BILINEAR model_input_names = ["patches", "token_grids"] valid_kwargs = IsaacImageProcessorFastKwargs - unused_kwargs = ["size", "do_center_crop", "crop_size"] + unused_kwargs = ["size", "do_center_crop", "crop_size", "pad_size", "do_pad"] do_resize = True do_center_crop = False @@ -193,7 +179,6 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): image_std = list(VISION_STD) do_convert_rgb = True disable_grouping = False - size_divisor: Optional[int] = None def __init__( self, @@ -201,11 +186,6 @@ def __init__( ) -> None: super().__init__(**kwargs) - pixel_shuffle_scale = 1 if self.pixel_shuffle_scale is None else int(self.pixel_shuffle_scale) - if pixel_shuffle_scale < 1: - raise ValueError("`pixel_shuffle_scale` must be >= 1") - self.pixel_shuffle_scale = pixel_shuffle_scale - def _validate_preprocess_kwargs(self, **kwargs): # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for. 
kwargs.pop("do_resize", None) @@ -219,29 +199,10 @@ def resize( self, image: torch.Tensor, size: SizeDict, - interpolation: Optional[Any] = None, - antialias: bool = True, **kwargs, ) -> torch.Tensor: - if size.height is None or size.width is None: - raise ValueError("IsaacImageProcessorFast requires explicit `height` and `width` when resizing.") - - resize_mode: Any = interpolation - if hasattr(resize_mode, "value"): - resize_mode = resize_mode.value - elif hasattr(resize_mode, "name"): - resize_mode = resize_mode.name.lower() - elif resize_mode is None: - resize_mode = "bilinear" - - if isinstance(resize_mode, str): - mode_key = resize_mode.lower() - else: - mode_key = resize_mode - - resize_kwargs: dict[str, Any] = {} - if mode_key in {"linear", "bilinear", "bicubic", "trilinear"}: - resize_kwargs["align_corners"] = False + resize_kwargs: dict[str, Any] = {"align_corners": False} + resize_mode = "bilinear" return F.interpolate( image, @@ -254,10 +215,7 @@ def _preprocess( self, images: list[torch.Tensor], do_resize: bool, - size: Optional[SizeDict], interpolation: Optional[Any], - do_center_crop: bool, - crop_size: Optional[SizeDict], do_rescale: Optional[bool], rescale_factor: Optional[float], do_normalize: Optional[bool], @@ -265,8 +223,6 @@ def _preprocess( image_std: Optional[Union[float, Sequence[float]]], disable_grouping: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, - do_pad: Optional[bool] = None, - pad_size: Optional[SizeDict] = None, *, patch_size: Optional[int] = None, max_num_patches: Optional[int] = None, @@ -274,20 +230,15 @@ def _preprocess( pixel_shuffle_scale: Optional[int] = None, **kwargs, ) -> BatchFeature: - if do_center_crop: - raise ValueError("`do_center_crop` is not supported by IsaacImageProcessorFast.") - if do_pad: - raise ValueError("`do_pad` is not supported by IsaacImageProcessorFast.") - grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) - processed_patches_grouped: dict[tuple[int, ...], torch.Tensor] = {} - token_grids_grouped: dict[tuple[int, ...], torch.Tensor] = {} - virtual_dims_grouped: dict[tuple[int, ...], torch.Tensor] = {} - real_dims_grouped: dict[tuple[int, ...], torch.Tensor] = {} + + grouped_outputs = {} for shape, stacked_images in grouped_images.items(): if stacked_images.ndim != 4: - raise ValueError("Expected batched channel-first image tensors.") + raise ValueError( + f"Expected images shaped as (batch, channels, height, width); got shape {tuple(stacked_images.shape)}." + ) batch_size, channels, original_height, original_width = stacked_images.shape @@ -296,7 +247,9 @@ def _preprocess( channels = 3 if original_height * original_width > self.MAX_PIXELS: - raise ValueError(f"Image (w={original_width}, h={original_height}) > MAX=`{self.MAX_PIXELS}`") + raise ValueError( + f"Image area {original_height * original_width} (h={original_height}, w={original_width}) exceeds MAX_PIXELS={self.MAX_PIXELS}; enable resizing or provide smaller inputs." + ) target_height, target_width = get_image_size_for_max_num_patches( original_height, @@ -316,7 +269,9 @@ def _preprocess( ) else: if ((original_height % patch_size) != 0) or ((original_width % patch_size) != 0): - raise ValueError("Image dimensions must be divisible by patch_size when resize is disabled.") + raise ValueError( + f"Image dimensions (h={original_height}, w={original_width}) must be divisible by patch_size={patch_size} when resize is disabled; enable resizing or adjust the input resolution." 
+ ) image_batch = stacked_images target_height, target_width = original_height, original_width @@ -330,10 +285,7 @@ def _preprocess( image_std=image_std, ) - nhwc_images = image_batch.permute(0, 2, 3, 1) - nhwc_images = _compute_residual_p_frames(nhwc_images, is_p_frame=[False] * batch_size) - - patches = torch_extract_patches(nhwc_images.permute(0, 3, 1, 2), patch_size, patch_size) + patches = torch_extract_patches(image_batch, patch_size, patch_size) _, height_tokens, width_tokens, _ = patches.shape token_grid = ( @@ -358,7 +310,7 @@ def _preprocess( if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): raise ValueError( - "Spatial dimensions must be divisible by pixel_shuffle_scale when pixel shuffle is enabled." + f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale}; adjust resize/patch parameters or disable pixel shuffle." ) virtual_height = height_tokens // pixel_shuffle_scale virtual_width = width_tokens // pixel_shuffle_scale @@ -372,54 +324,24 @@ def _preprocess( .unsqueeze(0) .repeat(batch_size, 1) ) + grouped_outputs[shape] = (patches, token_grid, virtual_dim, real_dim) - processed_patches_grouped[shape] = patches - token_grids_grouped[shape] = token_grid - virtual_dims_grouped[shape] = virtual_dim - real_dims_grouped[shape] = real_dim - - patches_slices = reorder_images(processed_patches_grouped, grouped_images_index) - token_grid_slices = reorder_images(token_grids_grouped, grouped_images_index) - virtual_dim_slices = reorder_images(virtual_dims_grouped, grouped_images_index) - real_dim_slices = reorder_images(real_dims_grouped, grouped_images_index) - - patches_tensor = torch.stack(patches_slices, dim=0) - token_grids_tensor = torch.stack(token_grid_slices, dim=0) - virtual_dims_tensor = torch.stack(virtual_dim_slices, dim=0) - real_dims_tensor = torch.stack(real_dim_slices, dim=0) - - return BatchFeature( - data={ - "patches": patches_tensor, - "token_grids": token_grids_tensor, - "virtual_pixel_size": virtual_dims_tensor, - "real_pixel_size": real_dims_tensor, - }, - tensor_type=return_tensors, - ) - + # Helper to reorder a single item of the tuple payloads using the same grouped_images_index + def _reorder_grouped_item( + grouped: dict[tuple[int, ...], tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], + grouped_index: dict[tuple[int, ...], list[int]], + item_idx: int, + ) -> list[torch.Tensor]: + return reorder_images({k: v[item_idx] for k, v in grouped.items()}, grouped_index) -def document_mask_function_from_cu_seqlens(cu_seqlens: Optional[torch.Tensor]) -> Optional[Callable]: - """Return a mask function that blocks cross-document attention from packed ``cu_seqlens``. + keys = ("patches", "token_grids", "virtual_pixel_size", "real_pixel_size") + tensors: dict[str, torch.Tensor] = {} - The returned callable matches the signature expected by ``masking_utils`` mask factories and - yields ``True`` only when query/key positions belong to the same packed segment. 
- """ + for i, key in enumerate(keys): + slices = _reorder_grouped_item(grouped_outputs, grouped_images_index, i) + tensors[key] = torch.stack(slices, dim=0) - if cu_seqlens is None: - return None - - if cu_seqlens.numel() < 2: - return None - - seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() - if seq_sizes.numel() == 0: - return None - - total_tokens = int(seq_sizes.sum().item()) - seg_ids = torch.repeat_interleave(torch.arange(seq_sizes.numel(), device=cu_seqlens.device), seq_sizes) - packed_sequence_mask = seg_ids.view(1, total_tokens) - return packed_sequence_mask_function(packed_sequence_mask) + return BatchFeature(data=tensors, tensor_type=return_tensors) def create_document_attention_mask( @@ -427,16 +349,23 @@ def create_document_attention_mask( input_embeds: torch.Tensor, cu_seqlens: Optional[torch.Tensor], ) -> Optional[Union[torch.Tensor, Any]]: - """Materialize a backend-specific block-diagonal attention mask. + """ + Materialize a backend-specific block-diagonal attention mask from packed cu_seqlens. - This uses the standard `masking_utils` mask interface (same mechanism as Llama4), - so the returned object matches the selected attention backend (e.g. SDPA bool mask, - eager additive mask, or flex `BlockMask`). + Returns None if cu_seqlens is missing/degenerate. """ + if cu_seqlens is None or cu_seqlens.numel() < 2: + return None # Degenerate input: nothing to mask - mask_function = document_mask_function_from_cu_seqlens(cu_seqlens) - if mask_function is None: - return None + seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() + if seq_sizes.numel() == 0 or int(seq_sizes.sum()) == 0: + return None # All-empty segments produce no attention blocks + + seg_ids = torch.repeat_interleave( + torch.arange(seq_sizes.numel(), device=cu_seqlens.device), + seq_sizes, + ) + mask_function = packed_sequence_mask_function(seg_ids.view(1, -1)) seq_len = input_embeds.shape[1] cache_position = torch.arange(seq_len, device=input_embeds.device, dtype=torch.long) @@ -457,7 +386,7 @@ def create_document_attention_mask( ) -class IsaacVisionEmbeddings(nn.Module): +class IsaacVisionEmbeddings(Siglip2VisionEmbeddings): """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences. Isaac accepts variable-resolution vision inputs as a single packed sequence with per-image @@ -466,7 +395,7 @@ class IsaacVisionEmbeddings(nn.Module): """ def __init__(self, config: IsaacVisionConfig): - super().__init__() + super().__init__(config) self.config = config self.embed_dim = config.hidden_size self.patch_size = config.patch_size @@ -480,6 +409,7 @@ def __init__(self, config: IsaacVisionConfig): self.position_embedding_size = int(self.num_patches**0.5) self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + @check_model_inputs def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor: # Rebatch packed variable-resolution patches to resize per-image position embeddings # and track lengths for varlen attention metadata. @@ -502,65 +432,6 @@ def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> to embeddings = patch_embeds + resized_positional_embeddings return self._unpack_from_batch(embeddings, seq_lengths) - @staticmethod - def resize_positional_embeddings( - positional_embeddings: torch.Tensor, - spatial_shapes: torch.LongTensor, - max_length: int, - ) -> torch.Tensor: - """ - Resize positional embeddings to image-specific size and pad to a fixed size. 
- - Args: - positional_embeddings (`torch.Tensor`): - Position embeddings of shape (height, width, embed_dim) - spatial_shapes (`torch.LongTensor`): - Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to - max_length (`int`): - Maximum length of the positional embeddings to pad resized positional embeddings to - - Returns: - `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim) - """ - batch_size = spatial_shapes.shape[0] - embed_dim = positional_embeddings.shape[-1] - source_dtype = positional_embeddings.dtype - - resulted_positional_embeddings = torch.empty( - (batch_size, max_length, embed_dim), - device=positional_embeddings.device, - dtype=source_dtype, - ) - - # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation - positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0) - - # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU - if positional_embeddings.device.type == "cpu": - positional_embeddings = positional_embeddings.to(torch.float32) - - for i in range(batch_size): - # (1, dim, height, width) -> (1, dim, target_height, target_width) - height, width = spatial_shapes[i] - resized_embeddings = F.interpolate( - positional_embeddings, - size=(height, width), - mode="bilinear", - align_corners=False, - antialias=True, - ) - - # (1, dim, target_height, target_width) -> (target_height * target_width, dim) - resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1) - - # Cast to original dtype - resized_embeddings = resized_embeddings.to(source_dtype) - - resulted_positional_embeddings[i, : height * width] = resized_embeddings - resulted_positional_embeddings[i, height * width :] = resized_embeddings[0] - - return resulted_positional_embeddings - def _pack_to_batch( self, seq_patches: torch.Tensor, @@ -569,59 +440,33 @@ def _pack_to_batch( """Rebatch a packed patch sequence using per-image grids to align embeddings. Args: - seq_patches (`torch.Tensor`): Packed patches of shape `(total_patches, patch_dim)`. - spatial_shapes (`torch.Tensor`): Per-image patch grids of shape `(num_images, 2)` as `(H_tokens, W_tokens)`. + seq_patches: Packed patches of shape (total_patches, patch_dim). + spatial_shapes: Per-image patch grids of shape (num_images, 2) as (H_tokens, W_tokens). Returns: - `tuple[Optional[torch.Tensor], torch.Tensor]`: A padded batch tensor shaped - `(batch, max_len, patch_dim)` plus `seq_lengths` used to form `cu_seqlens` for - variable-length attention. + (packed_pixel_values, seq_lengths) where: + - packed_pixel_values: (batch, max_len, patch_dim) padded with zeros, or None if batch_size == 0 + - seq_lengths: (batch,) lengths for each image """ - if seq_patches.ndim != 2: - raise ValueError("`seq_patches` is expected to be 2D (total_patches, patch_dim).") - if spatial_shapes.ndim != 2 or spatial_shapes.size(-1) != 2: - raise ValueError("`spatial_shapes` must have shape (num_images, 2) with (height_tokens, width_tokens).") - - seq_lengths = spatial_shapes.long().prod(dim=-1) - total_patches = int(seq_lengths.sum().item()) - if total_patches != seq_patches.size(0): - raise ValueError( - "Mismatch between packed patches and spatial shapes: got " - f"{seq_patches.size(0)} patches but spatial shapes imply {total_patches}." 
- ) - - batch_size = spatial_shapes.size(0) + # Per-image token counts + seq_lengths = spatial_shapes.long().prod(dim=-1) # (B,) + batch_size = int(seq_lengths.numel()) if batch_size == 0: return None, seq_lengths - max_length = int(seq_lengths.max().item()) - patch_dim = seq_patches.size(-1) - device = seq_patches.device - - packed_pixel_values = seq_patches.new_zeros((batch_size, max_length, patch_dim), device=device) - - start = 0 - for batch_idx, length in enumerate(seq_lengths.tolist()): - if length == 0: - continue - end = start + length - packed_pixel_values[batch_idx, :length] = seq_patches[start:end] - start = end - + # Split the packed sequence into per-image chunks, then pad to a batch + lengths_list = seq_lengths.tolist() + chunks = seq_patches.split(lengths_list, dim=0) + packed_pixel_values = nn.utils.rnn.pad_sequence(chunks, batch_first=True) # zero-padded by default return packed_pixel_values, seq_lengths def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor) -> torch.Tensor: """Flatten a padded batch back to packed sequence order using `seq_lengths`.""" - output_chunks: list[torch.Tensor] = [] - for batch_idx, length in enumerate(seq_lengths.tolist()): - if length == 0: - continue - output_chunks.append(embeddings[batch_idx, :length]) - - if not output_chunks: + lengths = seq_lengths.to(device=embeddings.device).tolist() + chunks = [embeddings[i, :l] for i, l in enumerate(lengths) if l > 0] + if not chunks: return embeddings.new_zeros((0, embeddings.size(-1))) - - return torch.cat(output_chunks, dim=0) + return torch.cat(chunks, dim=0) class IsaacVisionAttention(Siglip2Attention): @@ -653,11 +498,9 @@ def forward( if attn_impl != "sdpa": attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl] - dropout = 0.0 if not self.training else self.dropout attention_kwargs: dict[str, Any] = { "is_causal": False, "scaling": self.scale, - "dropout": dropout, } supports_varlen = cu_seqlens is not None and attn_impl in { @@ -667,10 +510,6 @@ def forward( "paged|flash_attention_2", "paged|flash_attention_3", } - - if output_attentions and attn_impl == "eager": - attention_kwargs["output_attentions"] = True - if supports_varlen: if max_seqlen is not None: max_q = max_k = int(max_seqlen) @@ -697,9 +536,7 @@ def forward( attention_mask, **attention_kwargs, ) - attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() - attn_output = self.out_proj(attn_output) return attn_output, attn_weights @@ -756,23 +593,6 @@ def __init__(self, config: IsaacVisionConfig): super().__init__(config) self.layers = nn.ModuleList([IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) - @can_return_tuple - @check_model_inputs - def forward( - self, - inputs_embeds, - attention_mask: Optional[torch.Tensor] = None, - **kwargs: Unpack[TransformersKwargs], - ): - hidden_states = inputs_embeds - for encoder_layer in self.layers: - hidden_states = encoder_layer( - hidden_states, - attention_mask, - **kwargs, - ) - return BaseModelOutput(last_hidden_state=hidden_states) - def create_pixel_shuffle_index_map( seq_sizes: torch.Tensor, @@ -798,47 +618,29 @@ def create_pixel_shuffle_index_map( packed sequence for the j-th sub-patch that forms the i-th output token. 
""" - if device is None: - device = seq_sizes.device - - scale_factor = int(scale_factor) - if scale_factor < 2: - raise ValueError("`scale_factor` must be โ‰ฅ 2") - - # Safety: all spatial dims must be divisible by the scale factor - # Cannot run under torch compile fullgraph mode hence if not is_torchdynamo_compiling(): - if not ((token_grids[:, 0] % scale_factor == 0).all() and (token_grids[:, 1] % scale_factor == 0).all()): + if (token_grids % scale_factor).any(): raise AssertionError( - "Every (H,W) in `token_grids` must be divisible by " - f"scale_factor={scale_factor}, got {token_grids.tolist()}" + f"Every (H,W) in token_grids must be divisible by scale_factor={scale_factor}, got {token_grids.tolist()}" ) gather_chunks: list[torch.Tensor] = [] tok_offset = 0 - - for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist(), strict=False): - # Build the (H, W) grid of flat indices for this image - grid = torch.arange(seq_len, device=device, dtype=torch.int64) + tok_offset - grid = grid.view(h, w) # (H, W) - - # -------- identical ordering to your fixed-res routine -------- - # Step 1: split width into blocks of scale_factor - grid = grid.view(h, w // scale_factor, scale_factor) # (H, W/scale_factor, scale_factor) - # Step 2: now split height into blocks of scale_factor - grid = grid.view(h // scale_factor, scale_factor, w // scale_factor, scale_factor) - # (H/scale_factor, scale_factor, W/scale_factor, scale_factor) - # Step 3: final permutation to (H/scale_factor, W/scale_factor, scale_factor, scale_factor) - grid = grid.permute(0, 2, 1, 3).contiguous() # (H/scale_factor, W/scale_factor, scale_factor, scale_factor) - # Step 4: each (scale_factor, scale_factor) block forms one output token - gather_chunks.append(grid.reshape(-1, scale_factor * scale_factor)) - # (H*W / scale_factor**2, scale_factor**2) + for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist()): + # Flat indices for this image's packed segment + grid = torch.arange(seq_len, device=device, dtype=torch.int64).view(h, w) + tok_offset + + # Block into (H/s, W/s) groups; each group contributes s*s indices + grid = ( + grid.view(h // scale_factor, scale_factor, w // scale_factor, scale_factor) + .permute(0, 2, 1, 3) + .contiguous() + ) + gather_chunks.append(grid.view(-1, scale_factor * scale_factor)) tok_offset += seq_len - # Concatenate over all images in the packed batch - gather_idx = torch.cat(gather_chunks, dim=0) # (ฮฃ_i HแตขWแตข/scale_factor**2, scale_factor**2) - return gather_idx + return torch.cat(gather_chunks, dim=0) def pixel_shuffle_varlen( @@ -870,7 +672,9 @@ def pixel_shuffle_varlen( return_with_batch_dim = x.dim() == 3 if return_with_batch_dim: if x.size(0) != 1: - raise AssertionError("Packed sequence is expected to have batch_size == 1") + raise ValueError( + f"Packed vision sequences expect a singleton batch dimension; received batch_size={x.size(0)}." + ) embeddings = x.squeeze(0) # (seq, embed) else: embeddings = x # (seq, embed) @@ -881,7 +685,8 @@ def pixel_shuffle_varlen( # Calculate seq_sizes from token_grids seq_sizes = torch.prod(token_grids, dim=-1) - # Build index map and gather in one go + # Build a single gather index so pixel shuffle works on the packed stream + # without unpacking per-image grids. 
gather_idx = create_pixel_shuffle_index_map( seq_sizes=seq_sizes, token_grids=token_grids, @@ -902,6 +707,19 @@ def pixel_shuffle_varlen( class IsaacVisionTransformer(nn.Module): + """Vision tower that packs variable-resolution patches, applies varlen attention, and pixel-shuffles outputs. + + Args: + config (IsaacVisionConfig): Vision configuration with pixel-shuffle and patching parameters. + + Inputs: + packed_seq_patches (Tuple[Tensor, Tensor]): ``(patches, token_grids)`` where ``patches`` is a packed + patch sequence and ``token_grids`` holds per-image (H_tokens, W_tokens). + + Returns: + torch.Tensor: Vision embeddings after encoder + pixel shuffle, shaped ``(seq_len, hidden_size * s^2)``. + """ + _supports_sdpa = True def __init__(self, config: IsaacVisionConfig): @@ -919,12 +737,12 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): # Get embeddings from packed sequence hidden_states = self.embeddings(seq_patches, token_grids) - # Add a pseudo batch dimension for the encoder + # Add a pseudo batch dimension so we can reuse the batch-first encoder stack + # while still driving per-image cu_seqlens through the varlen attention path. hidden_states = hidden_states.unsqueeze(0) # Generate cumulative sequence lengths for variable-length attention - cu_seqlens = torch.zeros(seq_sizes.size(0) + 1, dtype=torch.int32, device=hidden_states.device) - cu_seqlens[1:] = seq_sizes.cumsum(0) + cu_seqlens = F.pad(seq_sizes.cumsum(0).to(torch.int32), (1, 0)) attention_mask = create_document_attention_mask(self.config, hidden_states, cu_seqlens) @@ -952,6 +770,8 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): class IsaacMultiModalProjector(nn.Module): + """Maps vision tower outputs to the text hidden size with a SiLU MLP.""" + def __init__(self, config: IsaacConfig): super().__init__() self.vision_hidden_size = config.vision_config.hidden_size * ( @@ -970,7 +790,17 @@ def forward(self, image_features): class IsaacVisionEmbedding(nn.Module): - """Vision embedding wrapper exposing tower and projector.""" + """Wraps the vision tower plus projection into the text hidden size. + + Args: + config (IsaacConfig): Composite config containing both vision and text settings. + + Inputs: + vision_tokens (Tuple[Tensor, Tensor]): Packed vision patches and token grids. + + Returns: + torch.Tensor: Projected vision embeddings aligned to the text hidden size. + """ _supports_sdpa = True @@ -1041,7 +871,8 @@ def get_image_size_for_max_num_patches( num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size) if min_num_patches is not None and num_patches < min_num_patches: - # Scale up + # Scale up via binary search to satisfy the minimum patch budget while + # preserving divisibility by patch_size * pixel_shuffle_scale. scale_min, scale_max = 1.0, 100.0 while (scale_max - scale_min) >= eps: scale = (scale_min + scale_max) / 2 @@ -1161,58 +992,25 @@ def to_dict(self): return output -# ============================================================================ -# Processor Components -# ============================================================================ - - -def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> Event: - r"""Wrap a text into an `Event` compatible with the multimodal TensorStream. +class IsaacProcessor(ProcessorMixin): + """Processor that pairs the Isaac image processor with the Qwen2 tokenizer. Args: - tokenizer (`AutoTokenizer`): - Tokenizer used to convert text into model vocabulary ids. 
- text (`str`): - Plain-text fragment to encode. - time (`float`, *optional*, defaults to 0.0): - Timeline coordinate associated with the event. Both start and end times use the same value because text - segments are instantaneous in the scheduler. + image_processor: Vision preprocessor (fast) used for patch extraction. + tokenizer: Qwen2 tokenizer instance. + vision_token (str, optional): Placeholder token marking image locations. Defaults to "". + max_sequence_length (int, optional): Maximum combined text+vision tokens kept. Defaults to 16384. + rescale_factor (float, optional): Image rescale factor; defaults to 1/255. + config (IsaacConfig | dict, optional): If provided, overrides processor defaults from the model config. Returns: - `Event`: Event carrying a `(num_tokens, 1)` tensor of token ids with matching - metadata so that downstream processors can compute modality-specific embeddings. + BatchFeature: Contains ``input_ids`` and ``packed_inputs`` (patch tensors, grids, offsets, lengths, modality, positions). """ - tokens = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt").squeeze(0) - - # Calculate dimensions for the event - num_tokens = len(tokens) - dims_virtual = [num_tokens, 1] # [sequence_length, 1] - dims_real = dims_virtual.copy() - - # Ensure tokens has the right shape for tensor_stream_token_view - # It expects a 2D tensor where sum(dim=-1) gives the token IDs - if tokens.dim() == 1: - tokens = tokens.unsqueeze(-1) - - return Event( - data=tokens, - type=TextType.text, - time=(time, time), - dims_virtual=dims_virtual, - dims_real=dims_real, - idx_range=(0, num_tokens), - ) - - -# ============================================================================ -# Processor -# ============================================================================ - -class IsaacProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = ("IsaacImageProcessorFast",) tokenizer_class = ("Qwen2Tokenizer",) + pad_token_id = 151643 def __init__( self, @@ -1224,79 +1022,174 @@ def __init__( rescale_factor: Optional[float] = None, config: Optional[Union[IsaacConfig, dict]] = None, ) -> None: - if tokenizer is None: - raise ValueError("`tokenizer` must be provided to initialize IsaacProcessor.") - if isinstance(config, dict): config = IsaacConfig(**config) if config is not None: - max_sequence_length = config.max_sequence_length vision_token = config.vision_token + max_sequence_length = config.max_sequence_length rescale_factor = config.vision_rescale_factor resolved_rescale_factor = float(rescale_factor) if rescale_factor is not None else float(1 / 255) - if config is not None: config.vision_rescale_factor = resolved_rescale_factor self.image_processor = image_processor - super().__init__(image_processor, tokenizer) + self.current_processor = self.image_processor self.config = config - - # Mirror tokenizer chat template so ProcessorMixin.apply_chat_template works. 
self.chat_template = getattr(self.tokenizer, "chat_template", None) - self.vision_token = vision_token self.max_sequence_length = max_sequence_length - def build_event_stream_simple( - self, - text: str, - images: Optional[list[Image]] = None, - ) -> Stream: - events = [] - # Process text and images - # Find all occurrences of vision token - - pattern = re.escape(self.vision_token) - parts = re.split(f"({pattern})", text) # Keep the delimiter in the result - - image_idx = 0 - for current_time, part in enumerate(parts): - if part == self.vision_token: - # Replace vision token with image event - if images is None or image_idx >= len(images): - raise ValueError("Encountered vision token without a corresponding image.") - - features = self.image_processor( - images=images[image_idx], - return_tensors=TensorType.PYTORCH, - ) + def _pack_single(self, text: str, images: Optional[list[Image]]) -> dict[str, Optional[torch.Tensor]]: + # Parse by vision_token; interleave text segments and image segments. + segments = text.split(self.vision_token) + num_images = len(segments) - 1 + if num_images and (images is None or len(images) != num_images): + raise ValueError( + f"Expected one image per '{self.vision_token}' token: found {num_images} token(s) but received {0 if images is None else len(images)} image(s)." + ) - patches = features["patches"][0] # (H_tokens, W_tokens, embed) - virtual_dims = features["virtual_pixel_size"][0].tolist() - real_dims = features["real_pixel_size"][0].tolist() - - vision_event = Event( - data=patches.reshape(-1, patches.shape[-1]), - type=VisionType.image, - time=(current_time, current_time), - dims_virtual=virtual_dims, - dims_real=real_dims, - idx_range=(0, math.prod(virtual_dims)), + items: list[dict[str, Any]] = [] + total = 0 + + for index, segment in enumerate(segments): + if segment: + tok = ( + self.tokenizer.encode(segment, add_special_tokens=False, return_tensors="pt") + .squeeze(0) + .to(torch.long) + ) + segment_length = int(tok.numel()) + items.append({"type": "text", "segment_length": segment_length, "tok": tok}) + total += segment_length + + if index < num_images: + feat = self.image_processor(images=images[index], return_tensors=TensorType.PYTORCH) + patches = feat["patches"][0].reshape(-1, feat["patches"].shape[-1]) + + virtual_pixel_size = feat["virtual_pixel_size"][0].to(torch.long).tolist() + real_pixel_size = feat["real_pixel_size"][0].to(torch.long).tolist() + dims = tuple((virtual_pixel_size + [1, 1, 1])[:3]) # (T,H,W) in virtual space + segment_length = int(dims[0] * dims[1] * dims[2]) + + items.append( + { + "type": "image", + "segment_length": segment_length, + "dims": dims, + "patches": patches, + "grid": (int(real_pixel_size[1]), int(real_pixel_size[2])), + } ) - events.append(vision_event) - image_idx += 1 - elif part: # Non-empty text part - # tokens = self.text_processor.tokenize(part, add_special_tokens=False) - text_event = create_text_event(self.tokenizer, part, time=current_time) - events.append(text_event) + total += segment_length + + # Tail crop window. 
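A short sketch of the tail-crop arithmetic that follows, with invented segment lengths: only the global window [start, end) of length max_sequence_length survives, and each segment keeps the local slice that overlaps it.

segment_lengths = [5, 4, 3]  # toy layout: text, image, text
max_sequence_length = 8
total = sum(segment_lengths)  # 12
start, end = max(0, total - max_sequence_length), total  # keep global [4, 12)

offset = 0
for length in segment_lengths:
    lo, hi = max(start, offset), min(end, offset + length)
    kept = (lo - offset, hi - offset) if hi > lo else None  # local keep-range
    print(offset, kept)
    offset += length
# 0 (4, 5) -> only the last token of the first segment survives
# 5 (0, 4) -> the whole image segment survives
# 9 (0, 3) -> the whole trailing text segment survives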
+ start = max(0, total - self.max_sequence_length) + end = total + + fill_value = self.pad_token_id + base_device: Optional[torch.device] = None + position_ids, modality, input_ids = [], [], [] + vpatches, grids, vision_token_offsets, vision_token_lengths = [], [], [], [] + + global_offset = 0 + position_offset = 0 + + for item in items: + segment_length = int(item["segment_length"]) + current_window_start = max(start, global_offset) + current_window_end = min(end, global_offset + segment_length) + has_overlap = current_window_end > current_window_start + + if has_overlap and base_device is None: + base_device = item["patches"].device if item["type"] == "image" else item["tok"].device + + if has_overlap: + segment_local_start = int(current_window_start - global_offset) + segment_local_end = int(current_window_end - global_offset) + segment_local_indices = torch.arange( + segment_local_start, segment_local_end, device=base_device, dtype=torch.long + ) + segment_kept_length = segment_local_end - segment_local_start + + if item["type"] == "text": + slice_index = segment_local_indices + position_offset + zero_axis_pad = torch.zeros_like(slice_index) + position_ids.append(torch.stack((slice_index, zero_axis_pad, zero_axis_pad), -1)) + modality.append( + torch.full( + (segment_kept_length,), ModalityType.text.value, device=base_device, dtype=torch.long + ) + ) + input_ids.append(item["tok"].to(base_device)[segment_local_start:segment_local_end]) + position_offset += segment_length + else: + num_pos_slices, grid_height_tokens, grid_width_tokens = item["dims"] + hw = grid_height_tokens * grid_width_tokens + slice_index = (segment_local_indices // hw) + position_offset + rem = segment_local_indices % hw + row_index = rem // grid_width_tokens + col_index = rem % grid_width_tokens + position_ids.append(torch.stack((slice_index, row_index, col_index), -1)) + modality.append( + torch.full( + (segment_kept_length,), ModalityType.image.value, device=base_device, dtype=torch.long + ) + ) + input_ids.append( + torch.full((segment_kept_length,), fill_value, device=base_device, dtype=torch.long) + ) + + vpatches.append(item["patches"].to(base_device)) # full patches; slice later via offsets/lengths + # Record per-image slice boundaries so we can drop cropped virtual tokens + # after pixel shuffle without re-packing the entire vision stream. 
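The image branch above recovers (slice, row, col) MRoPE coordinates from each surviving flat virtual index. A standalone sketch for a 1x2x3 virtual grid (toy dims):

import torch

t, h, w = 1, 2, 3  # virtual (T, H_tokens, W_tokens)
idx = torch.arange(t * h * w)  # flat virtual indices 0..5
hw = h * w
coords = torch.stack((idx // hw, (idx % hw) // w, (idx % hw) % w), dim=-1)
print(coords.tolist())
# [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0], [0, 1, 1], [0, 1, 2]]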
+ grids.append(item["grid"]) + vision_token_offsets.append(segment_local_start) + vision_token_lengths.append(segment_kept_length) + + position_offset += int(num_pos_slices) + + else: + position_offset += segment_length if item["type"] == "text" else int(item["dims"][0]) + + global_offset += segment_length + + modality_tensor = ( + torch.cat(modality, 0).unsqueeze(0) + if modality + else torch.zeros((1, 0), device=base_device, dtype=torch.long) + ) + position_ids = ( + torch.cat(position_ids, 0).unsqueeze(0) + if position_ids + else torch.zeros((1, 0, 3), device=base_device, dtype=torch.long) + ) + input_ids = ( + torch.cat(input_ids, 0).unsqueeze(0) + if input_ids + else torch.zeros((1, 0), device=base_device, dtype=torch.long) + ) + + if vpatches: + vision_patches = torch.cat(vpatches, 0) + vision_token_grids = torch.tensor(grids, device=base_device, dtype=torch.long) + vision_token_offsets = torch.tensor(vision_token_offsets, device=base_device, dtype=torch.long) + vision_token_lengths = torch.tensor(vision_token_lengths, device=base_device, dtype=torch.long) + else: + vision_patches = vision_token_grids = vision_token_offsets = vision_token_lengths = None - # Create stream without scheduling (events already in order) - return create_stream(events, priority=[TextType.text, VisionType.image], schedule=True) + return { + "input_ids": input_ids, + "vision_patches": vision_patches, + "vision_token_grids": vision_token_grids, + "vision_token_offsets": vision_token_offsets, + "vision_token_lengths": vision_token_lengths, + "modality_tensor": modality_tensor, + "position_ids": position_ids, + } def __call__( self, @@ -1305,103 +1198,32 @@ def __call__( return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, **kwargs, ) -> BatchFeature: - """ - Process text and images into TensorStream format. - Args: - text: Input text or list of texts with vision tokens - images: PIL image or list of images (optional) - return_tensors: Format for output tensors - - Returns: - BatchFeature with input_ids and tensor_stream - """ - # Normalize inputs to lists - if isinstance(text, str): - texts = [text] - else: - texts = text + texts = [text] if isinstance(text, str) else text + if len(texts) != 1: + raise ValueError( + f"IsaacProcessor currently supports batch_size=1; received {len(texts)} text prompts. Split the batch and call the processor per sample." + ) + images_list = None if images is not None: - if isinstance(images, Image): - images_list = [images] - else: - images_list = images - else: - images_list = None - - if len(texts) != 1: - raise ValueError("IsaacProcessor currently supports batch_size=1") - if images_list is not None: - # Count vision tokens in text to validate image count - vision_token_count = texts[0].count(self.vision_token) - if vision_token_count != len(images_list): + images_list = [images] if isinstance(images, Image) else images + n_tok = texts[0].count(self.vision_token) + if n_tok != len(images_list): raise ValueError( - f"Number of {self.vision_token} tokens in text ({vision_token_count}) " - f"must match number of images ({len(images_list)})" + f"Expected {len(images_list)} occurrences of '{self.vision_token}' (one per provided image), but found {n_tok} in the text." 
) - # Build event stream - stream = self.build_event_stream_simple( - text=texts[0], - images=images_list, - ) - - # Create TensorStream - tensor_stream = TensorStream([stream]) - - # Slice to max length if needed - _, T = tensor_stream.shape - if T > self.max_sequence_length: - tensor_stream = ts_slice(tensor_stream, start=T - self.max_sequence_length, end=T) - - # Get token view - tokens = tensor_stream_token_view(tensor_stream) - if return_tensors in (TensorType.PYTORCH, "pt"): - input_ids = torch.as_tensor(tokens, dtype=torch.long) - else: - input_ids = tokens - - data = { - "input_ids": input_ids, - "tensor_stream": tensor_stream, - } - - return BatchFeature(data=data) - - -# ============================================================================ -# Model -# ============================================================================ - - -def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor: - r"""Create 3D positional indices for token input. - - Args: - input_ids (`torch.Tensor`): - Tensor of shape `(batch_size, seq_len)` containing token ids. - - Returns: - `torch.Tensor`: Positional indices with shape `(batch_size, seq_len, 3)` where each channel duplicates the - 1D position so it can be consumed by the 3-axis MRoPE rotary embedding. - """ - batch_size, seq_length = input_ids.shape - position_ids = torch.arange(seq_length, device=input_ids.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - position_ids = position_ids.unsqueeze(2).expand(-1, -1, 3) # Add 3D for MRoPE - return position_ids + packed = self._pack_single(texts[0], images_list) + input_ids = packed.pop("input_ids") + return BatchFeature(data={"input_ids": input_ids, "packed_inputs": packed}) class IsaacRotaryEmbedding(qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding): - EXTRA_ROPE_KEYS = {"mrope_section", "mrope_interleaved"} - def __init__(self, config: IsaacConfig, device=None): rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} - - sanitized_scaling = {k: v for k, v in rope_scaling.items() if k not in self.EXTRA_ROPE_KEYS} config_for_rope = copy.copy(rope_source_cfg) - config_for_rope.rope_scaling = sanitized_scaling if sanitized_scaling else None + config_for_rope.rope_scaling = rope_scaling init_device = device if device is not None and getattr(device, "type", None) != "meta" else None super().__init__(config_for_rope, device=init_device) @@ -1419,12 +1241,6 @@ def _resolve_mrope_section(section: Optional[list[int]], rotary_half_dim: int) - return base section = [int(v) for v in section] - if len(section) != 3: - raise ValueError("`mrope_section` must contain exactly three elements (temporal, height, width)") - if sum(section) != rotary_half_dim: - raise ValueError( - f"`mrope_section` must sum to the rotary half-dimension ({rotary_half_dim}). Received {section}." 
- ) return section def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor: @@ -1438,11 +1254,6 @@ def forward( modality_tensor: torch.Tensor, hidden_states: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - if position_ids.ndim != 3 or position_ids.size(-1) != 3: - raise ValueError("`position_ids` must have shape (batch, seq_len, 3) for MRoPE") - if modality_tensor.shape != position_ids.shape[:2]: - raise ValueError("`modality_tensor` must align with the first two dims of `position_ids`") - if hidden_states is None: batch, seq_len, _ = position_ids.shape hidden_states = torch.zeros( @@ -1455,31 +1266,29 @@ def forward( with torch.no_grad(): pos = position_ids.clone() - image_value = VisionType.image.value if VisionType is not None else 1 - not_spatial = modality_tensor != image_value + not_spatial = modality_tensor != ModalityType.image.value if not_spatial.any(): + # Collapse non-vision modalities to 1D positions so rotary embedding + # treats them like text tokens while keeping image tokens 3D. data_1d = pos[not_spatial][..., 0].unsqueeze(-1) pos[not_spatial] = data_1d.expand(-1, pos.shape[-1]) pos_axes = pos.permute(2, 0, 1).contiguous() cos_axes, sin_axes = super().forward(hidden_states, pos_axes) - cos_axes = cos_axes.to(hidden_states.dtype) sin_axes = sin_axes.to(hidden_states.dtype) - - cos_combined = self._combine_axes(cos_axes) - sin_combined = self._combine_axes(sin_axes) + cos_combined, sin_combined = self._combine_axes(cos_axes), self._combine_axes(sin_axes) return cos_combined, sin_combined +@auto_docstring class IsaacModel(Qwen3PreTrainedModel): supports_gradient_checkpointing = True _can_compile_fullgraph = False _supports_flex_attn = False _can_record_outputs = {"attentions": OutputRecorder(IsaacVisionAttention, index=1)} - # Expose tied-weights mapping even if empty for base model tests. all_tied_weights_keys: dict[str, str] = {} def __init__(self, config: IsaacConfig): @@ -1487,30 +1296,17 @@ def __init__(self, config: IsaacConfig): text_cfg_source = config.text_config text_cfg = copy.deepcopy(text_cfg_source) - self.text_model = AutoModel.from_config(text_cfg) - # Ensure downstream callers observe the composed config - self.text_model.config = config + self.text_model = Qwen3Model._from_config(text_cfg) + self.text_model.config = config # Ensure downstream callers observe the composed config self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) - if config.vision_config is None: - raise ValueError("IsaacConfig should always have vision_config") - self.vision_embedding = IsaacVisionEmbedding(config) self.vision_embedding._supports_sdpa = True - - # Dispatch table for TensorStream balanced embedding (text + vision) - self.embed_fns = { - TextType: self.embed_text_tokens, - VisionType: self.embed_vision, - } - - # Keep track of config attributes that downstream utilities may query directly on the model. 
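The rotary forward above keeps full 3D coordinates only for image tokens and collapses every other modality to a replicated 1D position. A minimal sketch of that masking (toy tensors):

import torch

pos = torch.tensor([[[0, 0, 0], [1, 7, 9], [2, 0, 0]]])  # (batch=1, seq=3, 3 axes)
is_image = torch.tensor([[False, True, False]])
pos[~is_image] = pos[~is_image][..., :1].expand(-1, 3)  # text rows become (p, p, p)
print(pos.tolist())
# [[[0, 0, 0], [1, 7, 9], [2, 2, 2]]]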
self.max_sequence_length = config.max_sequence_length self.vision_rescale_factor = config.vision_rescale_factor self.vision_token = config.vision_token - # Initialize weights and parallel plans (including tp_plan from the text model) self.post_init() # Respect config-specified gradient checkpointing @@ -1541,118 +1337,73 @@ def embed_tokens(self, value: nn.Module) -> None: def vision_model(self) -> nn.Module: return self.vision_embedding.vision_tower - def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: - """Embed text tokens, squeezing singleton dimensions.""" - # Text events are shaped as (..., 1); squeeze the singleton index dim - h = self.text_model.embed_tokens(token_ids) - if h.dim() >= 2 and h.size(-2) == 1: - h = h[..., 0, :] - return h - - def embed_vision(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - """Embed vision tokens using the vision encoder.""" - # vision tokens is (seq_patches, token_grids) - return self.vision_embedding(vision_tokens) - - def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor: + def embed_packed_inputs( + self, input_ids: torch.Tensor, packed_inputs: dict[str, Optional[torch.Tensor]] + ) -> tuple[torch.Tensor, torch.Tensor]: """ - Embed each modality stream independently, preserving the original TensorStream - structure. + Expects input_ids for text tokens and packed_inputs containing: + - modality_tensor: (batch, seq_len) modality ids aligned to the sequence + - position_ids: (batch, seq_len, 3) MRoPE coordinates (optional) + - vision_patches: concatenated vision tokens shaped (total_tokens, embed_dim) or None + - vision_token_grids: (num_images, 2) token grid sizes or None + - vision_token_offsets: (num_images,) offsets into each image's virtual token span (optional) + - vision_token_lengths: (num_images,) surviving virtual token lengths per image (optional) """ - flat_stream = tensor_stream.flat_stream() - per_modality_stream = group_streams(flat_stream, group_fn=lambda ev: ev.type, schedule=False) - per_modality_compact_stream = {k: v.compact() for k, v in per_modality_stream.items()} - - # Collect per-event grids for vision tokens (H, W like dims sans time) - token_grids = defaultdict(list) - for stream in tensor_stream.streams: - for event in stream: - token_grids[event.type].append(event.dims(virtual=False)) - - embedded_compact = {} - for stream_type, modality_payload_tensor in per_modality_compact_stream.items(): - if stream_type.modality == VisionType: - # Build a (N_events, 2) grid tensor with spatial dims only - grids = token_grids.get(stream_type, []) - if len(grids) == 0: - input_tensor = modality_payload_tensor - else: - token_grids_tensor = torch.tensor(grids, dtype=torch.long, device=tensor_stream.device)[:, 1:] - input_tensor = (modality_payload_tensor, token_grids_tensor) - embedded_compact[stream_type] = self.embed_fns[stream_type.modality](input_tensor) - else: - embedded_compact[stream_type] = self.embed_fns[stream_type.modality](modality_payload_tensor) - - # Reconstruct a TensorStream with embedded payloads and compact - embedded_ts = reconstruct_tensor_stream_from_compact_dict(tensor_stream, embedded_compact) - h = embedded_ts.compact() # (B, T, D) - return h - - @staticmethod - def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor: - return compute_position_ids_input_ids(input_ids) - - def _prepare_position_and_modality( - self, - position_ids: Optional[torch.LongTensor], - modality_tensor: Optional[torch.LongTensor], - tensor_stream: 
Optional[TensorStream], - inputs_embeds: torch.Tensor, - cache_position: torch.LongTensor, - ) -> tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.Tensor, torch.Tensor]: - text_value = TextType.text.value if TextType is not None else 0 - batch_size, seq_len = inputs_embeds.shape[:2] - - if modality_tensor is None: - if tensor_stream is not None: - modality_tensor = modality_mask(tensor_stream) - else: - modality_tensor = torch.full( - (batch_size, seq_len), text_value, device=inputs_embeds.device, dtype=torch.long - ) - else: - modality_tensor = modality_tensor.to(device=inputs_embeds.device, dtype=torch.long) - expected_shape = (batch_size, seq_len) - if modality_tensor.shape != torch.Size(expected_shape): - raise ValueError( - f"modality_tensor must have shape (batch_size, seq_len) {expected_shape}, " - f"but got {tuple(modality_tensor.shape)}" - ) - - if position_ids is None: - if tensor_stream is not None: - position_ids = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) - else: - position_ids = cache_position.view(1, -1).expand(modality_tensor.shape[0], -1) - - if position_ids.ndim == 2: - position_ids = position_ids.to(device=inputs_embeds.device) - position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) - - if position_ids.shape[1] != seq_len: - start_positions = position_ids[:, :1, 0] - position_ids = torch.arange(seq_len, device=inputs_embeds.device).view(1, -1) - position_ids = position_ids + start_positions - position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) - - cos, sin = self.rotary_emb( - position_ids, - modality_tensor, - hidden_states=inputs_embeds, - ) + modality = packed_inputs["modality_tensor"].to(device=input_ids.device, dtype=torch.long) + embeds = self.text_model.embed_tokens(input_ids) + + vision_patches = packed_inputs.get("vision_patches") + if vision_patches is None: + return embeds, modality + + token_grids = packed_inputs["vision_token_grids"].to(device=vision_patches.device, dtype=torch.long) + vision = self.vision_embedding((vision_patches, token_grids)) # (total_tokens, hidden) + + # per-image token counts AFTER pixel-shuffle + s = int(self.config.vision_config.pixel_shuffle_scale_factor) + sizes = token_grids.prod(-1).div(s * s, rounding_mode="floor").tolist() + offsets = packed_inputs.get("vision_token_offsets") + lengths = packed_inputs.get("vision_token_lengths") + + if offsets is not None or lengths is not None: + off = ( + offsets.to(device=vision.device, dtype=torch.long) + if offsets is not None + else torch.zeros(len(sizes), device=vision.device, dtype=torch.long) + ) + ln = ( + lengths.to(device=vision.device, dtype=torch.long) + if lengths is not None + else torch.tensor(sizes, device=vision.device, dtype=torch.long) + ) - decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids - return position_ids, modality_tensor, decoder_position_ids, cos, sin + # Honor per-image crop windows (after pixel shuffle) so we only splice back + # the surviving virtual tokens instead of the full vision span. 
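A toy sketch of the splice implemented just below: post-shuffle vision tokens are split per image and only the [offset, offset + length) window of each image is kept (invented sizes):

import torch

vision = torch.arange(10).float().unsqueeze(-1)  # 10 post-shuffle tokens, hidden=1
sizes = [6, 4]                                   # tokens per image
offsets, lengths = [2, 0], [4, 3]                # crop windows from the processor
chunks = vision.split(sizes, dim=0)
picked = [c[o : o + l] for c, o, l in zip(chunks, offsets, lengths) if l > 0]
print(torch.cat(picked).squeeze(-1).tolist())
# [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]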
+        chunks = vision.split(sizes, dim=0)
+        picked: list[torch.Tensor] = []
+        for c, n, o, l in zip(chunks, sizes, off.tolist(), ln.tolist()):
+            if n <= 0:
+                continue
+            o = max(0, min(int(o), n))
+            l = max(0, min(int(l), n - o))
+            if l:
+                picked.append(c[o : o + l])
+        vision = torch.cat(picked, 0) if picked else vision.new_zeros((0, vision.size(-1)))
+
+        m = modality == ModalityType.image.value
+        embeds = embeds.clone()
+        embeds[m] = vision.to(device=embeds.device, dtype=embeds.dtype)
+
+        return embeds, modality
 
     @auto_docstring
     @check_model_inputs
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        tensor_stream: Optional[TensorStream] = None,
+        packed_inputs: Optional[dict[str, torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        modality_tensor: Optional[torch.LongTensor] = None,
         past_key_values: Optional[list[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
@@ -1665,32 +1416,29 @@ def forward(
         Computes position embeddings once and passes them through all layers.
 
         Args:
-            tensor_stream (`TensorStream`, *optional*):
-                Packed multimodal stream of text and vision events to embed directly. Mutually exclusive with
-                `input_ids` and `inputs_embeds`. When provided, the method derives `position_ids` and `modality_tensor`
-                if they are not supplied.
-            modality_tensor (`torch.LongTensor`, *optional*):
-                Modality identifiers aligned with the embedded sequence, shaped `(batch_size, seq_len)` and containing
-                values from `TextType`/`VisionType`. Automatically built from `tensor_stream` or `input_ids` when
-                omitted.
+            packed_inputs (`dict`, *optional*):
+                Packed multimodal payload produced by `IsaacProcessor`: modality ids, 3D MRoPE position ids, and
+                vision patch tensors/grids with optional crop offsets/lengths. When provided, `input_ids` must also
+                be passed so the text embeddings can be built; the modality tensor and position ids are taken from
+                the payload when not supplied explicitly.
         """
         output_attentions = kwargs.pop("output_attentions", None)
-        # Get inputs
-        if tensor_stream is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both tensor_stream and inputs_embeds")
-        if tensor_stream is None and input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        # Resolve the input source (prefer packed_inputs > ids > embeds).
+        modality_tensor: Optional[torch.Tensor] = None
+        precomputed_position_ids: Optional[torch.Tensor] = None
 
-        # Resolve the input source (TensorStream takes precedence over token ids).
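The per-image counts feeding that splice come straight from the token grids: an (H, W) patch grid collapses to H*W / s**2 virtual tokens after pixel shuffle with scale s. Sketch with toy grids:

import torch

token_grids = torch.tensor([[8, 12], [4, 4]])  # (H_tokens, W_tokens) per image
s = 2
sizes = token_grids.prod(-1).div(s * s, rounding_mode="floor")
print(sizes.tolist())  # [24, 4]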
- if tensor_stream is not None: - inputs_embeds = self.embed_stream(tensor_stream) + if packed_inputs is not None: + inputs_embeds, modality_tensor = self.embed_packed_inputs(input_ids, packed_inputs) + precomputed_position_ids = packed_inputs.get("position_ids") + if precomputed_position_ids is not None: + precomputed_position_ids = precomputed_position_ids.to(inputs_embeds.device) elif input_ids is not None: inputs_embeds = self.text_model.embed_tokens(input_ids) - elif inputs_embeds is None: - raise ValueError("You have to specify either tensor_stream, input_ids or inputs_embeds") + device = inputs_embeds.device batch_size, seq_len = inputs_embeds.shape[:2] # Ensure cache exists when requested @@ -1700,21 +1448,36 @@ def forward( if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=device) if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long) + attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.long) - position_ids, modality_tensor, decoder_position_ids, cos, sin = self._prepare_position_and_modality( - position_ids=position_ids, - modality_tensor=modality_tensor, - tensor_stream=tensor_stream, - inputs_embeds=inputs_embeds, - cache_position=cache_position, - ) + position_ids = position_ids if position_ids is not None else precomputed_position_ids + if position_ids is None: + position_ids = cache_position.view(1, -1).expand(batch_size, -1) - # Prepare attention mask - if not isinstance(attention_mask, dict): + if modality_tensor is None: + modality_tensor = torch.full( + (batch_size, seq_len), ModalityType.text.value, device=device, dtype=torch.long + ) + else: + modality_tensor = modality_tensor.to(device=device, dtype=torch.long) + + position_ids = position_ids.to(device=device) + + if position_ids.ndim == 2: + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) + if position_ids.shape[1] != seq_len: + start_positions = position_ids[:, :1, 0] + position_ids = torch.arange(seq_len, device=device).view(1, -1) + start_positions + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) + + cos, sin = self.rotary_emb(position_ids, modality_tensor, hidden_states=inputs_embeds) + + decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids + + if not isinstance(attention_mask, dict): # Prepare attention mask attention_mask = create_masks_for_generate( config=self.config, input_embeds=inputs_embeds, @@ -1724,19 +1487,15 @@ def forward( position_ids=decoder_position_ids, ) - is_attention_mask_dict = isinstance(attention_mask, dict) - - # Initialize hidden states + is_mask_dict = isinstance(attention_mask, dict) hidden_states = inputs_embeds all_attentions = [] if output_attentions else None - for decoder_layer in self.text_model.layers: - layer_attention_mask = ( - attention_mask[decoder_layer.attention_type] if is_attention_mask_dict else attention_mask - ) - layer_outputs = decoder_layer( + for layer in self.text_model.layers: + layer_mask = attention_mask[layer.attention_type] if is_mask_dict else attention_mask + layer_outputs = layer( hidden_states, - attention_mask=layer_attention_mask, + attention_mask=layer_mask, position_ids=decoder_position_ids, past_key_values=past_key_values, 
use_cache=use_cache, @@ -1751,7 +1510,6 @@ def forward( if output_attentions and layer_outputs_is_tuple: all_attentions.append(layer_outputs[1]) - # Final layer norm hidden_states = self.text_model.norm(hidden_states) return BaseModelOutputWithPast( @@ -1762,6 +1520,7 @@ def forward( ) +@auto_docstring class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin): """Isaac multimodal model for conditional generation.""" @@ -1775,13 +1534,15 @@ def __init__(self, config: IsaacConfig): self.model = IsaacModel(config) # Use our custom model self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Tracks rotary position offsets computed during a full forward pass so decode steps can reuse them. self.rope_deltas = None + @auto_docstring + @can_return_tuple + @check_model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, - tensor_stream: Optional[TensorStream] = None, + packed_inputs: Optional[dict[str, torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[list[torch.FloatTensor]] = None, @@ -1791,36 +1552,45 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | CausalLMOutputWithPast: - r""" - Forward pass for conditional generation supporting both standard inputs and TensorStream. + """Run multimodal CausalLM forward, accepting packed vision/text inputs. - tensor_stream (`TensorStream`, *optional*): - Packed multimodal stream (text, vision, audio tokens) that already encodes spatial metadata. When provided, - the model derives embeddings, modality masks, and 3D rotary coordinates directly from the stream instead of - `input_ids`. - """ + Args: + input_ids: Text token ids. + packed_inputs (`dict`, *optional*): + Packed vision/text payload from ``IsaacProcessor`` containing modality ids, MRoPE position ids, and + vision patch tensors/grids (with optional offsets/lengths) used to rebuild embeddings. + attention_mask: Attention mask or mask dict; created if not provided. + position_ids: Optional 3D MRoPE positions; auto-derived when absent. + past_key_values: Cache for decoding. + inputs_embeds: Precomputed embeddings (bypass embedding layer). + labels: Target ids for computing language modeling loss. + use_cache: Whether to return caches. + cache_position: Positions for cache-aware generation. + Returns: + CausalLMOutputWithPast: logits, optional loss, caches, hidden states, attentions. + """ output_attentions = kwargs.pop("output_attentions", None) - # Don't compute embeddings here - let the inner model handle it - if tensor_stream is not None: - input_ids = None - if input_ids is None and inputs_embeds is None and tensor_stream is None: - raise ValueError("Either input_ids, inputs_embeds, or tensor_stream must be provided.") + if position_ids is None and packed_inputs is not None: + pos_3d = packed_inputs.get("position_ids") + if pos_3d is not None: + position_ids, self.rope_deltas = self.get_rope_index( + position_ids=pos_3d, + attention_mask=attention_mask, + ) - # Record rope deltas on prefill when TensorStream is provided; leave position_ids building to IsaacModel. 
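A hedged sketch of the decode-time position advance handled in this region. The delta convention (max prefill position + 1 - prompt length) follows Qwen2.5-VL-style MRoPE and is an assumption here for illustration; the cached-delta-plus-cache_position arithmetic is what the surrounding code relies on:

import torch

pos_3d = torch.tensor([[[0, 0, 0], [1, 0, 0], [1, 0, 1], [1, 1, 0]]])  # prefill (1, 4, 3)
seq_len = pos_3d.shape[1]
rope_delta = pos_3d.amax(dim=(1, 2)) + 1 - seq_len  # assumed convention -> tensor([-2])

cache_position = torch.tensor([seq_len])  # first decode step
base = torch.arange(1)[None, :, None].expand(1, -1, 3)  # fresh 1D positions
decode_pos = base + int(cache_position[0] + rope_delta)
print(decode_pos.tolist())  # [[[2, 2, 2]]]: the next token sits at max prefill position + 1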
- if position_ids is None and tensor_stream is not None: - position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask) elif position_ids is None and cache_position is not None and self.rope_deltas is not None: - # Decode continuation after TensorStream prefill: advance positions using cached rope offsets. if input_ids is not None: - base_position_ids = compute_position_ids_input_ids(input_ids) + base_position_ids = torch.arange(input_ids.size(1), device=input_ids.device)[None, :, None].expand( + input_ids.size(0), -1, 3 + ) else: - if inputs_embeds is None: - raise ValueError("inputs_embeds must be provided when input_ids is None during decode") batch_size, seq_len = inputs_embeds.shape[:2] dummy_ids = torch.zeros((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long) - base_position_ids = compute_position_ids_input_ids(dummy_ids) + base_position_ids = torch.arange(dummy_ids.size(1), device=dummy_ids.device)[None, :, None].expand( + dummy_ids.size(0), -1, 3 + ) rope_delta = (cache_position[0] + self.rope_deltas).to(base_position_ids.device) if not isinstance(rope_delta, int): @@ -1829,10 +1599,9 @@ def forward( outputs = self.model( input_ids=input_ids, - tensor_stream=tensor_stream, + packed_inputs=packed_inputs, attention_mask=attention_mask, position_ids=position_ids, - modality_tensor=None, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -1859,43 +1628,44 @@ def forward( def set_input_embeddings(self, value: nn.Module) -> None: self.model.set_input_embeddings(value) vocab_size = getattr(value, "num_embeddings", None) - if vocab_size is not None: - self.config.vocab_size = vocab_size - self.model.config.vocab_size = vocab_size - if hasattr(self.model, "text_model"): - self.model.text_model.config.vocab_size = vocab_size - if self.lm_head.weight.shape[0] != vocab_size: - self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False) - if hasattr(self.model, "embed_tokens"): - self.lm_head.weight = self.model.text_model.embed_tokens.weight + self.config.vocab_size = vocab_size + self.model.config.vocab_size = vocab_size + self.model.text_model.config.vocab_size = vocab_size + if self.lm_head.weight.shape[0] != vocab_size: + self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False) + self.lm_head.weight = self.model.text_model.embed_tokens.weight def get_rope_index( self, - input_ids: Optional[torch.Tensor], - tensor_stream: Optional[TensorStream], - attention_mask: Optional[torch.Tensor], + *, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Compute MRoPE position ids from a TensorStream (or 1D fallback). + """ + Compute (position_ids_3d, rope_deltas) without TensorStream. - Returns (position_ids, rope_deltas). position_ids is (B,L,3) for MRoPE. - rope_deltas is (B,1) used to advance positions in decode. + - If `position_ids` is provided, it must be shape (B, L, 3). + - Else, if `input_ids` is provided, position ids are synthesized as (B, L, 3). + - `rope_deltas` is (B, 1) used to advance positions during decode. 
""" - # tensor_stream present: compute 3D coords - if tensor_stream is None and input_ids is None: - raise ValueError("`tensor_stream` or `input_ids` must be provided to compute rope indices") - if tensor_stream is not None: - pos_3d = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) + if position_ids is None: + pos_3d = torch.arange(input_ids.size(1), device=input_ids.device)[None, :, None].expand( + input_ids.size(0), -1, 3 + ) else: - pos_3d = compute_position_ids_input_ids(input_ids) - B, L, _ = pos_3d.shape + pos_3d = position_ids + if pos_3d.ndim != 3 or pos_3d.size(-1) != 3: + raise ValueError( + f"`position_ids` must have shape (batch, seq_len, 3) for MRoPE; got shape {tuple(pos_3d.shape)}." + ) - # Max position per batch across the 3 planes and sequence dimension: (B,) + B, L, _ = pos_3d.shape m_per_batch = pos_3d.amax(dim=(1, 2)) - # Sequence lengths per batch: (B,) if attention_mask is None: - seq_lens = torch.full_like(m_per_batch, L) + seq_lens = torch.full((B,), L, device=pos_3d.device, dtype=m_per_batch.dtype) else: seq_lens = attention_mask.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=m_per_batch.device) @@ -1908,33 +1678,11 @@ def prepare_inputs_for_generation( past_key_values: Optional[list[torch.FloatTensor]] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - tensor_stream: Optional[TensorStream] = None, + packed_inputs: Optional[dict[str, torch.Tensor]] = None, cache_position: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - use_cache: bool = True, **kwargs, ) -> dict[str, Any]: - """ - Prepare inputs for generation, handling TensorStream inputs properly. - """ - if cache_position is None: - seq_length = None - device = None - if input_ids is not None: - seq_length = input_ids.shape[1] - device = input_ids.device - elif inputs_embeds is not None: - seq_length = inputs_embeds.shape[1] - device = inputs_embeds.device - elif tensor_stream is not None: - _, seq_length = tensor_stream.shape - device = tensor_stream.device - if seq_length is not None: - # prepare_inputs_for_generation may be invoked outside `generate`, so synthesize the - # same cache positions that GenerationMixin would have created during prefill. 
- cache_position = torch.arange(seq_length, dtype=torch.long, device=device) - - # Call parent preparation model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, @@ -1942,25 +1690,15 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, cache_position=cache_position, position_ids=position_ids, - use_cache=use_cache, **kwargs, ) + if packed_inputs is None: + return model_inputs cache_position = model_inputs.get("cache_position", cache_position) - - # Handle TensorStream only for the prefill step first_step = cache_position is None or cache_position[0] == 0 - if tensor_stream is not None and first_step: - model_inputs["tensor_stream"] = tensor_stream - # Let forward rebuild MRoPE coordinates from the TensorStream - model_inputs["position_ids"] = None - else: - model_inputs["tensor_stream"] = None - - # TensorStream decode path: preserve rotary offsets from prefill; let forward rebuild positions - if tensor_stream is not None and not first_step and self.rope_deltas is not None: - model_inputs["position_ids"] = None - return model_inputs + model_inputs["packed_inputs"] = packed_inputs if first_step else None + model_inputs["position_ids"] = None return model_inputs @@ -1969,19 +1707,6 @@ def can_generate(cls) -> bool: return True -def _compute_residual_p_frames(frames: torch.Tensor, is_p_frame: list[bool]) -> torch.Tensor: - """Compute residuals for P-frames to stay in sync with the training pipeline.""" - if not any(is_p_frame): - return frames - - frame_indices = torch.arange(len(is_p_frame), device=frames.device) - i_frame_mask = torch.tensor([not flag for flag in is_p_frame], device=frames.device) - last_i_indices = torch.cummax((i_frame_mask * (1 + frame_indices)), dim=0).values.long() - 1 - p_indices = frame_indices[torch.tensor(is_p_frame, device=frames.device)] - frames[p_indices] = frames[p_indices] - frames[last_i_indices[p_indices]] - return frames - - __all__ = [ "IsaacConfig", "IsaacModel", diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index df90ae550756..6c75ef572c6b 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -19,20 +19,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
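A hedged usage sketch for the processor defined in this file; the checkpoint path is a placeholder, and resolving it through AutoProcessor assumes the auto-class registration added earlier in this patch:

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("path/to/isaac-checkpoint")  # placeholder path
image = Image.new("RGB", (224, 224))  # toy input
prompt = f"Describe this image: {processor.vision_token}"
batch = processor(text=prompt, images=image)
# batch["input_ids"]: (1, seq_len), pad ids at image positions
# batch["packed_inputs"]: vision patches/grids, crop offsets/lengths,
# modality ids, and 3D MRoPE position ids consumed by IsaacModel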
-import math -import re -from typing import Optional, Union +from typing import Any, Optional, Union from ...feature_extraction_utils import BatchFeature -from ...models.auto.tokenization_auto import AutoTokenizer from ...processing_utils import ProcessorMixin from ...utils import TensorType -from ...utils.import_utils import ( - is_perceptron_available, - is_torch_available, - is_vision_available, -) +from ...utils.import_utils import is_torch_available, is_vision_available from .configuration_isaac import IsaacConfig +from .modeling_isaac import ModalityType if is_torch_available(): @@ -44,73 +38,26 @@ else: Image = None -if is_perceptron_available(): - from perceptron.tensorstream.ops import slice as ts_slice - from perceptron.tensorstream.ops import tensor_stream_token_view - from perceptron.tensorstream.tensorstream import Event, Stream, TensorStream, TextType, VisionType, create_stream -else: - ts_slice = None - Event = None - Stream = None - TensorStream = None - TextType = None - VisionType = None - create_stream = None - group_streams = None - - -# ============================================================================ -# Processor Components -# ============================================================================ - -def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> Event: - r"""Wrap a text into an `Event` compatible with the multimodal TensorStream. +class IsaacProcessor(ProcessorMixin): + """Processor that pairs the Isaac image processor with the Qwen2 tokenizer. Args: - tokenizer (`AutoTokenizer`): - Tokenizer used to convert text into model vocabulary ids. - text (`str`): - Plain-text fragment to encode. - time (`float`, *optional*, defaults to 0.0): - Timeline coordinate associated with the event. Both start and end times use the same value because text - segments are instantaneous in the scheduler. + image_processor: Vision preprocessor (fast) used for patch extraction. + tokenizer: Qwen2 tokenizer instance. + vision_token (str, optional): Placeholder token marking image locations. Defaults to "". + max_sequence_length (int, optional): Maximum combined text+vision tokens kept. Defaults to 16384. + rescale_factor (float, optional): Image rescale factor; defaults to 1/255. + config (IsaacConfig | dict, optional): If provided, overrides processor defaults from the model config. Returns: - `Event`: Event carrying a `(num_tokens, 1)` tensor of token ids with matching - metadata so that downstream processors can compute modality-specific embeddings. + BatchFeature: Contains ``input_ids`` and ``packed_inputs`` (patch tensors, grids, offsets, lengths, modality, positions). 
""" - tokens = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt").squeeze(0) - - # Calculate dimensions for the event - num_tokens = len(tokens) - dims_virtual = [num_tokens, 1] # [sequence_length, 1] - dims_real = dims_virtual.copy() - - # Ensure tokens has the right shape for tensor_stream_token_view - # It expects a 2D tensor where sum(dim=-1) gives the token IDs - if tokens.dim() == 1: - tokens = tokens.unsqueeze(-1) - return Event( - data=tokens, - type=TextType.text, - time=(time, time), - dims_virtual=dims_virtual, - dims_real=dims_real, - idx_range=(0, num_tokens), - ) - - -# ============================================================================ -# Processor -# ============================================================================ - - -class IsaacProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = ("IsaacImageProcessorFast",) tokenizer_class = ("Qwen2Tokenizer",) + pad_token_id = 151643 def __init__( self, @@ -122,79 +69,174 @@ def __init__( rescale_factor: Optional[float] = None, config: Optional[Union[IsaacConfig, dict]] = None, ) -> None: - if tokenizer is None: - raise ValueError("`tokenizer` must be provided to initialize IsaacProcessor.") - if isinstance(config, dict): config = IsaacConfig(**config) if config is not None: - max_sequence_length = config.max_sequence_length vision_token = config.vision_token + max_sequence_length = config.max_sequence_length rescale_factor = config.vision_rescale_factor resolved_rescale_factor = float(rescale_factor) if rescale_factor is not None else float(1 / 255) - if config is not None: config.vision_rescale_factor = resolved_rescale_factor self.image_processor = image_processor - super().__init__(image_processor, tokenizer) + self.current_processor = self.image_processor self.config = config - - # Mirror tokenizer chat template so ProcessorMixin.apply_chat_template works. self.chat_template = getattr(self.tokenizer, "chat_template", None) - self.vision_token = vision_token self.max_sequence_length = max_sequence_length - def build_event_stream_simple( - self, - text: str, - images: Optional[list[Image]] = None, - ) -> Stream: - events = [] - # Process text and images - # Find all occurrences of vision token - - pattern = re.escape(self.vision_token) - parts = re.split(f"({pattern})", text) # Keep the delimiter in the result - - image_idx = 0 - for current_time, part in enumerate(parts): - if part == self.vision_token: - # Replace vision token with image event - if images is None or image_idx >= len(images): - raise ValueError("Encountered vision token without a corresponding image.") - - features = self.image_processor( - images=images[image_idx], - return_tensors=TensorType.PYTORCH, + def _pack_single(self, text: str, images: Optional[list[Image]]) -> dict[str, Optional[torch.Tensor]]: + # Parse by vision_token; interleave text segments and image segments. + segments = text.split(self.vision_token) + num_images = len(segments) - 1 + if num_images and (images is None or len(images) != num_images): + raise ValueError( + f"Expected one image per '{self.vision_token}' token: found {num_images} token(s) but received {0 if images is None else len(images)} image(s)." 
+ ) + + items: list[dict[str, Any]] = [] + total = 0 + + for index, segment in enumerate(segments): + if segment: + tok = ( + self.tokenizer.encode(segment, add_special_tokens=False, return_tensors="pt") + .squeeze(0) + .to(torch.long) ) - - patches = features["patches"][0] # (H_tokens, W_tokens, embed) - virtual_dims = features["virtual_pixel_size"][0].tolist() - real_dims = features["real_pixel_size"][0].tolist() - - vision_event = Event( - data=patches.reshape(-1, patches.shape[-1]), - type=VisionType.image, - time=(current_time, current_time), - dims_virtual=virtual_dims, - dims_real=real_dims, - idx_range=(0, math.prod(virtual_dims)), + segment_length = int(tok.numel()) + items.append({"type": "text", "segment_length": segment_length, "tok": tok}) + total += segment_length + + if index < num_images: + feat = self.image_processor(images=images[index], return_tensors=TensorType.PYTORCH) + patches = feat["patches"][0].reshape(-1, feat["patches"].shape[-1]) + + virtual_pixel_size = feat["virtual_pixel_size"][0].to(torch.long).tolist() + real_pixel_size = feat["real_pixel_size"][0].to(torch.long).tolist() + dims = tuple((virtual_pixel_size + [1, 1, 1])[:3]) # (T,H,W) in virtual space + segment_length = int(dims[0] * dims[1] * dims[2]) + + items.append( + { + "type": "image", + "segment_length": segment_length, + "dims": dims, + "patches": patches, + "grid": (int(real_pixel_size[1]), int(real_pixel_size[2])), + } ) - events.append(vision_event) - image_idx += 1 - elif part: # Non-empty text part - # tokens = self.text_processor.tokenize(part, add_special_tokens=False) - text_event = create_text_event(self.tokenizer, part, time=current_time) - events.append(text_event) + total += segment_length + + # Tail crop window. + start = max(0, total - self.max_sequence_length) + end = total + + fill_value = self.pad_token_id + base_device: Optional[torch.device] = None + position_ids, modality, input_ids = [], [], [] + vpatches, grids, vision_token_offsets, vision_token_lengths = [], [], [], [] + + global_offset = 0 + position_offset = 0 + + for item in items: + segment_length = int(item["segment_length"]) + current_window_start = max(start, global_offset) + current_window_end = min(end, global_offset + segment_length) + has_overlap = current_window_end > current_window_start + + if has_overlap and base_device is None: + base_device = item["patches"].device if item["type"] == "image" else item["tok"].device + + if has_overlap: + segment_local_start = int(current_window_start - global_offset) + segment_local_end = int(current_window_end - global_offset) + segment_local_indices = torch.arange( + segment_local_start, segment_local_end, device=base_device, dtype=torch.long + ) + segment_kept_length = segment_local_end - segment_local_start + + if item["type"] == "text": + slice_index = segment_local_indices + position_offset + zero_axis_pad = torch.zeros_like(slice_index) + position_ids.append(torch.stack((slice_index, zero_axis_pad, zero_axis_pad), -1)) + modality.append( + torch.full( + (segment_kept_length,), ModalityType.text.value, device=base_device, dtype=torch.long + ) + ) + input_ids.append(item["tok"].to(base_device)[segment_local_start:segment_local_end]) + position_offset += segment_length + else: + num_pos_slices, grid_height_tokens, grid_width_tokens = item["dims"] + hw = grid_height_tokens * grid_width_tokens + slice_index = (segment_local_indices // hw) + position_offset + rem = segment_local_indices % hw + row_index = rem // grid_width_tokens + col_index = rem % grid_width_tokens + 
position_ids.append(torch.stack((slice_index, row_index, col_index), -1)) + modality.append( + torch.full( + (segment_kept_length,), ModalityType.image.value, device=base_device, dtype=torch.long + ) + ) + input_ids.append( + torch.full((segment_kept_length,), fill_value, device=base_device, dtype=torch.long) + ) + + vpatches.append(item["patches"].to(base_device)) # full patches; slice later via offsets/lengths + # Record per-image slice boundaries so we can drop cropped virtual tokens + # after pixel shuffle without re-packing the entire vision stream. + grids.append(item["grid"]) + vision_token_offsets.append(segment_local_start) + vision_token_lengths.append(segment_kept_length) + + position_offset += int(num_pos_slices) + + else: + position_offset += segment_length if item["type"] == "text" else int(item["dims"][0]) - # Create stream without scheduling (events already in order) - return create_stream(events, priority=[TextType.text, VisionType.image], schedule=True) + global_offset += segment_length + + modality_tensor = ( + torch.cat(modality, 0).unsqueeze(0) + if modality + else torch.zeros((1, 0), device=base_device, dtype=torch.long) + ) + position_ids = ( + torch.cat(position_ids, 0).unsqueeze(0) + if position_ids + else torch.zeros((1, 0, 3), device=base_device, dtype=torch.long) + ) + input_ids = ( + torch.cat(input_ids, 0).unsqueeze(0) + if input_ids + else torch.zeros((1, 0), device=base_device, dtype=torch.long) + ) + + if vpatches: + vision_patches = torch.cat(vpatches, 0) + vision_token_grids = torch.tensor(grids, device=base_device, dtype=torch.long) + vision_token_offsets = torch.tensor(vision_token_offsets, device=base_device, dtype=torch.long) + vision_token_lengths = torch.tensor(vision_token_lengths, device=base_device, dtype=torch.long) + else: + vision_patches = vision_token_grids = vision_token_offsets = vision_token_lengths = None + + return { + "input_ids": input_ids, + "vision_patches": vision_patches, + "vision_token_grids": vision_token_grids, + "vision_token_offsets": vision_token_offsets, + "vision_token_lengths": vision_token_lengths, + "modality_tensor": modality_tensor, + "position_ids": position_ids, + } def __call__( self, @@ -203,68 +245,24 @@ def __call__( return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, **kwargs, ) -> BatchFeature: - """ - Process text and images into TensorStream format. - Args: - text: Input text or list of texts with vision tokens - images: PIL image or list of images (optional) - return_tensors: Format for output tensors - - Returns: - BatchFeature with input_ids and tensor_stream - """ - # Normalize inputs to lists - if isinstance(text, str): - texts = [text] - else: - texts = text + texts = [text] if isinstance(text, str) else text + if len(texts) != 1: + raise ValueError( + f"IsaacProcessor currently supports batch_size=1; received {len(texts)} text prompts. Split the batch and call the processor per sample." 
+ ) + images_list = None if images is not None: - if isinstance(images, Image): - images_list = [images] - else: - images_list = images - else: - images_list = None - - if len(texts) != 1: - raise ValueError("IsaacProcessor currently supports batch_size=1") - if images_list is not None: - # Count vision tokens in text to validate image count - vision_token_count = texts[0].count(self.vision_token) - if vision_token_count != len(images_list): + images_list = [images] if isinstance(images, Image) else images + n_tok = texts[0].count(self.vision_token) + if n_tok != len(images_list): raise ValueError( - f"Number of {self.vision_token} tokens in text ({vision_token_count}) " - f"must match number of images ({len(images_list)})" + f"Expected {len(images_list)} occurrences of '{self.vision_token}' (one per provided image), but found {n_tok} in the text." ) - # Build event stream - stream = self.build_event_stream_simple( - text=texts[0], - images=images_list, - ) - - # Create TensorStream - tensor_stream = TensorStream([stream]) - - # Slice to max length if needed - _, T = tensor_stream.shape - if T > self.max_sequence_length: - tensor_stream = ts_slice(tensor_stream, start=T - self.max_sequence_length, end=T) - - # Get token view - tokens = tensor_stream_token_view(tensor_stream) - if return_tensors in (TensorType.PYTORCH, "pt"): - input_ids = torch.as_tensor(tokens, dtype=torch.long) - else: - input_ids = tokens - - data = { - "input_ids": input_ids, - "tensor_stream": tensor_stream, - } - - return BatchFeature(data=data) + packed = self._pack_single(texts[0], images_list) + input_ids = packed.pop("input_ids") + return BatchFeature(data={"input_ids": input_ids, "packed_inputs": packed}) __all__ = ["IsaacProcessor"] diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py index 6425edfd6110..24437ec57221 100644 --- a/tests/models/isaac/test_modeling_isaac.py +++ b/tests/models/isaac/test_modeling_isaac.py @@ -36,12 +36,10 @@ is_torch_available, ) from transformers.image_utils import load_image -from transformers.masking_utils import eager_mask, sdpa_mask from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast from transformers.models.isaac.modeling_isaac import ( IsaacVisionAttention, IsaacVisionConfig, - document_mask_function_from_cu_seqlens, ) from transformers.models.isaac.processing_isaac import IsaacProcessor from transformers.testing_utils import ( @@ -153,74 +151,6 @@ def _rounded(value: torch.Tensor | float) -> float: } -@require_torch -class IsaacDocumentMaskingTest(unittest.TestCase): - def test_document_mask_function_from_cu_seqlens(self): - cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32) - mask_fn = document_mask_function_from_cu_seqlens(cu_seqlens) - - self.assertIsNotNone(mask_fn) - # Same document (indices 1 and 2) - self.assertTrue(mask_fn(0, 0, 1, 2)) - # Cross-document (index 1 in first doc, 3 in second doc) - self.assertFalse(mask_fn(0, 0, 1, 3)) - # Same second document (indices 3 and 4) - self.assertTrue(mask_fn(0, 0, 4, 3)) - - def test_document_mask_function_materializes_with_masking_utils(self): - cu_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32) - total_tokens = 4 - mask_fn = document_mask_function_from_cu_seqlens(cu_seqlens) - - cache_position = torch.arange(total_tokens, device=cu_seqlens.device, dtype=torch.long) - expected_bool = torch.tensor( - [ - [ - [ - [True, True, False, False], - [True, True, False, False], - [False, False, True, True], - [False, False, True, True], - ] 
- ] - ], - device=cu_seqlens.device, - ) - - sdpa = sdpa_mask( - batch_size=1, - cache_position=cache_position, - kv_length=total_tokens, - kv_offset=0, - mask_function=mask_fn, - attention_mask=None, - allow_is_causal_skip=False, - allow_is_bidirectional_skip=False, - allow_torch_fix=False, - use_vmap=False, - ) - # sdpa_mask returns True for allowed positions; SDPA expects True to mean "mask out" - self.assertTrue(torch.equal(sdpa, expected_bool)) - - eager = eager_mask( - batch_size=1, - cache_position=cache_position, - kv_length=total_tokens, - kv_offset=0, - mask_function=mask_fn, - attention_mask=None, - allow_is_bidirectional_skip=False, - use_vmap=False, - dtype=torch.float32, - ) - expected_additive = torch.where( - expected_bool, - torch.tensor(0.0, device=cu_seqlens.device, dtype=torch.float32), - torch.tensor(torch.finfo(torch.float32).min, device=cu_seqlens.device, dtype=torch.float32), - ) - self.assertTrue(torch.equal(eager, expected_additive)) - - def create_isaac_processor( tokenizer, isaac_config, @@ -481,6 +411,10 @@ def test_assisted_decoding_matches_greedy_search_0_random(self): def test_assisted_decoding_matches_greedy_search_1_same(self): pass + @unittest.skip(reason="Unsupported") + def test_flash_attn_kernels_inference_equivalence(self): + pass + @unittest.skip(reason="Assisted decoding not supported; Qwen3 backbone does not implement returning attentions") def test_assisted_decoding_sample(self): pass @@ -507,33 +441,6 @@ def test_model_forward(self): (self.model_tester.batch_size, self.model_tester.seq_length, config.hidden_size), ) - @require_tensorstream - def test_modality_tensor_requires_matching_shape(self): - config, input_ids, attention_mask, _ = self.model_tester.prepare_config_and_inputs() - model = IsaacModel(config).to(torch_device) - model.eval() - - modality_tensor = torch.zeros( - (self.model_tester.batch_size, self.model_tester.seq_length), - device=torch_device, - dtype=torch.long, - ) - with torch.no_grad(): - result = model(input_ids=input_ids, attention_mask=attention_mask, modality_tensor=modality_tensor) - - self.assertEqual( - result.last_hidden_state.shape, - (self.model_tester.batch_size, self.model_tester.seq_length, config.hidden_size), - ) - - bad_modality_tensor = torch.zeros( - (self.model_tester.batch_size, self.model_tester.seq_length + 1), - device=torch_device, - dtype=torch.long, - ) - with self.assertRaisesRegex(ValueError, "modality_tensor must have shape"): - model(input_ids=input_ids, attention_mask=attention_mask, modality_tensor=bad_modality_tensor) - @require_tensorstream def test_for_conditional_generation(self): config, input_ids, attention_mask, labels = self.model_tester.prepare_config_and_inputs() @@ -549,15 +456,6 @@ def test_for_conditional_generation(self): ) self.assertIsNotNone(result.loss) - def test_prepare_inputs_for_generation(self): - config, input_ids, attention_mask, _ = self.model_tester.prepare_config_and_inputs() - model = IsaacForConditionalGeneration(config) - model.to(torch_device) - - prepared_inputs = model.prepare_inputs_for_generation(input_ids=input_ids, attention_mask=attention_mask) - self.assertIn("input_ids", prepared_inputs) - self.assertIn("position_ids", prepared_inputs) - @require_tensorstream def test_isaac_for_conditional_generation_initialization(self): config = self.model_tester.get_config() @@ -567,7 +465,6 @@ def test_isaac_for_conditional_generation_initialization(self): self.assertTrue(hasattr(model, "model")) self.assertTrue(hasattr(model, "lm_head")) 
self.assertTrue(hasattr(model.model, "vision_embedding")) - self.assertTrue(hasattr(model.model, "embed_fns")) input_ids = torch.randint(0, config.vocab_size, (1, 10), device=torch_device, dtype=torch.long) with torch.no_grad(): @@ -589,46 +486,115 @@ def test_isaac_for_conditional_generation_loss_and_generate_flag(self): self.assertEqual(outputs.loss.ndim, 0) self.assertEqual(outputs.logits.shape, (batch_size, seq_len, config.vocab_size)) - @require_vision - @require_tensorstream - def test_isaac_generation_with_tensor_stream(self): - config = self.model_tester.get_config() - tokenizer = SimpleIsaacTokenizer() - image_processor = IsaacImageProcessorFast( - patch_size=config.vision_config.patch_size, - max_num_patches=config.vision_config.num_patches, - pixel_shuffle_scale=config.vision_config.pixel_shuffle_scale_factor, - rescale_factor=config.vision_rescale_factor, - ) - processor = IsaacProcessor( - image_processor=image_processor, - tokenizer=tokenizer, - config=config, + +@require_torch +@require_flash_attn +class IsaacAttentionDtypeTest(unittest.TestCase): + def _make_config(self): + return IsaacVisionConfig( + hidden_size=32, + intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=4, + num_channels=3, + num_patches=64, + patch_size=4, + attention_dropout=0.0, + pixel_shuffle_scale_factor=1, ) - model = IsaacForConditionalGeneration(config).to(torch_device) - model.eval() + def _skip_if_no_cuda_bf16(self): + if not torch.cuda.is_available(): + pytest.skip("CUDA required for flash attention dtype/parity tests.") + if not torch.cuda.is_bf16_supported(): + pytest.skip("CUDA bfloat16 support required.") - messages = [{"role": "user", "content": "Hello there!"}] - prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - processed = processor(text=prompt, images=None, return_tensors="pt") - - input_ids = processed["input_ids"].to(torch_device) - tensor_stream = processed["tensor_stream"].to(torch_device) - generated = model.generate( - input_ids=input_ids, - tensor_stream=tensor_stream, - max_new_tokens=5, - do_sample=False, - pad_token_id=processor.tokenizer.pad_token_id, - eos_token_id=processor.tokenizer.eos_token_id, - ) + def test_flash_attention_matches_weight_dtype_bf16(self): + self._skip_if_no_cuda_bf16() + torch.manual_seed(0) + + device = torch.device("cuda") + config = self._make_config() + config._attn_implementation = "flash_attention_2" + + attn = IsaacVisionAttention(config).to(device=device, dtype=torch.bfloat16).eval() + + hidden_states = torch.randn(2, 4, config.hidden_size, device=device, dtype=torch.bfloat16) + + with torch.no_grad(): + attn_output, _ = attn(hidden_states) - self.assertEqual(generated.shape[0], 1) - self.assertGreaterEqual(generated.shape[1], input_ids.shape[1]) - decoded_prompt = processor.tokenizer.decode(generated[0], skip_special_tokens=True) - self.assertIsInstance(decoded_prompt, str) - self.assertNotEqual(decoded_prompt.strip(), "") + assert attn_output.dtype == attn.out_proj.weight.dtype + assert attn_output.dtype == hidden_states.dtype + + def test_flash_attention_matches_weight_dtype_bf16_with_padding(self): + self._skip_if_no_cuda_bf16() + torch.manual_seed(0) + + device = torch.device("cuda") + config = self._make_config() + config._attn_implementation = "flash_attention_2" + + attn = IsaacVisionAttention(config).to(device=device, dtype=torch.bfloat16).eval() + + hidden_states = torch.randn(2, 4, config.hidden_size, device=device, dtype=torch.bfloat16) + attention_mask = torch.tensor([[1, 
1, 1, 0], [1, 1, 0, 0]], device=device, dtype=torch.bool) + + with torch.no_grad(): + attn_output, _ = attn(hidden_states, attention_mask=attention_mask) + + assert attn_output.dtype == attn.out_proj.weight.dtype + assert attn_output.dtype == hidden_states.dtype + + def test_flash_attention_matches_weight_dtype_bf16_with_cu_seqlens(self): + self._skip_if_no_cuda_bf16() + torch.manual_seed(0) + + device = torch.device("cuda") + config = self._make_config() + config._attn_implementation = "flash_attention_2" + + attn = IsaacVisionAttention(config).to(device=device, dtype=torch.bfloat16).eval() + + hidden_states = torch.randn(1, 5, config.hidden_size, device=device, dtype=torch.bfloat16) + cu_seqlens = torch.tensor([0, 3, 5], device=device, dtype=torch.int32) + + with torch.no_grad(): + attn_output, _ = attn(hidden_states, cu_seqlens=cu_seqlens, max_seqlen=3) + + assert attn_output.dtype == attn.out_proj.weight.dtype + assert attn_output.dtype == hidden_states.dtype + + def test_flash_attention_parity_with_sdpa_bf16(self): + self._skip_if_no_cuda_bf16() + torch.manual_seed(0) + + device = torch.device("cuda") + config_sdpa = self._make_config() + config_sdpa._attn_implementation = "sdpa" + + config_fa2 = self._make_config() + config_fa2._attn_implementation = "flash_attention_2" + + attn_sdpa = IsaacVisionAttention(config_sdpa).to(device=device, dtype=torch.bfloat16).eval() + attn_fa2 = IsaacVisionAttention(config_fa2).to(device=device, dtype=torch.bfloat16).eval() + + # Align weights so the only difference is the backend + attn_fa2.load_state_dict(attn_sdpa.state_dict()) + + hidden_states = torch.randn(2, 4, config_sdpa.hidden_size, device=device, dtype=torch.bfloat16) + + with torch.no_grad(): + out_sdpa, _ = attn_sdpa(hidden_states) + out_fa2, _ = attn_fa2(hidden_states) + + torch.testing.assert_close( + out_fa2.float(), + out_sdpa.float(), + rtol=1e-3, + atol=1e-3, + msg="FlashAttention2 output deviates from SDPA baseline beyond tolerance", + ) @require_torch @@ -769,11 +735,18 @@ def setUp(self): def _generate_from_messages(self, messages, images, num_tokens=None): prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip() processor_output = self.processor(text=prompt, images=images, return_tensors="pt") - tensor_stream = processor_output["tensor_stream"].to(self.device) + packed_inputs = processor_output["packed_inputs"] + input_ids = processor_output["input_ids"].to(self.device) + prompt_len = input_ids.shape[1] + packed_inputs = { + key: (value.to(self.device) if isinstance(value, torch.Tensor) else value) + for key, value in packed_inputs.items() + } with torch.no_grad(): outputs = self.model.generate( - tensor_stream=tensor_stream, + input_ids=input_ids, + packed_inputs=packed_inputs, max_new_tokens=num_tokens or self.max_new_tokens, do_sample=False, pad_token_id=self.tokenizer.eos_token_id, @@ -783,7 +756,8 @@ def _generate_from_messages(self, messages, images, num_tokens=None): ) generated_ids = outputs.sequences - generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) + generated_tail = generated_ids[:, prompt_len:] + generated_text = self.tokenizer.decode(generated_tail[0], skip_special_tokens=True) return generated_text def test_generate_from_image_text(self): @@ -846,11 +820,17 @@ def test_logit_equivalence(self): ] prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip() processor_output = self.processor(text=prompt, images=images, return_tensors="pt") - 
tensor_stream = processor_output["tensor_stream"].to(self.device) + packed_inputs = processor_output["packed_inputs"] + input_ids = processor_output["input_ids"] + device = next(self.model.parameters()).device + input_ids = input_ids.to(device) + # Move packed tensors to model device + packed_inputs = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in packed_inputs.items()} with torch.no_grad(): outputs = self.model.generate( - tensor_stream=tensor_stream, + input_ids=input_ids, + packed_inputs=packed_inputs, max_new_tokens=num_tokens or self.max_new_tokens, do_sample=False, pad_token_id=self.tokenizer.eos_token_id, @@ -919,11 +899,18 @@ def test_hf_generate_box_points(self): messages, images = document_to_messages(document, vision_token=self.hf_config.vision_token) prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip() processor_output = self.processor(text=prompt, images=images, return_tensors="pt") - tensor_stream = processor_output["tensor_stream"].to(self.device) + packed_inputs = processor_output["packed_inputs"] + input_ids = processor_output["input_ids"].to(self.device) + prompt_len = input_ids.shape[1] + packed_inputs = { + key: (value.to(self.device) if isinstance(value, torch.Tensor) else value) + for key, value in packed_inputs.items() + } with torch.no_grad(): outputs = self.model.generate( - tensor_stream=tensor_stream, + input_ids=input_ids, + packed_inputs=packed_inputs, max_new_tokens=self.max_new_tokens, do_sample=False, pad_token_id=self.tokenizer.eos_token_id, @@ -932,7 +919,8 @@ def test_hf_generate_box_points(self): ) generated_ids = outputs.sequences - hf_generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) + hf_generated_tail = generated_ids[:, prompt_len:] + hf_generated_text = self.tokenizer.decode(hf_generated_tail[0], skip_special_tokens=True) points = extract_points(hf_generated_text) assert len(points) == 1 first_point = points[0] diff --git a/tests/models/isaac/test_processing_isaac.py b/tests/models/isaac/test_processing_isaac.py index 4d173fdac305..41e28bf253e0 100644 --- a/tests/models/isaac/test_processing_isaac.py +++ b/tests/models/isaac/test_processing_isaac.py @@ -15,13 +15,14 @@ """Testing suite for the Isaac processor.""" import pytest +import torch from transformers import IsaacConfig, PythonBackend from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast +from transformers.models.isaac.modeling_isaac import ModalityType from transformers.models.isaac.processing_isaac import IsaacProcessor from transformers.testing_utils import require_torch, require_vision from transformers.utils import is_vision_available -from transformers.utils.import_utils import is_perceptron_available if is_vision_available(): @@ -30,15 +31,6 @@ Image = None -if is_perceptron_available(): - from perceptron.tensorstream.tensorstream import TensorStream -else: - TensorStream = None - - -require_tensorstream = pytest.mark.skipif(TensorStream is None, reason="TensorStream backend is not available") - - class SimpleIsaacTokenizer(PythonBackend): vocab_files_names = {} model_input_names = ["input_ids"] @@ -105,6 +97,82 @@ def _make_dummy_image(size=(32, 32), color=(255, 0, 0)): return Image.new("RGB", size, color=color) +def _make_processor_with_max_len(tokenizer, base_config, max_len): + config = IsaacConfig(**base_config.to_dict()) + config.max_sequence_length = max_len + vision_config = config.vision_config + image_processor = 
IsaacImageProcessorFast( + patch_size=vision_config.patch_size, + max_num_patches=vision_config.num_patches, + pixel_shuffle_scale=vision_config.pixel_shuffle_scale_factor, + rescale_factor=config.vision_rescale_factor, + ) + return IsaacProcessor(image_processor=image_processor, tokenizer=tokenizer, config=config) + + +def _run_processor(processor, text, images=None): + return processor(text=text, images=images, return_tensors="pt") + + +def _assert_common(outputs): + assert set(outputs.keys()) == {"input_ids", "packed_inputs"} + input_ids = outputs["input_ids"] + packed_inputs = outputs["packed_inputs"] + + expected_packed_keys = { + "vision_patches", + "vision_token_grids", + "vision_token_offsets", + "vision_token_lengths", + "modality_tensor", + "position_ids", + } + assert set(packed_inputs.keys()) == expected_packed_keys + + assert input_ids.shape[0] == 1 + assert input_ids.dtype == torch.long + + modality = packed_inputs["modality_tensor"] + position_ids = packed_inputs["position_ids"] + assert modality.shape == (1, input_ids.shape[1]) + assert position_ids.shape == (1, input_ids.shape[1], 3) + assert modality.dtype == torch.long + assert position_ids.dtype == torch.long + assert modality.device == input_ids.device == position_ids.device + + return input_ids, packed_inputs + + +def _assert_no_vision(packed_inputs): + assert packed_inputs["vision_patches"] is None + assert packed_inputs["vision_token_grids"] is None + assert packed_inputs["vision_token_offsets"] is None + assert packed_inputs["vision_token_lengths"] is None + + +def _assert_vision_segments(packed_inputs, expected_segments): + assert packed_inputs["vision_patches"] is not None + assert packed_inputs["vision_token_grids"] is not None + assert packed_inputs["vision_token_offsets"] is not None + assert packed_inputs["vision_token_lengths"] is not None + + assert packed_inputs["vision_token_grids"].shape[0] == expected_segments + assert packed_inputs["vision_token_offsets"].shape == (expected_segments,) + assert packed_inputs["vision_token_lengths"].shape == (expected_segments,) + + +def _count_modality(packed_inputs, modality_value): + modality = packed_inputs["modality_tensor"] + return int((modality == modality_value).sum().item()) + + +def _get_image_token_length(processor, image, vision_token): + outputs = _run_processor(processor, text=vision_token, images=[image]) + _, packed = _assert_common(outputs) + _assert_vision_segments(packed, expected_segments=1) + return packed["vision_token_lengths"][0].item() + + @pytest.fixture def isaac_tiny_config(): text_config = { @@ -168,7 +236,6 @@ def isaac_processor(isaac_tokenizer, isaac_tiny_config): @require_torch @require_vision -@require_tensorstream def test_isaac_processor_matches_config_defaults(isaac_processor, isaac_tiny_config): assert isaac_processor.vision_token == isaac_tiny_config.vision_token assert isaac_processor.max_sequence_length == isaac_tiny_config.max_sequence_length @@ -179,86 +246,218 @@ def test_isaac_processor_matches_config_defaults(isaac_processor, isaac_tiny_con @require_torch @require_vision -@require_tensorstream -def test_isaac_processor_text_only_round_trip(isaac_processor): - messages = [{"role": "user", "content": "Hello, how are you?"}] - prompt = isaac_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - outputs = isaac_processor(text=prompt, images=None, return_tensors="pt") - - assert "input_ids" in outputs - assert "tensor_stream" in outputs - assert TensorStream is not None - assert 
isinstance(outputs["tensor_stream"], TensorStream) - assert outputs["input_ids"].shape[0] == 1 +def test_text_only_has_no_vision_fields(isaac_processor): + outputs = _run_processor(isaac_processor, text="Hello, how are you?", images=None) + _, packed = _assert_common(outputs) + _assert_no_vision(packed) @require_torch -@require_tensorstream -def test_isaac_processor_accepts_batchencoding_chat_template(isaac_processor): +@require_vision +def test_accepts_batchencoding_chat_template(isaac_processor): messages = [{"role": "user", "content": "Hello, how are you?"}] batch_encoding = isaac_processor.apply_chat_template(messages, add_generation_prompt=True) - outputs = isaac_processor(text=batch_encoding, images=None, return_tensors="pt") - - assert "input_ids" in outputs - assert "tensor_stream" in outputs - assert TensorStream is not None - assert isinstance(outputs["tensor_stream"], TensorStream) - assert outputs["input_ids"].shape[0] == 1 + outputs = _run_processor(isaac_processor, text=batch_encoding, images=None) + _, packed = _assert_common(outputs) + _assert_no_vision(packed) @require_torch @require_vision -@require_tensorstream -def test_isaac_processor_with_single_image(isaac_processor): +def test_single_image_returns_offsets_and_lengths(isaac_processor): vision_token = isaac_processor.vision_token text = f"Look at this {vision_token} and describe it." image = _make_dummy_image() - outputs = isaac_processor(text=text, images=[image], return_tensors="pt") - assert TensorStream is not None - assert isinstance(outputs["tensor_stream"], TensorStream) - assert outputs["input_ids"].ndim == 2 + outputs = _run_processor(isaac_processor, text=text, images=[image]) + _, packed = _assert_common(outputs) + _assert_vision_segments(packed, expected_segments=1) + + grid_tokens = torch.prod(packed["vision_token_grids"], dim=-1) + torch.testing.assert_close(packed["vision_token_lengths"], grid_tokens) + torch.testing.assert_close(packed["vision_token_offsets"], torch.zeros_like(packed["vision_token_offsets"])) @require_torch @require_vision -@require_tensorstream -def test_isaac_processor_with_multiple_images(isaac_processor): +def test_multiple_images_have_matching_offsets_lengths_and_grids(isaac_processor): vision_token = isaac_processor.vision_token text = f"First {vision_token} then {vision_token}" images = [_make_dummy_image(color=(255, 0, 0)), _make_dummy_image(color=(0, 255, 0))] - outputs = isaac_processor(text=text, images=images, return_tensors="pt") - assert TensorStream is not None - assert isinstance(outputs["tensor_stream"], TensorStream) - assert outputs["input_ids"].shape[0] == 1 + outputs = _run_processor(isaac_processor, text=text, images=images) + _, packed = _assert_common(outputs) + _assert_vision_segments(packed, expected_segments=2) + + grid_tokens = torch.prod(packed["vision_token_grids"], dim=-1) + torch.testing.assert_close(packed["vision_token_lengths"], grid_tokens) + torch.testing.assert_close(packed["vision_token_offsets"], torch.zeros_like(packed["vision_token_offsets"])) @require_torch @require_vision -@require_tensorstream -def test_isaac_processor_error_on_image_mismatch(isaac_processor): +def test_error_on_image_mismatch(isaac_processor): vision_token = isaac_processor.vision_token text = f"{vision_token} {vision_token}" image = _make_dummy_image() - with pytest.raises(ValueError, match="must match number of images"): - isaac_processor(text=text, images=[image], return_tensors="pt") + with pytest.raises(ValueError, match="occurrences of"): + 
_run_processor(isaac_processor, text=text, images=[image]) + + +@require_torch +@require_vision +def test_consecutive_vision_tokens_allow_empty_text_segments(isaac_processor): + vision_token = isaac_processor.vision_token + text = f"prefix {vision_token}{vision_token} suffix" + images = [_make_dummy_image(), _make_dummy_image(color=(0, 0, 255))] + + outputs = _run_processor(isaac_processor, text=text, images=images) + _, packed = _assert_common(outputs) + _assert_vision_segments(packed, expected_segments=2) + + torch.testing.assert_close(packed["vision_token_offsets"], torch.zeros_like(packed["vision_token_offsets"])) + grid_tokens = torch.prod(packed["vision_token_grids"], dim=-1) + torch.testing.assert_close(packed["vision_token_lengths"], grid_tokens) + + +@require_torch +@require_vision +def test_device_and_dtype_consistency(isaac_processor): + vision_token = isaac_processor.vision_token + text = f"Describe this {vision_token}" + image = _make_dummy_image() + + outputs = _run_processor(isaac_processor, text=text, images=[image]) + input_ids, packed = _assert_common(outputs) + _assert_vision_segments(packed, expected_segments=1) + + tensors = [ + input_ids, + packed["position_ids"], + packed["modality_tensor"], + packed["vision_token_offsets"], + packed["vision_token_lengths"], + packed["vision_token_grids"], + ] + devices = {t.device for t in tensors} + assert len(devices) == 1 + for t in tensors: + assert t.dtype == torch.long + + +@require_torch +@require_vision +def test_no_crop_when_total_below_max(isaac_processor): + vision_token = isaac_processor.vision_token + text = f"hello {vision_token} world" + image = _make_dummy_image() + + outputs = _run_processor(isaac_processor, text=text, images=[image]) + input_ids, packed = _assert_common(outputs) + _assert_vision_segments(packed, expected_segments=1) + + grid_tokens = torch.prod(packed["vision_token_grids"], dim=-1) + text_tokens = _count_modality(packed, ModalityType.text.value) + assert input_ids.shape[1] == grid_tokens.item() + text_tokens + + +@require_torch +@require_vision +def test_exact_fit_keeps_all_tokens(isaac_processor, isaac_tokenizer, isaac_tiny_config): + vision_token = isaac_processor.vision_token + text = f"hey {vision_token} there" + image = _make_dummy_image() + + base_outputs = _run_processor(isaac_processor, text=text, images=[image]) + base_length = base_outputs["input_ids"].shape[1] + base_packed = base_outputs["packed_inputs"] + base_vision_length = base_packed["vision_token_lengths"][0].item() + + processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, base_length) + outputs = _run_processor(processor, text=text, images=[image]) + + input_ids, packed = _assert_common(outputs) + _assert_vision_segments(packed, expected_segments=1) + assert input_ids.shape[1] == base_length + assert packed["vision_token_lengths"].item() == base_vision_length @require_torch @require_vision -@require_tensorstream -def test_isaac_processor_consistent_tensor_stream_types(isaac_processor): - text_only = "Simple question?" 
- text_with_image = f"Describe this {isaac_processor.vision_token}" +def test_crop_truncates_text_segment_only(isaac_processor, isaac_tokenizer, isaac_tiny_config): + vision_token = isaac_processor.vision_token + text_prefix_tokens = " ".join([f"t{i}" for i in range(8)]) + text_suffix = "tail end" + text = f"{text_prefix_tokens} {vision_token} {text_suffix}" + image = _make_dummy_image() + + base_outputs = _run_processor(isaac_processor, text=text, images=[image]) + base_packed = base_outputs["packed_inputs"] + full_text_tokens = _count_modality(base_packed, ModalityType.text.value) + vision_length = base_packed["vision_token_lengths"][0].item() + + max_len = base_outputs["input_ids"].shape[1] - 4 + processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, max_len) + outputs = _run_processor(processor, text=text, images=[image]) + + input_ids, packed = _assert_common(outputs) + _assert_vision_segments(packed, expected_segments=1) + assert input_ids.shape[1] == max_len + + kept_text_tokens = _count_modality(packed, ModalityType.text.value) + assert kept_text_tokens == full_text_tokens - 4 + torch.testing.assert_close(packed["vision_token_offsets"], torch.zeros_like(packed["vision_token_offsets"])) + assert packed["vision_token_lengths"].item() == vision_length + + +@require_torch +@require_vision +def test_crop_cuts_through_image_segment(isaac_processor, isaac_tokenizer, isaac_tiny_config): + vision_token = isaac_processor.vision_token + text_before = "hi" + text_after = "bye" + text = f"{text_before} {vision_token} {text_after}" + image = _make_dummy_image() + + base_outputs = _run_processor(isaac_processor, text=text, images=[image]) + base_packed = base_outputs["packed_inputs"] + vision_full = base_packed["vision_token_lengths"][0].item() + text_before_len = len(isaac_tokenizer.encode(text_before, add_special_tokens=False)) + text_after_len = len(isaac_tokenizer.encode(text_after, add_special_tokens=False)) + total_length = vision_full + text_before_len + text_after_len + + max_len = 40 + start = total_length - max_len + expected_offset = max(0, start - text_before_len) + expected_length = vision_full - expected_offset + + processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, max_len) + outputs = _run_processor(processor, text=text, images=[image]) + + input_ids, packed = _assert_common(outputs) + _assert_vision_segments(packed, expected_segments=1) + + assert input_ids.shape[1] == max_len + assert packed["vision_token_offsets"].item() == expected_offset + assert packed["vision_token_lengths"].item() == expected_length + assert _count_modality(packed, ModalityType.text.value) == text_after_len + + +@require_torch +@require_vision +def test_crop_removes_all_vision_when_window_excludes_images(isaac_processor, isaac_tokenizer, isaac_tiny_config): + vision_token = isaac_processor.vision_token + text_tail = "closing" + text = f"{vision_token} {text_tail}" image = _make_dummy_image() - outputs_text = isaac_processor(text=text_only, images=None, return_tensors="pt") - outputs_image = isaac_processor(text=text_with_image, images=[image], return_tensors="pt") + tail_tokens = len(isaac_processor.tokenizer.encode(text_tail, add_special_tokens=False)) + processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, tail_tokens) + outputs = _run_processor(processor, text=text, images=[image]) - assert TensorStream is not None - assert isinstance(outputs_text["tensor_stream"], TensorStream) - assert isinstance(outputs_image["tensor_stream"], 
TensorStream) - assert outputs_text["input_ids"].shape[0] == outputs_image["input_ids"].shape[0] == 1 + input_ids, packed = _assert_common(outputs) + _assert_no_vision(packed) + assert input_ids.shape[1] == tail_tokens + assert _count_modality(packed, ModalityType.text.value) == tail_tokens From 2b6969820aecfabeb662873f28d110121b96386e Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Tue, 30 Dec 2025 17:00:53 +0400 Subject: [PATCH 59/77] feat: batched inference + rope refactor --- docs/source/en/model_doc/isaac.md | 2 +- .../models/isaac/configuration_isaac.py | 2 + .../isaac/image_processing_isaac_fast.py | 55 +- .../models/isaac/modeling_isaac.py | 289 +++++----- .../models/isaac/modular_isaac.py | 516 +++++++++--------- .../models/isaac/processing_isaac.py | 130 ++++- src/transformers/utils/import_utils.py | 8 - tests/models/isaac/test_modeling_isaac.py | 407 +++++++++----- tests/models/isaac/test_processing_isaac.py | 208 ++++++- 9 files changed, 996 insertions(+), 621 deletions(-) diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md index 538d33033d53..0d4c74b3a06c 100644 --- a/docs/source/en/model_doc/isaac.md +++ b/docs/source/en/model_doc/isaac.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-12.* +*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-30.* *This model was added to Hugging Face Transformers in 2025.*
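[Reviewer note] The "batched inference + rope refactor" below hinges on per-row RoPE deltas: after prefill, each batch row records how far its highest M-RoPE position lags (or leads) its unpadded token count, and decode resumes from cache_position plus that delta. A minimal sketch of the arithmetic, assuming the (batch, seq_len, 3) position layout used throughout this series; the rope_delta helper here is illustrative only, the shipped logic lives in IsaacModel.get_rope_index further down:

    import torch

    def rope_delta(position_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # position_ids: (batch, seq_len, 3) M-RoPE positions from the prefill pass
        # attention_mask: (batch, seq_len), 1 marks real (unpadded) tokens
        max_pos = position_ids.amax(dim=(1, 2))       # highest position id per row
        seq_lens = attention_mask.eq(1).sum(dim=-1)   # unpadded length per row
        return (max_pos + 1 - seq_lens).unsqueeze(1)  # (batch, 1)

    # Text-only prompt: positions advance one per token, so the delta is zero.
    pos = torch.arange(6).view(1, 6, 1).expand(1, 6, 3)
    mask = torch.ones(1, 6, dtype=torch.long)
    assert rope_delta(pos, mask).item() == 0

    # An image that packs a 2x2 token grid into one position slice makes the
    # maximum position lag the sequence length, so the delta goes negative.
    pos_img = torch.tensor([[[0, 0, 0], [1, 0, 0],                        # text
                             [2, 0, 0], [2, 0, 1], [2, 1, 0], [2, 1, 1],  # image
                             [3, 0, 0]]])                                 # text
    mask_img = torch.ones(1, 7, dtype=torch.long)
    assert rope_delta(pos_img, mask_img).item() == -3

Left padding is the one wrinkle the sketch omits: during decode the refactored get_rope_index additionally shifts by attention_mask.sum(1) - attention_mask.size(1), so left-padded rows stay aligned.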
diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index 96ab227aa87e..fd6820ad83af 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -144,6 +144,8 @@ def __init__( if vision_attn is not None: self.vision_config._attn_implementation = vision_attn + if getattr(self, "_attn_implementation", None) is None: + self._attn_implementation = "sdpa" # Vision normalization parameters self.vision_rescale_factor = float(vision_rescale_factor) diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py b/src/transformers/models/isaac/image_processing_isaac_fast.py index f487c249571d..caf35d3eacc6 100644 --- a/src/transformers/models/isaac/image_processing_isaac_fast.py +++ b/src/transformers/models/isaac/image_processing_isaac_fast.py @@ -28,8 +28,6 @@ from ...image_utils import PILImageResampling from ...processing_utils import Unpack from ...utils import TensorType, auto_docstring - -# Vision preprocessing constants from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD from ...utils.import_utils import is_torch_available @@ -158,7 +156,6 @@ def get_image_size_for_max_num_patches( @auto_docstring class IsaacImageProcessorFast(BaseImageProcessorFast): MAX_PIXELS = 60_000_000 # 60โ€‘megapixel ceiling โ‰ˆ 8200 ร— 7300 px - r"""Fast torch-based image processor for Isaac vision inputs.""" resample = PILImageResampling.BILINEAR model_input_names = ["patches", "token_grids"] @@ -234,21 +231,10 @@ def _preprocess( grouped_outputs = {} for shape, stacked_images in grouped_images.items(): - if stacked_images.ndim != 4: - raise ValueError( - f"Expected images shaped as (batch, channels, height, width); got shape {tuple(stacked_images.shape)}." - ) - batch_size, channels, original_height, original_width = stacked_images.shape if bool(self.do_convert_rgb) and channels == 1: stacked_images = stacked_images.repeat(1, 3, 1, 1) - channels = 3 - - if original_height * original_width > self.MAX_PIXELS: - raise ValueError( - f"Image area {original_height * original_width} (h={original_height}, w={original_width}) exceeds MAX_PIXELS={self.MAX_PIXELS}; enable resizing or provide smaller inputs." - ) target_height, target_width = get_image_size_for_max_num_patches( original_height, @@ -258,43 +244,31 @@ def _preprocess( min_num_patches=min_num_patches, pixel_shuffle_scale=pixel_shuffle_scale, ) - if do_resize: - resize_size = SizeDict(height=target_height, width=target_width) image_batch = self.resize( - image=stacked_images, - size=resize_size, - interpolation=interpolation, + stacked_images, SizeDict(height=target_height, width=target_width), interpolation=interpolation ) else: - if ((original_height % patch_size) != 0) or ((original_width % patch_size) != 0): + if (original_height % patch_size) or (original_width % patch_size): raise ValueError( f"Image dimensions (h={original_height}, w={original_width}) must be divisible by patch_size={patch_size} when resize is disabled; enable resizing or adjust the input resolution." 
) - image_batch = stacked_images - target_height, target_width = original_height, original_width - - if do_rescale: - image_batch = self.rescale_and_normalize( - image_batch, - do_rescale=do_rescale, - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - ) + image_batch, target_height, target_width = stacked_images, original_height, original_width + + image_batch = self.rescale_and_normalize( + image_batch, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) patches = torch_extract_patches(image_batch, patch_size, patch_size) _, height_tokens, width_tokens, _ = patches.shape token_grid = ( - torch.tensor( - [height_tokens, width_tokens], - dtype=torch.long, - device=patches.device, - ) - .unsqueeze(0) - .repeat(batch_size, 1) + torch.tensor([height_tokens, width_tokens], device=patches.device).long().expand(batch_size, 2) ) real_dim = ( @@ -325,8 +299,7 @@ def _preprocess( ) grouped_outputs[shape] = (patches, token_grid, virtual_dim, real_dim) - # Helper to reorder a single item of the tuple payloads using the same grouped_images_index - def _reorder_grouped_item( + def _reorder_grouped_item( # reorder an item of tuple payloads using the same grouped_images_index grouped: dict[tuple[int, ...], tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], grouped_index: dict[tuple[int, ...], list[int]], item_idx: int, diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 7f31ae47c480..528706593a4b 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -202,7 +202,6 @@ def _pack_to_batch( - packed_pixel_values: (batch, max_len, patch_dim) padded with zeros, or None if batch_size == 0 - seq_lengths: (batch,) lengths for each image """ - # Per-image token counts seq_lengths = spatial_shapes.long().prod(dim=-1) # (B,) batch_size = int(seq_lengths.numel()) if batch_size == 0: @@ -218,8 +217,6 @@ def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor """Flatten a padded batch back to packed sequence order using `seq_lengths`.""" lengths = seq_lengths.to(device=embeddings.device).tolist() chunks = [embeddings[i, :l] for i, l in enumerate(lengths) if l > 0] - if not chunks: - return embeddings.new_zeros((0, embeddings.size(-1))) return torch.cat(chunks, dim=0) @@ -647,18 +644,6 @@ def forward(self, image_features): class IsaacVisionEmbedding(nn.Module): - """Wraps the vision tower plus projection into the text hidden size. - - Args: - config (IsaacConfig): Composite config containing both vision and text settings. - - Inputs: - vision_tokens (Tuple[Tensor, Tensor]): Packed vision patches and token grids. - - Returns: - torch.Tensor: Projected vision embeddings aligned to the text hidden size. - """ - _supports_sdpa = True def __init__(self, config: IsaacConfig): @@ -722,17 +707,12 @@ def forward( with torch.no_grad(): pos = position_ids.clone() not_spatial = modality_tensor != ModalityType.image.value - if not_spatial.any(): - # Collapse non-vision modalities to 1D positions so rotary embedding - # treats them like text tokens while keeping image tokens 3D. 
- data_1d = pos[not_spatial][..., 0].unsqueeze(-1) - pos[not_spatial] = data_1d.expand(-1, pos.shape[-1]) - + data_1d = pos[not_spatial][..., 0].unsqueeze(-1) # Collapse non-vision modalities to 1D positions + pos[not_spatial] = data_1d.expand(-1, pos.shape[-1]) pos_axes = pos.permute(2, 0, 1).contiguous() cos_axes, sin_axes = super().forward(hidden_states, pos_axes) - cos_axes = cos_axes.to(hidden_states.dtype) - sin_axes = sin_axes.to(hidden_states.dtype) + cos_axes, sin_axes = cos_axes.to(hidden_states.dtype), sin_axes.to(hidden_states.dtype) cos_combined, sin_combined = self._combine_axes(cos_axes), self._combine_axes(sin_axes) return cos_combined, sin_combined @@ -768,6 +748,7 @@ def __init__(self, config: IsaacConfig): self.max_sequence_length = config.max_sequence_length self.vision_rescale_factor = config.vision_rescale_factor self.vision_token = config.vision_token + self.rope_deltas = None self.post_init() @@ -810,6 +791,7 @@ def embed_packed_inputs( - vision_token_grids: (num_images, 2) token grid sizes or None - vision_token_offsets: (num_images,) offsets into each image's virtual token span (optional) - vision_token_lengths: (num_images,) surviving virtual token lengths per image (optional) + - vision_token_batch_indices: (num_images,) batch row for each image (optional; defaults to zeros) """ modality = packed_inputs["modality_tensor"].to(device=input_ids.device, dtype=torch.long) embeds = self.text_model.embed_tokens(input_ids) @@ -822,42 +804,103 @@ def embed_packed_inputs( vision = self.vision_embedding((vision_patches, token_grids)) # (total_tokens, hidden) # per-image token counts AFTER pixel-shuffle - s = int(self.config.vision_config.pixel_shuffle_scale_factor) - sizes = token_grids.prod(-1).div(s * s, rounding_mode="floor").tolist() + vision_reduction_factor = int(self.config.vision_config.pixel_shuffle_scale_factor) + sizes = ( + token_grids.prod(-1).div(vision_reduction_factor * vision_reduction_factor, rounding_mode="floor").tolist() + ) offsets = packed_inputs.get("vision_token_offsets") lengths = packed_inputs.get("vision_token_lengths") + batch_indices = packed_inputs.get("vision_token_batch_indices") + + chunks = vision.split(sizes, dim=0) + picked: list[torch.Tensor] = [] + picked_batch: list[int] = [] + for chunk, size, offset, length, batch_index in zip( + chunks, + sizes, + offsets.tolist(), + lengths.tolist(), + (batch_indices.tolist() if batch_indices is not None else [0] * len(sizes)), + ): + if size <= 0: + continue + offset = max(0, min(int(offset), size)) + length = max(0, min(int(length), size - offset)) + if length: + picked.append(chunk[offset : offset + length]) + picked_batch.append(int(batch_index)) + if picked: + vision_chunks = picked + vision_batch_idx = picked_batch + else: + vision_chunks = vision_batch_idx = [] - if offsets is not None or lengths is not None: - off = ( - offsets.to(device=vision.device, dtype=torch.long) - if offsets is not None - else torch.zeros(len(sizes), device=vision.device, dtype=torch.long) - ) - ln = ( - lengths.to(device=vision.device, dtype=torch.long) - if lengths is not None - else torch.tensor(sizes, device=vision.device, dtype=torch.long) - ) - - # Honor per-image crop windows (after pixel shuffle) so we only splice back - # the surviving virtual tokens instead of the full vision span. 
- chunks = vision.split(sizes, dim=0) - picked: list[torch.Tensor] = [] - for c, n, o, l in zip(chunks, sizes, off.tolist(), ln.tolist()): - if n <= 0: - continue - o = max(0, min(int(o), n)) - l = max(0, min(int(l), n - o)) - if l: - picked.append(c[o : o + l]) - vision = torch.cat(picked, 0) if picked else vision.new_zeros((0, vision.size(-1))) - - m = modality == ModalityType.image.value + vision = torch.cat(vision_chunks, 0) if vision_chunks else vision.new_zeros((0, vision.size(-1))) embeds = embeds.clone() - embeds[m] = vision.to(device=embeds.device, dtype=embeds.dtype) + num_batches = modality.shape[0] + image_positions = [ + (modality[b] == ModalityType.image.value).nonzero(as_tuple=False).squeeze(-1) for b in range(num_batches) + ] + cursors = [0 for _ in range(num_batches)] + + for chunk, batch_index in zip(vision_chunks, vision_batch_idx): + if chunk.numel() == 0: + continue + positions = image_positions[batch_index] + start = cursors[batch_index] + end = start + chunk.shape[0] + embeds[batch_index, positions[start:end]] = chunk.to(device=embeds.device, dtype=embeds.dtype) + cursors[batch_index] = end return embeds, modality + def get_rope_index( + self, + *, + position_ids: Optional[torch.Tensor] = None, + attention_mask: torch.Tensor, + inputs_embeds: torch.Tensor, + cache_position: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Build 3D position ids and per-batch RoPE deltas.""" + + device = inputs_embeds.device + batch_size, seq_len = inputs_embeds.shape[:2] + + if position_ids is None: + cp = cache_position.to(device=device, dtype=torch.long) + if cp.ndim == 1: + cp = cp.view(1, -1).expand(batch_size or 1, -1) + + base_delta = torch.as_tensor( + 0 if self.rope_deltas is None else self.rope_deltas, + device=device, + dtype=torch.long, + ).reshape(-1, 1) + base_delta = torch.broadcast_to(base_delta, (batch_size, 1)) + + mask_delta = attention_mask.to(device=device, dtype=torch.long).sum(1, keepdim=True) - attention_mask.size( + 1 + ) + rope_position = cp + base_delta + mask_delta + pos_3d = rope_position.unsqueeze(-1).expand(-1, -1, 3) + return pos_3d, base_delta + + position_ids = position_ids.to(device=device) + if position_ids.ndim == 2: + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) + + if position_ids.shape[1] != seq_len: + start_positions = position_ids[:, :1, 0] + position_ids = torch.arange(seq_len, device=position_ids.device).view(1, -1) + start_positions + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) + + attn = attention_mask.to(device=device, dtype=torch.long) + m_per_batch = position_ids.amax(dim=(1, 2)) + seq_lens = attn.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=device) + rope_deltas = (m_per_batch + 1 - seq_lens).to(dtype=position_ids.dtype).unsqueeze(1) + return position_ids, rope_deltas + @auto_docstring @check_model_inputs def forward( @@ -879,8 +922,7 @@ def forward( Args: packed_inputs (`dict`, *optional*): - Plain tensor payloads extracted from a TensorStream. When provided, it replaces the TensorStream path - and requires `input_ids` for text tokens (or `text_token_ids` so `input_ids` can be rebuilt). + Plain tensor payloads. When provided, requires `input_ids` for text tokens (or `text_token_ids` so `input_ids` can be rebuilt). modality_tensor (`torch.LongTensor`, *optional*): Modality identifiers aligned with the embedded sequence, shaped `(batch_size, seq_len)` and containing values from `ModalityType`. Automatically built from `packed_inputs` or treated as text-only when omitted. 
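To make the hunk above easier to review: a toy sketch (assumed shapes and values, not the shipped code) of how embed_packed_inputs splices cropped vision chunks back into the token sequence. Each image's surviving window is selected via its offset/length, then written onto the positions its modality row marks as image tokens, with a cursor advancing per batch row:

    import torch

    hidden = 4
    embeds = torch.zeros(1, 6, hidden)             # (batch, seq_len, hidden)
    modality = torch.tensor([[0, 0, 1, 1, 0, 0]])  # 1 stands in for ModalityType.image.value
    vision = torch.arange(3 * hidden, dtype=torch.float).view(3, hidden)

    # Tail cropping kept only the last two of this image's three virtual tokens.
    offset, length = 1, 2
    chunk = vision[offset : offset + length]

    # Fill this row's image positions in order; a real batch keeps one cursor per row.
    positions = (modality[0] == 1).nonzero(as_tuple=False).squeeze(-1)
    embeds[0, positions[:length]] = chunk
    assert torch.equal(embeds[0, 2:4], chunk)

The new vision_token_batch_indices tensor is what generalizes this beyond batch_size=1: each chunk carries the batch row it belongs to, so images from different prompts no longer have to share a single flattened stream.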
@@ -888,58 +930,47 @@ def forward( output_attentions = kwargs.pop("output_attentions", None) - # Resolve the input source (prefer packed_inputs > ids > embeds). modality_tensor: Optional[torch.Tensor] = None - precomputed_position_ids: Optional[torch.Tensor] = None if packed_inputs is not None: inputs_embeds, modality_tensor = self.embed_packed_inputs(input_ids, packed_inputs) - precomputed_position_ids = packed_inputs.get("position_ids") - if precomputed_position_ids is not None: - precomputed_position_ids = precomputed_position_ids.to(inputs_embeds.device) elif input_ids is not None: inputs_embeds = self.text_model.embed_tokens(input_ids) device = inputs_embeds.device batch_size, seq_len = inputs_embeds.shape[:2] - # Ensure cache exists when requested if use_cache and past_key_values is None: - cache_config = self.config.get_text_config() if hasattr(self.config, "get_text_config") else self.config - past_key_values = DynamicCache(config=cache_config) + past_key_values = DynamicCache(config=self.config.get_text_config()) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=device) if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.long) + attention_mask = torch.ones(inputs_embeds.shape[:2], device=inputs_embeds.device, dtype=torch.long) - position_ids = position_ids if position_ids is not None else precomputed_position_ids - if position_ids is None: - position_ids = cache_position.view(1, -1).expand(batch_size, -1) + if position_ids is None and packed_inputs is not None and packed_inputs.get("position_ids") is not None: + position_ids = packed_inputs.get("position_ids").to(device=device) + + position_ids, rope_deltas = self.get_rope_index( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + ) + self.rope_deltas = rope_deltas if modality_tensor is None: modality_tensor = torch.full( (batch_size, seq_len), ModalityType.text.value, device=device, dtype=torch.long ) - else: - modality_tensor = modality_tensor.to(device=device, dtype=torch.long) - - position_ids = position_ids.to(device=device) - - if position_ids.ndim == 2: - position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) - if position_ids.shape[1] != seq_len: - start_positions = position_ids[:, :1, 0] - position_ids = torch.arange(seq_len, device=device).view(1, -1) + start_positions - position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) cos, sin = self.rotary_emb(position_ids, modality_tensor, hidden_states=inputs_embeds) decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids - if not isinstance(attention_mask, dict): # Prepare attention mask + if not isinstance(attention_mask, dict): attention_mask = create_masks_for_generate( config=self.config, input_embeds=inputs_embeds, @@ -1219,22 +1250,18 @@ class IsaacPreTrainedModel(PreTrainedModel): @auto_docstring class IsaacForConditionalGeneration(IsaacPreTrainedModel, GenerationMixin): - """Isaac multimodal model for conditional generation.""" - _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - config_class = IsaacConfig _can_compile_fullgraph = False all_tied_weights_keys: dict[str, str] = {"lm_head.weight": 
"model.text_model.embed_tokens.weight"} def __init__(self, config: IsaacConfig): super().__init__(config) - self.model = IsaacModel(config) # Use our custom model + self.model = IsaacModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self.rope_deltas = None # Initialize weights and apply final processing self.post_init() @@ -1258,48 +1285,15 @@ def forward( """Run multimodal CausalLM forward, accepting packed vision/text inputs. Args: - input_ids: Text token ids. packed_inputs (`dict`, *optional*): Packed vision/text payload from ``IsaacProcessor`` containing modality ids, MRoPE position ids, and vision patch tensors/grids (with optional offsets/lengths) used to rebuild embeddings. - attention_mask: Attention mask or mask dict; created if not provided. - position_ids: Optional 3D MRoPE positions; auto-derived when absent. - past_key_values: Cache for decoding. - inputs_embeds: Precomputed embeddings (bypass embedding layer). - labels: Target ids for computing language modeling loss. - use_cache: Whether to return caches. - cache_position: Positions for cache-aware generation. Returns: CausalLMOutputWithPast: logits, optional loss, caches, hidden states, attentions. """ output_attentions = kwargs.pop("output_attentions", None) - if position_ids is None and packed_inputs is not None: - pos_3d = packed_inputs.get("position_ids") - if pos_3d is not None: - position_ids, self.rope_deltas = self.get_rope_index( - position_ids=pos_3d, - attention_mask=attention_mask, - ) - - elif position_ids is None and cache_position is not None and self.rope_deltas is not None: - if input_ids is not None: - base_position_ids = torch.arange(input_ids.size(1), device=input_ids.device)[None, :, None].expand( - input_ids.size(0), -1, 3 - ) - else: - batch_size, seq_len = inputs_embeds.shape[:2] - dummy_ids = torch.zeros((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long) - base_position_ids = torch.arange(dummy_ids.size(1), device=dummy_ids.device)[None, :, None].expand( - dummy_ids.size(0), -1, 3 - ) - - rope_delta = (cache_position[0] + self.rope_deltas).to(base_position_ids.device) - if not isinstance(rope_delta, int): - rope_delta = rope_delta.repeat_interleave(base_position_ids.shape[0] // rope_delta.shape[0], dim=0) - position_ids = base_position_ids.add(rope_delta) - outputs = self.model( input_ids=input_ids, packed_inputs=packed_inputs, @@ -1312,10 +1306,8 @@ def forward( cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] logits = self.lm_head(hidden_states) - loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) @@ -1328,53 +1320,6 @@ def forward( attentions=outputs.attentions if output_attentions else None, ) - def set_input_embeddings(self, value: nn.Module) -> None: - self.model.set_input_embeddings(value) - vocab_size = getattr(value, "num_embeddings", None) - self.config.vocab_size = vocab_size - self.model.config.vocab_size = vocab_size - self.model.text_model.config.vocab_size = vocab_size - if self.lm_head.weight.shape[0] != vocab_size: - self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False) - self.lm_head.weight = self.model.text_model.embed_tokens.weight - - def get_rope_index( - self, - *, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compute 
(position_ids_3d, rope_deltas) without TensorStream. - - - If `position_ids` is provided, it must be shape (B, L, 3). - - Else, if `input_ids` is provided, position ids are synthesized as (B, L, 3). - - `rope_deltas` is (B, 1) used to advance positions during decode. - """ - - if position_ids is None: - pos_3d = torch.arange(input_ids.size(1), device=input_ids.device)[None, :, None].expand( - input_ids.size(0), -1, 3 - ) - else: - pos_3d = position_ids - if pos_3d.ndim != 3 or pos_3d.size(-1) != 3: - raise ValueError( - f"`position_ids` must have shape (batch, seq_len, 3) for MRoPE; got shape {tuple(pos_3d.shape)}." - ) - - B, L, _ = pos_3d.shape - m_per_batch = pos_3d.amax(dim=(1, 2)) - - if attention_mask is None: - seq_lens = torch.full((B,), L, device=pos_3d.device, dtype=m_per_batch.dtype) - else: - seq_lens = attention_mask.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=m_per_batch.device) - - rope_deltas = (m_per_batch + 1 - seq_lens).to(dtype=pos_3d.dtype).unsqueeze(1) - return pos_3d, rope_deltas - def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, @@ -1398,8 +1343,8 @@ def prepare_inputs_for_generation( if packed_inputs is None: return model_inputs - cache_position = model_inputs.get("cache_position", cache_position) - first_step = cache_position is None or cache_position[0] == 0 + past_len = past_key_values.get_seq_length() if past_key_values is not None else 0 + first_step = past_len == 0 model_inputs["packed_inputs"] = packed_inputs if first_step else None model_inputs["position_ids"] = None @@ -1409,5 +1354,15 @@ def prepare_inputs_for_generation( def can_generate(cls) -> bool: return True + def set_input_embeddings(self, value: nn.Module) -> None: + self.model.set_input_embeddings(value) + vocab_size = getattr(value, "num_embeddings", None) + self.config.vocab_size = vocab_size + self.model.config.vocab_size = vocab_size + self.model.text_model.config.vocab_size = vocab_size + if self.lm_head.weight.shape[0] != vocab_size: + self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False) + self.lm_head.weight = self.model.text_model.embed_tokens.weight + __all__ = ["IsaacModel", "IsaacPreTrainedModel", "IsaacForConditionalGeneration"] diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 31c915de5b03..1d561ddca798 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -18,31 +18,8 @@ import copy import math from collections.abc import Callable, Sequence -from typing import Any, Optional, Union - -from ...utils.import_utils import ( - is_torch_available, - is_torchdynamo_compiling, - is_torchvision_available, - is_vision_available, -) - - -if is_torch_available(): - import torch - import torch.nn as nn - import torch.nn.functional as F - - -if is_vision_available(): - from PIL.Image import Image -else: - Image = None - -if is_torchvision_available(): - from ..pix2struct.image_processing_pix2struct_fast import torch_extract_patches - from enum import IntEnum +from typing import Any, Optional, Union from ...cache_utils import DynamicCache from ...configuration_utils import PretrainedConfig, layer_type_validation @@ -65,8 +42,6 @@ from ...models.qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3Model, Qwen3PreTrainedModel from ...processing_utils import ProcessorMixin, Unpack from ...utils import TensorType, auto_docstring - -# Vision preprocessing constants from ...utils.constants import IMAGENET_STANDARD_MEAN as 
VISION_MEAN
 from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD
 from ...utils.generic import (
@@ -75,6 +50,12 @@
     can_return_tuple,
     check_model_inputs,
 )
+from ...utils.import_utils import (
+    is_torch_available,
+    is_torchdynamo_compiling,
+    is_torchvision_available,
+    is_vision_available,
+)
 from ..qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling
 from ..siglip2.configuration_siglip2 import Siglip2VisionConfig
 from ..siglip2.modeling_siglip2 import (
@@ -85,6 +66,18 @@
 )


+if is_torch_available():
+    import torch
+    import torch.nn as nn
+    import torch.nn.functional as F
+if is_vision_available():
+    from PIL.Image import Image
+else:
+    Image = None
+if is_torchvision_available():
+    from ..pix2struct.image_processing_pix2struct_fast import torch_extract_patches
+
+
 class ModalityType(IntEnum):
     """
     Modality identifiers for events.
@@ -159,7 +152,6 @@ class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False):
 @auto_docstring
 class IsaacImageProcessorFast(BaseImageProcessorFast):
     MAX_PIXELS = 60_000_000  # 60-megapixel ceiling ≈ 8200 × 7300 px
-    r"""Fast torch-based image processor for Isaac vision inputs."""

     resample = PILImageResampling.BILINEAR
     model_input_names = ["patches", "token_grids"]
@@ -235,21 +227,10 @@ def _preprocess(

         grouped_outputs = {}
         for shape, stacked_images in grouped_images.items():
-            if stacked_images.ndim != 4:
-                raise ValueError(
-                    f"Expected images shaped as (batch, channels, height, width); got shape {tuple(stacked_images.shape)}."
-                )
-
             batch_size, channels, original_height, original_width = stacked_images.shape
             if bool(self.do_convert_rgb) and channels == 1:
                 stacked_images = stacked_images.repeat(1, 3, 1, 1)
-                channels = 3
-
-            if original_height * original_width > self.MAX_PIXELS:
-                raise ValueError(
-                    f"Image area {original_height * original_width} (h={original_height}, w={original_width}) exceeds MAX_PIXELS={self.MAX_PIXELS}; enable resizing or provide smaller inputs."
-                )

             target_height, target_width = get_image_size_for_max_num_patches(
                 original_height,
@@ -259,43 +240,31 @@
                 min_num_patches=min_num_patches,
                 pixel_shuffle_scale=pixel_shuffle_scale,
             )
-            if do_resize:
-                resize_size = SizeDict(height=target_height, width=target_width)
+            if do_resize:
                 image_batch = self.resize(
-                    image=stacked_images,
-                    size=resize_size,
-                    interpolation=interpolation,
+                    stacked_images, SizeDict(height=target_height, width=target_width), interpolation=interpolation
                 )
             else:
-                if ((original_height % patch_size) != 0) or ((original_width % patch_size) != 0):
+                if (original_height % patch_size) or (original_width % patch_size):
                     raise ValueError(
                         f"Image dimensions (h={original_height}, w={original_width}) must be divisible by patch_size={patch_size} when resize is disabled; enable resizing or adjust the input resolution."
) - image_batch = stacked_images - target_height, target_width = original_height, original_width - - if do_rescale: - image_batch = self.rescale_and_normalize( - image_batch, - do_rescale=do_rescale, - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - ) + image_batch, target_height, target_width = stacked_images, original_height, original_width + + image_batch = self.rescale_and_normalize( + image_batch, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) patches = torch_extract_patches(image_batch, patch_size, patch_size) _, height_tokens, width_tokens, _ = patches.shape token_grid = ( - torch.tensor( - [height_tokens, width_tokens], - dtype=torch.long, - device=patches.device, - ) - .unsqueeze(0) - .repeat(batch_size, 1) + torch.tensor([height_tokens, width_tokens], device=patches.device).long().expand(batch_size, 2) ) real_dim = ( @@ -326,8 +295,7 @@ def _preprocess( ) grouped_outputs[shape] = (patches, token_grid, virtual_dim, real_dim) - # Helper to reorder a single item of the tuple payloads using the same grouped_images_index - def _reorder_grouped_item( + def _reorder_grouped_item( # reorder an item of tuple payloads using the same grouped_images_index grouped: dict[tuple[int, ...], tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], grouped_index: dict[tuple[int, ...], list[int]], item_idx: int, @@ -448,7 +416,6 @@ def _pack_to_batch( - packed_pixel_values: (batch, max_len, patch_dim) padded with zeros, or None if batch_size == 0 - seq_lengths: (batch,) lengths for each image """ - # Per-image token counts seq_lengths = spatial_shapes.long().prod(dim=-1) # (B,) batch_size = int(seq_lengths.numel()) if batch_size == 0: @@ -464,8 +431,6 @@ def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor """Flatten a padded batch back to packed sequence order using `seq_lengths`.""" lengths = seq_lengths.to(device=embeddings.device).tolist() chunks = [embeddings[i, :l] for i, l in enumerate(lengths) if l > 0] - if not chunks: - return embeddings.new_zeros((0, embeddings.size(-1))) return torch.cat(chunks, dim=0) @@ -790,18 +755,6 @@ def forward(self, image_features): class IsaacVisionEmbedding(nn.Module): - """Wraps the vision tower plus projection into the text hidden size. - - Args: - config (IsaacConfig): Composite config containing both vision and text settings. - - Inputs: - vision_tokens (Tuple[Tensor, Tensor]): Packed vision patches and token grids. - - Returns: - torch.Tensor: Projected vision embeddings aligned to the text hidden size. 
- """ - _supports_sdpa = True def __init__(self, config: IsaacConfig): @@ -975,6 +928,8 @@ def __init__( if vision_attn is not None: self.vision_config._attn_implementation = vision_attn + if getattr(self, "_attn_implementation", None) is None: + self._attn_implementation = "sdpa" # Vision normalization parameters self.vision_rescale_factor = float(vision_rescale_factor) @@ -1037,23 +992,95 @@ def __init__( self.image_processor = image_processor super().__init__(image_processor, tokenizer) + text_pad_token_id = getattr(self.tokenizer, "pad_token_id", None) + image_pad_token_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>") + + self.text_pad_token_id = int(text_pad_token_id) + self.image_pad_token_id = int(image_pad_token_id) + self.pad_token_id = self.text_pad_token_id + self.current_processor = self.image_processor self.config = config self.chat_template = getattr(self.tokenizer, "chat_template", None) self.vision_token = vision_token self.max_sequence_length = max_sequence_length + def _pack_batch( + self, texts: list[str], images_list: Optional[list[Optional[list[Image]]]] + ) -> dict[str, Optional[torch.Tensor]]: + if images_list is None: + pairs = ((t, None) for t in texts) + else: + pairs = zip(texts, images_list, strict=True) + + per_sample: list[dict[str, Optional[torch.Tensor]]] = [] + for txt, imgs in pairs: + if imgs is not None and isinstance(imgs, Image): + imgs = [imgs] + per_sample.append(self._pack_single(txt, imgs)) + + lengths = [int(p["input_ids"].shape[1]) for p in per_sample] + max_len = max(lengths, default=0) + batch = len(per_sample) + + # Use first device with data as anchor + base_device = torch.device("cpu") + for p in per_sample: + if p["input_ids"].numel() > 0: + base_device = p["input_ids"].device + break + + pad_id = self.text_pad_token_id + padded_input_ids = torch.full((batch, max_len), pad_id, device=base_device, dtype=torch.long) + padded_modality = torch.full((batch, max_len), ModalityType.text.value, device=base_device, dtype=torch.long) + padded_position_ids = torch.zeros((batch, max_len, 3), device=base_device, dtype=torch.long) + + for i, (sample, l) in enumerate(zip(per_sample, lengths)): + if l: + padded_input_ids[i, -l:] = sample["input_ids"][0] + padded_modality[i, -l:] = sample["modality_tensor"][0] + padded_position_ids[i, -l:] = sample["position_ids"][0] + + # Vision-side aggregation + v_samples = [(b, s) for b, s in enumerate(per_sample) if s["vision_patches"] is not None] + if v_samples: + vision_patches_list = [s["vision_patches"] for _, s in v_samples] + vision_grids_list = [s["vision_token_grids"] for _, s in v_samples] + vision_offsets_list = [s["vision_token_offsets"] for _, s in v_samples] + vision_lengths_list = [s["vision_token_lengths"] for _, s in v_samples] + vision_batch_indices = [torch.full_like(s["vision_token_offsets"], b) for b, s in v_samples] + + vision_patches = torch.cat(vision_patches_list, dim=0) + vision_token_grids = torch.cat(vision_grids_list, dim=0) + vision_token_offsets = torch.cat(vision_offsets_list, dim=0) + vision_token_lengths = torch.cat(vision_lengths_list, dim=0) + vision_token_batch_indices = torch.cat(vision_batch_indices, dim=0) + else: + vision_patches = vision_token_grids = vision_token_offsets = vision_token_lengths = ( + vision_token_batch_indices + ) = None + + return { + "input_ids": padded_input_ids, + "vision_patches": vision_patches, + "vision_token_grids": vision_token_grids, + "vision_token_offsets": vision_token_offsets, + "vision_token_lengths": vision_token_lengths, + 
"vision_token_batch_indices": vision_token_batch_indices, + "modality_tensor": padded_modality, + "position_ids": padded_position_ids, + } + def _pack_single(self, text: str, images: Optional[list[Image]]) -> dict[str, Optional[torch.Tensor]]: - # Parse by vision_token; interleave text segments and image segments. - segments = text.split(self.vision_token) + segments = text.split(self.vision_token) # Parse by vision_token; interleave text segments and image segments. num_images = len(segments) - 1 - if num_images and (images is None or len(images) != num_images): - raise ValueError( - f"Expected one image per '{self.vision_token}' token: found {num_images} token(s) but received {0 if images is None else len(images)} image(s)." - ) - items: list[dict[str, Any]] = [] total = 0 + num_provided_images = len(images) if images is not None else 0 + if not num_images == num_provided_images: + raise ValueError( + f"IsaacProcessor expects one image per image token, got {num_images} tokens and {num_provided_images} images in sample with text {text} " + ) for index, segment in enumerate(segments): if segment: @@ -1090,7 +1117,7 @@ def _pack_single(self, text: str, images: Optional[list[Image]]) -> dict[str, Op start = max(0, total - self.max_sequence_length) end = total - fill_value = self.pad_token_id + image_pad_value = self.image_pad_token_id base_device: Optional[torch.device] = None position_ids, modality, input_ids = [], [], [] vpatches, grids, vision_token_offsets, vision_token_lengths = [], [], [], [] @@ -1140,7 +1167,7 @@ def _pack_single(self, text: str, images: Optional[list[Image]]) -> dict[str, Op ) ) input_ids.append( - torch.full((segment_kept_length,), fill_value, device=base_device, dtype=torch.long) + torch.full((segment_kept_length,), image_pad_value, device=base_device, dtype=torch.long) ) vpatches.append(item["patches"].to(base_device)) # full patches; slice later via offsets/lengths @@ -1157,6 +1184,9 @@ def _pack_single(self, text: str, images: Optional[list[Image]]) -> dict[str, Op global_offset += segment_length + if base_device is None: + base_device = torch.device("cpu") + modality_tensor = ( torch.cat(modality, 0).unsqueeze(0) if modality @@ -1199,21 +1229,28 @@ def __call__( **kwargs, ) -> BatchFeature: texts = [text] if isinstance(text, str) else text - if len(texts) != 1: - raise ValueError( - f"IsaacProcessor currently supports batch_size=1; received {len(texts)} text prompts. Split the batch and call the processor per sample." - ) - - images_list = None + images_list: Optional[list[Optional[list[Image]]]] = None if images is not None: - images_list = [images] if isinstance(images, Image) else images - n_tok = texts[0].count(self.vision_token) - if n_tok != len(images_list): - raise ValueError( - f"Expected {len(images_list)} occurrences of '{self.vision_token}' (one per provided image), but found {n_tok} in the text." 
- ) - - packed = self._pack_single(texts[0], images_list) + if isinstance(images, list) and len(images) == len(texts): + if not images: + images_list = [] + elif isinstance(images[0], list): + images_list = images # already per-sample + else: + images_list = [[img] for img in images] # list of images, one per sample + else: + images_list = [] + for t in texts: + n_tok = t.count(self.vision_token) + if n_tok == 0: + images_list.append(None) + else: + if isinstance(images, list): + images_list.append(images) + else: + images_list.append([images]) + + packed = self._pack_batch(texts, images_list) input_ids = packed.pop("input_ids") return BatchFeature(data={"input_ids": input_ids, "packed_inputs": packed}) @@ -1267,17 +1304,12 @@ def forward( with torch.no_grad(): pos = position_ids.clone() not_spatial = modality_tensor != ModalityType.image.value - if not_spatial.any(): - # Collapse non-vision modalities to 1D positions so rotary embedding - # treats them like text tokens while keeping image tokens 3D. - data_1d = pos[not_spatial][..., 0].unsqueeze(-1) - pos[not_spatial] = data_1d.expand(-1, pos.shape[-1]) - + data_1d = pos[not_spatial][..., 0].unsqueeze(-1) # Collapse non-vision modalities to 1D positions + pos[not_spatial] = data_1d.expand(-1, pos.shape[-1]) pos_axes = pos.permute(2, 0, 1).contiguous() cos_axes, sin_axes = super().forward(hidden_states, pos_axes) - cos_axes = cos_axes.to(hidden_states.dtype) - sin_axes = sin_axes.to(hidden_states.dtype) + cos_axes, sin_axes = cos_axes.to(hidden_states.dtype), sin_axes.to(hidden_states.dtype) cos_combined, sin_combined = self._combine_axes(cos_axes), self._combine_axes(sin_axes) return cos_combined, sin_combined @@ -1306,6 +1338,7 @@ def __init__(self, config: IsaacConfig): self.max_sequence_length = config.max_sequence_length self.vision_rescale_factor = config.vision_rescale_factor self.vision_token = config.vision_token + self.rope_deltas = None self.post_init() @@ -1348,6 +1381,7 @@ def embed_packed_inputs( - vision_token_grids: (num_images, 2) token grid sizes or None - vision_token_offsets: (num_images,) offsets into each image's virtual token span (optional) - vision_token_lengths: (num_images,) surviving virtual token lengths per image (optional) + - vision_token_batch_indices: (num_images,) batch row for each image (optional; defaults to zeros) """ modality = packed_inputs["modality_tensor"].to(device=input_ids.device, dtype=torch.long) embeds = self.text_model.embed_tokens(input_ids) @@ -1360,42 +1394,103 @@ def embed_packed_inputs( vision = self.vision_embedding((vision_patches, token_grids)) # (total_tokens, hidden) # per-image token counts AFTER pixel-shuffle - s = int(self.config.vision_config.pixel_shuffle_scale_factor) - sizes = token_grids.prod(-1).div(s * s, rounding_mode="floor").tolist() + vision_reduction_factor = int(self.config.vision_config.pixel_shuffle_scale_factor) + sizes = ( + token_grids.prod(-1).div(vision_reduction_factor * vision_reduction_factor, rounding_mode="floor").tolist() + ) offsets = packed_inputs.get("vision_token_offsets") lengths = packed_inputs.get("vision_token_lengths") + batch_indices = packed_inputs.get("vision_token_batch_indices") + + chunks = vision.split(sizes, dim=0) + picked: list[torch.Tensor] = [] + picked_batch: list[int] = [] + for chunk, size, offset, length, batch_index in zip( + chunks, + sizes, + offsets.tolist(), + lengths.tolist(), + (batch_indices.tolist() if batch_indices is not None else [0] * len(sizes)), + ): + if size <= 0: + continue + offset = max(0, min(int(offset), 
size)) + length = max(0, min(int(length), size - offset)) + if length: + picked.append(chunk[offset : offset + length]) + picked_batch.append(int(batch_index)) + if picked: + vision_chunks = picked + vision_batch_idx = picked_batch + else: + vision_chunks = vision_batch_idx = [] - if offsets is not None or lengths is not None: - off = ( - offsets.to(device=vision.device, dtype=torch.long) - if offsets is not None - else torch.zeros(len(sizes), device=vision.device, dtype=torch.long) - ) - ln = ( - lengths.to(device=vision.device, dtype=torch.long) - if lengths is not None - else torch.tensor(sizes, device=vision.device, dtype=torch.long) - ) - - # Honor per-image crop windows (after pixel shuffle) so we only splice back - # the surviving virtual tokens instead of the full vision span. - chunks = vision.split(sizes, dim=0) - picked: list[torch.Tensor] = [] - for c, n, o, l in zip(chunks, sizes, off.tolist(), ln.tolist()): - if n <= 0: - continue - o = max(0, min(int(o), n)) - l = max(0, min(int(l), n - o)) - if l: - picked.append(c[o : o + l]) - vision = torch.cat(picked, 0) if picked else vision.new_zeros((0, vision.size(-1))) - - m = modality == ModalityType.image.value + vision = torch.cat(vision_chunks, 0) if vision_chunks else vision.new_zeros((0, vision.size(-1))) embeds = embeds.clone() - embeds[m] = vision.to(device=embeds.device, dtype=embeds.dtype) + num_batches = modality.shape[0] + image_positions = [ + (modality[b] == ModalityType.image.value).nonzero(as_tuple=False).squeeze(-1) for b in range(num_batches) + ] + cursors = [0 for _ in range(num_batches)] + + for chunk, batch_index in zip(vision_chunks, vision_batch_idx): + if chunk.numel() == 0: + continue + positions = image_positions[batch_index] + start = cursors[batch_index] + end = start + chunk.shape[0] + embeds[batch_index, positions[start:end]] = chunk.to(device=embeds.device, dtype=embeds.dtype) + cursors[batch_index] = end return embeds, modality + def get_rope_index( + self, + *, + position_ids: Optional[torch.Tensor] = None, + attention_mask: torch.Tensor, + inputs_embeds: torch.Tensor, + cache_position: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Build 3D position ids and per-batch RoPE deltas.""" + + device = inputs_embeds.device + batch_size, seq_len = inputs_embeds.shape[:2] + + if position_ids is None: + cp = cache_position.to(device=device, dtype=torch.long) + if cp.ndim == 1: + cp = cp.view(1, -1).expand(batch_size or 1, -1) + + base_delta = torch.as_tensor( + 0 if self.rope_deltas is None else self.rope_deltas, + device=device, + dtype=torch.long, + ).reshape(-1, 1) + base_delta = torch.broadcast_to(base_delta, (batch_size, 1)) + + mask_delta = attention_mask.to(device=device, dtype=torch.long).sum(1, keepdim=True) - attention_mask.size( + 1 + ) + rope_position = cp + base_delta + mask_delta + pos_3d = rope_position.unsqueeze(-1).expand(-1, -1, 3) + return pos_3d, base_delta + + position_ids = position_ids.to(device=device) + if position_ids.ndim == 2: + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) + + if position_ids.shape[1] != seq_len: + start_positions = position_ids[:, :1, 0] + position_ids = torch.arange(seq_len, device=position_ids.device).view(1, -1) + start_positions + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) + + attn = attention_mask.to(device=device, dtype=torch.long) + m_per_batch = position_ids.amax(dim=(1, 2)) + seq_lens = attn.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=device) + rope_deltas = (m_per_batch + 1 - 
seq_lens).to(dtype=position_ids.dtype).unsqueeze(1) + return position_ids, rope_deltas + @auto_docstring @check_model_inputs def forward( @@ -1417,8 +1512,7 @@ def forward( Args: packed_inputs (`dict`, *optional*): - Plain tensor payloads extracted from a TensorStream. When provided, it replaces the TensorStream path - and requires `input_ids` for text tokens (or `text_token_ids` so `input_ids` can be rebuilt). + Plain tensor payloads. When provided, requires `input_ids` for text tokens (or `text_token_ids` so `input_ids` can be rebuilt). modality_tensor (`torch.LongTensor`, *optional*): Modality identifiers aligned with the embedded sequence, shaped `(batch_size, seq_len)` and containing values from `ModalityType`. Automatically built from `packed_inputs` or treated as text-only when omitted. @@ -1426,58 +1520,47 @@ def forward( output_attentions = kwargs.pop("output_attentions", None) - # Resolve the input source (prefer packed_inputs > ids > embeds). modality_tensor: Optional[torch.Tensor] = None - precomputed_position_ids: Optional[torch.Tensor] = None if packed_inputs is not None: inputs_embeds, modality_tensor = self.embed_packed_inputs(input_ids, packed_inputs) - precomputed_position_ids = packed_inputs.get("position_ids") - if precomputed_position_ids is not None: - precomputed_position_ids = precomputed_position_ids.to(inputs_embeds.device) elif input_ids is not None: inputs_embeds = self.text_model.embed_tokens(input_ids) device = inputs_embeds.device batch_size, seq_len = inputs_embeds.shape[:2] - # Ensure cache exists when requested if use_cache and past_key_values is None: - cache_config = self.config.get_text_config() if hasattr(self.config, "get_text_config") else self.config - past_key_values = DynamicCache(config=cache_config) + past_key_values = DynamicCache(config=self.config.get_text_config()) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=device) if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.long) + attention_mask = torch.ones(inputs_embeds.shape[:2], device=inputs_embeds.device, dtype=torch.long) - position_ids = position_ids if position_ids is not None else precomputed_position_ids - if position_ids is None: - position_ids = cache_position.view(1, -1).expand(batch_size, -1) + if position_ids is None and packed_inputs is not None and packed_inputs.get("position_ids") is not None: + position_ids = packed_inputs.get("position_ids").to(device=device) + + position_ids, rope_deltas = self.get_rope_index( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + ) + self.rope_deltas = rope_deltas if modality_tensor is None: modality_tensor = torch.full( (batch_size, seq_len), ModalityType.text.value, device=device, dtype=torch.long ) - else: - modality_tensor = modality_tensor.to(device=device, dtype=torch.long) - - position_ids = position_ids.to(device=device) - - if position_ids.ndim == 2: - position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) - if position_ids.shape[1] != seq_len: - start_positions = position_ids[:, :1, 0] - position_ids = torch.arange(seq_len, device=device).view(1, -1) + start_positions - position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) cos, sin = self.rotary_emb(position_ids, modality_tensor, hidden_states=inputs_embeds) decoder_position_ids = 
position_ids[..., 0] if position_ids.ndim == 3 else position_ids - if not isinstance(attention_mask, dict): # Prepare attention mask + if not isinstance(attention_mask, dict): attention_mask = create_masks_for_generate( config=self.config, input_embeds=inputs_embeds, @@ -1522,8 +1605,6 @@ def forward( @auto_docstring class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin): - """Isaac multimodal model for conditional generation.""" - config_class = IsaacConfig _can_compile_fullgraph = False _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"} @@ -1531,10 +1612,9 @@ class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin): def __init__(self, config: IsaacConfig): super().__init__(config) - self.model = IsaacModel(config) # Use our custom model + self.model = IsaacModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self.rope_deltas = None @auto_docstring @can_return_tuple @@ -1555,48 +1635,15 @@ def forward( """Run multimodal CausalLM forward, accepting packed vision/text inputs. Args: - input_ids: Text token ids. packed_inputs (`dict`, *optional*): Packed vision/text payload from ``IsaacProcessor`` containing modality ids, MRoPE position ids, and vision patch tensors/grids (with optional offsets/lengths) used to rebuild embeddings. - attention_mask: Attention mask or mask dict; created if not provided. - position_ids: Optional 3D MRoPE positions; auto-derived when absent. - past_key_values: Cache for decoding. - inputs_embeds: Precomputed embeddings (bypass embedding layer). - labels: Target ids for computing language modeling loss. - use_cache: Whether to return caches. - cache_position: Positions for cache-aware generation. Returns: CausalLMOutputWithPast: logits, optional loss, caches, hidden states, attentions. 
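+
+        Example (illustrative sketch; assumes `model` and `processor` were loaded with
+        `from_pretrained` and `image` is a `PIL.Image`):
+
+        ```python
+        >>> inputs = processor(text=f"{processor.vision_token} Describe the image.", images=[image])
+        >>> outputs = model(input_ids=inputs["input_ids"], packed_inputs=inputs["packed_inputs"])
+        >>> next_token = outputs.logits[:, -1].argmax(dim=-1)
+        ```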
""" output_attentions = kwargs.pop("output_attentions", None) - if position_ids is None and packed_inputs is not None: - pos_3d = packed_inputs.get("position_ids") - if pos_3d is not None: - position_ids, self.rope_deltas = self.get_rope_index( - position_ids=pos_3d, - attention_mask=attention_mask, - ) - - elif position_ids is None and cache_position is not None and self.rope_deltas is not None: - if input_ids is not None: - base_position_ids = torch.arange(input_ids.size(1), device=input_ids.device)[None, :, None].expand( - input_ids.size(0), -1, 3 - ) - else: - batch_size, seq_len = inputs_embeds.shape[:2] - dummy_ids = torch.zeros((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long) - base_position_ids = torch.arange(dummy_ids.size(1), device=dummy_ids.device)[None, :, None].expand( - dummy_ids.size(0), -1, 3 - ) - - rope_delta = (cache_position[0] + self.rope_deltas).to(base_position_ids.device) - if not isinstance(rope_delta, int): - rope_delta = rope_delta.repeat_interleave(base_position_ids.shape[0] // rope_delta.shape[0], dim=0) - position_ids = base_position_ids.add(rope_delta) - outputs = self.model( input_ids=input_ids, packed_inputs=packed_inputs, @@ -1609,10 +1656,8 @@ def forward( cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] logits = self.lm_head(hidden_states) - loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) @@ -1625,53 +1670,6 @@ def forward( attentions=outputs.attentions if output_attentions else None, ) - def set_input_embeddings(self, value: nn.Module) -> None: - self.model.set_input_embeddings(value) - vocab_size = getattr(value, "num_embeddings", None) - self.config.vocab_size = vocab_size - self.model.config.vocab_size = vocab_size - self.model.text_model.config.vocab_size = vocab_size - if self.lm_head.weight.shape[0] != vocab_size: - self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False) - self.lm_head.weight = self.model.text_model.embed_tokens.weight - - def get_rope_index( - self, - *, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compute (position_ids_3d, rope_deltas) without TensorStream. - - - If `position_ids` is provided, it must be shape (B, L, 3). - - Else, if `input_ids` is provided, position ids are synthesized as (B, L, 3). - - `rope_deltas` is (B, 1) used to advance positions during decode. - """ - - if position_ids is None: - pos_3d = torch.arange(input_ids.size(1), device=input_ids.device)[None, :, None].expand( - input_ids.size(0), -1, 3 - ) - else: - pos_3d = position_ids - if pos_3d.ndim != 3 or pos_3d.size(-1) != 3: - raise ValueError( - f"`position_ids` must have shape (batch, seq_len, 3) for MRoPE; got shape {tuple(pos_3d.shape)}." 
- ) - - B, L, _ = pos_3d.shape - m_per_batch = pos_3d.amax(dim=(1, 2)) - - if attention_mask is None: - seq_lens = torch.full((B,), L, device=pos_3d.device, dtype=m_per_batch.dtype) - else: - seq_lens = attention_mask.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=m_per_batch.device) - - rope_deltas = (m_per_batch + 1 - seq_lens).to(dtype=pos_3d.dtype).unsqueeze(1) - return pos_3d, rope_deltas - def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, @@ -1695,8 +1693,8 @@ def prepare_inputs_for_generation( if packed_inputs is None: return model_inputs - cache_position = model_inputs.get("cache_position", cache_position) - first_step = cache_position is None or cache_position[0] == 0 + past_len = past_key_values.get_seq_length() if past_key_values is not None else 0 + first_step = past_len == 0 model_inputs["packed_inputs"] = packed_inputs if first_step else None model_inputs["position_ids"] = None @@ -1706,6 +1704,16 @@ def prepare_inputs_for_generation( def can_generate(cls) -> bool: return True + def set_input_embeddings(self, value: nn.Module) -> None: + self.model.set_input_embeddings(value) + vocab_size = getattr(value, "num_embeddings", None) + self.config.vocab_size = vocab_size + self.model.config.vocab_size = vocab_size + self.model.text_model.config.vocab_size = vocab_size + if self.lm_head.weight.shape[0] != vocab_size: + self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False) + self.lm_head.weight = self.model.text_model.embed_tokens.weight + __all__ = [ "IsaacConfig", diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 6c75ef572c6b..c7308d98d425 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -31,8 +31,6 @@ if is_torch_available(): import torch - - if is_vision_available(): from PIL.Image import Image else: @@ -84,23 +82,95 @@ def __init__( self.image_processor = image_processor super().__init__(image_processor, tokenizer) + text_pad_token_id = getattr(self.tokenizer, "pad_token_id", None) + image_pad_token_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>") + + self.text_pad_token_id = int(text_pad_token_id) + self.image_pad_token_id = int(image_pad_token_id) + self.pad_token_id = self.text_pad_token_id + self.current_processor = self.image_processor self.config = config self.chat_template = getattr(self.tokenizer, "chat_template", None) self.vision_token = vision_token self.max_sequence_length = max_sequence_length + def _pack_batch( + self, texts: list[str], images_list: Optional[list[Optional[list[Image]]]] + ) -> dict[str, Optional[torch.Tensor]]: + if images_list is None: + pairs = ((t, None) for t in texts) + else: + pairs = zip(texts, images_list, strict=True) + + per_sample: list[dict[str, Optional[torch.Tensor]]] = [] + for txt, imgs in pairs: + if imgs is not None and isinstance(imgs, Image): + imgs = [imgs] + per_sample.append(self._pack_single(txt, imgs)) + + lengths = [int(p["input_ids"].shape[1]) for p in per_sample] + max_len = max(lengths, default=0) + batch = len(per_sample) + + # Use first device with data as anchor + base_device = torch.device("cpu") + for p in per_sample: + if p["input_ids"].numel() > 0: + base_device = p["input_ids"].device + break + + pad_id = self.text_pad_token_id + padded_input_ids = torch.full((batch, max_len), pad_id, device=base_device, dtype=torch.long) + padded_modality = torch.full((batch, max_len), ModalityType.text.value, device=base_device, 
dtype=torch.long) + padded_position_ids = torch.zeros((batch, max_len, 3), device=base_device, dtype=torch.long) + + for i, (sample, l) in enumerate(zip(per_sample, lengths)): + if l: + padded_input_ids[i, -l:] = sample["input_ids"][0] + padded_modality[i, -l:] = sample["modality_tensor"][0] + padded_position_ids[i, -l:] = sample["position_ids"][0] + + # Vision-side aggregation + v_samples = [(b, s) for b, s in enumerate(per_sample) if s["vision_patches"] is not None] + if v_samples: + vision_patches_list = [s["vision_patches"] for _, s in v_samples] + vision_grids_list = [s["vision_token_grids"] for _, s in v_samples] + vision_offsets_list = [s["vision_token_offsets"] for _, s in v_samples] + vision_lengths_list = [s["vision_token_lengths"] for _, s in v_samples] + vision_batch_indices = [torch.full_like(s["vision_token_offsets"], b) for b, s in v_samples] + + vision_patches = torch.cat(vision_patches_list, dim=0) + vision_token_grids = torch.cat(vision_grids_list, dim=0) + vision_token_offsets = torch.cat(vision_offsets_list, dim=0) + vision_token_lengths = torch.cat(vision_lengths_list, dim=0) + vision_token_batch_indices = torch.cat(vision_batch_indices, dim=0) + else: + vision_patches = vision_token_grids = vision_token_offsets = vision_token_lengths = ( + vision_token_batch_indices + ) = None + + return { + "input_ids": padded_input_ids, + "vision_patches": vision_patches, + "vision_token_grids": vision_token_grids, + "vision_token_offsets": vision_token_offsets, + "vision_token_lengths": vision_token_lengths, + "vision_token_batch_indices": vision_token_batch_indices, + "modality_tensor": padded_modality, + "position_ids": padded_position_ids, + } + def _pack_single(self, text: str, images: Optional[list[Image]]) -> dict[str, Optional[torch.Tensor]]: - # Parse by vision_token; interleave text segments and image segments. - segments = text.split(self.vision_token) + segments = text.split(self.vision_token) # Parse by vision_token; interleave text segments and image segments. num_images = len(segments) - 1 - if num_images and (images is None or len(images) != num_images): - raise ValueError( - f"Expected one image per '{self.vision_token}' token: found {num_images} token(s) but received {0 if images is None else len(images)} image(s)." 
- )
-
         items: list[dict[str, Any]] = []
         total = 0
+        num_provided_images = len(images) if images is not None else 0
+        if num_images != num_provided_images:
+            raise ValueError(
+                f"IsaacProcessor expects one image per image token: got {num_images} image token(s) and "
+                f"{num_provided_images} image(s) for sample with text {text!r}."
+            )

         for index, segment in enumerate(segments):
             if segment:
@@ -137,7 +207,7 @@ def _pack_single(self, text: str, images: Optional[list[Image]]) -> dict[str, Op
         start = max(0, total - self.max_sequence_length)
         end = total

-        fill_value = self.pad_token_id
+        image_pad_value = self.image_pad_token_id
         base_device: Optional[torch.device] = None
         position_ids, modality, input_ids = [], [], []
         vpatches, grids, vision_token_offsets, vision_token_lengths = [], [], [], []
@@ -187,7 +257,7 @@ def _pack_single(self, text: str, images: Optional[list[Image]]) -> dict[str, Op
                     )
                 )
                 input_ids.append(
-                    torch.full((segment_kept_length,), fill_value, device=base_device, dtype=torch.long)
+                    torch.full((segment_kept_length,), image_pad_value, device=base_device, dtype=torch.long)
                 )

                 vpatches.append(item["patches"].to(base_device))  # full patches; slice later via offsets/lengths
@@ -204,6 +274,9 @@ def _pack_single(self, text: str, images: Optional[list[Image]]) -> dict[str, Op

             global_offset += segment_length

+        if base_device is None:
+            base_device = torch.device("cpu")
+
         modality_tensor = (
             torch.cat(modality, 0).unsqueeze(0)
             if modality
@@ -246,21 +319,28 @@ def __call__(
         **kwargs,
     ) -> BatchFeature:
         texts = [text] if isinstance(text, str) else text
-        if len(texts) != 1:
-            raise ValueError(
-                f"IsaacProcessor currently supports batch_size=1; received {len(texts)} text prompts. Split the batch and call the processor per sample."
-            )
-
-        images_list = None
+        images_list: Optional[list[Optional[list[Image]]]] = None
         if images is not None:
-            images_list = [images] if isinstance(images, Image) else images
-            n_tok = texts[0].count(self.vision_token)
-            if n_tok != len(images_list):
-                raise ValueError(
-                    f"Expected {len(images_list)} occurrences of '{self.vision_token}' (one per provided image), but found {n_tok} in the text."
- )
-
-        packed = self._pack_single(texts[0], images_list)
+            if isinstance(images, list) and len(images) == len(texts):
+                if not images:
+                    images_list = []
+                elif isinstance(images[0], list):
+                    images_list = images  # already per-sample
+                else:
+                    images_list = [[img] for img in images]  # list of images, one per sample
+            else:
+                images_list = []
+                for t in texts:
+                    n_tok = t.count(self.vision_token)
+                    if n_tok == 0:
+                        images_list.append(None)
+                    else:
+                        if isinstance(images, list):
+                            images_list.append(images)
+                        else:
+                            images_list.append([images])
+
+        packed = self._pack_batch(texts, images_list)
         input_ids = packed.pop("input_ids")

         return BatchFeature(data={"input_ids": input_ids, "packed_inputs": packed})
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index b27abaa3c643..90f71a53d70a 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -698,14 +698,6 @@ def is_mamba_2_ssm_available() -> bool:
     return is_torch_cuda_available() and is_available and version.parse(mamba_ssm_version) >= version.parse("2.0.4")


-@lru_cache
-def is_perceptron_available() -> bool:
-    if is_torch_cuda_available() and _is_package_available("perceptron"):
-        return True
-    else:
-        return False
-
-
 @lru_cache
 def is_flash_linear_attention_available():
     is_available, fla_version = _is_package_available("fla", return_version=True)
diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py
index 24437ec57221..66da717cd0b6 100644
--- a/tests/models/isaac/test_modeling_isaac.py
+++ b/tests/models/isaac/test_modeling_isaac.py
@@ -17,7 +17,9 @@
 import base64
 import io
 import os
+import re
 import unittest
+from collections import namedtuple
 from functools import lru_cache
 from pathlib import Path

@@ -50,7 +52,6 @@
     torch_device,
 )
 from transformers.utils import is_vision_available
-from transformers.utils.import_utils import is_perceptron_available


 if is_vision_available():
@@ -64,15 +65,74 @@
 if is_torch_available():
     import torch

-if is_perceptron_available():
-    from perceptron.pointing.parser import extract_points
-    from perceptron.tensorstream.tensorstream import TensorStream
-else:
-    TensorStream = None
-    extract_points = None

+SinglePoint = namedtuple("SinglePoint", ["x", "y", "mention", "t"], defaults=(None, None))
+BoundingBox = namedtuple(
+    "BoundingBox",
+    ["top_left", "bottom_right", "mention", "t"],
+    defaults=(None, None),
+)
+
+_POINT_OR_BOX_TAG = re.compile(
+    r"<(?P<tag>point|point_box)(?P<attrs>[^>]*)>(?P<body>[\s\S]*?)</(?P=tag)>", re.IGNORECASE
+)
+_ATTR_RE = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))")
+_COORD_RE = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")
+
+
+def _maybe_float(val):
+    if val is None:
+        return None
+    try:
+        return float(val)
+    except ValueError:
+        return None
+
+
+def _parse_attrs(attr_text: str) -> dict:
+    attrs = {}
+    for match in _ATTR_RE.finditer(attr_text or ""):
+        key = match.group(1)
+        val = match.group(2) or match.group(3) or ""
+        attrs[key] = val
+    return attrs
+
+
+def _parse_point_body(body: str, mention=None, t=None):
+    match = _COORD_RE.search(body)
+    if not match:
+        raise ValueError(f"Malformed <point> tag: {body!r}")
+    x, y = int(match.group(1)), int(match.group(2))
+    return SinglePoint(x, y, mention, _maybe_float(t))
+
+
+def _parse_box_body(body: str, mention=None, t=None):
+    coords = list(_COORD_RE.finditer(body))
+    if len(coords) < 2:
+        raise ValueError(f"Malformed <point_box> tag: {body!r}")
+    x1, y1 = int(coords[0].group(1)), int(coords[0].group(2))
+    x2, y2 = int(coords[1].group(1)), 
int(coords[1].group(2)) + return BoundingBox(SinglePoint(x1, y1, None, None), SinglePoint(x2, y2, None, None), mention, _maybe_float(t)) + + +def extract_points(text: str, expected: str | None = None): + """Minimal parser for Isaac pointing tags used in tests.""" + + results = [] + for match in _POINT_OR_BOX_TAG.finditer(text or ""): + tag = match.group("tag").lower() + attrs = _parse_attrs(match.group("attrs")) + mention = attrs.get("mention") + t = attrs.get("t") + if tag == "point": + if expected not in (None, "point"): + continue + results.append(_parse_point_body(match.group("body"), mention=mention, t=t)) + elif tag == "point_box": + if expected not in (None, "box"): + continue + results.append(_parse_box_body(match.group("body"), mention=mention, t=t)) + return results -require_tensorstream = pytest.mark.skipif(TensorStream is None, reason="TensorStream backend is not available") BASE_MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base") MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1") @@ -151,6 +211,28 @@ def _rounded(value: torch.Tensor | float) -> float: } +def infer_pad_from_tail(sequence: torch.Tensor) -> tuple[int | None, int]: + """ + Infer the pad value used in a 1D sequence by scanning the repeated tail. + + Returns (pad_value or None if no padding detected, last_nonpad_index). + """ + + if sequence.ndim != 1: + raise ValueError("sequence must be 1D") + + pad_candidate = sequence[-1].item() + idx = sequence.shape[0] - 1 + while idx >= 0 and sequence[idx].item() == pad_candidate: + idx -= 1 + + if idx == sequence.shape[0] - 1: + return None, idx + if idx < 0: + return pad_candidate, -1 + return pad_candidate, idx + + def create_isaac_processor( tokenizer, isaac_config, @@ -427,7 +509,6 @@ def test_prompt_lookup_decoding_matches_greedy_search(self): def test_retain_grad_hidden_states_attentions(self): pass - @require_tensorstream def test_model_forward(self): config, input_ids, attention_mask, _ = self.model_tester.prepare_config_and_inputs() model = IsaacModel(config) @@ -441,7 +522,6 @@ def test_model_forward(self): (self.model_tester.batch_size, self.model_tester.seq_length, config.hidden_size), ) - @require_tensorstream def test_for_conditional_generation(self): config, input_ids, attention_mask, labels = self.model_tester.prepare_config_and_inputs() model = IsaacForConditionalGeneration(config) @@ -456,7 +536,6 @@ def test_for_conditional_generation(self): ) self.assertIsNotNone(result.loss) - @require_tensorstream def test_isaac_for_conditional_generation_initialization(self): config = self.model_tester.get_config() model = IsaacForConditionalGeneration(config) @@ -471,7 +550,6 @@ def test_isaac_for_conditional_generation_initialization(self): outputs = model(input_ids=input_ids, return_dict=True) self.assertEqual(outputs.logits.shape, (1, 10, config.vocab_size)) - @require_tensorstream def test_isaac_for_conditional_generation_loss_and_generate_flag(self): config = self.model_tester.get_config() model = IsaacForConditionalGeneration(config).to(torch_device) @@ -597,120 +675,9 @@ def test_flash_attention_parity_with_sdpa_bf16(self): ) -@require_torch -@require_flash_attn -class IsaacAttentionDtypeTest(unittest.TestCase): - def _make_config(self): - return IsaacVisionConfig( - hidden_size=32, - intermediate_size=64, - num_hidden_layers=1, - num_attention_heads=4, - num_channels=3, - num_patches=64, - patch_size=4, - attention_dropout=0.0, - pixel_shuffle_scale_factor=1, - ) - - def _skip_if_no_cuda_bf16(self): - if not 
torch.cuda.is_available(): - pytest.skip("CUDA required for flash attention dtype/parity tests.") - if not torch.cuda.is_bf16_supported(): - pytest.skip("CUDA bfloat16 support required.") - - def test_flash_attention_matches_weight_dtype_bf16(self): - self._skip_if_no_cuda_bf16() - torch.manual_seed(0) - - device = torch.device("cuda") - config = self._make_config() - config._attn_implementation = "flash_attention_2" - - attn = IsaacVisionAttention(config).to(device=device, dtype=torch.bfloat16).eval() - - hidden_states = torch.randn(2, 4, config.hidden_size, device=device, dtype=torch.bfloat16) - - with torch.no_grad(): - attn_output, _ = attn(hidden_states) - - assert attn_output.dtype == attn.out_proj.weight.dtype - assert attn_output.dtype == hidden_states.dtype - - def test_flash_attention_matches_weight_dtype_bf16_with_padding(self): - self._skip_if_no_cuda_bf16() - torch.manual_seed(0) - - device = torch.device("cuda") - config = self._make_config() - config._attn_implementation = "flash_attention_2" - - attn = IsaacVisionAttention(config).to(device=device, dtype=torch.bfloat16).eval() - - hidden_states = torch.randn(2, 4, config.hidden_size, device=device, dtype=torch.bfloat16) - attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], device=device, dtype=torch.bool) - - with torch.no_grad(): - attn_output, _ = attn(hidden_states, attention_mask=attention_mask) - - assert attn_output.dtype == attn.out_proj.weight.dtype - assert attn_output.dtype == hidden_states.dtype - - def test_flash_attention_matches_weight_dtype_bf16_with_cu_seqlens(self): - self._skip_if_no_cuda_bf16() - torch.manual_seed(0) - - device = torch.device("cuda") - config = self._make_config() - config._attn_implementation = "flash_attention_2" - - attn = IsaacVisionAttention(config).to(device=device, dtype=torch.bfloat16).eval() - - hidden_states = torch.randn(1, 5, config.hidden_size, device=device, dtype=torch.bfloat16) - cu_seqlens = torch.tensor([0, 3, 5], device=device, dtype=torch.int32) - - with torch.no_grad(): - attn_output, _ = attn(hidden_states, cu_seqlens=cu_seqlens, max_seqlen=3) - - assert attn_output.dtype == attn.out_proj.weight.dtype - assert attn_output.dtype == hidden_states.dtype - - def test_flash_attention_parity_with_sdpa_bf16(self): - self._skip_if_no_cuda_bf16() - torch.manual_seed(0) - - device = torch.device("cuda") - config_sdpa = self._make_config() - config_sdpa._attn_implementation = "sdpa" - - config_fa2 = self._make_config() - config_fa2._attn_implementation = "flash_attention_2" - - attn_sdpa = IsaacVisionAttention(config_sdpa).to(device=device, dtype=torch.bfloat16).eval() - attn_fa2 = IsaacVisionAttention(config_fa2).to(device=device, dtype=torch.bfloat16).eval() - - # Align weights so the only difference is the backend - attn_fa2.load_state_dict(attn_sdpa.state_dict()) - - hidden_states = torch.randn(2, 4, config_sdpa.hidden_size, device=device, dtype=torch.bfloat16) - - with torch.no_grad(): - out_sdpa, _ = attn_sdpa(hidden_states) - out_fa2, _ = attn_fa2(hidden_states) - - torch.testing.assert_close( - out_fa2.float(), - out_sdpa.float(), - rtol=1e-3, - atol=1e-3, - msg="FlashAttention2 output deviates from SDPA baseline beyond tolerance", - ) - - @require_torch @require_vision @slow -@require_tensorstream @require_flash_attn class IsaacGenerationIntegrationTest(unittest.TestCase): max_new_tokens = 25 @@ -737,6 +704,13 @@ def _generate_from_messages(self, messages, images, num_tokens=None): processor_output = self.processor(text=prompt, images=images, 
return_tensors="pt")
         packed_inputs = processor_output["packed_inputs"]
         input_ids = processor_output["input_ids"].to(self.device)
+        attention_mask = processor_output.get("attention_mask")
+        if attention_mask is None:
+            pad_id = self.tokenizer.pad_token_id
+            if pad_id is None:
+                pad_id = getattr(self.processor, "pad_token_id", 0)
+            attention_mask = processor_output["input_ids"].ne(pad_id).long()
+        attention_mask = attention_mask.to(self.device)
         prompt_len = input_ids.shape[1]
         packed_inputs = {
             key: (value.to(self.device) if isinstance(value, torch.Tensor) else value)
@@ -746,6 +720,7 @@ def _generate_from_messages(self, messages, images, num_tokens=None):
         with torch.no_grad():
             outputs = self.model.generate(
                 input_ids=input_ids,
+                attention_mask=attention_mask,
                 packed_inputs=packed_inputs,
                 max_new_tokens=num_tokens or self.max_new_tokens,
                 do_sample=False,
@@ -804,6 +779,47 @@ def test_vqa_from_image(self):
         expected_response = "\nNo, it is not safe to cross the street at this moment. The traffic light for pedestrians is red, indicating that it is not safe to cross."
         assert generated_text == expected_response

+    def _generate_batch(self, prompts, images_list, num_tokens=None):
+        processor_output = self.processor(text=prompts, images=images_list, return_tensors="pt")
+        input_ids = processor_output["input_ids"]
+        if input_ids.dim() == 1:
+            input_ids = input_ids.unsqueeze(0)
+
+        # Use processor-provided attention_mask if available; otherwise fallback.
+        attention_mask = processor_output.get("attention_mask", None)
+        if attention_mask is None:
+            pad_id = self.tokenizer.pad_token_id
+            if pad_id is None:
+                pad_id = getattr(self.processor, "pad_token_id", 0)
+            attention_mask = input_ids.ne(pad_id).long()
+
+        input_ids = input_ids.to(self.device)
+        attention_mask = attention_mask.to(self.device)
+
+        packed_inputs = {
+            k: (v.to(self.device) if isinstance(v, torch.Tensor) else v)
+            for k, v in processor_output["packed_inputs"].items()
+        }
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                packed_inputs=packed_inputs,
+                max_new_tokens=num_tokens or self.max_new_tokens,
+                do_sample=False,
+                pad_token_id=self.tokenizer.eos_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
+                return_dict_in_generate=True,
+            )
+        sequences = outputs.sequences
+        generated_texts = []
+        for i in range(sequences.shape[0]):
+            tail_ids = sequences[i, :]  # full row (prompt + generation); singles are matched as substrings later
+            generated_texts.append(self.tokenizer.decode(tail_ids, skip_special_tokens=True))
+
+        return generated_texts
+
     def test_logit_equivalence(self):
         image = _load_red_dot_image()
         if image is None:
@@ -852,11 +868,154 @@ def test_logit_equivalence(self):
         }
         assert logit_stats == expected_logit_stats

+    def test_batched_generation_matches_individual(self):
+        # Build individual scenarios matching existing integration tests
+        red_image = _load_red_dot_image()
+        if red_image is None:
+            pytest.skip("PIL.Image is required for Isaac generation tests.")
+
+        vqa_document = [
+            {
+                "type": "image",
+                "content": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp",
+                "role": "user",
+            },
+            {
+                "type": "text",
+                "content": "Is it safe to cross the street at this moment?",
+                "role": "user",
+            },
+        ]
+
+        # Text-only
+        doc_text_only = [{"type": "text", "content": "What is the Pythagorean theorem?", "role": "user"}]
+        messages_text_only, images_text_only = document_to_messages(doc_text_only)
+        single_text_only = self._generate_from_messages(
messages_text_only, images_text_only, num_tokens=self.max_new_tokens
+        )
+        assert single_text_only, "Text-only single generation is empty"
+
+        # Image + text
+        messages_image_text = [
+            {"role": "user", "content": "Describe this image:"},
+            {"role": "user", "content": "<image>"},
+        ]
+        single_image_text = self._generate_from_messages(messages_image_text, [red_image])
+        assert single_image_text, "Image-text single generation is empty"
+
+        # VQA
+        messages_vqa, images_vqa = document_to_messages(vqa_document)
+        single_vqa = self._generate_from_messages(messages_vqa, images_vqa, num_tokens=self.max_new_tokens)
+        assert single_vqa, "VQA single generation is empty"
+
+        single_texts = [single_text_only, single_image_text, single_vqa]
+
+        # Build batch inputs
+        prompts = [
+            self.processor.apply_chat_template(messages_text_only, tokenize=False, add_generation_prompt=True).strip(),
+            self.processor.apply_chat_template(
+                messages_image_text, tokenize=False, add_generation_prompt=True
+            ).strip(),
+            self.processor.apply_chat_template(messages_vqa, tokenize=False, add_generation_prompt=True).strip(),
+        ]
+        images_list = [images_text_only, [red_image], images_vqa]
+
+        # Input-level sanity
+        assert len(prompts) == len(images_list) == 3
+        for i, (p, imgs) in enumerate(zip(prompts, images_list)):
+            expected_tokens = p.count(self.hf_config.vision_token)
+            num_imgs = len(imgs)
+            assert expected_tokens == num_imgs, (
+                f"sample {i} vision token/image mismatch: {expected_tokens} vs {num_imgs}"
+            )
+
+        pad_id = self.tokenizer.pad_token_id
+        if pad_id is None:
+            pad_id = getattr(self.processor, "pad_token_id", 0)
+
+        per_sample_outputs = [
+            self.processor(text=prompt, images=imgs, return_tensors="pt") for prompt, imgs in zip(prompts, images_list)
+        ]
+        batch_outputs = self.processor(text=prompts, images=images_list, return_tensors="pt")
+        batch_input_ids = batch_outputs["input_ids"]
+        batch_packed = batch_outputs["packed_inputs"]
+
+        sample_lengths = [output["input_ids"].squeeze(0).shape[0] for output in per_sample_outputs]
+        max_length = max(sample_lengths)
+
+        expected_vision_patches = []
+        expected_vision_grids = []
+        expected_vision_offsets = []
+        expected_vision_lengths = []
+        expected_vision_batch_indices = []
+
+        for i, (single_output, batch_ids, single_len) in enumerate(
+            zip(per_sample_outputs, batch_input_ids, sample_lengths)
+        ):
+            single_ids = single_output["input_ids"].squeeze(0)
+            single_packed = single_output["packed_inputs"]
+
+            torch.testing.assert_close(batch_ids[-single_len:], single_ids)
+
+            batch_modality_row = batch_packed["modality_tensor"][i]
+            expected_modality = torch.full(
+                (max_length,),
+                batch_modality_row[-1].item(),
+                dtype=batch_modality_row.dtype,
+                device=batch_modality_row.device,
+            )
+            expected_modality[-single_len:] = single_packed["modality_tensor"].squeeze(0)
+            torch.testing.assert_close(batch_modality_row, expected_modality)
+
+            batch_positions_row = batch_packed["position_ids"][i]
+            expected_positions = torch.zeros(
+                (max_length, 3), dtype=batch_positions_row.dtype, device=batch_positions_row.device
+            )
+            expected_positions[-single_len:] = single_packed["position_ids"].squeeze(0)
+            torch.testing.assert_close(batch_positions_row, expected_positions)
+
+            if single_packed["vision_patches"] is not None:
+                expected_vision_patches.append(single_packed["vision_patches"])
+                expected_vision_grids.append(single_packed["vision_token_grids"])
+                expected_vision_offsets.append(single_packed["vision_token_offsets"])
+                expected_vision_lengths.append(single_packed["vision_token_lengths"])
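+                # Tag each expected per-image tensor with its batch row index so the
+                # concatenated batch-level vision payload can be checked order-for-order below.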
+ expected_vision_batch_indices.append(torch.full_like(single_packed["vision_token_batch_indices"], i)) + + if single_len == max_length: + continue + + pad_span = batch_ids[: max_length - single_len] + assert torch.all(pad_span == pad_id), f"sample {i} left pad span not padded with pad id" + + attention_mask = batch_ids.ne(pad_id).long() + assert not torch.any(attention_mask[: max_length - single_len]), f"sample {i} mask ones inside left pad" + assert torch.all(attention_mask[-single_len:]), f"sample {i} mask zeros inside content" + + if expected_vision_patches: + torch.testing.assert_close(batch_packed["vision_patches"], torch.cat(expected_vision_patches, dim=0)) + torch.testing.assert_close(batch_packed["vision_token_grids"], torch.cat(expected_vision_grids, dim=0)) + torch.testing.assert_close(batch_packed["vision_token_offsets"], torch.cat(expected_vision_offsets, dim=0)) + torch.testing.assert_close(batch_packed["vision_token_lengths"], torch.cat(expected_vision_lengths, dim=0)) + torch.testing.assert_close( + batch_packed["vision_token_batch_indices"], torch.cat(expected_vision_batch_indices, dim=0) + ) + else: + assert batch_packed["vision_patches"] is None + assert batch_packed["vision_token_grids"] is None + assert batch_packed["vision_token_offsets"] is None + assert batch_packed["vision_token_lengths"] is None + assert batch_packed["vision_token_batch_indices"] is None + + batch_texts = self._generate_batch(prompts, images_list, num_tokens=100) + assert len(batch_texts) == len(single_texts) == 3 + + for i, (btxt, stxt) in enumerate(zip(batch_texts, single_texts)): + assert stxt in btxt, f"batch[{i}] mismatch: {btxt!r} vs single[{i}] {stxt!r}" + @require_torch @require_vision @slow -@require_tensorstream @require_flash_attn class IsaacBoxPointingIntegrationTest(unittest.TestCase): max_new_tokens = 256 diff --git a/tests/models/isaac/test_processing_isaac.py b/tests/models/isaac/test_processing_isaac.py index 41e28bf253e0..944ab62645ad 100644 --- a/tests/models/isaac/test_processing_isaac.py +++ b/tests/models/isaac/test_processing_isaac.py @@ -14,15 +14,22 @@ """Testing suite for the Isaac processor.""" +import os +import unittest +from pathlib import Path + import pytest import torch +from huggingface_hub import is_offline_mode from transformers import IsaacConfig, PythonBackend +from transformers.image_processing_utils import ImageProcessingMixin from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast from transformers.models.isaac.modeling_isaac import ModalityType from transformers.models.isaac.processing_isaac import IsaacProcessor from transformers.testing_utils import require_torch, require_vision from transformers.utils import is_vision_available +from transformers.utils.generic import TensorType if is_vision_available(): @@ -124,6 +131,7 @@ def _assert_common(outputs): "vision_token_grids", "vision_token_offsets", "vision_token_lengths", + "vision_token_batch_indices", "modality_tensor", "position_ids", } @@ -148,6 +156,7 @@ def _assert_no_vision(packed_inputs): assert packed_inputs["vision_token_grids"] is None assert packed_inputs["vision_token_offsets"] is None assert packed_inputs["vision_token_lengths"] is None + assert packed_inputs["vision_token_batch_indices"] is None def _assert_vision_segments(packed_inputs, expected_segments): @@ -155,10 +164,12 @@ def _assert_vision_segments(packed_inputs, expected_segments): assert packed_inputs["vision_token_grids"] is not None assert packed_inputs["vision_token_offsets"] is not None assert 
packed_inputs["vision_token_lengths"] is not None + assert packed_inputs["vision_token_batch_indices"] is not None assert packed_inputs["vision_token_grids"].shape[0] == expected_segments assert packed_inputs["vision_token_offsets"].shape == (expected_segments,) assert packed_inputs["vision_token_lengths"].shape == (expected_segments,) + assert packed_inputs["vision_token_batch_indices"].shape == (expected_segments,) def _count_modality(packed_inputs, modality_value): @@ -166,6 +177,23 @@ def _count_modality(packed_inputs, modality_value): return int((modality == modality_value).sum().item()) +def _pad_to_max(tensors: list[torch.Tensor], pad_value: int) -> torch.Tensor: + """Pad a list of (L, ...) tensors to (B, L_max, ...).""" + max_len = max(t.shape[0] for t in tensors) + batch = len(tensors) + if tensors[0].ndim == 1: + out = torch.full((batch, max_len), pad_value, device=tensors[0].device, dtype=tensors[0].dtype) + for i, t in enumerate(tensors): + out[i, : t.shape[0]] = t + return out + # assume (L, K) + k = tensors[0].shape[1] + out = torch.full((batch, max_len, k), pad_value, device=tensors[0].device, dtype=tensors[0].dtype) + for i, t in enumerate(tensors): + out[i, : t.shape[0]] = t + return out + + def _get_image_token_length(processor, image, vision_token): outputs = _run_processor(processor, text=vision_token, images=[image]) _, packed = _assert_common(outputs) @@ -302,7 +330,7 @@ def test_error_on_image_mismatch(isaac_processor): text = f"{vision_token} {vision_token}" image = _make_dummy_image() - with pytest.raises(ValueError, match="occurrences of"): + with pytest.raises(ValueError, match="one image per"): _run_processor(isaac_processor, text=text, images=[image]) @@ -461,3 +489,181 @@ def test_crop_removes_all_vision_when_window_excludes_images(isaac_processor, is _assert_no_vision(packed) assert input_ids.shape[1] == tail_tokens assert _count_modality(packed, ModalityType.text.value) == tail_tokens + + +@require_torch +@require_vision +def test_batch_outputs_match_individual_calls(isaac_processor): + texts = ["hi", "this one is longer"] + + per_sample = [_run_processor(isaac_processor, text=t, images=None) for t in texts] + batch_outputs = _run_processor(isaac_processor, text=texts, images=None) + + assert set(batch_outputs.keys()) == {"input_ids", "packed_inputs"} + batch_input_ids = batch_outputs["input_ids"] + batch_packed = batch_outputs["packed_inputs"] + + assert set(batch_packed.keys()) == { + "vision_patches", + "vision_token_grids", + "vision_token_offsets", + "vision_token_lengths", + "vision_token_batch_indices", + "modality_tensor", + "position_ids", + } + + assert batch_input_ids.shape[0] == len(texts) + assert batch_packed["modality_tensor"].shape[0] == len(texts) + assert batch_packed["position_ids"].shape[0] == len(texts) + + sample_lengths = [output["input_ids"].squeeze(0).shape[0] for output in per_sample] + max_length = max(sample_lengths) + pad_id = isaac_processor.pad_token_id + + for i, (single_output, batch_ids, single_len) in enumerate(zip(per_sample, batch_input_ids, sample_lengths)): + single_ids = single_output["input_ids"].squeeze(0) + single_packed = single_output["packed_inputs"] + + torch.testing.assert_close(batch_ids[-single_len:], single_ids) + + batch_modality_row = batch_packed["modality_tensor"][i] + expected_modality = torch.full( + (max_length,), + batch_modality_row[-1].item(), + dtype=batch_modality_row.dtype, + device=batch_modality_row.device, + ) + expected_modality[-single_len:] = single_packed["modality_tensor"].squeeze(0) + 
torch.testing.assert_close(batch_modality_row, expected_modality) + + batch_positions_row = batch_packed["position_ids"][i] + expected_positions = torch.zeros( + (max_length, 3), dtype=batch_positions_row.dtype, device=batch_positions_row.device + ) + expected_positions[-single_len:] = single_packed["position_ids"].squeeze(0) + torch.testing.assert_close(batch_positions_row, expected_positions) + + if single_len == max_length: + continue + + pad_span = batch_ids[: max_length - single_len] + assert torch.all(pad_span == pad_id) + + attention_mask = batch_ids.ne(pad_id).long() + assert not torch.any(attention_mask[: max_length - single_len]) + assert torch.all(attention_mask[-single_len:]) + + _assert_no_vision(batch_packed) + + +class StubTokenizer(SimpleIsaacTokenizer): + def __init__(self): + super().__init__() + self._base = 2000 + + def encode(self, text, add_special_tokens=False, return_tensors=None): + token_ids = torch.tensor([self._base + b for b in text.encode("utf-8")], dtype=torch.long) + if return_tensors in {"pt", TensorType.PYTORCH}: + return token_ids.unsqueeze(0) + return token_ids + + def convert_tokens_to_ids(self, token): + if token == "<|image_pad|>": + return 151655 + if token == self.pad_token: + return super().convert_tokens_to_ids(token) + return None + + +class StubImageProcessor(ImageProcessingMixin): + def __call__(self, images=None, return_tensors=None): + patches = torch.ones((1, 2, 2, 3), dtype=torch.float32) + sizes = torch.tensor([[1, 2, 2]], dtype=torch.long) + return { + "patches": patches, + "virtual_pixel_size": sizes, + "real_pixel_size": sizes, + } + + +BASE_MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base") +BASE_MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/3") or None +LOCAL_CHECKPOINT = os.environ.get("ISAAC_TEST_MODEL_PATH") + + +def _checkpoint_or_skip(): + if LOCAL_CHECKPOINT: + resolved = Path(LOCAL_CHECKPOINT).expanduser() + if not resolved.exists(): + pytest.skip(f"Local checkpoint path {resolved} does not exist.") + return str(resolved) + if is_offline_mode(): + pytest.skip("Offline mode: set ISAAC_TEST_MODEL_PATH to a local checkpoint to run these tests.") + return BASE_MODEL_ID + + +def _create_real_processor(): + checkpoint = _checkpoint_or_skip() + config = IsaacConfig.from_pretrained(checkpoint, revision=BASE_MODEL_REVISION) + processor = IsaacProcessor.from_pretrained(checkpoint, revision=BASE_MODEL_REVISION) + tokenizer = processor.tokenizer + return processor, tokenizer, config + + +@require_torch +@require_vision +class TestIsaacProcessorRealPadding(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.processor, cls.tokenizer, cls.config = _create_real_processor() + cls.dummy_image = _make_dummy_image() + cls.vision_token = cls.config.vision_token + cls.pad_id = cls.tokenizer.pad_token_id + cls.image_pad_id = cls.tokenizer.convert_tokens_to_ids("<|image_pad|>") + if cls.pad_id is None or cls.image_pad_id is None: + pytest.skip("pad/image pad ids unavailable for processor") + + def _check_padding_and_masks(self, input_ids: torch.Tensor, pad_id: int): + for row in range(input_ids.size(0)): + row_ids = input_ids[row] + nonpad_positions = (row_ids != pad_id).nonzero(as_tuple=False) + last_nonpad = int(nonpad_positions.max()) if nonpad_positions.numel() else -1 + if last_nonpad + 1 < row_ids.numel(): + tail = row_ids[last_nonpad + 1 :] + assert torch.all(tail == pad_id) + attn = (row_ids != pad_id).long() + if last_nonpad >= 0: + assert torch.all(attn[: last_nonpad + 1] == 1) 
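+                # e.g., a right-padded row [a, b, c, P, P] has last_nonpad == 2, so attn[:3] is all ones and the tail attn[3:] must sum to zero.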
+ assert int(attn[last_nonpad + 1 :].sum()) == 0 + + def test_single_vs_batched_consistency(self): + prompt = f"hello {self.vision_token} world" + images_single = [self.dummy_image] + + single = self.processor(text=prompt, images=images_single, return_tensors="pt") + single_ids = single["input_ids"].squeeze(0) + + batch_prompts = [prompt, "short"] + batch_images = [images_single, None] + batch = self.processor(text=batch_prompts, images=batch_images, return_tensors="pt") + batch_ids = batch["input_ids"][0] + modality = batch["packed_inputs"]["modality_tensor"][0] + + assert torch.equal(batch_ids[: single_ids.size(0)], single_ids) + + image_positions = modality == ModalityType.image.value + if image_positions.any(): + assert torch.all(batch_ids[image_positions] == self.image_pad_id) + assert torch.all(batch_ids[image_positions] != self.pad_id) + + nonpad = (batch_ids != self.pad_id).nonzero(as_tuple=False) + last_nonpad = int(nonpad.max()) if nonpad.numel() else -1 + if last_nonpad + 1 < batch_ids.numel(): + tail = batch_ids[last_nonpad + 1 :] + assert torch.all(tail == self.pad_id) + + attn = (batch_ids != self.pad_id).long() + if last_nonpad >= 0: + assert torch.all(attn[: last_nonpad + 1] == 1) + assert int(attn[last_nonpad + 1 :].sum()) == 0 From 2884211dc097b1f0208a2e0bc2c3006f6b022070 Mon Sep 17 00:00:00 2001 From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com> Date: Fri, 9 Jan 2026 18:46:28 +0400 Subject: [PATCH 60/77] Update src/transformers/models/isaac/modular_isaac.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/isaac/modular_isaac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 1d561ddca798..2adc54916a3e 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -1255,7 +1255,7 @@ def __call__( return BatchFeature(data={"input_ids": input_ids, "packed_inputs": packed}) -class IsaacRotaryEmbedding(qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding): +class IsaacRotaryEmbedding(Qwen2_5_VLRotaryEmbedding): def __init__(self, config: IsaacConfig, device=None): rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} From fdbd63331ed7dae3de93051f72fd53549e968c28 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Wed, 14 Jan 2026 20:19:57 +0400 Subject: [PATCH 61/77] Squash merge into main --- .../models/isaac/configuration_isaac.py | 2 + .../isaac/image_processing_isaac_fast.py | 12 +- .../models/isaac/modeling_isaac.py | 291 ++++++++++-------- .../models/isaac/modular_isaac.py | 233 ++++++-------- 4 files changed, 263 insertions(+), 275 deletions(-) diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index fd6820ad83af..e98e9a9da0f2 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -123,6 +123,8 @@ def __init__( self.hidden_act = self.text_config.hidden_act self.use_cache = self.text_config.use_cache self.rope_theta = self.rope_parameters["rope_theta"] + self.max_position_embeddings = getattr(self.text_config, "max_position_embeddings", max_sequence_length) + self.text_config.max_position_embeddings = self.max_position_embeddings self.layer_types = getattr(self.text_config, 
"layer_types", None) layer_type_validation(self.layer_types, self.num_hidden_layers) diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py b/src/transformers/models/isaac/image_processing_isaac_fast.py index caf35d3eacc6..c1de230c635e 100644 --- a/src/transformers/models/isaac/image_processing_isaac_fast.py +++ b/src/transformers/models/isaac/image_processing_isaac_fast.py @@ -299,18 +299,14 @@ def _preprocess( ) grouped_outputs[shape] = (patches, token_grid, virtual_dim, real_dim) - def _reorder_grouped_item( # reorder an item of tuple payloads using the same grouped_images_index - grouped: dict[tuple[int, ...], tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], - grouped_index: dict[tuple[int, ...], list[int]], - item_idx: int, - ) -> list[torch.Tensor]: - return reorder_images({k: v[item_idx] for k, v in grouped.items()}, grouped_index) - keys = ("patches", "token_grids", "virtual_pixel_size", "real_pixel_size") tensors: dict[str, torch.Tensor] = {} for i, key in enumerate(keys): - slices = _reorder_grouped_item(grouped_outputs, grouped_images_index, i) + slices = reorder_images( + {shape: values[i] for shape, values in grouped_outputs.items()}, + grouped_images_index, + ) tensors[key] = torch.stack(slices, dim=0) return BatchFeature(data=tensors, tensor_type=return_tensors) diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 528706593a4b..156017d32d6d 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -23,11 +23,11 @@ import copy from collections.abc import Callable from enum import IntEnum -from typing import Any, Optional, Union +from typing import Any, Optional +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache -from ...configuration_utils import PretrainedConfig from ...generation.utils import GenerationMixin from ...image_processing_utils_fast import ImagesKwargs from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func @@ -35,16 +35,16 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...models.qwen3.modeling_qwen3 import Qwen3Model, Qwen3PreTrainedModel +from ...models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3Model, Qwen3PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring -from ...utils.generic import OutputRecorder, TransformersKwargs, can_return_tuple, check_model_inputs +from ...utils.generic import OutputRecorder, TransformersKwargs, can_return_tuple, check_model_inputs, maybe_autocast from ...utils.import_utils import ( is_torch_available, is_torchdynamo_compiling, ) -from ..qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling from .configuration_isaac import IsaacConfig, IsaacVisionConfig @@ -95,7 +95,14 @@ def __init__(self, config: IsaacVisionConfig): self.num_patches = config.num_patches self.position_embedding_size = int(self.num_patches**0.5) - self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + self.position_embedding = nn.Parameter( + torch.empty( + self.position_embedding_size, + self.position_embedding_size, + 
self.embed_dim, + ) + ) + nn.init.normal_(self.position_embedding) @staticmethod def resize_positional_embeddings( @@ -174,11 +181,7 @@ def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> to target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(packed_pixel_values.to(dtype=target_dtype)) - positional_embeddings = self.position_embedding.weight.reshape( - self.position_embedding_size, - self.position_embedding_size, - -1, - ) + positional_embeddings = self.position_embedding resized_positional_embeddings = self.resize_positional_embeddings( positional_embeddings, spatial_shapes, max_length=packed_pixel_values.shape[1] ) @@ -221,7 +224,7 @@ def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor class IsaacVisionAttention(nn.Module): - """Custom attention that supports variable-length sequences with flash attention.""" + """Custom attention that supports variable-length sequences with flash/SDPA backends.""" def __init__(self, config): super().__init__() @@ -248,14 +251,9 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - kwargs.pop("output_hidden_states", None) - kwargs.pop("return_dict", None) - batch_size, seq_length, embed_dim = hidden_states.shape queries = self.q_proj(hidden_states) keys = self.k_proj(hidden_states) @@ -270,35 +268,47 @@ def forward( if attn_impl != "sdpa": attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl] + seq_sizes = kwargs.pop("seq_sizes", None) + attention_kwargs: dict[str, Any] = { "is_causal": False, "scaling": self.scale, } - supports_varlen = cu_seqlens is not None and attn_impl in { - "flash_attention_2", - "flash_attention_3", - "flex_attention", - "paged|flash_attention_2", - "paged|flash_attention_3", - } - if supports_varlen: - if max_seqlen is not None: - max_q = max_k = int(max_seqlen) - elif cu_seqlens.numel() >= 2: - lengths = cu_seqlens[1:] - cu_seqlens[:-1] - max_q = max_k = lengths.max() if lengths.numel() > 0 else seq_length + if seq_sizes is not None and seq_sizes.numel() > 0: + if attn_impl in {"flash_attention_2", "flash_attention_3"}: + cu_seqlens = F.pad(seq_sizes.cumsum(0).to(torch.int32), (1, 0)) + max_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attention_kwargs.update( + { + "cu_seq_lens_q": cu_seqlens, + "cu_seq_lens_k": cu_seqlens, + "max_length_q": max_len, + "max_length_k": max_len, + } + ) else: - max_q = max_k = seq_length - - attention_kwargs.update( - { - "cu_seq_lens_q": cu_seqlens, - "cu_seq_lens_k": cu_seqlens, - "max_length_q": max_q, - "max_length_k": max_k, - } - ) + seg_ids = torch.repeat_interleave( + torch.arange(seq_sizes.numel(), device=seq_sizes.device), seq_sizes + ).view(1, -1) + mask_function = packed_sequence_mask_function(seg_ids) + cache_position = torch.arange(seq_length, device=hidden_states.device, dtype=torch.long) + mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[attn_impl] + attention_mask = mask_interface( + batch_size=batch_size, + cache_position=cache_position, + kv_length=seq_length, + kv_offset=0, + mask_function=mask_function, + attention_mask=attention_mask, + allow_is_causal_skip=False, + allow_is_bidirectional_skip=False, + dtype=hidden_states.dtype, + config=self.config, + use_vmap=False, + ) + else: + attention_mask = None attn_output, attn_weights = 
attention_interface( self, @@ -345,18 +355,11 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, output_attentions: bool = False, **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: r""" - cu_seqlens (`torch.Tensor`, *optional*): - Prefix-sum tensor whose length equals the number of documents + 1. The difference between successive - entries gives each document's token count and enables block-diagonal attention masking for packed batches. - max_seqlen (`int`, *optional*): - Maximum document length referenced by `cu_seqlens`. Passed to FlashAttention so it can size temporary - buffers for packed variable-length attention. + Variable-length metadata (e.g., `seq_sizes`) flows via `**kwargs` to attention for backend-specific handling. """ # Run attention directly so variable-length metadata reaches FlashAttention. residual = hidden_states @@ -364,10 +367,10 @@ def forward( attn_output, _ = self.self_attn( hidden_states, attention_mask=attention_mask, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, + output_attentions=output_attentions, **kwargs, ) + hidden_states = residual + attn_output residual = hidden_states @@ -406,48 +409,6 @@ def forward( return BaseModelOutput(last_hidden_state=hidden_states) -def create_document_attention_mask( - config: PretrainedConfig, - input_embeds: torch.Tensor, - cu_seqlens: Optional[torch.Tensor], -) -> Optional[Union[torch.Tensor, Any]]: - """ - Materialize a backend-specific block-diagonal attention mask from packed cu_seqlens. - - Returns None if cu_seqlens is missing/degenerate. - """ - if cu_seqlens is None or cu_seqlens.numel() < 2: - return None # Degenerate input: nothing to mask - - seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() - if seq_sizes.numel() == 0 or int(seq_sizes.sum()) == 0: - return None # All-empty segments produce no attention blocks - - seg_ids = torch.repeat_interleave( - torch.arange(seq_sizes.numel(), device=cu_seqlens.device), - seq_sizes, - ) - mask_function = packed_sequence_mask_function(seg_ids.view(1, -1)) - - seq_len = input_embeds.shape[1] - cache_position = torch.arange(seq_len, device=input_embeds.device, dtype=torch.long) - - mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] - return mask_interface( - batch_size=input_embeds.shape[0], - cache_position=cache_position, - kv_length=seq_len, - kv_offset=0, - mask_function=mask_function, - attention_mask=None, - allow_is_causal_skip=False, - allow_is_bidirectional_skip=False, - dtype=input_embeds.dtype, - config=config, - use_vmap=False, - ) - - def create_pixel_shuffle_index_map( seq_sizes: torch.Tensor, token_grids: torch.Tensor, @@ -560,7 +521,7 @@ def pixel_shuffle_varlen( return out -class IsaacVisionTransformer(nn.Module): +class IsaacVisionTransformer(PreTrainedModel): """Vision tower that packs variable-resolution patches, applies varlen attention, and pixel-shuffles outputs. 
Args: @@ -575,15 +536,22 @@ class IsaacVisionTransformer(nn.Module): """ _supports_sdpa = True + _supports_flash_attn = True def __init__(self, config: IsaacVisionConfig): - super().__init__() + super().__init__(config) self.config = config self.embeddings = IsaacVisionEmbeddings(config) self.encoder = IsaacVisionEncoder(config) self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, IsaacVisionEmbeddings): + init.zeros_(module.position_embedding) + def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): seq_patches, token_grids = packed_seq_patches seq_sizes = torch.prod(token_grids, dim=-1) @@ -592,19 +560,14 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): hidden_states = self.embeddings(seq_patches, token_grids) # Add a pseudo batch dimension so we can reuse the batch-first encoder stack - # while still driving per-image cu_seqlens through the varlen attention path. + # while still driving per-image sequence metadata through the varlen attention path. hidden_states = hidden_states.unsqueeze(0) - # Generate cumulative sequence lengths for variable-length attention - cu_seqlens = F.pad(seq_sizes.cumsum(0).to(torch.int32), (1, 0)) - - attention_mask = create_document_attention_mask(self.config, hidden_states, cu_seqlens) - - # Pass through encoder with variable-length attention parameters + # Pass through encoder with variable-length metadata for attention encoder_outputs = self.encoder( inputs_embeds=hidden_states, - attention_mask=attention_mask, - cu_seqlens=cu_seqlens, + attention_mask=None, + seq_sizes=seq_sizes, ) hidden_states = encoder_outputs.last_hidden_state @@ -644,8 +607,6 @@ def forward(self, image_features): class IsaacVisionEmbedding(nn.Module): - _supports_sdpa = True - def __init__(self, config: IsaacConfig): super().__init__() vision_cfg = config.vision_config @@ -658,36 +619,66 @@ def forward(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Ten return self.multimodal_projector(hidden_states) -class IsaacRotaryEmbedding(qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding): +class IsaacRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + def __init__(self, config: IsaacConfig, device=None): + super().__init__() rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} config_for_rope = copy.copy(rope_source_cfg) config_for_rope.rope_scaling = rope_scaling init_device = device if device is not None and getattr(device, "type", None) != "meta" else None - super().__init__(config_for_rope, device=init_device) + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) rotary_half_dim = self.inv_freq.shape[0] self.mrope_section = 
self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim) self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size @staticmethod - def _resolve_mrope_section(section: Optional[list[int]], rotary_half_dim: int) -> list[int]: - if section is None: - weights = (2, 1, 1) - base = [rotary_half_dim * w // sum(weights) for w in weights] - base[0] += rotary_half_dim - sum(base) - return base + def compute_default_rope_parameters( + config: Optional[IsaacConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - section = [int(v) for v in section] - return section + attention_factor = 1.0 # Unused in this type of RoPE - def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor: - split_sections = tuple(self.mrope_section * 2) - chunks = tensor.split(split_sections, dim=-1) - return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + # Ignore copy def forward( self, position_ids: torch.Tensor, @@ -706,17 +697,46 @@ def forward( with torch.no_grad(): pos = position_ids.clone() - not_spatial = modality_tensor != ModalityType.image.value + not_spatial = modality_tensor == 1 data_1d = pos[not_spatial][..., 0].unsqueeze(-1) # Collapse non-vision modalities to 1D positions pos[not_spatial] = data_1d.expand(-1, pos.shape[-1]) pos_axes = pos.permute(2, 0, 1).contiguous() - cos_axes, sin_axes = super().forward(hidden_states, pos_axes) - cos_axes, sin_axes = cos_axes.to(hidden_states.dtype), sin_axes.to(hidden_states.dtype) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, pos_axes.shape[1], -1, 1) + pos_axes_expanded = pos_axes[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = ( + hidden_states.device.type + if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps" + else "cpu" + ) + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ pos_axes_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + cos_axes, sin_axes = cos.to(hidden_states.dtype), sin.to(hidden_states.dtype) cos_combined, sin_combined = self._combine_axes(cos_axes), self._combine_axes(sin_axes) return cos_combined, sin_combined + @staticmethod + def _resolve_mrope_section(section: Optional[list[int]], rotary_half_dim: int) -> list[int]: + if section is None: + weights = (2, 1, 1) + base = [rotary_half_dim * w // sum(weights) for w in weights] + base[0] += 
rotary_half_dim - sum(base) + return base + + section = [int(v) for v in section] + return section + + def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor: + split_sections = tuple(self.mrope_section * 2) + chunks = tensor.split(split_sections, dim=-1) + return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) + @auto_docstring class IsaacModel(PreTrainedModel): @@ -730,7 +750,10 @@ class IsaacModel(PreTrainedModel): _supports_flex_attn = False _can_compile_fullgraph = False _supports_attention_backend = True - _can_record_outputs = {"attentions": OutputRecorder(IsaacVisionAttention, index=1)} + _can_record_outputs = { + "attentions": OutputRecorder(Qwen3Attention, index=1), + "vision_attentions": OutputRecorder(IsaacVisionAttention, index=1), + } all_tied_weights_keys: dict[str, str] = {} def __init__(self, config: IsaacConfig): @@ -744,7 +767,6 @@ def __init__(self, config: IsaacConfig): self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) self.vision_embedding = IsaacVisionEmbedding(config) - self.vision_embedding._supports_sdpa = True self.max_sequence_length = config.max_sequence_length self.vision_rescale_factor = config.vision_rescale_factor self.vision_token = config.vision_token @@ -752,10 +774,6 @@ def __init__(self, config: IsaacConfig): self.post_init() - # Respect config-specified gradient checkpointing - if getattr(config, "gradient_checkpointing", False): - self.gradient_checkpointing_enable() - def get_input_embeddings(self) -> nn.Module: return self.text_model.get_input_embeddings() @@ -862,7 +880,14 @@ def get_rope_index( inputs_embeds: torch.Tensor, cache_position: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - """Build 3D position ids and per-batch RoPE deltas.""" + """Prepare multimodal RoPE positions and carry forward per-batch offsets. + + Unlike vanilla 1D RoPE, Isaac builds 3-axis indices for text and vision tokens. + If callers do not supply positions, we synthesize them from `cache_position` and + use `attention_mask` to strip left padding so pad tokens never consume RoPE slots. 
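+        For example, a left-padded text row `[pad, pad, t0, t1]` assigns `t0` position 0 on all three axes rather than position 2.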
+ The returned `rope_deltas` capture any custom offset (i.e., prefill length) and + are reused across generation steps so newly decoded tokens keep counting forward + after the cached prefix.""" device = inputs_embeds.device batch_size, seq_len = inputs_embeds.shape[:2] @@ -982,7 +1007,6 @@ def forward( is_mask_dict = isinstance(attention_mask, dict) hidden_states = inputs_embeds - all_attentions = [] if output_attentions else None for layer in self.text_model.layers: layer_mask = attention_mask[layer.attention_type] if is_mask_dict else attention_mask @@ -998,10 +1022,7 @@ def forward( **kwargs, ) - layer_outputs_is_tuple = isinstance(layer_outputs, tuple) - hidden_states = layer_outputs[0] if layer_outputs_is_tuple else layer_outputs - if output_attentions and layer_outputs_is_tuple: - all_attentions.append(layer_outputs[1]) + hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs hidden_states = self.text_model.norm(hidden_states) @@ -1009,7 +1030,7 @@ def forward( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=(hidden_states,), - attentions=tuple(all_attentions) if output_attentions else None, + attentions=None, ) diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 2adc54916a3e..e07aab004a3a 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -37,19 +37,14 @@ ) from ...masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, create_masks_for_generate, packed_sequence_mask_function from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...models.qwen3.configuration_qwen3 import Qwen3Config -from ...models.qwen3.modeling_qwen3 import Qwen3ForCausalLM, Qwen3Model, Qwen3PreTrainedModel +from ...models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3ForCausalLM, Qwen3Model, Qwen3PreTrainedModel from ...processing_utils import ProcessorMixin, Unpack from ...utils import TensorType, auto_docstring from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD -from ...utils.generic import ( - OutputRecorder, - TransformersKwargs, - can_return_tuple, - check_model_inputs, -) +from ...utils.generic import OutputRecorder, TransformersKwargs, can_return_tuple, check_model_inputs, maybe_autocast from ...utils.import_utils import ( is_torch_available, is_torchdynamo_compiling, @@ -57,6 +52,7 @@ is_vision_available, ) from ..qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling +from ..qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding from ..siglip2.configuration_siglip2 import Siglip2VisionConfig from ..siglip2.modeling_siglip2 import ( Siglip2Attention, @@ -295,65 +291,19 @@ def _preprocess( ) grouped_outputs[shape] = (patches, token_grid, virtual_dim, real_dim) - def _reorder_grouped_item( # reorder an item of tuple payloads using the same grouped_images_index - grouped: dict[tuple[int, ...], tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], - grouped_index: dict[tuple[int, ...], list[int]], - item_idx: int, - ) -> list[torch.Tensor]: - return reorder_images({k: v[item_idx] for k, v in grouped.items()}, grouped_index) - keys = ("patches", "token_grids", "virtual_pixel_size", "real_pixel_size") tensors: dict[str, torch.Tensor] = {} for i, key in 
enumerate(keys): - slices = _reorder_grouped_item(grouped_outputs, grouped_images_index, i) + slices = reorder_images( + {shape: values[i] for shape, values in grouped_outputs.items()}, + grouped_images_index, + ) tensors[key] = torch.stack(slices, dim=0) return BatchFeature(data=tensors, tensor_type=return_tensors) -def create_document_attention_mask( - config: PretrainedConfig, - input_embeds: torch.Tensor, - cu_seqlens: Optional[torch.Tensor], -) -> Optional[Union[torch.Tensor, Any]]: - """ - Materialize a backend-specific block-diagonal attention mask from packed cu_seqlens. - - Returns None if cu_seqlens is missing/degenerate. - """ - if cu_seqlens is None or cu_seqlens.numel() < 2: - return None # Degenerate input: nothing to mask - - seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() - if seq_sizes.numel() == 0 or int(seq_sizes.sum()) == 0: - return None # All-empty segments produce no attention blocks - - seg_ids = torch.repeat_interleave( - torch.arange(seq_sizes.numel(), device=cu_seqlens.device), - seq_sizes, - ) - mask_function = packed_sequence_mask_function(seg_ids.view(1, -1)) - - seq_len = input_embeds.shape[1] - cache_position = torch.arange(seq_len, device=input_embeds.device, dtype=torch.long) - - mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] - return mask_interface( - batch_size=input_embeds.shape[0], - cache_position=cache_position, - kv_length=seq_len, - kv_offset=0, - mask_function=mask_function, - attention_mask=None, - allow_is_causal_skip=False, - allow_is_bidirectional_skip=False, - dtype=input_embeds.dtype, - config=config, - use_vmap=False, - ) - - class IsaacVisionEmbeddings(Siglip2VisionEmbeddings): """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences. @@ -375,7 +325,14 @@ def __init__(self, config: IsaacVisionConfig): self.num_patches = config.num_patches self.position_embedding_size = int(self.num_patches**0.5) - self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + self.position_embedding = nn.Parameter( + torch.empty( + self.position_embedding_size, + self.position_embedding_size, + self.embed_dim, + ) + ) + nn.init.normal_(self.position_embedding) @check_model_inputs def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor: @@ -388,11 +345,7 @@ def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> to target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(packed_pixel_values.to(dtype=target_dtype)) - positional_embeddings = self.position_embedding.weight.reshape( - self.position_embedding_size, - self.position_embedding_size, - -1, - ) + positional_embeddings = self.position_embedding resized_positional_embeddings = self.resize_positional_embeddings( positional_embeddings, spatial_shapes, max_length=packed_pixel_values.shape[1] ) @@ -435,20 +388,15 @@ def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor class IsaacVisionAttention(Siglip2Attention): - """Custom attention that supports variable-length sequences with flash attention.""" + """Custom attention that supports variable-length sequences with flash/SDPA backends.""" def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, **kwargs, ): - kwargs.pop("output_hidden_states", None) - kwargs.pop("return_dict", None) - batch_size, seq_length, embed_dim = 
hidden_states.shape queries = self.q_proj(hidden_states) keys = self.k_proj(hidden_states) @@ -463,35 +411,47 @@ def forward( if attn_impl != "sdpa": attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl] + seq_sizes = kwargs.pop("seq_sizes", None) + attention_kwargs: dict[str, Any] = { "is_causal": False, "scaling": self.scale, } - supports_varlen = cu_seqlens is not None and attn_impl in { - "flash_attention_2", - "flash_attention_3", - "flex_attention", - "paged|flash_attention_2", - "paged|flash_attention_3", - } - if supports_varlen: - if max_seqlen is not None: - max_q = max_k = int(max_seqlen) - elif cu_seqlens.numel() >= 2: - lengths = cu_seqlens[1:] - cu_seqlens[:-1] - max_q = max_k = lengths.max() if lengths.numel() > 0 else seq_length + if seq_sizes is not None and seq_sizes.numel() > 0: + if attn_impl in {"flash_attention_2", "flash_attention_3"}: + cu_seqlens = F.pad(seq_sizes.cumsum(0).to(torch.int32), (1, 0)) + max_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attention_kwargs.update( + { + "cu_seq_lens_q": cu_seqlens, + "cu_seq_lens_k": cu_seqlens, + "max_length_q": max_len, + "max_length_k": max_len, + } + ) else: - max_q = max_k = seq_length - - attention_kwargs.update( - { - "cu_seq_lens_q": cu_seqlens, - "cu_seq_lens_k": cu_seqlens, - "max_length_q": max_q, - "max_length_k": max_k, - } - ) + seg_ids = torch.repeat_interleave( + torch.arange(seq_sizes.numel(), device=seq_sizes.device), seq_sizes + ).view(1, -1) + mask_function = packed_sequence_mask_function(seg_ids) + cache_position = torch.arange(seq_length, device=hidden_states.device, dtype=torch.long) + mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[attn_impl] + attention_mask = mask_interface( + batch_size=batch_size, + cache_position=cache_position, + kv_length=seq_length, + kv_offset=0, + mask_function=mask_function, + attention_mask=attention_mask, + allow_is_causal_skip=False, + allow_is_bidirectional_skip=False, + dtype=hidden_states.dtype, + config=self.config, + use_vmap=False, + ) + else: + attention_mask = None attn_output, attn_weights = attention_interface( self, @@ -518,18 +478,11 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, output_attentions: bool = False, **kwargs: Unpack[TransformersKwargs], ): r""" - cu_seqlens (`torch.Tensor`, *optional*): - Prefix-sum tensor whose length equals the number of documents + 1. The difference between successive - entries gives each document's token count and enables block-diagonal attention masking for packed batches. - max_seqlen (`int`, *optional*): - Maximum document length referenced by `cu_seqlens`. Passed to FlashAttention so it can size temporary - buffers for packed variable-length attention. + Variable-length metadata (e.g., `seq_sizes`) flows via `**kwargs` to attention for backend-specific handling. """ # Run attention directly so variable-length metadata reaches FlashAttention. 
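        # e.g., two packed images of 4 and 9 patches arrive as seq_sizes = [4, 9]; the flash path above
        # turns this into cu_seqlens = [0, 4, 13], while the sdpa path builds an equivalent
        # block-diagonal mask so patches never attend across image boundaries.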
residual = hidden_states @@ -537,10 +490,10 @@ def forward( attn_output, _ = self.self_attn( hidden_states, attention_mask=attention_mask, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, + output_attentions=output_attentions, **kwargs, ) + hidden_states = residual + attn_output residual = hidden_states @@ -671,7 +624,7 @@ def pixel_shuffle_varlen( return out -class IsaacVisionTransformer(nn.Module): +class IsaacVisionTransformer(PreTrainedModel): """Vision tower that packs variable-resolution patches, applies varlen attention, and pixel-shuffles outputs. Args: @@ -686,15 +639,22 @@ class IsaacVisionTransformer(nn.Module): """ _supports_sdpa = True + _supports_flash_attn = True def __init__(self, config: IsaacVisionConfig): - super().__init__() + super().__init__(config) self.config = config self.embeddings = IsaacVisionEmbeddings(config) self.encoder = IsaacVisionEncoder(config) self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, IsaacVisionEmbeddings): + init.zeros_(module.position_embedding) + def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): seq_patches, token_grids = packed_seq_patches seq_sizes = torch.prod(token_grids, dim=-1) @@ -703,19 +663,14 @@ def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): hidden_states = self.embeddings(seq_patches, token_grids) # Add a pseudo batch dimension so we can reuse the batch-first encoder stack - # while still driving per-image cu_seqlens through the varlen attention path. + # while still driving per-image sequence metadata through the varlen attention path. 
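+        # e.g., token_grids = [[2, 2], [3, 3]] packs 4 + 9 = 13 patch embeddings into one
+        # (1, 13, hidden_size) tensor, with seq_sizes = [4, 9] marking each image's span.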
hidden_states = hidden_states.unsqueeze(0) - # Generate cumulative sequence lengths for variable-length attention - cu_seqlens = F.pad(seq_sizes.cumsum(0).to(torch.int32), (1, 0)) - - attention_mask = create_document_attention_mask(self.config, hidden_states, cu_seqlens) - - # Pass through encoder with variable-length attention parameters + # Pass through encoder with variable-length metadata for attention encoder_outputs = self.encoder( inputs_embeds=hidden_states, - attention_mask=attention_mask, - cu_seqlens=cu_seqlens, + attention_mask=None, + seq_sizes=seq_sizes, ) hidden_states = encoder_outputs.last_hidden_state @@ -755,8 +710,6 @@ def forward(self, image_features): class IsaacVisionEmbedding(nn.Module): - _supports_sdpa = True - def __init__(self, config: IsaacConfig): super().__init__() vision_cfg = config.vision_config @@ -907,6 +860,8 @@ def __init__( self.hidden_act = self.text_config.hidden_act self.use_cache = self.text_config.use_cache self.rope_theta = self.rope_parameters["rope_theta"] + self.max_position_embeddings = getattr(self.text_config, "max_position_embeddings", max_sequence_length) + self.text_config.max_position_embeddings = self.max_position_embeddings self.layer_types = getattr(self.text_config, "layer_types", None) layer_type_validation(self.layer_types, self.num_hidden_layers) @@ -1303,13 +1258,26 @@ def forward( with torch.no_grad(): pos = position_ids.clone() - not_spatial = modality_tensor != ModalityType.image.value + not_spatial = modality_tensor == 1 data_1d = pos[not_spatial][..., 0].unsqueeze(-1) # Collapse non-vision modalities to 1D positions pos[not_spatial] = data_1d.expand(-1, pos.shape[-1]) pos_axes = pos.permute(2, 0, 1).contiguous() - cos_axes, sin_axes = super().forward(hidden_states, pos_axes) - cos_axes, sin_axes = cos_axes.to(hidden_states.dtype), sin_axes.to(hidden_states.dtype) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, pos_axes.shape[1], -1, 1) + pos_axes_expanded = pos_axes[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = ( + hidden_states.device.type + if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps" + else "cpu" + ) + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ pos_axes_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + cos_axes, sin_axes = cos.to(hidden_states.dtype), sin.to(hidden_states.dtype) cos_combined, sin_combined = self._combine_axes(cos_axes), self._combine_axes(sin_axes) return cos_combined, sin_combined @@ -1320,7 +1288,10 @@ class IsaacModel(Qwen3PreTrainedModel): supports_gradient_checkpointing = True _can_compile_fullgraph = False _supports_flex_attn = False - _can_record_outputs = {"attentions": OutputRecorder(IsaacVisionAttention, index=1)} + _can_record_outputs = { + "attentions": OutputRecorder(Qwen3Attention, index=1), + "vision_attentions": OutputRecorder(IsaacVisionAttention, index=1), + } all_tied_weights_keys: dict[str, str] = {} def __init__(self, config: IsaacConfig): @@ -1334,7 +1305,6 @@ def __init__(self, config: IsaacConfig): self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) self.vision_embedding = IsaacVisionEmbedding(config) - self.vision_embedding._supports_sdpa = True self.max_sequence_length = config.max_sequence_length self.vision_rescale_factor = config.vision_rescale_factor self.vision_token = 
config.vision_token @@ -1342,10 +1312,6 @@ def __init__(self, config: IsaacConfig): self.post_init() - # Respect config-specified gradient checkpointing - if getattr(config, "gradient_checkpointing", False): - self.gradient_checkpointing_enable() - def get_input_embeddings(self) -> nn.Module: return self.text_model.get_input_embeddings() @@ -1452,7 +1418,14 @@ def get_rope_index( inputs_embeds: torch.Tensor, cache_position: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - """Build 3D position ids and per-batch RoPE deltas.""" + """Prepare multimodal RoPE positions and carry forward per-batch offsets. + + Unlike vanilla 1D RoPE, Isaac builds 3-axis indices for text and vision tokens. + If callers do not supply positions, we synthesize them from `cache_position` and + use `attention_mask` to strip left padding so pad tokens never consume RoPE slots. + The returned `rope_deltas` capture any custom offset (i.e., prefill length) and + are reused across generation steps so newly decoded tokens keep counting forward + after the cached prefix.""" device = inputs_embeds.device batch_size, seq_len = inputs_embeds.shape[:2] @@ -1572,7 +1545,6 @@ def forward( is_mask_dict = isinstance(attention_mask, dict) hidden_states = inputs_embeds - all_attentions = [] if output_attentions else None for layer in self.text_model.layers: layer_mask = attention_mask[layer.attention_type] if is_mask_dict else attention_mask @@ -1588,10 +1560,7 @@ def forward( **kwargs, ) - layer_outputs_is_tuple = isinstance(layer_outputs, tuple) - hidden_states = layer_outputs[0] if layer_outputs_is_tuple else layer_outputs - if output_attentions and layer_outputs_is_tuple: - all_attentions.append(layer_outputs[1]) + hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs hidden_states = self.text_model.norm(hidden_states) @@ -1599,7 +1568,7 @@ def forward( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=(hidden_states,), - attentions=tuple(all_attentions) if output_attentions else None, + attentions=None, ) From 6ba2fdb7cb3480a813b327eac10aff9323cd412b Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Thu, 15 Jan 2026 01:34:44 +0400 Subject: [PATCH 62/77] style: alias norm to communicate scope --- src/transformers/models/isaac/modeling_isaac.py | 6 +++++- src/transformers/models/isaac/modular_isaac.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 156017d32d6d..8221921121b4 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -786,6 +786,10 @@ def set_input_embeddings(self, value: nn.Module) -> None: self.config.text_config.vocab_size = vocab_size self.text_model.config.vocab_size = vocab_size + @property + def final_norm(self) -> nn.Module: + return self.text_model.norm + @property def embed_tokens(self) -> nn.Module: return self.text_model.embed_tokens @@ -1024,7 +1028,7 @@ def forward( hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs - hidden_states = self.text_model.norm(hidden_states) + hidden_states = self.final_norm(hidden_states) return BaseModelOutputWithPast( last_hidden_state=hidden_states, diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index e07aab004a3a..f8de4b070d80 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ 
b/src/transformers/models/isaac/modular_isaac.py @@ -1324,6 +1324,10 @@ def set_input_embeddings(self, value: nn.Module) -> None: self.config.text_config.vocab_size = vocab_size self.text_model.config.vocab_size = vocab_size + @property + def final_norm(self) -> nn.Module: + return self.text_model.norm + @property def embed_tokens(self) -> nn.Module: return self.text_model.embed_tokens @@ -1562,7 +1566,7 @@ def forward( hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs - hidden_states = self.text_model.norm(hidden_states) + hidden_states = self.final_norm(hidden_states) return BaseModelOutputWithPast( last_hidden_state=hidden_states, From 57cbd79475cd95febeee15798503ad916ac8d4e7 Mon Sep 17 00:00:00 2001 From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com> Date: Wed, 4 Mar 2026 21:44:46 +0400 Subject: [PATCH 63/77] refactor: no packed batch inference (#14) * fix: update imports * fix: replace removed check_model_inputs with merge_with_config_defaults and capture_outputs * fix: no capture outputs within capture outputs * refactor: move isaac vision internals to padded batched flow * refactor: align isaac vision attention with standard mask interfaces * refactor: remove packed_inputs from isaac model api and generation path * chore: purge isaac packing internals and sync modular outputs * refactor: remove isaac packing pipeline and align with transformers batched attention standards * refactor: drop final isaac packed compatibility path * refactor: use OutputRecorder for isaac hidden states * refactor: remove manual output_attentions handling in isaac model * refactor: rely on output recorder for isaac attentions * fix: do not deepcopy text config * style: remove overly defensive checks * style: remove unneeded pops * refactor: simplify pixshuf * style: drop unused vision_model alias * wip simplify * wip simplify 2 * perf: remove device syncs * test: add isaac pixel shuffle strict invariant characterization * refactor: make isaac pixel shuffle tensor-only with strict invariants * chore: regenerate isaac generated files after modular pixel shuffle refactor * style: drop redundant check * refactor: simplify config wiring * refactor: unify multimodal check for input preparation * refactor: drop now redundant init override * style: drop unused attention mask flow through pixel shuffle * style: collapse resize callsite for readability * style: drop more redundant checks * refactor: rely on siglip2 for vision attention * refactor: enforce invariant * refactor: simplify processor * fix: add post init call to vision transformer --- .../models/isaac/configuration_isaac.py | 42 +- .../models/isaac/image_processing_isaac.py | 9 +- .../isaac/image_processing_isaac_fast.py | 58 +- .../models/isaac/modeling_isaac.py | 810 +++++----- .../models/isaac/modular_isaac.py | 1311 ++++++++--------- .../models/isaac/processing_isaac.py | 453 +++--- tests/models/isaac/test_modeling_isaac.py | 168 ++- 7 files changed, 1336 insertions(+), 1515 deletions(-) diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index e98e9a9da0f2..ddd12e55c958 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -4,7 +4,6 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this.
# ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# coding=utf-8 # Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,8 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union - from ...configuration_utils import PreTrainedConfig, PretrainedConfig, layer_type_validation from ...models.qwen3.configuration_qwen3 import Qwen3Config @@ -89,15 +86,13 @@ class IsaacConfig(PretrainedConfig): def __init__( self, - vision_config: Optional[IsaacVisionConfig] = None, - text_config: Optional[Union[Qwen3Config, dict]] = None, + vision_config: IsaacVisionConfig | None = None, + text_config: Qwen3Config | dict | None = None, vision_rescale_factor: float = 1 / 255, max_sequence_length: int = 16384, vision_token: str = "", **kwargs, ): - attn_implementation = kwargs.get("attn_implementation") - if isinstance(text_config, dict): self.text_config = self.sub_configs["text_config"](**text_config) elif isinstance(text_config, Qwen3Config): @@ -105,6 +100,13 @@ def __init__( elif text_config is None: self.text_config = self.sub_configs["text_config"]() + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif isinstance(vision_config, IsaacVisionConfig): + self.vision_config = vision_config + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + # Seed RoPE parameters before base init so the shared mixin can standardize/validate them. self.rope_parameters = getattr(self.text_config, "rope_parameters", None) self.layer_types = getattr(self.text_config, "layer_types", None) @@ -129,23 +131,6 @@ def __init__( self.layer_types = getattr(self.text_config, "layer_types", None) layer_type_validation(self.layer_types, self.num_hidden_layers) - # Handle vision config - either dict or IsaacVisionConfig instance - if isinstance(vision_config, dict): - self.vision_config = self.sub_configs["vision_config"](**vision_config) - elif isinstance(vision_config, IsaacVisionConfig): - self.vision_config = vision_config - elif vision_config is None: - self.vision_config = self.sub_configs["vision_config"]() - - # Propagate user-requested attention backend to the vision sub-config when provided. 
- if attn_implementation is not None: - if isinstance(attn_implementation, dict): - vision_attn = attn_implementation.get("vision_config", attn_implementation.get("", None)) - else: - vision_attn = attn_implementation - if vision_attn is not None: - self.vision_config._attn_implementation = vision_attn - if getattr(self, "_attn_implementation", None) is None: self._attn_implementation = "sdpa" # Vision normalization parameters @@ -155,14 +140,5 @@ def __init__( self.max_sequence_length = max_sequence_length self.vision_token = vision_token - def to_dict(self): - output = super().to_dict() - # Ensure nested configs round-trip through dict serialization - if hasattr(self, "text_config") and self.text_config is not None: - output["text_config"] = self.text_config.to_dict() - if hasattr(self, "vision_config") and self.vision_config is not None: - output["vision_config"] = self.vision_config.to_dict() - return output - __all__ = ["IsaacConfig"] diff --git a/src/transformers/models/isaac/image_processing_isaac.py b/src/transformers/models/isaac/image_processing_isaac.py index 9e09a15fc072..755690da955a 100644 --- a/src/transformers/models/isaac/image_processing_isaac.py +++ b/src/transformers/models/isaac/image_processing_isaac.py @@ -19,13 +19,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from ...image_processing_utils_fast import ImagesKwargs class IsaacImageProcessorKwargs(ImagesKwargs, total=False): - patch_size: Optional[int] - max_num_patches: Optional[int] - min_num_patches: Optional[int] - pixel_shuffle_scale: Optional[int] + patch_size: int | None + max_num_patches: int | None + min_num_patches: int | None + pixel_shuffle_scale: int | None diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py b/src/transformers/models/isaac/image_processing_isaac_fast.py index c1de230c635e..fe2b3d5fa8dd 100644 --- a/src/transformers/models/isaac/image_processing_isaac_fast.py +++ b/src/transformers/models/isaac/image_processing_isaac_fast.py @@ -4,7 +4,6 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# coding=utf-8 # Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,12 +20,11 @@ import math from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any from ...feature_extraction_utils import BatchFeature from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict, group_images_by_shape, reorder_images from ...image_utils import PILImageResampling -from ...processing_utils import Unpack from ...utils import TensorType, auto_docstring from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD @@ -80,7 +78,7 @@ def get_image_size_for_max_num_patches( image_width: int, patch_size: int, max_num_patches: int, - min_num_patches: Optional[int] = None, + min_num_patches: int | None = None, eps: float = 1e-5, pixel_shuffle_scale: int = 1, ) -> tuple[int, int]: @@ -164,10 +162,10 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): do_resize = True do_center_crop = False - patch_size: Optional[int] = 16 - max_num_patches: Optional[int] = 256 - min_num_patches: Optional[int] = None - pixel_shuffle_scale: Optional[int] = 1 + patch_size: int | None = 16 + max_num_patches: int | None = 256 + min_num_patches: int | None = None + pixel_shuffle_scale: int | None = 1 do_pad = False do_rescale = True do_normalize = True @@ -176,19 +174,9 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): do_convert_rgb = True disable_grouping = False - def __init__( - self, - **kwargs: Unpack[IsaacImageProcessorFastKwargs], - ) -> None: - super().__init__(**kwargs) - def _validate_preprocess_kwargs(self, **kwargs): # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for. kwargs.pop("do_resize", None) - kwargs.pop("size", None) - kwargs.pop("do_center_crop", None) - kwargs.pop("crop_size", None) - kwargs.pop("disable_grouping", None) return super()._validate_preprocess_kwargs(**kwargs) def resize( @@ -197,33 +185,25 @@ def resize( size: SizeDict, **kwargs, ) -> torch.Tensor: - resize_kwargs: dict[str, Any] = {"align_corners": False} - resize_mode = "bilinear" - - return F.interpolate( - image, - size=(size.height, size.width), - mode=resize_mode, - **resize_kwargs, - ) + return F.interpolate(image, size=(size.height, size.width), mode="bilinear", align_corners=False) def _preprocess( self, images: list[torch.Tensor], do_resize: bool, - interpolation: Optional[Any], - do_rescale: Optional[bool], - rescale_factor: Optional[float], - do_normalize: Optional[bool], - image_mean: Optional[Union[float, Sequence[float]]], - image_std: Optional[Union[float, Sequence[float]]], - disable_grouping: Optional[bool] = None, - return_tensors: Optional[Union[str, TensorType]] = None, + interpolation: Any | None, + do_rescale: bool | None, + rescale_factor: float | None, + do_normalize: bool | None, + image_mean: float | Sequence[float] | None, + image_std: float | Sequence[float] | None, + disable_grouping: bool | None = None, + return_tensors: str | TensorType | None = None, *, - patch_size: Optional[int] = None, - max_num_patches: Optional[int] = None, - min_num_patches: Optional[int] = None, - pixel_shuffle_scale: Optional[int] = None, + patch_size: int | None = None, + max_num_patches: int | None = None, + min_num_patches: int | None = None, + pixel_shuffle_scale: int | None = None, **kwargs, ) -> BatchFeature: grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) diff --git 
a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 8221921121b4..5ce71788ceae 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -4,7 +4,6 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# coding=utf-8 # Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -31,20 +30,21 @@ from ...generation.utils import GenerationMixin from ...image_processing_utils_fast import ImagesKwargs from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func -from ...masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, create_masks_for_generate, packed_sequence_mask_function +from ...masking_utils import create_bidirectional_mask, create_masks_for_generate from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3Model, Qwen3PreTrainedModel +from ...models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3DecoderLayer, Qwen3Model, Qwen3PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring -from ...utils.generic import OutputRecorder, TransformersKwargs, can_return_tuple, check_model_inputs, maybe_autocast +from ...utils import auto_docstring, torch_compilable_check +from ...utils.generic import TransformersKwargs, can_return_tuple, maybe_autocast, merge_with_config_defaults from ...utils.import_utils import ( is_torch_available, is_torchdynamo_compiling, ) +from ...utils.output_capturing import OutputRecorder, capture_outputs from .configuration_isaac import IsaacConfig, IsaacVisionConfig @@ -68,10 +68,21 @@ class ModalityType(IntEnum): class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False): - patch_size: Optional[int] - max_num_patches: Optional[int] - min_num_patches: Optional[int] - pixel_shuffle_scale: Optional[int] + """ + patch_size (`int`, *optional*): + Side length (in pixels) for square patches extracted from resized images. + max_num_patches (`int`, *optional*): + Upper bound on extracted patches per image after resizing. + min_num_patches (`int`, *optional*): + Lower bound on extracted patches per image after resizing. + pixel_shuffle_scale (`int`, *optional*): + Pixel-shuffle reduction factor applied in the vision tower. 
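As a rough worked example of how these kwargs interact (illustrative numbers; assumes the `get_image_size_for_max_num_patches` helper from the fast image processor, with height and width as its first two arguments):

    target_h, target_w = get_image_size_for_max_num_patches(
        1024,                     # image_height
        768,                      # image_width
        patch_size=16,
        max_num_patches=256,
        pixel_shuffle_scale=2,
    )
    # The helper is expected to return a size whose 16x16 patch grid stays under the budget
    # and keeps both grid sides divisible by pixel_shuffle_scale.
    num_patches = (target_h // 16) * (target_w // 16)    # <= 256
    num_vision_tokens = num_patches // (2 * 2)           # what the language model actually sees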
+ """ + + patch_size: int | None + max_num_patches: int | None + min_num_patches: int | None + pixel_shuffle_scale: int | None class IsaacVisionEmbeddings(nn.Module): @@ -143,7 +154,10 @@ def resize_positional_embeddings( for i in range(batch_size): # (1, dim, height, width) -> (1, dim, target_height, target_width) - height, width = spatial_shapes[i] + height, width = spatial_shapes[i].tolist() # will be itemized in F.interpolate either way + torch_compilable_check((width > 0), "Width of resized positional embeddings must be positive.") + torch_compilable_check((height > 0), "Height of resized positional embeddings must be positive.") + torch_compilable_check((height * width) <= max_length, "Resized positional embeddings exceed max_length.") resized_embeddings = F.interpolate( positional_embeddings, size=(height, width), @@ -163,8 +177,14 @@ def resize_positional_embeddings( return resulted_positional_embeddings - @check_model_inputs - def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor: + @merge_with_config_defaults + @capture_outputs + def forward( + self, + pixel_values: torch.Tensor, + spatial_shapes: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: """ Args: pixel_values (`torch.FloatTensor`): @@ -172,55 +192,44 @@ def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> to spatial_shapes (`list[tuple[int, int]]`): Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to """ - # Rebatch packed variable-resolution patches to resize per-image position embeddings - # and track lengths for varlen attention metadata. - packed_pixel_values, seq_lengths = self._pack_to_batch(seq_patches, spatial_shapes) - if packed_pixel_values is None: - return seq_patches.new_zeros((0, self.embed_dim)) - + # pixel_values: (num_images, max_patches, patch_dim) target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding(packed_pixel_values.to(dtype=target_dtype)) + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) - positional_embeddings = self.position_embedding resized_positional_embeddings = self.resize_positional_embeddings( - positional_embeddings, spatial_shapes, max_length=packed_pixel_values.shape[1] + self.position_embedding, + spatial_shapes, + max_length=pixel_values.shape[1], ) - embeddings = patch_embeds + resized_positional_embeddings - return self._unpack_from_batch(embeddings, seq_lengths) - def _pack_to_batch( - self, - seq_patches: torch.Tensor, - spatial_shapes: torch.Tensor, - ) -> tuple[Optional[torch.Tensor], torch.Tensor]: - """Rebatch a packed patch sequence using per-image grids to align embeddings. + if attention_mask is not None: + embeddings = embeddings * attention_mask.unsqueeze(-1).to(dtype=embeddings.dtype) - Args: - seq_patches: Packed patches of shape (total_patches, patch_dim). - spatial_shapes: Per-image patch grids of shape (num_images, 2) as (H_tokens, W_tokens). 
+ return embeddings - Returns: - (packed_pixel_values, seq_lengths) where: - - packed_pixel_values: (batch, max_len, patch_dim) padded with zeros, or None if batch_size == 0 - - seq_lengths: (batch,) lengths for each image - """ - seq_lengths = spatial_shapes.long().prod(dim=-1) # (B,) - batch_size = int(seq_lengths.numel()) - if batch_size == 0: - return None, seq_lengths - # Split the packed sequence into per-image chunks, then pad to a batch - lengths_list = seq_lengths.tolist() - chunks = seq_patches.split(lengths_list, dim=0) - packed_pixel_values = nn.utils.rnn.pad_sequence(chunks, batch_first=True) # zero-padded by default - return packed_pixel_values, seq_lengths +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask - def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor) -> torch.Tensor: - """Flatten a padded batch back to packed sequence order using `seq_lengths`.""" - lengths = seq_lengths.to(device=embeddings.device).tolist() - chunks = [embeddings[i, :l] for i, l in enumerate(lengths) if l > 0] - return torch.cat(chunks, dim=0) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights class IsaacVisionAttention(nn.Module): @@ -249,12 +258,13 @@ def __init__(self, config): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + attention_mask: torch.Tensor | None = None, **kwargs, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: """Input shape: Batch x Time x Channel""" + batch_size, seq_length, embed_dim = hidden_states.shape + queries = self.q_proj(hidden_states) keys = self.k_proj(hidden_states) values = self.v_proj(hidden_states) @@ -263,52 +273,9 @@ def forward( keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - attn_impl = self.config._attn_implementation - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"] - if attn_impl != "sdpa": - attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl] - - seq_sizes = kwargs.pop("seq_sizes", None) - - attention_kwargs: dict[str, Any] = { - "is_causal": False, - "scaling": self.scale, - } - - if seq_sizes is not None and seq_sizes.numel() > 0: - if attn_impl in {"flash_attention_2", "flash_attention_3"}: - cu_seqlens = F.pad(seq_sizes.cumsum(0).to(torch.int32), (1, 0)) - max_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - attention_kwargs.update( - { - "cu_seq_lens_q": cu_seqlens, - "cu_seq_lens_k": cu_seqlens, - "max_length_q": max_len, - "max_length_k": max_len, - } - ) - else: - seg_ids = torch.repeat_interleave( - torch.arange(seq_sizes.numel(), device=seq_sizes.device), seq_sizes - ).view(1, -1) - mask_function = packed_sequence_mask_function(seg_ids) - cache_position = torch.arange(seq_length, device=hidden_states.device, dtype=torch.long) - mask_interface 
= ALL_MASK_ATTENTION_FUNCTIONS[attn_impl] - attention_mask = mask_interface( - batch_size=batch_size, - cache_position=cache_position, - kv_length=seq_length, - kv_offset=0, - mask_function=mask_function, - attention_mask=attention_mask, - allow_is_causal_skip=False, - allow_is_bidirectional_skip=False, - dtype=hidden_states.dtype, - config=self.config, - use_vmap=False, - ) - else: - attention_mask = None + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -316,8 +283,11 @@ def forward( keys, values, attention_mask, - **attention_kwargs, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, ) + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) @@ -340,7 +310,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class IsaacVisionEncoderLayer(GradientCheckpointingLayer): - """Isaac vision encoder layer with variable-length attention.""" + """Isaac vision encoder layer using the shared attention interfaces.""" def __init__(self, config: IsaacVisionConfig): super().__init__() @@ -354,24 +324,18 @@ def __init__(self, config: IsaacVisionConfig): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + attention_mask: torch.Tensor, **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: - r""" - Variable-length metadata (e.g., `seq_sizes`) flows via `**kwargs` to attention for backend-specific handling. - """ - # Run attention directly so variable-length metadata reaches FlashAttention. residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) - attn_output, _ = self.self_attn( - hidden_states, + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, - output_attentions=output_attentions, **kwargs, ) - - hidden_states = residual + attn_output + hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) @@ -382,7 +346,7 @@ def forward( class IsaacVisionEncoder(nn.Module): - """Encoder using Isaac encoder layers with variable-length attention support.""" + """Encoder using Isaac encoder layers.""" def __init__(self, config: IsaacVisionConfig): super().__init__() @@ -395,7 +359,7 @@ def __init__(self, config: IsaacVisionConfig): def forward( self, inputs_embeds, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: hidden_states = inputs_embeds @@ -409,130 +373,83 @@ def forward( return BaseModelOutput(last_hidden_state=hidden_states) -def create_pixel_shuffle_index_map( - seq_sizes: torch.Tensor, - token_grids: torch.Tensor, - scale_factor: int = 1, - device: Optional[torch.device] = None, -) -> torch.Tensor: - """ - Build a gather-index map that tells us, for every *output* token after - pixel-shuffle, which `scale_factor**2` *input* tokens are being merged. - - Args - ---- - seq_sizes : (num_images,) - #patches in each image (row-major order) - token_grids : (num_images,2) - (height, width) for every image - scale_factor : spatial down-scale factor (โ‰ฅ2) - device : (optional) overrides `seq_sizes.device` - - Returns - ------- - gather_idx : (new_total_seq_len, scale_factor**2) int64 tensor. 
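For reference, the eager fallback above is plain softmax attention; a minimal standalone equivalent (illustrative shapes, scaling = head_dim ** -0.5):

    import torch

    q = k = v = torch.randn(1, 8, 16, 64)                  # (batch, heads, seq, head_dim)
    scores = torch.matmul(q, k.transpose(-1, -2)) * (64 ** -0.5)
    weights = torch.softmax(scores, dim=-1)
    out = torch.matmul(weights, v).transpose(1, 2)         # (batch, seq, heads, head_dim)
    out = out.reshape(1, 16, 8 * 64)                       # ready for out_proj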
-        gather_idx[i, j] is the *flat* index into the *original*
-        packed sequence for the j-th sub-patch that forms the
-        i-th output token.
-    """
-    if not is_torchdynamo_compiling():
-        if (token_grids % scale_factor).any():
-            raise AssertionError(
-                f"Every (H,W) in token_grids must be divisible by scale_factor={scale_factor}, got {token_grids.tolist()}"
-            )
-
-    gather_chunks: list[torch.Tensor] = []
-    tok_offset = 0
-    for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist()):
-        # Flat indices for this image's packed segment
-        grid = torch.arange(seq_len, device=device, dtype=torch.int64).view(h, w) + tok_offset
-
-        # Block into (H/s, W/s) groups; each group contributes s*s indices
-        grid = (
-            grid.view(h // scale_factor, scale_factor, w // scale_factor, scale_factor)
-            .permute(0, 2, 1, 3)
-            .contiguous()
-        )
-        gather_chunks.append(grid.view(-1, scale_factor * scale_factor))
-
-        tok_offset += seq_len
-
-    return torch.cat(gather_chunks, dim=0)
-
-
-def pixel_shuffle_varlen(
+def pixel_shuffle_padded(
     x: torch.Tensor,
     token_grids: torch.Tensor,
     scale_factor: int = 1,
 ) -> torch.Tensor:
-    r"""Apply pixel shuffle to a packed vision sequence without unpacking per image.
+    """Apply pixel shuffle per image on padded batched vision embeddings.
 
     Args:
         x (`torch.Tensor`):
-            Concatenated vision embeddings. Accepts `(seq_len, hidden_size)` or `(1, seq_len, hidden_size)` shapes
-            produced by stacking image patches.
+            Vision embeddings of shape `(num_images, max_patches, hidden_size)`.
         token_grids (`torch.Tensor`):
-            Integer tensor of shape `(num_images, 2)` whose rows give the `(height, width)` patch grid sizes
-            corresponding to each image segment inside `x`.
+            Grid sizes `(height, width)` per image, shape `(num_images, 2)`.
         scale_factor (`int`, *optional*, defaults to 1):
-            Spatial down-sampling factor specific to pixel shuffle. Values greater than one merge `scale_factor**2` neighboring patches into a
-            single embedding channel-group.
+            Spatial down-sampling factor.
 
     Returns:
-        `torch.Tensor`: Pixel-shuffled embeddings with shape matching the input convention:
-        `(seq_len, hidden_size * scale_factor**2)` when the input was 2D, or `(1, seq_len, hidden_size * scale_factor**2)`
-        if the singleton batch dimension was present.
-
-    Raises:
-        ValueError: If more than one batch item is provided.
+        `torch.Tensor`: Pixel-shuffled embeddings of shape
+        `(num_images, max_output_tokens, hidden_size * scale_factor**2)`, zero-padded past each
+        image's valid output length.
     """
-    return_with_batch_dim = x.dim() == 3
-    if return_with_batch_dim:
-        if x.size(0) != 1:
-            raise ValueError(
-                f"Packed vision sequences expect a singleton batch dimension; received batch_size={x.size(0)}."
- ) - embeddings = x.squeeze(0) # (seq, embed) - else: - embeddings = x # (seq, embed) + num_images, max_patches, embed_dim = x.shape + output_dim = embed_dim * scale_factor * scale_factor + + token_grids = token_grids.to(device=x.device, dtype=torch.long) + heights = token_grids[:, 0] + widths = token_grids[:, 1] + full_lengths = heights * widths + + non_empty = full_lengths > 0 + if not is_torchdynamo_compiling(): + divisible = ((heights % scale_factor) == 0) & ((widths % scale_factor) == 0) + torch_compilable_check( + (~non_empty) | divisible, + f"Every non-empty (H, W) grid must be divisible by pixel_shuffle_scale={scale_factor}.", + ) - embed_dim = embeddings.size(-1) - scale_factor = int(scale_factor) + output_lengths = (heights // scale_factor) * (widths // scale_factor) + max_output_tokens = output_lengths.max() + shuffled_4d = x.new_zeros((num_images, max_output_tokens, scale_factor * scale_factor, embed_dim)) - # Calculate seq_sizes from token_grids - seq_sizes = torch.prod(token_grids, dim=-1) + token_positions = torch.arange(max_patches, device=x.device, dtype=torch.long).unsqueeze(0).expand(num_images, -1) + valid_token_mask = token_positions < full_lengths.unsqueeze(1) - # Build a single gather index so pixel shuffle works on the packed stream - # without unpacking per-image grids. - gather_idx = create_pixel_shuffle_index_map( - seq_sizes=seq_sizes, - token_grids=token_grids, - scale_factor=scale_factor, - device=embeddings.device, - ) # (new_seq, scale_factor**2) + safe_widths = torch.where(widths > 0, widths, torch.ones_like(widths)) + row_index = torch.div(token_positions, safe_widths.unsqueeze(1), rounding_mode="floor") + col_index = token_positions.remainder(safe_widths.unsqueeze(1)) - # Gather โ†’ (new_seq, scale_factor**2, embed_dim) - gathered = embeddings[gather_idx] # fancy indexing keeps gradient + output_widths = widths.div(scale_factor, rounding_mode="floor") + output_index = row_index.div(scale_factor, rounding_mode="floor") * output_widths.unsqueeze(1) + output_index = output_index + col_index.div(scale_factor, rounding_mode="floor") + sub_index = row_index.remainder(scale_factor) * scale_factor + col_index.remainder(scale_factor) - # Merge the scale_factor**2 group dimension into channels to finish the shuffle - out = gathered.reshape(gathered.size(0), embed_dim * scale_factor * scale_factor) + batch_index = torch.arange(num_images, device=x.device, dtype=torch.long).unsqueeze(1).expand_as(token_positions) + shuffled_4d[batch_index[valid_token_mask], output_index[valid_token_mask], sub_index[valid_token_mask]] = x[ + valid_token_mask + ] - # Restore batch dimension if needed - if return_with_batch_dim: - out = out.unsqueeze(0) - return out + shuffled = shuffled_4d.view(num_images, max_output_tokens, output_dim) + return shuffled class IsaacVisionTransformer(PreTrainedModel): - """Vision tower that packs variable-resolution patches, applies varlen attention, and pixel-shuffles outputs. + """Vision tower for padded variable-resolution patches with per-image masks. Args: config (IsaacVisionConfig): Vision configuration with pixel-shuffle and patching parameters. Inputs: - packed_seq_patches (Tuple[Tensor, Tensor]): ``(patches, token_grids)`` where ``patches`` is a packed - patch sequence and ``token_grids`` holds per-image (H_tokens, W_tokens). 
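A small usage sketch of the padded pixel shuffle defined above, with hidden size 1 so the numbers are readable:

    import torch

    x = torch.arange(16, dtype=torch.float32).view(1, 16, 1)   # one image, 4x4 patch grid
    token_grids = torch.tensor([[4, 4]])
    out = pixel_shuffle_padded(x, token_grids, scale_factor=2)
    # out.shape == (1, 4, 4): the 16 patches collapse into 4 output tokens, each holding
    # one 2x2 neighborhood, e.g. out[0, 0] == tensor([0., 1., 4., 5.])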
+        vision_tokens (Tuple[Tensor, Tensor, Optional[Tensor]]):
+            `(patches, token_grids, patch_attention_mask)` where:
+            - `patches`: `(num_images, max_patches, patch_dim)`
+            - `token_grids`: `(num_images, 2)` with per-image `(H_tokens, W_tokens)`
+            - `patch_attention_mask`: `(num_images, max_patches)` or `None`
 
     Returns:
-        torch.Tensor: Vision embeddings after encoder + pixel shuffle, shaped ``(seq_len, hidden_size * s^2)``.
+        torch.Tensor: Pixel-shuffled vision embeddings, shaped `(num_images, max_output_tokens, hidden_size * s^2)`.
@@ -546,44 +463,42 @@ def __init__(self, config: IsaacVisionConfig):
         self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor
 
+        self.post_init()
+
     @torch.no_grad()
     def _init_weights(self, module):
         super()._init_weights(module)
         if isinstance(module, IsaacVisionEmbeddings):
             init.zeros_(module.position_embedding)
 
-    def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]):
-        seq_patches, token_grids = packed_seq_patches
-        seq_sizes = torch.prod(token_grids, dim=-1)
-
-        # Get embeddings from packed sequence
-        hidden_states = self.embeddings(seq_patches, token_grids)
-
-        # Add a pseudo batch dimension so we can reuse the batch-first encoder stack
-        # while still driving per-image sequence metadata through the varlen attention path.
-        hidden_states = hidden_states.unsqueeze(0)
+    def forward(
+        self,
+        vision_tokens: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor | None],
+    ) -> torch.Tensor:
+        if len(vision_tokens) == 2:
+            seq_patches, token_grids = vision_tokens
+            vision_patch_attention_mask = None
+        else:
+            seq_patches, token_grids, vision_patch_attention_mask = vision_tokens
+        hidden_states = self.embeddings(
+            seq_patches,
+            token_grids,
+            attention_mask=vision_patch_attention_mask,
+        )
 
-        # Pass through encoder with variable-length metadata for attention
-        encoder_outputs = self.encoder(
+        encoder_attention_mask = create_bidirectional_mask(
+            config=self.config,
             inputs_embeds=hidden_states,
-            attention_mask=None,
-            seq_sizes=seq_sizes,
+            attention_mask=vision_patch_attention_mask,
         )
-        hidden_states = encoder_outputs.last_hidden_state
-
-        # Apply final layer normalization
-        hidden_states = self.post_layernorm(hidden_states)
+        encoder_outputs = self.encoder(inputs_embeds=hidden_states, attention_mask=encoder_attention_mask)
+        hidden_states = self.post_layernorm(encoder_outputs.last_hidden_state)
 
-        hidden_states = pixel_shuffle_varlen(
+        return pixel_shuffle_padded(
             x=hidden_states,
             token_grids=token_grids,
             scale_factor=self.pixel_shuffle_scale_factor,
         )
-        # Remove the pseudo batch dimension we added earlier
-        hidden_states = hidden_states.squeeze(0)
-
-        # Return the full sequence of embeddings
-        return hidden_states
 
 
 class IsaacMultiModalProjector(nn.Module):
@@ -614,9 +529,14 @@ def __init__(self, config: IsaacConfig):
         self.vision_tower = IsaacVisionTransformer(vision_cfg)
         self.multimodal_projector = IsaacMultiModalProjector(config)
 
-    def forward(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
-        hidden_states = self.vision_tower(vision_tokens)
-        return self.multimodal_projector(hidden_states)
+    def forward(
+        self,
+        vision_tokens: tuple[torch.Tensor, torch.Tensor, torch.Tensor | None],
+    ) -> torch.Tensor:
+        vision_patches, token_grids, vision_patch_attention_mask = vision_tokens
+        hidden_states = self.vision_tower((vision_patches, token_grids,
vision_patch_attention_mask)) + projected = self.multimodal_projector(hidden_states) + return projected class IsaacRotaryEmbedding(nn.Module): @@ -628,8 +548,6 @@ def __init__(self, config: IsaacConfig, device=None): rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} config_for_rope = copy.copy(rope_source_cfg) config_for_rope.rope_scaling = rope_scaling - - init_device = device if device is not None and getattr(device, "type", None) != "meta" else None self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings @@ -650,9 +568,9 @@ def __init__(self, config: IsaacConfig, device=None): @staticmethod def compute_default_rope_parameters( - config: Optional[IsaacConfig] = None, + config: IsaacConfig | None = None, device: Optional["torch.device"] = None, - seq_len: Optional[int] = None, + seq_len: int | None = None, ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies according to the original RoPE implementation @@ -683,7 +601,7 @@ def forward( self, position_ids: torch.Tensor, modality_tensor: torch.Tensor, - hidden_states: Optional[torch.Tensor] = None, + hidden_states: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if hidden_states is None: batch, seq_len, _ = position_ids.shape @@ -722,7 +640,7 @@ def forward( return cos_combined, sin_combined @staticmethod - def _resolve_mrope_section(section: Optional[list[int]], rotary_half_dim: int) -> list[int]: + def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> list[int]: if section is None: weights = (2, 1, 1) base = [rotary_half_dim * w // sum(weights) for w in weights] @@ -751,18 +669,15 @@ class IsaacModel(PreTrainedModel): _can_compile_fullgraph = False _supports_attention_backend = True _can_record_outputs = { - "attentions": OutputRecorder(Qwen3Attention, index=1), - "vision_attentions": OutputRecorder(IsaacVisionAttention, index=1), + "hidden_states": OutputRecorder(Qwen3DecoderLayer), + "attentions": Qwen3Attention, + "vision_attentions": IsaacVisionAttention, } all_tied_weights_keys: dict[str, str] = {} def __init__(self, config: IsaacConfig): Qwen3PreTrainedModel.__init__(self, config) - - text_cfg_source = config.text_config - text_cfg = copy.deepcopy(text_cfg_source) - self.text_model = Qwen3Model._from_config(text_cfg) - self.text_model.config = config # Ensure downstream callers observe the composed config + self.text_model = Qwen3Model._from_config(config.text_config) self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) @@ -798,88 +713,73 @@ def embed_tokens(self) -> nn.Module: def embed_tokens(self, value: nn.Module) -> None: self.text_model.embed_tokens = value - @property - def vision_model(self) -> nn.Module: - return self.vision_embedding.vision_tower - - def embed_packed_inputs( - self, input_ids: torch.Tensor, packed_inputs: dict[str, Optional[torch.Tensor]] + def embed_multimodal_inputs( + self, + input_ids: torch.Tensor, + modality_tensor: torch.Tensor, + vision_patches: torch.Tensor, + vision_token_grids: torch.Tensor, + vision_patch_attention_mask: torch.Tensor | None = None, + vision_token_offsets: torch.Tensor | None = None, + vision_token_lengths: torch.Tensor | None = None, + vision_image_attention_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Expects input_ids for text tokens and packed_inputs containing: - - modality_tensor: (batch, seq_len) modality ids aligned to the sequence - - position_ids: (batch, seq_len, 3) MRoPE 
coordinates (optional) - - vision_patches: concatenated vision tokens shaped (total_tokens, embed_dim) or None - - vision_token_grids: (num_images, 2) token grid sizes or None - - vision_token_offsets: (num_images,) offsets into each image's virtual token span (optional) - - vision_token_lengths: (num_images,) surviving virtual token lengths per image (optional) - - vision_token_batch_indices: (num_images,) batch row for each image (optional; defaults to zeros) - """ - modality = packed_inputs["modality_tensor"].to(device=input_ids.device, dtype=torch.long) + modality = modality_tensor.to(device=input_ids.device, dtype=torch.long) embeds = self.text_model.embed_tokens(input_ids) + image_token_mask = modality == ModalityType.image.value - vision_patches = packed_inputs.get("vision_patches") - if vision_patches is None: + if vision_patches is None or vision_token_grids is None: + if torch.any(image_token_mask): + raise ValueError("Image placeholders require `vision_patches` and `vision_token_grids`.") return embeds, modality - token_grids = packed_inputs["vision_token_grids"].to(device=vision_patches.device, dtype=torch.long) - vision = self.vision_embedding((vision_patches, token_grids)) # (total_tokens, hidden) - - # per-image token counts AFTER pixel-shuffle - vision_reduction_factor = int(self.config.vision_config.pixel_shuffle_scale_factor) - sizes = ( - token_grids.prod(-1).div(vision_reduction_factor * vision_reduction_factor, rounding_mode="floor").tolist() + vision_patches = vision_patches.to(device=embeds.device) + token_grids = vision_token_grids.to(device=embeds.device, dtype=torch.long) + image_attention_mask = ( + vision_image_attention_mask.to(device=embeds.device, dtype=torch.bool) + if vision_image_attention_mask is not None + else torch.ones(token_grids.shape[:2], device=embeds.device, dtype=torch.bool) ) - offsets = packed_inputs.get("vision_token_offsets") - lengths = packed_inputs.get("vision_token_lengths") - batch_indices = packed_inputs.get("vision_token_batch_indices") - - chunks = vision.split(sizes, dim=0) - picked: list[torch.Tensor] = [] - picked_batch: list[int] = [] - for chunk, size, offset, length, batch_index in zip( - chunks, - sizes, - offsets.tolist(), - lengths.tolist(), - (batch_indices.tolist() if batch_indices is not None else [0] * len(sizes)), - ): - if size <= 0: - continue - offset = max(0, min(int(offset), size)) - length = max(0, min(int(length), size - offset)) - if length: - picked.append(chunk[offset : offset + length]) - picked_batch.append(int(batch_index)) - if picked: - vision_chunks = picked - vision_batch_idx = picked_batch - else: - vision_chunks = vision_batch_idx = [] - - vision = torch.cat(vision_chunks, 0) if vision_chunks else vision.new_zeros((0, vision.size(-1))) - embeds = embeds.clone() - num_batches = modality.shape[0] - image_positions = [ - (modality[b] == ModalityType.image.value).nonzero(as_tuple=False).squeeze(-1) for b in range(num_batches) - ] - cursors = [0 for _ in range(num_batches)] - - for chunk, batch_index in zip(vision_chunks, vision_batch_idx): - if chunk.numel() == 0: - continue - positions = image_positions[batch_index] - start = cursors[batch_index] - end = start + chunk.shape[0] - embeds[batch_index, positions[start:end]] = chunk.to(device=embeds.device, dtype=embeds.dtype) - cursors[batch_index] = end + patch_attention_mask = ( + vision_patch_attention_mask.to(device=embeds.device, dtype=torch.long) + if vision_patch_attention_mask is not None + else torch.ones(vision_patches.shape[:3], 
device=embeds.device, dtype=torch.long) + ) + offsets = ( + vision_token_offsets.to(device=embeds.device, dtype=torch.long) + if vision_token_offsets is not None + else torch.zeros(token_grids.shape[:2], device=embeds.device, dtype=torch.long) + ) + reduction_factor = int(self.config.vision_config.pixel_shuffle_scale_factor) ** 2 + lengths = ( + vision_token_lengths.to(device=embeds.device, dtype=torch.long) + if vision_token_lengths is not None + else token_grids.prod(-1).div(reduction_factor, rounding_mode="floor").to(dtype=torch.long) + ) + + flat_vision_patches = vision_patches[image_attention_mask] + flat_patch_attention_mask = patch_attention_mask[image_attention_mask] + flat_token_grids = token_grids[image_attention_mask] + flat_offsets = offsets[image_attention_mask] + flat_lengths = lengths[image_attention_mask] + + vision_embeddings = self.vision_embedding((flat_vision_patches, flat_token_grids, flat_patch_attention_mask)) + token_positions = torch.arange(flat_lengths.max(), device=embeds.device, dtype=torch.long) + gather_positions = flat_offsets[:, None] + token_positions[None, :] + gather_mask = token_positions[None, :] < flat_lengths[:, None] + image_features = vision_embeddings[ + torch.arange(vision_embeddings.shape[0], device=embeds.device, dtype=torch.long)[:, None], + gather_positions, + ][gather_mask] + scatter_mask = image_token_mask.unsqueeze(-1).expand_as(embeds) + embeds = embeds.masked_scatter(scatter_mask, image_features) return embeds, modality def get_rope_index( self, *, - position_ids: Optional[torch.Tensor] = None, + position_ids: torch.Tensor | None = None, attention_mask: torch.Tensor, inputs_embeds: torch.Tensor, cache_position: torch.Tensor, @@ -901,12 +801,13 @@ def get_rope_index( if cp.ndim == 1: cp = cp.view(1, -1).expand(batch_size or 1, -1) - base_delta = torch.as_tensor( - 0 if self.rope_deltas is None else self.rope_deltas, - device=device, - dtype=torch.long, - ).reshape(-1, 1) - base_delta = torch.broadcast_to(base_delta, (batch_size, 1)) + is_new_prefill = cp[:, :1].eq(0).all(dim=1, keepdim=True) + if self.rope_deltas is None: + base_delta = torch.zeros((batch_size, 1), device=device, dtype=torch.long) + else: + previous_delta = torch.as_tensor(self.rope_deltas, device=device, dtype=torch.long).reshape(-1, 1) + previous_delta = torch.broadcast_to(previous_delta, (batch_size, 1)) + base_delta = torch.where(is_new_prefill, torch.zeros_like(previous_delta), previous_delta) mask_delta = attention_mask.to(device=device, dtype=torch.long).sum(1, keepdim=True) - attention_mask.size( 1 @@ -931,17 +832,24 @@ def get_rope_index( return position_ids, rope_deltas @auto_docstring - @check_model_inputs + @merge_with_config_defaults + @capture_outputs def forward( self, - input_ids: Optional[torch.LongTensor] = None, - packed_inputs: Optional[dict[str, torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor | None = None, + modality_tensor: torch.LongTensor | None = None, + vision_patches: torch.Tensor | None = None, + vision_patch_attention_mask: torch.Tensor | None = None, + vision_token_grids: torch.LongTensor | None = None, + vision_token_offsets: torch.LongTensor | None = None, + vision_token_lengths: torch.LongTensor | None = None, + 
vision_image_attention_mask: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPast: """ @@ -950,21 +858,59 @@ def forward( Computes position embeddings once and passes them through all layers. Args: - packed_inputs (`dict`, *optional*): - Plain tensor payloads. When provided, requires `input_ids` for text tokens (or `text_token_ids` so `input_ids` can be rebuilt). modality_tensor (`torch.LongTensor`, *optional*): Modality identifiers aligned with the embedded sequence, shaped `(batch_size, seq_len)` and containing - values from `ModalityType`. Automatically built from `packed_inputs` or treated as text-only when omitted. + values from `ModalityType`. Treated as text-only when omitted. + vision_patches (`torch.FloatTensor`, *optional*): + Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`. + vision_patch_attention_mask (`torch.LongTensor`, *optional*): + Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. + vision_token_grids (`torch.LongTensor`, *optional*): + Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`. + vision_token_offsets (`torch.LongTensor`, *optional*): + Start offsets inside the per-image vision embedding sequence, shape `(batch_size, max_images)`. + vision_token_lengths (`torch.LongTensor`, *optional*): + Number of vision tokens to consume per image, shape `(batch_size, max_images)`. + vision_image_attention_mask (`torch.LongTensor`, *optional*): + Mask indicating which image slots are populated, shape `(batch_size, max_images)`. 
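Illustrative construction of these padded vision inputs for a batch of two samples with at most two images each (all numbers hypothetical):

    import torch

    vision_patches = torch.zeros(2, 2, 256, 768)    # (batch, max_images, max_patches, patch_dim)
    vision_token_grids = torch.tensor(
        [[[16, 16], [8, 16]],
         [[16, 8], [0, 0]]]                         # second sample fills only one image slot
    )
    vision_image_attention_mask = torch.tensor([[1, 1], [1, 0]], dtype=torch.bool)
    s = 2                                           # pixel_shuffle_scale_factor
    vision_token_lengths = vision_token_grids.prod(-1) // (s * s)
    # -> tensor([[64, 32], [32, 0]]), the vision tokens each image adds to the sequence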
""" - output_attentions = kwargs.pop("output_attentions", None) - - modality_tensor: Optional[torch.Tensor] = None + if inputs_embeds is None: + if input_ids is None: + raise ValueError("`input_ids` or `inputs_embeds` must be provided.") + + has_vision_inputs = any( + value is not None + for value in ( + vision_patches, + vision_patch_attention_mask, + vision_token_grids, + vision_token_offsets, + vision_token_lengths, + vision_image_attention_mask, + ) + ) + if modality_tensor is not None or has_vision_inputs: + if modality_tensor is None: + modality_tensor = torch.full_like(input_ids, ModalityType.text.value) + inputs_embeds, modality_tensor = self.embed_multimodal_inputs( + input_ids=input_ids, + modality_tensor=modality_tensor, + vision_patches=vision_patches, + vision_patch_attention_mask=vision_patch_attention_mask, + vision_token_grids=vision_token_grids, + vision_token_offsets=vision_token_offsets, + vision_token_lengths=vision_token_lengths, + vision_image_attention_mask=vision_image_attention_mask, + ) + else: + inputs_embeds = self.text_model.embed_tokens(input_ids) - if packed_inputs is not None: - inputs_embeds, modality_tensor = self.embed_packed_inputs(input_ids, packed_inputs) - elif input_ids is not None: - inputs_embeds = self.text_model.embed_tokens(input_ids) + if modality_tensor is None: + batch_size, seq_len = inputs_embeds.shape[:2] + modality_tensor = torch.full( + (batch_size, seq_len), ModalityType.text.value, device=inputs_embeds.device, dtype=torch.long + ) device = inputs_embeds.device batch_size, seq_len = inputs_embeds.shape[:2] @@ -979,9 +925,6 @@ def forward( if attention_mask is None: attention_mask = torch.ones(inputs_embeds.shape[:2], device=inputs_embeds.device, dtype=torch.long) - if position_ids is None and packed_inputs is not None and packed_inputs.get("position_ids") is not None: - position_ids = packed_inputs.get("position_ids").to(device=device) - position_ids, rope_deltas = self.get_rope_index( position_ids=position_ids, attention_mask=attention_mask, @@ -990,11 +933,6 @@ def forward( ) self.rope_deltas = rope_deltas - if modality_tensor is None: - modality_tensor = torch.full( - (batch_size, seq_len), ModalityType.text.value, device=device, dtype=torch.long - ) - cos, sin = self.rotary_emb(position_ids, modality_tensor, hidden_states=inputs_embeds) decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids @@ -1002,7 +940,7 @@ def forward( if not isinstance(attention_mask, dict): attention_mask = create_masks_for_generate( config=self.config, - input_embeds=inputs_embeds, + inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, @@ -1022,7 +960,6 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=(cos, sin), - output_attentions=output_attentions, **kwargs, ) @@ -1032,9 +969,7 @@ def forward( return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=(hidden_states,), - attentions=None, + past_key_values=past_key_values if use_cache else None, ) @@ -1067,7 +1002,7 @@ def rotate_half(x): @use_kernel_func_from_hub("rotary_pos_emb") -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
Args: @@ -1075,8 +1010,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -1106,32 +1039,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs: Unpack[TransformersKwargs], -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - @use_kernelized_func(apply_rotary_pos_emb) class IsaacAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1167,11 +1074,11 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_values: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -1187,9 +1094,9 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self, @@ -1223,12 +1130,12 @@ def __init__(self, config: IsaacConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: torch.Tensor | None = None, + 
position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: residual = hidden_states @@ -1276,7 +1183,7 @@ class IsaacPreTrainedModel(PreTrainedModel): @auto_docstring class IsaacForConditionalGeneration(IsaacPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"} - _tp_plan = {"lm_head": "colwise_rep"} + _tp_plan = {"lm_head": "colwise_gather_output"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config_class = IsaacConfig _can_compile_fullgraph = False @@ -1293,41 +1200,56 @@ def __init__(self, config: IsaacConfig): @auto_docstring @can_return_tuple - @check_model_inputs + @merge_with_config_defaults def forward( self, - input_ids: Optional[torch.LongTensor] = None, - packed_inputs: Optional[dict[str, torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor | None = None, + modality_tensor: torch.LongTensor | None = None, + vision_patches: torch.Tensor | None = None, + vision_patch_attention_mask: torch.Tensor | None = None, + vision_token_grids: torch.LongTensor | None = None, + vision_token_offsets: torch.LongTensor | None = None, + vision_token_lengths: torch.LongTensor | None = None, + vision_image_attention_mask: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | CausalLMOutputWithPast: - """Run multimodal CausalLM forward, accepting packed vision/text inputs. - - Args: - packed_inputs (`dict`, *optional*): - Packed vision/text payload from ``IsaacProcessor`` containing modality ids, MRoPE position ids, and - vision patch tensors/grids (with optional offsets/lengths) used to rebuild embeddings. - - Returns: - CausalLMOutputWithPast: logits, optional loss, caches, hidden states, attentions. + r""" + modality_tensor (`torch.LongTensor`, *optional*): + Modality identifiers aligned with the token sequence, shaped `(batch_size, seq_len)`. + vision_patches (`torch.FloatTensor`, *optional*): + Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`. + vision_patch_attention_mask (`torch.LongTensor`, *optional*): + Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. + vision_token_grids (`torch.LongTensor`, *optional*): + Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`. + vision_token_offsets (`torch.LongTensor`, *optional*): + Start offsets inside the per-image vision embedding sequence, shape `(batch_size, max_images)`. + vision_token_lengths (`torch.LongTensor`, *optional*): + Number of vision tokens to consume per image, shape `(batch_size, max_images)`. 
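Downstream, these tensors only matter for the prefill step; a hypothetical generate call wiring them through might look like:

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        vision_patches=vision_patches,
        vision_token_grids=vision_token_grids,
        vision_patch_attention_mask=vision_patch_attention_mask,
        vision_image_attention_mask=vision_image_attention_mask,
        max_new_tokens=32,
    )
    # prepare_inputs_for_generation forwards the vision kwargs on the first step only and
    # drops them once the KV cache holds the prefix, so decode steps are text-only.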
+ vision_image_attention_mask (`torch.LongTensor`, *optional*): + Mask indicating which image slots are populated, shape `(batch_size, max_images)`. """ - output_attentions = kwargs.pop("output_attentions", None) - outputs = self.model( input_ids=input_ids, - packed_inputs=packed_inputs, + modality_tensor=modality_tensor, + vision_patches=vision_patches, + vision_patch_attention_mask=vision_patch_attention_mask, + vision_token_grids=vision_token_grids, + vision_token_offsets=vision_token_offsets, + vision_token_lengths=vision_token_lengths, + vision_image_attention_mask=vision_image_attention_mask, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, cache_position=cache_position, **kwargs, ) @@ -1342,18 +1264,24 @@ def forward( logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, - attentions=outputs.attentions if output_attentions else None, + attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, - past_key_values: Optional[list[torch.FloatTensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - packed_inputs: Optional[dict[str, torch.Tensor]] = None, - cache_position: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + past_key_values: list[torch.FloatTensor] | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + modality_tensor: torch.LongTensor | None = None, + vision_patches: torch.Tensor | None = None, + vision_patch_attention_mask: torch.Tensor | None = None, + vision_token_grids: torch.LongTensor | None = None, + vision_token_offsets: torch.LongTensor | None = None, + vision_token_lengths: torch.LongTensor | None = None, + vision_image_attention_mask: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, **kwargs, ) -> dict[str, Any]: model_inputs = super().prepare_inputs_for_generation( @@ -1362,16 +1290,26 @@ def prepare_inputs_for_generation( attention_mask=attention_mask, inputs_embeds=inputs_embeds, cache_position=cache_position, - position_ids=position_ids, + position_ids=None, **kwargs, ) - if packed_inputs is None: + multimodal_inputs = { + "modality_tensor": modality_tensor, + "vision_patches": vision_patches, + "vision_patch_attention_mask": vision_patch_attention_mask, + "vision_token_grids": vision_token_grids, + "vision_token_offsets": vision_token_offsets, + "vision_token_lengths": vision_token_lengths, + "vision_image_attention_mask": vision_image_attention_mask, + } + if not any(value is not None for value in multimodal_inputs.values()): return model_inputs past_len = past_key_values.get_seq_length() if past_key_values is not None else 0 first_step = past_len == 0 - model_inputs["packed_inputs"] = packed_inputs if first_step else None - model_inputs["position_ids"] = None + for key, value in multimodal_inputs.items(): + model_inputs[key] = value if first_step else None + model_inputs["position_ids"] = position_ids if first_step else None return model_inputs diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index f8de4b070d80..851f9b4bc339 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -1,4 +1,3 @@ 
-# coding=utf-8 # Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,10 +16,11 @@ import copy import math -from collections.abc import Callable, Sequence +from collections.abc import Sequence from enum import IntEnum -from typing import Any, Optional, Union +from typing import Any +from ... import initialization as init from ...cache_utils import DynamicCache from ...configuration_utils import PretrainedConfig, layer_type_validation from ...feature_extraction_utils import BatchFeature @@ -35,23 +35,29 @@ from ...image_utils import ( PILImageResampling, ) -from ...masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, create_masks_for_generate, packed_sequence_mask_function +from ...masking_utils import create_bidirectional_mask, create_masks_for_generate from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_utils import PreTrainedModel from ...models.qwen3.configuration_qwen3 import Qwen3Config -from ...models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3ForCausalLM, Qwen3Model, Qwen3PreTrainedModel +from ...models.qwen3.modeling_qwen3 import ( + Qwen3Attention, + Qwen3DecoderLayer, + Qwen3ForCausalLM, + Qwen3Model, + Qwen3PreTrainedModel, +) from ...processing_utils import ProcessorMixin, Unpack -from ...utils import TensorType, auto_docstring +from ...utils import TensorType, auto_docstring, torch_compilable_check from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD -from ...utils.generic import OutputRecorder, TransformersKwargs, can_return_tuple, check_model_inputs, maybe_autocast +from ...utils.generic import TransformersKwargs, can_return_tuple, maybe_autocast, merge_with_config_defaults from ...utils.import_utils import ( is_torch_available, is_torchdynamo_compiling, is_torchvision_available, is_vision_available, ) -from ..qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling +from ...utils.output_capturing import OutputRecorder, capture_outputs from ..qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding from ..siglip2.configuration_siglip2 import Siglip2VisionConfig from ..siglip2.modeling_siglip2 import ( @@ -139,10 +145,21 @@ def __init__( class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False): - patch_size: Optional[int] - max_num_patches: Optional[int] - min_num_patches: Optional[int] - pixel_shuffle_scale: Optional[int] + """ + patch_size (`int`, *optional*): + Side length (in pixels) for square patches extracted from resized images. + max_num_patches (`int`, *optional*): + Upper bound on extracted patches per image after resizing. + min_num_patches (`int`, *optional*): + Lower bound on extracted patches per image after resizing. + pixel_shuffle_scale (`int`, *optional*): + Pixel-shuffle reduction factor applied in the vision tower. 
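A hypothetical end-to-end call through the fast image processor with these kwargs (the exact output keys depend on `_preprocess`):

    from PIL import Image

    processor = IsaacImageProcessorFast(patch_size=16, max_num_patches=256, pixel_shuffle_scale=2)
    features = processor(images=[Image.new("RGB", (1024, 768))], return_tensors="pt")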
+ """ + + patch_size: int | None + max_num_patches: int | None + min_num_patches: int | None + pixel_shuffle_scale: int | None @auto_docstring @@ -156,10 +173,10 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): do_resize = True do_center_crop = False - patch_size: Optional[int] = 16 - max_num_patches: Optional[int] = 256 - min_num_patches: Optional[int] = None - pixel_shuffle_scale: Optional[int] = 1 + patch_size: int | None = 16 + max_num_patches: int | None = 256 + min_num_patches: int | None = None + pixel_shuffle_scale: int | None = 1 do_pad = False do_rescale = True do_normalize = True @@ -168,19 +185,9 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): do_convert_rgb = True disable_grouping = False - def __init__( - self, - **kwargs: Unpack[IsaacImageProcessorFastKwargs], - ) -> None: - super().__init__(**kwargs) - def _validate_preprocess_kwargs(self, **kwargs): # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for. kwargs.pop("do_resize", None) - kwargs.pop("size", None) - kwargs.pop("do_center_crop", None) - kwargs.pop("crop_size", None) - kwargs.pop("disable_grouping", None) return super()._validate_preprocess_kwargs(**kwargs) def resize( @@ -189,33 +196,25 @@ def resize( size: SizeDict, **kwargs, ) -> torch.Tensor: - resize_kwargs: dict[str, Any] = {"align_corners": False} - resize_mode = "bilinear" - - return F.interpolate( - image, - size=(size.height, size.width), - mode=resize_mode, - **resize_kwargs, - ) + return F.interpolate(image, size=(size.height, size.width), mode="bilinear", align_corners=False) def _preprocess( self, images: list[torch.Tensor], do_resize: bool, - interpolation: Optional[Any], - do_rescale: Optional[bool], - rescale_factor: Optional[float], - do_normalize: Optional[bool], - image_mean: Optional[Union[float, Sequence[float]]], - image_std: Optional[Union[float, Sequence[float]]], - disable_grouping: Optional[bool] = None, - return_tensors: Optional[Union[str, TensorType]] = None, + interpolation: Any | None, + do_rescale: bool | None, + rescale_factor: float | None, + do_normalize: bool | None, + image_mean: float | Sequence[float] | None, + image_std: float | Sequence[float] | None, + disable_grouping: bool | None = None, + return_tensors: str | TensorType | None = None, *, - patch_size: Optional[int] = None, - max_num_patches: Optional[int] = None, - min_num_patches: Optional[int] = None, - pixel_shuffle_scale: Optional[int] = None, + patch_size: int | None = None, + max_num_patches: int | None = None, + min_num_patches: int | None = None, + pixel_shuffle_scale: int | None = None, **kwargs, ) -> BatchFeature: grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) @@ -334,308 +333,130 @@ def __init__(self, config: IsaacVisionConfig): ) nn.init.normal_(self.position_embedding) - @check_model_inputs - def forward(self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor) -> torch.Tensor: - # Rebatch packed variable-resolution patches to resize per-image position embeddings - # and track lengths for varlen attention metadata. 
- packed_pixel_values, seq_lengths = self._pack_to_batch(seq_patches, spatial_shapes) - if packed_pixel_values is None: - return seq_patches.new_zeros((0, self.embed_dim)) - + @merge_with_config_defaults + @capture_outputs + def forward( + self, + pixel_values: torch.Tensor, + spatial_shapes: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + # pixel_values: (num_images, max_patches, patch_dim) target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding(packed_pixel_values.to(dtype=target_dtype)) + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) - positional_embeddings = self.position_embedding resized_positional_embeddings = self.resize_positional_embeddings( - positional_embeddings, spatial_shapes, max_length=packed_pixel_values.shape[1] + self.position_embedding, + spatial_shapes, + max_length=pixel_values.shape[1], ) - embeddings = patch_embeds + resized_positional_embeddings - return self._unpack_from_batch(embeddings, seq_lengths) - - def _pack_to_batch( - self, - seq_patches: torch.Tensor, - spatial_shapes: torch.Tensor, - ) -> tuple[Optional[torch.Tensor], torch.Tensor]: - """Rebatch a packed patch sequence using per-image grids to align embeddings. - Args: - seq_patches: Packed patches of shape (total_patches, patch_dim). - spatial_shapes: Per-image patch grids of shape (num_images, 2) as (H_tokens, W_tokens). - - Returns: - (packed_pixel_values, seq_lengths) where: - - packed_pixel_values: (batch, max_len, patch_dim) padded with zeros, or None if batch_size == 0 - - seq_lengths: (batch,) lengths for each image - """ - seq_lengths = spatial_shapes.long().prod(dim=-1) # (B,) - batch_size = int(seq_lengths.numel()) - if batch_size == 0: - return None, seq_lengths + if attention_mask is not None: + embeddings = embeddings * attention_mask.unsqueeze(-1).to(dtype=embeddings.dtype) - # Split the packed sequence into per-image chunks, then pad to a batch - lengths_list = seq_lengths.tolist() - chunks = seq_patches.split(lengths_list, dim=0) - packed_pixel_values = nn.utils.rnn.pad_sequence(chunks, batch_first=True) # zero-padded by default - return packed_pixel_values, seq_lengths - - def _unpack_from_batch(self, embeddings: torch.Tensor, seq_lengths: torch.Tensor) -> torch.Tensor: - """Flatten a padded batch back to packed sequence order using `seq_lengths`.""" - lengths = seq_lengths.to(device=embeddings.device).tolist() - chunks = [embeddings[i, :l] for i, l in enumerate(lengths) if l > 0] - return torch.cat(chunks, dim=0) + return embeddings class IsaacVisionAttention(Siglip2Attention): """Custom attention that supports variable-length sequences with flash/SDPA backends.""" - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - **kwargs, - ): - batch_size, seq_length, embed_dim = hidden_states.shape - queries = self.q_proj(hidden_states) - keys = self.k_proj(hidden_states) - values = self.v_proj(hidden_states) - - queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - - attn_impl = self.config._attn_implementation - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"] - if attn_impl != "sdpa": - attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl] - - seq_sizes = 
kwargs.pop("seq_sizes", None) - - attention_kwargs: dict[str, Any] = { - "is_causal": False, - "scaling": self.scale, - } - - if seq_sizes is not None and seq_sizes.numel() > 0: - if attn_impl in {"flash_attention_2", "flash_attention_3"}: - cu_seqlens = F.pad(seq_sizes.cumsum(0).to(torch.int32), (1, 0)) - max_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - attention_kwargs.update( - { - "cu_seq_lens_q": cu_seqlens, - "cu_seq_lens_k": cu_seqlens, - "max_length_q": max_len, - "max_length_k": max_len, - } - ) - else: - seg_ids = torch.repeat_interleave( - torch.arange(seq_sizes.numel(), device=seq_sizes.device), seq_sizes - ).view(1, -1) - mask_function = packed_sequence_mask_function(seg_ids) - cache_position = torch.arange(seq_length, device=hidden_states.device, dtype=torch.long) - mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[attn_impl] - attention_mask = mask_interface( - batch_size=batch_size, - cache_position=cache_position, - kv_length=seq_length, - kv_offset=0, - mask_function=mask_function, - attention_mask=attention_mask, - allow_is_causal_skip=False, - allow_is_bidirectional_skip=False, - dtype=hidden_states.dtype, - config=self.config, - use_vmap=False, - ) - else: - attention_mask = None - - attn_output, attn_weights = attention_interface( - self, - queries, - keys, - values, - attention_mask, - **attention_kwargs, - ) - attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights + pass class IsaacVisionEncoderLayer(Siglip2EncoderLayer): - """Isaac vision encoder layer with variable-length attention.""" + """Isaac vision encoder layer using the shared attention interfaces.""" def __init__(self, config: IsaacVisionConfig): super().__init__(config) self.self_attn = IsaacVisionAttention(config) - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - **kwargs: Unpack[TransformersKwargs], - ): - r""" - Variable-length metadata (e.g., `seq_sizes`) flows via `**kwargs` to attention for backend-specific handling. - """ - # Run attention directly so variable-length metadata reaches FlashAttention. - residual = hidden_states - hidden_states = self.layer_norm1(hidden_states) - attn_output, _ = self.self_attn( - hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - **kwargs, - ) - - hidden_states = residual + attn_output - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states - class IsaacVisionEncoder(Siglip2Encoder): - """Encoder using Isaac encoder layers with variable-length attention support.""" + """Encoder using Isaac encoder layers.""" def __init__(self, config: IsaacVisionConfig): super().__init__(config) self.layers = nn.ModuleList([IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) -def create_pixel_shuffle_index_map( - seq_sizes: torch.Tensor, - token_grids: torch.Tensor, - scale_factor: int = 1, - device: Optional[torch.device] = None, -) -> torch.Tensor: - """ - Build a gather-index map that tells us, for every *output* token after - pixel-shuffle, which `scale_factor**2` *input* tokens are being merged. 
-
-    Args
-    ----
-    seq_sizes    : (num_images,) - #patches in each image (row-major order)
-    token_grids  : (num_images,2) - (height, width) for every image
-    scale_factor : spatial down-scale factor (≥2)
-    device       : (optional) overrides `seq_sizes.device`
-
-    Returns
-    -------
-    gather_idx : (new_total_seq_len, scale_factor**2) int64 tensor.
-                 gather_idx[i, j] is the *flat* index into the *original*
-                 packed sequence for the j-th sub-patch that forms the
-                 i-th output token.
-    """
-    if not is_torchdynamo_compiling():
-        if (token_grids % scale_factor).any():
-            raise AssertionError(
-                f"Every (H,W) in token_grids must be divisible by scale_factor={scale_factor}, got {token_grids.tolist()}"
-            )
-
-    gather_chunks: list[torch.Tensor] = []
-    tok_offset = 0
-    for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist()):
-        # Flat indices for this image's packed segment
-        grid = torch.arange(seq_len, device=device, dtype=torch.int64).view(h, w) + tok_offset
-
-        # Block into (H/s, W/s) groups; each group contributes s*s indices
-        grid = (
-            grid.view(h // scale_factor, scale_factor, w // scale_factor, scale_factor)
-            .permute(0, 2, 1, 3)
-            .contiguous()
-        )
-        gather_chunks.append(grid.view(-1, scale_factor * scale_factor))
-
-        tok_offset += seq_len
-
-    return torch.cat(gather_chunks, dim=0)
-
-
-def pixel_shuffle_varlen(
+def pixel_shuffle_padded(
     x: torch.Tensor,
     token_grids: torch.Tensor,
     scale_factor: int = 1,
 ) -> torch.Tensor:
-    r"""Apply pixel shuffle to a packed vision sequence without unpacking per image.
+    """Apply pixel shuffle per image on padded batched vision embeddings.
 
     Args:
         x (`torch.Tensor`):
-            Concatenated vision embeddings. Accepts `(seq_len, hidden_size)` or `(1, seq_len, hidden_size)` shapes
-            produced by stacking image patches.
+            Vision embeddings of shape `(num_images, max_patches, hidden_size)`.
         token_grids (`torch.Tensor`):
-            Integer tensor of shape `(num_images, 2)` whose rows give the `(height, width)` patch grid sizes
-            corresponding to each image segment inside `x`.
+            Grid sizes `(height, width)` per image, shape `(num_images, 2)`.
         scale_factor (`int`, *optional*, defaults to 1):
-            Spatial down-sampling factor specific to pixel shuffle. Values greater than one merge `scale_factor**2` neighboring patches into a
-            single embedding channel-group.
+            Spatial down-sampling factor.
 
     Returns:
-        `torch.Tensor`: Pixel-shuffled embeddings with shape matching the input convention:
-        `(seq_len, hidden_size * scale_factor**2)` when the input was 2D, or `(1, seq_len, hidden_size * scale_factor**2)`
-        if the singleton batch dimension was present.
-
-    Raises:
-        ValueError: If more than one batch item is provided.
+        `torch.Tensor`: Pixel-shuffled embeddings of shape
+        `(num_images, max_output_tokens, hidden_size * scale_factor**2)`, zero-padded beyond
+        each image's valid output length.
     """
-    return_with_batch_dim = x.dim() == 3
-    if return_with_batch_dim:
-        if x.size(0) != 1:
-            raise ValueError(
-                f"Packed vision sequences expect a singleton batch dimension; received batch_size={x.size(0)}."
- ) - embeddings = x.squeeze(0) # (seq, embed) - else: - embeddings = x # (seq, embed) + num_images, max_patches, embed_dim = x.shape + output_dim = embed_dim * scale_factor * scale_factor - embed_dim = embeddings.size(-1) - scale_factor = int(scale_factor) + token_grids = token_grids.to(device=x.device, dtype=torch.long) + heights = token_grids[:, 0] + widths = token_grids[:, 1] + full_lengths = heights * widths - # Calculate seq_sizes from token_grids - seq_sizes = torch.prod(token_grids, dim=-1) + non_empty = full_lengths > 0 + if not is_torchdynamo_compiling(): + divisible = ((heights % scale_factor) == 0) & ((widths % scale_factor) == 0) + torch_compilable_check( + (~non_empty) | divisible, + f"Every non-empty (H, W) grid must be divisible by pixel_shuffle_scale={scale_factor}.", + ) + + output_lengths = (heights // scale_factor) * (widths // scale_factor) + max_output_tokens = output_lengths.max() + shuffled_4d = x.new_zeros((num_images, max_output_tokens, scale_factor * scale_factor, embed_dim)) - # Build a single gather index so pixel shuffle works on the packed stream - # without unpacking per-image grids. - gather_idx = create_pixel_shuffle_index_map( - seq_sizes=seq_sizes, - token_grids=token_grids, - scale_factor=scale_factor, - device=embeddings.device, - ) # (new_seq, scale_factor**2) + token_positions = torch.arange(max_patches, device=x.device, dtype=torch.long).unsqueeze(0).expand(num_images, -1) + valid_token_mask = token_positions < full_lengths.unsqueeze(1) - # Gather โ†’ (new_seq, scale_factor**2, embed_dim) - gathered = embeddings[gather_idx] # fancy indexing keeps gradient + safe_widths = torch.where(widths > 0, widths, torch.ones_like(widths)) + row_index = torch.div(token_positions, safe_widths.unsqueeze(1), rounding_mode="floor") + col_index = token_positions.remainder(safe_widths.unsqueeze(1)) - # Merge the scale_factor**2 group dimension into channels to finish the shuffle - out = gathered.reshape(gathered.size(0), embed_dim * scale_factor * scale_factor) + output_widths = widths.div(scale_factor, rounding_mode="floor") + output_index = row_index.div(scale_factor, rounding_mode="floor") * output_widths.unsqueeze(1) + output_index = output_index + col_index.div(scale_factor, rounding_mode="floor") + sub_index = row_index.remainder(scale_factor) * scale_factor + col_index.remainder(scale_factor) - # Restore batch dimension if needed - if return_with_batch_dim: - out = out.unsqueeze(0) - return out + batch_index = torch.arange(num_images, device=x.device, dtype=torch.long).unsqueeze(1).expand_as(token_positions) + shuffled_4d[batch_index[valid_token_mask], output_index[valid_token_mask], sub_index[valid_token_mask]] = x[ + valid_token_mask + ] + + shuffled = shuffled_4d.view(num_images, max_output_tokens, output_dim) + return shuffled class IsaacVisionTransformer(PreTrainedModel): - """Vision tower that packs variable-resolution patches, applies varlen attention, and pixel-shuffles outputs. + """Vision tower for padded variable-resolution patches with per-image masks. Args: config (IsaacVisionConfig): Vision configuration with pixel-shuffle and patching parameters. Inputs: - packed_seq_patches (Tuple[Tensor, Tensor]): ``(patches, token_grids)`` where ``patches`` is a packed - patch sequence and ``token_grids`` holds per-image (H_tokens, W_tokens). 
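# Illustrative sketch, not part of the patch: what `pixel_shuffle_padded` above computes
# for a single 4x4 patch grid with scale_factor=2. Every 2x2 block of patch embeddings is
# concatenated channel-wise into one output token, so 16 patches of width C become
# 4 tokens of width 4*C; the helper then zero-pads per image across the batch.
import torch

h, w, C, s = 4, 4, 8, 2
x = torch.arange(1 * h * w * C, dtype=torch.float32).view(1, h * w, C)  # (num_images, max_patches, C)

# Reference computation on the unpadded grid: split into s x s blocks, flatten each block.
grid = x[0].view(h, w, C)
blocks = (
    grid.view(h // s, s, w // s, s, C)
    .permute(0, 2, 1, 3, 4)  # (block_row, block_col, r, c, C), matching sub_index = (r % s) * s + (c % s)
    .reshape((h // s) * (w // s), s * s * C)
)
print(blocks.shape)  # torch.Size([4, 32]) == (max_output_tokens, C * s**2)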
+        vision_tokens (Tuple[Tensor, Tensor, Optional[Tensor]]):
+            `(patches, token_grids, patch_attention_mask)` where:
+            - `patches`: `(num_images, max_patches, patch_dim)`
+            - `token_grids`: `(num_images, 2)` with per-image `(H_tokens, W_tokens)`
+            - `patch_attention_mask`: `(num_images, max_patches)` or `None`
 
     Returns:
-        torch.Tensor: Vision embeddings after encoder + pixel shuffle, shaped ``(seq_len, hidden_size * s^2)``.
+        torch.Tensor: Pixel-shuffled vision embeddings shaped
+        `(num_images, max_output_tokens, hidden_size * s**2)`, zero-padded past each image's
+        valid output length.
     """
 
     _supports_sdpa = True
@@ -649,44 +470,42 @@ def __init__(self, config: IsaacVisionConfig):
         self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor
 
+        self.post_init()
+
     @torch.no_grad()
     def _init_weights(self, module):
         super()._init_weights(module)
         if isinstance(module, IsaacVisionEmbeddings):
             init.zeros_(module.position_embedding)
 
-    def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]):
-        seq_patches, token_grids = packed_seq_patches
-        seq_sizes = torch.prod(token_grids, dim=-1)
-
-        # Get embeddings from packed sequence
-        hidden_states = self.embeddings(seq_patches, token_grids)
-
-        # Add a pseudo batch dimension so we can reuse the batch-first encoder stack
-        # while still driving per-image sequence metadata through the varlen attention path.
-        hidden_states = hidden_states.unsqueeze(0)
+    def forward(
+        self,
+        vision_tokens: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor | None],
+    ) -> torch.Tensor:
+        if len(vision_tokens) == 2:
+            seq_patches, token_grids = vision_tokens
+            vision_patch_attention_mask = None
+        else:
+            seq_patches, token_grids, vision_patch_attention_mask = vision_tokens
+        hidden_states = self.embeddings(
+            seq_patches,
+            token_grids,
+            attention_mask=vision_patch_attention_mask,
+        )
 
-        # Pass through encoder with variable-length metadata for attention
-        encoder_outputs = self.encoder(
+        encoder_attention_mask = create_bidirectional_mask(
+            config=self.config,
             inputs_embeds=hidden_states,
-            attention_mask=None,
-            seq_sizes=seq_sizes,
+            attention_mask=vision_patch_attention_mask,
         )
-        hidden_states = encoder_outputs.last_hidden_state
-
-        # Apply final layer normalization
-        hidden_states = self.post_layernorm(hidden_states)
+        encoder_outputs = self.encoder(inputs_embeds=hidden_states, attention_mask=encoder_attention_mask)
+        hidden_states = self.post_layernorm(encoder_outputs.last_hidden_state)
 
-        hidden_states = pixel_shuffle_varlen(
+        return pixel_shuffle_padded(
             x=hidden_states,
             token_grids=token_grids,
             scale_factor=self.pixel_shuffle_scale_factor,
         )
-        # Remove the pseudo batch dimension we added earlier
-        hidden_states = hidden_states.squeeze(0)
-
-        # Return the full sequence of embeddings
-        return hidden_states
 
 
 class IsaacMultiModalProjector(nn.Module):
@@ -717,9 +536,14 @@ def __init__(self, config: IsaacConfig):
         self.vision_tower = IsaacVisionTransformer(vision_cfg)
         self.multimodal_projector = IsaacMultiModalProjector(config)
 
-    def forward(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
-        hidden_states = self.vision_tower(vision_tokens)
-        return self.multimodal_projector(hidden_states)
+    def forward(
+        self,
+        vision_tokens: tuple[torch.Tensor, torch.Tensor, torch.Tensor | None],
+    ) -> torch.Tensor:
+        vision_patches, token_grids, vision_patch_attention_mask = vision_tokens
+        hidden_states = self.vision_tower((vision_patches, token_grids,
vision_patch_attention_mask)) + projected = self.multimodal_projector(hidden_states) + return projected def get_scaled_image_size( @@ -740,7 +564,7 @@ def get_image_size_for_max_num_patches( image_width: int, patch_size: int, max_num_patches: int, - min_num_patches: Optional[int] = None, + min_num_patches: int | None = None, eps: float = 1e-5, pixel_shuffle_scale: int = 1, ) -> tuple[int, int]: @@ -826,15 +650,13 @@ class IsaacConfig(PretrainedConfig): def __init__( self, - vision_config: Optional[IsaacVisionConfig] = None, - text_config: Optional[Union[Qwen3Config, dict]] = None, + vision_config: IsaacVisionConfig | None = None, + text_config: Qwen3Config | dict | None = None, vision_rescale_factor: float = 1 / 255, max_sequence_length: int = 16384, vision_token: str = "", **kwargs, ): - attn_implementation = kwargs.get("attn_implementation") - if isinstance(text_config, dict): self.text_config = self.sub_configs["text_config"](**text_config) elif isinstance(text_config, Qwen3Config): @@ -842,6 +664,13 @@ def __init__( elif text_config is None: self.text_config = self.sub_configs["text_config"]() + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif isinstance(vision_config, IsaacVisionConfig): + self.vision_config = vision_config + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + # Seed RoPE parameters before base init so the shared mixin can standardize/validate them. self.rope_parameters = getattr(self.text_config, "rope_parameters", None) self.layer_types = getattr(self.text_config, "layer_types", None) @@ -866,23 +695,6 @@ def __init__( self.layer_types = getattr(self.text_config, "layer_types", None) layer_type_validation(self.layer_types, self.num_hidden_layers) - # Handle vision config - either dict or IsaacVisionConfig instance - if isinstance(vision_config, dict): - self.vision_config = self.sub_configs["vision_config"](**vision_config) - elif isinstance(vision_config, IsaacVisionConfig): - self.vision_config = vision_config - elif vision_config is None: - self.vision_config = self.sub_configs["vision_config"]() - - # Propagate user-requested attention backend to the vision sub-config when provided. - if attn_implementation is not None: - if isinstance(attn_implementation, dict): - vision_attn = attn_implementation.get("vision_config", attn_implementation.get("", None)) - else: - vision_attn = attn_implementation - if vision_attn is not None: - self.vision_config._attn_implementation = vision_attn - if getattr(self, "_attn_implementation", None) is None: self._attn_implementation = "sdpa" # Vision normalization parameters @@ -892,15 +704,6 @@ def __init__( self.max_sequence_length = max_sequence_length self.vision_token = vision_token - def to_dict(self): - output = super().to_dict() - # Ensure nested configs round-trip through dict serialization - if hasattr(self, "text_config") and self.text_config is not None: - output["text_config"] = self.text_config.to_dict() - if hasattr(self, "vision_config") and self.vision_config is not None: - output["vision_config"] = self.vision_config.to_dict() - return output - class IsaacProcessor(ProcessorMixin): """Processor that pairs the Isaac image processor with the Qwen2 tokenizer. @@ -914,7 +717,7 @@ class IsaacProcessor(ProcessorMixin): config (IsaacConfig | dict, optional): If provided, overrides processor defaults from the model config. 
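# Hedged sketch of the sizing rule that `get_image_size_for_max_num_patches` (declared in
# this hunk; its body is not shown in this excerpt) implies: choose output dimensions that
# are multiples of patch_size * pixel_shuffle_scale while keeping the patch grid at or
# below max_num_patches. `sketch_image_size` is an illustrative stand-in, not the actual
# implementation.
import math

def sketch_image_size(height, width, patch_size=16, max_num_patches=256, pixel_shuffle_scale=1):
    granularity = patch_size * pixel_shuffle_scale
    scale = min(1.0, math.sqrt(max_num_patches / ((height / patch_size) * (width / patch_size))))
    new_h = max(granularity, round(height * scale / granularity) * granularity)
    new_w = max(granularity, round(width * scale / granularity) * granularity)
    return new_h, new_w  # a real implementation must also re-check the cap after rounding

print(sketch_image_size(1024, 768))  # (288, 224) -> an 18x14 grid, 252 <= 256 patches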
Returns: - BatchFeature: Contains ``input_ids`` and ``packed_inputs`` (patch tensors, grids, offsets, lengths, modality, positions). + BatchFeature: Top-level batched text and vision tensors. """ attributes = ["image_processor", "tokenizer"] @@ -929,8 +732,8 @@ def __init__( *, vision_token: str = "", max_sequence_length: int = 16384, - rescale_factor: Optional[float] = None, - config: Optional[Union[IsaacConfig, dict]] = None, + rescale_factor: float | None = None, + config: IsaacConfig | dict | None = None, ) -> None: if isinstance(config, dict): config = IsaacConfig(**config) @@ -960,254 +763,260 @@ def __init__( self.vision_token = vision_token self.max_sequence_length = max_sequence_length - def _pack_batch( - self, texts: list[str], images_list: Optional[list[Optional[list[Image]]]] - ) -> dict[str, Optional[torch.Tensor]]: - if images_list is None: - pairs = ((t, None) for t in texts) - else: + def _build_batch( + self, + text: str | list[str], + images: Image | list[Image] | None = None, + ) -> dict[str, torch.Tensor | None]: + texts = [text] if isinstance(text, str) else text + if images is None: + pairs = ((text_value, None) for text_value in texts) + elif isinstance(images, list) and len(images) == len(texts): + if not images: + images_list = [] + elif isinstance(images[0], list): + images_list = images + else: + images_list = [[image] for image in images] pairs = zip(texts, images_list, strict=True) - - per_sample: list[dict[str, Optional[torch.Tensor]]] = [] - for txt, imgs in pairs: - if imgs is not None and isinstance(imgs, Image): - imgs = [imgs] - per_sample.append(self._pack_single(txt, imgs)) - - lengths = [int(p["input_ids"].shape[1]) for p in per_sample] - max_len = max(lengths, default=0) - batch = len(per_sample) - - # Use first device with data as anchor - base_device = torch.device("cpu") - for p in per_sample: - if p["input_ids"].numel() > 0: - base_device = p["input_ids"].device - break - - pad_id = self.text_pad_token_id - padded_input_ids = torch.full((batch, max_len), pad_id, device=base_device, dtype=torch.long) - padded_modality = torch.full((batch, max_len), ModalityType.text.value, device=base_device, dtype=torch.long) - padded_position_ids = torch.zeros((batch, max_len, 3), device=base_device, dtype=torch.long) - - for i, (sample, l) in enumerate(zip(per_sample, lengths)): - if l: - padded_input_ids[i, -l:] = sample["input_ids"][0] - padded_modality[i, -l:] = sample["modality_tensor"][0] - padded_position_ids[i, -l:] = sample["position_ids"][0] - - # Vision-side aggregation - v_samples = [(b, s) for b, s in enumerate(per_sample) if s["vision_patches"] is not None] - if v_samples: - vision_patches_list = [s["vision_patches"] for _, s in v_samples] - vision_grids_list = [s["vision_token_grids"] for _, s in v_samples] - vision_offsets_list = [s["vision_token_offsets"] for _, s in v_samples] - vision_lengths_list = [s["vision_token_lengths"] for _, s in v_samples] - vision_batch_indices = [torch.full_like(s["vision_token_offsets"], b) for b, s in v_samples] - - vision_patches = torch.cat(vision_patches_list, dim=0) - vision_token_grids = torch.cat(vision_grids_list, dim=0) - vision_token_offsets = torch.cat(vision_offsets_list, dim=0) - vision_token_lengths = torch.cat(vision_lengths_list, dim=0) - vision_token_batch_indices = torch.cat(vision_batch_indices, dim=0) else: - vision_patches = vision_token_grids = vision_token_offsets = vision_token_lengths = ( - vision_token_batch_indices - ) = None - - return { - "input_ids": padded_input_ids, - 
"vision_patches": vision_patches, - "vision_token_grids": vision_token_grids, - "vision_token_offsets": vision_token_offsets, - "vision_token_lengths": vision_token_lengths, - "vision_token_batch_indices": vision_token_batch_indices, - "modality_tensor": padded_modality, - "position_ids": padded_position_ids, - } - - def _pack_single(self, text: str, images: Optional[list[Image]]) -> dict[str, Optional[torch.Tensor]]: - segments = text.split(self.vision_token) # Parse by vision_token; interleave text segments and image segments. - num_images = len(segments) - 1 - items: list[dict[str, Any]] = [] - total = 0 - num_provided_images = len(images) if images is not None else 0 - if not num_images == num_provided_images: - raise ValueError( - f"IsaacProcessor expects one image per image token, got {num_images} tokens and {num_provided_images} images in sample with text {text} " + pairs = ( + ( + text_value, + None + if text_value.count(self.vision_token) == 0 + else images + if isinstance(images, list) + else [images], + ) + for text_value in texts ) - for index, segment in enumerate(segments): - if segment: - tok = ( - self.tokenizer.encode(segment, add_special_tokens=False, return_tensors="pt") - .squeeze(0) - .to(torch.long) - ) - segment_length = int(tok.numel()) - items.append({"type": "text", "segment_length": segment_length, "tok": tok}) - total += segment_length - - if index < num_images: - feat = self.image_processor(images=images[index], return_tensors=TensorType.PYTORCH) - patches = feat["patches"][0].reshape(-1, feat["patches"].shape[-1]) - - virtual_pixel_size = feat["virtual_pixel_size"][0].to(torch.long).tolist() - real_pixel_size = feat["real_pixel_size"][0].to(torch.long).tolist() - dims = tuple((virtual_pixel_size + [1, 1, 1])[:3]) # (T,H,W) in virtual space - segment_length = int(dims[0] * dims[1] * dims[2]) - - items.append( - { - "type": "image", - "segment_length": segment_length, - "dims": dims, - "patches": patches, - "grid": (int(real_pixel_size[1]), int(real_pixel_size[2])), - } - ) - total += segment_length - - # Tail crop window. 
- start = max(0, total - self.max_sequence_length) - end = total - - image_pad_value = self.image_pad_token_id - base_device: Optional[torch.device] = None - position_ids, modality, input_ids = [], [], [] - vpatches, grids, vision_token_offsets, vision_token_lengths = [], [], [], [] - - global_offset = 0 - position_offset = 0 - - for item in items: - segment_length = int(item["segment_length"]) - current_window_start = max(start, global_offset) - current_window_end = min(end, global_offset + segment_length) - has_overlap = current_window_end > current_window_start - - if has_overlap and base_device is None: - base_device = item["patches"].device if item["type"] == "image" else item["tok"].device - - if has_overlap: - segment_local_start = int(current_window_start - global_offset) - segment_local_end = int(current_window_end - global_offset) - segment_local_indices = torch.arange( - segment_local_start, segment_local_end, device=base_device, dtype=torch.long + sample_input_ids: list[torch.Tensor] = [] + sample_modality: list[torch.Tensor] = [] + sample_position_ids: list[torch.Tensor] = [] + sample_vision_patches: list[list[torch.Tensor]] = [] + sample_vision_grids: list[torch.Tensor] = [] + sample_vision_offsets: list[torch.Tensor] = [] + sample_vision_lengths: list[torch.Tensor] = [] + + for text_value, sample_images in pairs: + segments = text_value.split(self.vision_token) + num_images = len(segments) - 1 + num_provided_images = len(sample_images) if sample_images is not None else 0 + if num_images != num_provided_images: + raise ValueError( + f"IsaacProcessor expects one image per image token, got {num_images} tokens and {num_provided_images} images in sample with text {text_value} " ) - segment_kept_length = segment_local_end - segment_local_start - - if item["type"] == "text": - slice_index = segment_local_indices + position_offset - zero_axis_pad = torch.zeros_like(slice_index) - position_ids.append(torch.stack((slice_index, zero_axis_pad, zero_axis_pad), -1)) - modality.append( - torch.full( - (segment_kept_length,), ModalityType.text.value, device=base_device, dtype=torch.long - ) + + items: list[dict[str, Any]] = [] + total = 0 + for index, segment in enumerate(segments): + if segment: + text_tokens = ( + self.tokenizer.encode(segment, add_special_tokens=False, return_tensors="pt") + .squeeze(0) + .to(torch.long) ) - input_ids.append(item["tok"].to(base_device)[segment_local_start:segment_local_end]) - position_offset += segment_length - else: - num_pos_slices, grid_height_tokens, grid_width_tokens = item["dims"] - hw = grid_height_tokens * grid_width_tokens - slice_index = (segment_local_indices // hw) + position_offset - rem = segment_local_indices % hw - row_index = rem // grid_width_tokens - col_index = rem % grid_width_tokens - position_ids.append(torch.stack((slice_index, row_index, col_index), -1)) - modality.append( - torch.full( - (segment_kept_length,), ModalityType.image.value, device=base_device, dtype=torch.long - ) + segment_length = int(text_tokens.numel()) + items.append({"type": "text", "segment_length": segment_length, "tokens": text_tokens}) + total += segment_length + + if index < num_images: + feature = self.image_processor(images=sample_images[index], return_tensors=TensorType.PYTORCH) + patches = feature["patches"][0].reshape(-1, feature["patches"].shape[-1]) + virtual_pixel_size = feature["virtual_pixel_size"][0].to(torch.long).tolist() + real_pixel_size = feature["real_pixel_size"][0].to(torch.long).tolist() + dims = tuple((virtual_pixel_size + [1, 1, 
1])[:3]) + segment_length = int(dims[0] * dims[1] * dims[2]) + items.append( + { + "type": "image", + "segment_length": segment_length, + "dims": dims, + "patches": patches, + "grid": (int(real_pixel_size[1]), int(real_pixel_size[2])), + } ) - input_ids.append( - torch.full((segment_kept_length,), image_pad_value, device=base_device, dtype=torch.long) + total += segment_length + + start = max(0, total - self.max_sequence_length) + end = total + base_device: torch.device | None = None + input_ids_chunks, modality_chunks, position_chunks = [], [], [] + vision_patches, vision_grids, vision_offsets, vision_lengths = [], [], [], [] + global_offset = 0 + position_offset = 0 + + for item in items: + segment_length = int(item["segment_length"]) + current_window_start = max(start, global_offset) + current_window_end = min(end, global_offset + segment_length) + has_overlap = current_window_end > current_window_start + + if has_overlap and base_device is None: + base_device = item["patches"].device if item["type"] == "image" else item["tokens"].device + + if has_overlap: + segment_local_start = int(current_window_start - global_offset) + segment_local_end = int(current_window_end - global_offset) + segment_local_indices = torch.arange( + segment_local_start, segment_local_end, device=base_device, dtype=torch.long ) + segment_kept_length = segment_local_end - segment_local_start + + if item["type"] == "text": + slice_index = segment_local_indices + position_offset + zero_axis = torch.zeros_like(slice_index) + position_chunks.append(torch.stack((slice_index, zero_axis, zero_axis), -1)) + modality_chunks.append( + torch.full( + (segment_kept_length,), ModalityType.text.value, device=base_device, dtype=torch.long + ) + ) + input_ids_chunks.append(item["tokens"].to(base_device)[segment_local_start:segment_local_end]) + position_offset += segment_length + else: + num_pos_slices, grid_height_tokens, grid_width_tokens = item["dims"] + hw = grid_height_tokens * grid_width_tokens + slice_index = (segment_local_indices // hw) + position_offset + rem = segment_local_indices % hw + position_chunks.append( + torch.stack((slice_index, rem // grid_width_tokens, rem % grid_width_tokens), -1) + ) + modality_chunks.append( + torch.full( + (segment_kept_length,), ModalityType.image.value, device=base_device, dtype=torch.long + ) + ) + input_ids_chunks.append( + torch.full( + (segment_kept_length,), self.image_pad_token_id, device=base_device, dtype=torch.long + ) + ) - vpatches.append(item["patches"].to(base_device)) # full patches; slice later via offsets/lengths - # Record per-image slice boundaries so we can drop cropped virtual tokens - # after pixel shuffle without re-packing the entire vision stream. 
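# Small sketch of the 3-axis (MRoPE-style) position ids assembled above: text tokens
# advance only the first axis, while each patch of an image at temporal slice t gets
# (t + position_offset, row, col). For a 2x3 patch grid that follows two text tokens:
import torch

grid_h, grid_w, position_offset = 2, 3, 2  # offset after text tokens at positions 0 and 1
hw = grid_h * grid_w
idx = torch.arange(hw)
pos = torch.stack((idx // hw + position_offset, (idx % hw) // grid_w, idx % grid_w), -1)
print(pos.tolist())
# [[2, 0, 0], [2, 0, 1], [2, 0, 2], [2, 1, 0], [2, 1, 1], [2, 1, 2]]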
- grids.append(item["grid"]) - vision_token_offsets.append(segment_local_start) - vision_token_lengths.append(segment_kept_length) - - position_offset += int(num_pos_slices) + vision_patches.append(item["patches"].to(base_device)) + vision_grids.append(item["grid"]) + vision_offsets.append(segment_local_start) + vision_lengths.append(segment_kept_length) + position_offset += int(num_pos_slices) + else: + position_offset += segment_length if item["type"] == "text" else int(item["dims"][0]) - else: - position_offset += segment_length if item["type"] == "text" else int(item["dims"][0]) + global_offset += segment_length - global_offset += segment_length + if base_device is None: + base_device = torch.device("cpu") - if base_device is None: - base_device = torch.device("cpu") + sample_input_ids.append( + torch.cat(input_ids_chunks, 0) + if input_ids_chunks + else torch.zeros((0,), device=base_device, dtype=torch.long) + ) + sample_modality.append( + torch.cat(modality_chunks, 0) + if modality_chunks + else torch.zeros((0,), device=base_device, dtype=torch.long) + ) + sample_position_ids.append( + torch.cat(position_chunks, 0) + if position_chunks + else torch.zeros((0, 3), device=base_device, dtype=torch.long) + ) + sample_vision_patches.append(vision_patches) + if vision_patches: + sample_vision_grids.append(torch.tensor(vision_grids, device=base_device, dtype=torch.long)) + sample_vision_offsets.append(torch.tensor(vision_offsets, device=base_device, dtype=torch.long)) + sample_vision_lengths.append(torch.tensor(vision_lengths, device=base_device, dtype=torch.long)) + else: + sample_vision_grids.append(torch.zeros((0, 2), device=base_device, dtype=torch.long)) + sample_vision_offsets.append(torch.zeros((0,), device=base_device, dtype=torch.long)) + sample_vision_lengths.append(torch.zeros((0,), device=base_device, dtype=torch.long)) - modality_tensor = ( - torch.cat(modality, 0).unsqueeze(0) - if modality - else torch.zeros((1, 0), device=base_device, dtype=torch.long) - ) - position_ids = ( - torch.cat(position_ids, 0).unsqueeze(0) - if position_ids - else torch.zeros((1, 0, 3), device=base_device, dtype=torch.long) + batch_size = len(sample_input_ids) + lengths = [int(sample_input.shape[0]) for sample_input in sample_input_ids] + max_len = max(lengths, default=0) + base_device = next( + (sample_input.device for sample_input in sample_input_ids if sample_input.numel() > 0), + torch.device("cpu"), ) - input_ids = ( - torch.cat(input_ids, 0).unsqueeze(0) - if input_ids - else torch.zeros((1, 0), device=base_device, dtype=torch.long) + + input_ids = torch.full((batch_size, max_len), self.text_pad_token_id, device=base_device, dtype=torch.long) + attention_mask = torch.zeros((batch_size, max_len), device=base_device, dtype=torch.long) + modality_tensor = torch.full( + (batch_size, max_len), ModalityType.text.value, device=base_device, dtype=torch.long ) + position_ids = torch.zeros((batch_size, max_len, 3), device=base_device, dtype=torch.long) - if vpatches: - vision_patches = torch.cat(vpatches, 0) - vision_token_grids = torch.tensor(grids, device=base_device, dtype=torch.long) - vision_token_offsets = torch.tensor(vision_token_offsets, device=base_device, dtype=torch.long) - vision_token_lengths = torch.tensor(vision_token_lengths, device=base_device, dtype=torch.long) + for batch_idx, length in enumerate(lengths): + if length == 0: + continue + input_ids[batch_idx, -length:] = sample_input_ids[batch_idx] + attention_mask[batch_idx, -length:] = 1 + modality_tensor[batch_idx, -length:] = 
sample_modality[batch_idx] + position_ids[batch_idx, -length:] = sample_position_ids[batch_idx] + + image_counts = [len(patches) for patches in sample_vision_patches] + max_images = max(image_counts, default=0) + if max_images == 0: + vision_patches = None + vision_patch_attention_mask = None + vision_token_grids = None + vision_token_offsets = None + vision_token_lengths = None + vision_image_attention_mask = None else: - vision_patches = vision_token_grids = vision_token_offsets = vision_token_lengths = None + first_patch = next((patches[0] for patches in sample_vision_patches if patches), None) + patch_dim = first_patch.shape[-1] + patch_dtype = first_patch.dtype + max_patches = max((patch.shape[0] for patches in sample_vision_patches for patch in patches), default=0) + + vision_patches = torch.zeros( + (batch_size, max_images, max_patches, patch_dim), device=base_device, dtype=patch_dtype + ) + vision_patch_attention_mask = torch.zeros( + (batch_size, max_images, max_patches), device=base_device, dtype=torch.long + ) + vision_token_grids = torch.zeros((batch_size, max_images, 2), device=base_device, dtype=torch.long) + vision_token_offsets = torch.zeros((batch_size, max_images), device=base_device, dtype=torch.long) + vision_token_lengths = torch.zeros((batch_size, max_images), device=base_device, dtype=torch.long) + vision_image_attention_mask = torch.zeros((batch_size, max_images), device=base_device, dtype=torch.long) + + for batch_idx, sample_patches in enumerate(sample_vision_patches): + sample_image_count = len(sample_patches) + if sample_image_count == 0: + continue + vision_token_grids[batch_idx, :sample_image_count] = sample_vision_grids[batch_idx] + vision_token_offsets[batch_idx, :sample_image_count] = sample_vision_offsets[batch_idx] + vision_token_lengths[batch_idx, :sample_image_count] = sample_vision_lengths[batch_idx] + vision_image_attention_mask[batch_idx, :sample_image_count] = 1 + + for image_idx, patches in enumerate(sample_patches): + patch_count = int(patches.shape[0]) + vision_patches[batch_idx, image_idx, :patch_count] = patches + vision_patch_attention_mask[batch_idx, image_idx, :patch_count] = 1 return { "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "modality_tensor": modality_tensor, "vision_patches": vision_patches, + "vision_patch_attention_mask": vision_patch_attention_mask, "vision_token_grids": vision_token_grids, "vision_token_offsets": vision_token_offsets, "vision_token_lengths": vision_token_lengths, - "modality_tensor": modality_tensor, - "position_ids": position_ids, + "vision_image_attention_mask": vision_image_attention_mask, } def __call__( self, - text: Union[str, list[str]], - images: Optional[Union[Image, list[Image]]] = None, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + text: str | list[str], + images: Image | list[Image] | None = None, + return_tensors: str | TensorType | None = TensorType.PYTORCH, **kwargs, ) -> BatchFeature: - texts = [text] if isinstance(text, str) else text - images_list: Optional[list[Optional[list[Image]]]] = None - if images is not None: - if isinstance(images, list) and len(images) == len(texts): - if not images: - images_list = [] - elif isinstance(images[0], list): - images_list = images # already per-sample - else: - images_list = [[img] for img in images] # list of images, one per sample - else: - images_list = [] - for t in texts: - n_tok = t.count(self.vision_token) - if n_tok == 0: - images_list.append(None) - else: - if 
isinstance(images, list): - images_list.append(images) - else: - images_list.append([images]) - - packed = self._pack_batch(texts, images_list) - input_ids = packed.pop("input_ids") - return BatchFeature(data={"input_ids": input_ids, "packed_inputs": packed}) + return BatchFeature(data=self._build_batch(text=text, images=images), tensor_type=return_tensors) class IsaacRotaryEmbedding(Qwen2_5_VLRotaryEmbedding): @@ -1217,15 +1026,17 @@ def __init__(self, config: IsaacConfig, device=None): config_for_rope = copy.copy(rope_source_cfg) config_for_rope.rope_scaling = rope_scaling - init_device = device if device is not None and getattr(device, "type", None) != "meta" else None - super().__init__(config_for_rope, device=init_device) + super().__init__( + config_for_rope, + device=device if device is not None and getattr(device, "type", None) != "meta" else None, + ) rotary_half_dim = self.inv_freq.shape[0] self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim) self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size @staticmethod - def _resolve_mrope_section(section: Optional[list[int]], rotary_half_dim: int) -> list[int]: + def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> list[int]: if section is None: weights = (2, 1, 1) base = [rotary_half_dim * w // sum(weights) for w in weights] @@ -1244,7 +1055,7 @@ def forward( self, position_ids: torch.Tensor, modality_tensor: torch.Tensor, - hidden_states: Optional[torch.Tensor] = None, + hidden_states: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if hidden_states is None: batch, seq_len, _ = position_ids.shape @@ -1289,18 +1100,15 @@ class IsaacModel(Qwen3PreTrainedModel): _can_compile_fullgraph = False _supports_flex_attn = False _can_record_outputs = { - "attentions": OutputRecorder(Qwen3Attention, index=1), - "vision_attentions": OutputRecorder(IsaacVisionAttention, index=1), + "hidden_states": OutputRecorder(Qwen3DecoderLayer), + "attentions": Qwen3Attention, + "vision_attentions": IsaacVisionAttention, } all_tied_weights_keys: dict[str, str] = {} def __init__(self, config: IsaacConfig): Qwen3PreTrainedModel.__init__(self, config) - - text_cfg_source = config.text_config - text_cfg = copy.deepcopy(text_cfg_source) - self.text_model = Qwen3Model._from_config(text_cfg) - self.text_model.config = config # Ensure downstream callers observe the composed config + self.text_model = Qwen3Model._from_config(config.text_config) self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) @@ -1336,88 +1144,73 @@ def embed_tokens(self) -> nn.Module: def embed_tokens(self, value: nn.Module) -> None: self.text_model.embed_tokens = value - @property - def vision_model(self) -> nn.Module: - return self.vision_embedding.vision_tower - - def embed_packed_inputs( - self, input_ids: torch.Tensor, packed_inputs: dict[str, Optional[torch.Tensor]] + def embed_multimodal_inputs( + self, + input_ids: torch.Tensor, + modality_tensor: torch.Tensor, + vision_patches: torch.Tensor, + vision_token_grids: torch.Tensor, + vision_patch_attention_mask: torch.Tensor | None = None, + vision_token_offsets: torch.Tensor | None = None, + vision_token_lengths: torch.Tensor | None = None, + vision_image_attention_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Expects input_ids for text tokens and packed_inputs containing: - - modality_tensor: (batch, seq_len) modality ids aligned to the sequence - - position_ids: (batch, 
seq_len, 3) MRoPE coordinates (optional) - - vision_patches: concatenated vision tokens shaped (total_tokens, embed_dim) or None - - vision_token_grids: (num_images, 2) token grid sizes or None - - vision_token_offsets: (num_images,) offsets into each image's virtual token span (optional) - - vision_token_lengths: (num_images,) surviving virtual token lengths per image (optional) - - vision_token_batch_indices: (num_images,) batch row for each image (optional; defaults to zeros) - """ - modality = packed_inputs["modality_tensor"].to(device=input_ids.device, dtype=torch.long) + modality = modality_tensor.to(device=input_ids.device, dtype=torch.long) embeds = self.text_model.embed_tokens(input_ids) + image_token_mask = modality == ModalityType.image.value - vision_patches = packed_inputs.get("vision_patches") - if vision_patches is None: + if vision_patches is None or vision_token_grids is None: + if torch.any(image_token_mask): + raise ValueError("Image placeholders require `vision_patches` and `vision_token_grids`.") return embeds, modality - token_grids = packed_inputs["vision_token_grids"].to(device=vision_patches.device, dtype=torch.long) - vision = self.vision_embedding((vision_patches, token_grids)) # (total_tokens, hidden) - - # per-image token counts AFTER pixel-shuffle - vision_reduction_factor = int(self.config.vision_config.pixel_shuffle_scale_factor) - sizes = ( - token_grids.prod(-1).div(vision_reduction_factor * vision_reduction_factor, rounding_mode="floor").tolist() + vision_patches = vision_patches.to(device=embeds.device) + token_grids = vision_token_grids.to(device=embeds.device, dtype=torch.long) + image_attention_mask = ( + vision_image_attention_mask.to(device=embeds.device, dtype=torch.bool) + if vision_image_attention_mask is not None + else torch.ones(token_grids.shape[:2], device=embeds.device, dtype=torch.bool) + ) + patch_attention_mask = ( + vision_patch_attention_mask.to(device=embeds.device, dtype=torch.long) + if vision_patch_attention_mask is not None + else torch.ones(vision_patches.shape[:3], device=embeds.device, dtype=torch.long) + ) + offsets = ( + vision_token_offsets.to(device=embeds.device, dtype=torch.long) + if vision_token_offsets is not None + else torch.zeros(token_grids.shape[:2], device=embeds.device, dtype=torch.long) + ) + reduction_factor = int(self.config.vision_config.pixel_shuffle_scale_factor) ** 2 + lengths = ( + vision_token_lengths.to(device=embeds.device, dtype=torch.long) + if vision_token_lengths is not None + else token_grids.prod(-1).div(reduction_factor, rounding_mode="floor").to(dtype=torch.long) ) - offsets = packed_inputs.get("vision_token_offsets") - lengths = packed_inputs.get("vision_token_lengths") - batch_indices = packed_inputs.get("vision_token_batch_indices") - - chunks = vision.split(sizes, dim=0) - picked: list[torch.Tensor] = [] - picked_batch: list[int] = [] - for chunk, size, offset, length, batch_index in zip( - chunks, - sizes, - offsets.tolist(), - lengths.tolist(), - (batch_indices.tolist() if batch_indices is not None else [0] * len(sizes)), - ): - if size <= 0: - continue - offset = max(0, min(int(offset), size)) - length = max(0, min(int(length), size - offset)) - if length: - picked.append(chunk[offset : offset + length]) - picked_batch.append(int(batch_index)) - if picked: - vision_chunks = picked - vision_batch_idx = picked_batch - else: - vision_chunks = vision_batch_idx = [] - - vision = torch.cat(vision_chunks, 0) if vision_chunks else vision.new_zeros((0, vision.size(-1))) - embeds = embeds.clone() 
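# Minimal sketch of the masked_scatter placement used by the rewritten
# embed_multimodal_inputs below: image-placeholder positions in the text embedding matrix
# are overwritten, in order, by the flattened vision features. The element counts on both
# sides must match exactly; the numbers here are illustrative.
import torch

embeds = torch.zeros(1, 5, 4)  # (batch, seq, hidden)
image_token_mask = torch.tensor([[False, True, True, False, True]])
image_features = torch.arange(3 * 4, dtype=torch.float32).view(3, 4)  # 3 vision tokens

scatter_mask = image_token_mask.unsqueeze(-1).expand_as(embeds)
embeds = embeds.masked_scatter(scatter_mask, image_features)
print(embeds[0, 1])  # tensor([0., 1., 2., 3.]) -> first vision token fills the first placeholder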
- num_batches = modality.shape[0] - image_positions = [ - (modality[b] == ModalityType.image.value).nonzero(as_tuple=False).squeeze(-1) for b in range(num_batches) - ] - cursors = [0 for _ in range(num_batches)] - - for chunk, batch_index in zip(vision_chunks, vision_batch_idx): - if chunk.numel() == 0: - continue - positions = image_positions[batch_index] - start = cursors[batch_index] - end = start + chunk.shape[0] - embeds[batch_index, positions[start:end]] = chunk.to(device=embeds.device, dtype=embeds.dtype) - cursors[batch_index] = end + + flat_vision_patches = vision_patches[image_attention_mask] + flat_patch_attention_mask = patch_attention_mask[image_attention_mask] + flat_token_grids = token_grids[image_attention_mask] + flat_offsets = offsets[image_attention_mask] + flat_lengths = lengths[image_attention_mask] + + vision_embeddings = self.vision_embedding((flat_vision_patches, flat_token_grids, flat_patch_attention_mask)) + token_positions = torch.arange(flat_lengths.max(), device=embeds.device, dtype=torch.long) + gather_positions = flat_offsets[:, None] + token_positions[None, :] + gather_mask = token_positions[None, :] < flat_lengths[:, None] + image_features = vision_embeddings[ + torch.arange(vision_embeddings.shape[0], device=embeds.device, dtype=torch.long)[:, None], + gather_positions, + ][gather_mask] + scatter_mask = image_token_mask.unsqueeze(-1).expand_as(embeds) + embeds = embeds.masked_scatter(scatter_mask, image_features) return embeds, modality def get_rope_index( self, *, - position_ids: Optional[torch.Tensor] = None, + position_ids: torch.Tensor | None = None, attention_mask: torch.Tensor, inputs_embeds: torch.Tensor, cache_position: torch.Tensor, @@ -1439,12 +1232,13 @@ def get_rope_index( if cp.ndim == 1: cp = cp.view(1, -1).expand(batch_size or 1, -1) - base_delta = torch.as_tensor( - 0 if self.rope_deltas is None else self.rope_deltas, - device=device, - dtype=torch.long, - ).reshape(-1, 1) - base_delta = torch.broadcast_to(base_delta, (batch_size, 1)) + is_new_prefill = cp[:, :1].eq(0).all(dim=1, keepdim=True) + if self.rope_deltas is None: + base_delta = torch.zeros((batch_size, 1), device=device, dtype=torch.long) + else: + previous_delta = torch.as_tensor(self.rope_deltas, device=device, dtype=torch.long).reshape(-1, 1) + previous_delta = torch.broadcast_to(previous_delta, (batch_size, 1)) + base_delta = torch.where(is_new_prefill, torch.zeros_like(previous_delta), previous_delta) mask_delta = attention_mask.to(device=device, dtype=torch.long).sum(1, keepdim=True) - attention_mask.size( 1 @@ -1469,17 +1263,24 @@ def get_rope_index( return position_ids, rope_deltas @auto_docstring - @check_model_inputs + @merge_with_config_defaults + @capture_outputs def forward( self, - input_ids: Optional[torch.LongTensor] = None, - packed_inputs: Optional[dict[str, torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor | None = None, + modality_tensor: torch.LongTensor | None = None, + vision_patches: torch.Tensor | None = None, + vision_patch_attention_mask: torch.Tensor | None = None, + vision_token_grids: torch.LongTensor | None = None, + vision_token_offsets: torch.LongTensor | None = None, + vision_token_lengths: torch.LongTensor | None = None, + 
vision_image_attention_mask: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPast: """ @@ -1488,21 +1289,59 @@ def forward( Computes position embeddings once and passes them through all layers. Args: - packed_inputs (`dict`, *optional*): - Plain tensor payloads. When provided, requires `input_ids` for text tokens (or `text_token_ids` so `input_ids` can be rebuilt). modality_tensor (`torch.LongTensor`, *optional*): Modality identifiers aligned with the embedded sequence, shaped `(batch_size, seq_len)` and containing - values from `ModalityType`. Automatically built from `packed_inputs` or treated as text-only when omitted. + values from `ModalityType`. Treated as text-only when omitted. + vision_patches (`torch.FloatTensor`, *optional*): + Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`. + vision_patch_attention_mask (`torch.LongTensor`, *optional*): + Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. + vision_token_grids (`torch.LongTensor`, *optional*): + Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`. + vision_token_offsets (`torch.LongTensor`, *optional*): + Start offsets inside the per-image vision embedding sequence, shape `(batch_size, max_images)`. + vision_token_lengths (`torch.LongTensor`, *optional*): + Number of vision tokens to consume per image, shape `(batch_size, max_images)`. + vision_image_attention_mask (`torch.LongTensor`, *optional*): + Mask indicating which image slots are populated, shape `(batch_size, max_images)`. 
""" - output_attentions = kwargs.pop("output_attentions", None) - - modality_tensor: Optional[torch.Tensor] = None + if inputs_embeds is None: + if input_ids is None: + raise ValueError("`input_ids` or `inputs_embeds` must be provided.") + + has_vision_inputs = any( + value is not None + for value in ( + vision_patches, + vision_patch_attention_mask, + vision_token_grids, + vision_token_offsets, + vision_token_lengths, + vision_image_attention_mask, + ) + ) + if modality_tensor is not None or has_vision_inputs: + if modality_tensor is None: + modality_tensor = torch.full_like(input_ids, ModalityType.text.value) + inputs_embeds, modality_tensor = self.embed_multimodal_inputs( + input_ids=input_ids, + modality_tensor=modality_tensor, + vision_patches=vision_patches, + vision_patch_attention_mask=vision_patch_attention_mask, + vision_token_grids=vision_token_grids, + vision_token_offsets=vision_token_offsets, + vision_token_lengths=vision_token_lengths, + vision_image_attention_mask=vision_image_attention_mask, + ) + else: + inputs_embeds = self.text_model.embed_tokens(input_ids) - if packed_inputs is not None: - inputs_embeds, modality_tensor = self.embed_packed_inputs(input_ids, packed_inputs) - elif input_ids is not None: - inputs_embeds = self.text_model.embed_tokens(input_ids) + if modality_tensor is None: + batch_size, seq_len = inputs_embeds.shape[:2] + modality_tensor = torch.full( + (batch_size, seq_len), ModalityType.text.value, device=inputs_embeds.device, dtype=torch.long + ) device = inputs_embeds.device batch_size, seq_len = inputs_embeds.shape[:2] @@ -1517,9 +1356,6 @@ def forward( if attention_mask is None: attention_mask = torch.ones(inputs_embeds.shape[:2], device=inputs_embeds.device, dtype=torch.long) - if position_ids is None and packed_inputs is not None and packed_inputs.get("position_ids") is not None: - position_ids = packed_inputs.get("position_ids").to(device=device) - position_ids, rope_deltas = self.get_rope_index( position_ids=position_ids, attention_mask=attention_mask, @@ -1528,11 +1364,6 @@ def forward( ) self.rope_deltas = rope_deltas - if modality_tensor is None: - modality_tensor = torch.full( - (batch_size, seq_len), ModalityType.text.value, device=device, dtype=torch.long - ) - cos, sin = self.rotary_emb(position_ids, modality_tensor, hidden_states=inputs_embeds) decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids @@ -1540,7 +1371,7 @@ def forward( if not isinstance(attention_mask, dict): attention_mask = create_masks_for_generate( config=self.config, - input_embeds=inputs_embeds, + inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, @@ -1560,7 +1391,6 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=(cos, sin), - output_attentions=output_attentions, **kwargs, ) @@ -1570,9 +1400,7 @@ def forward( return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=(hidden_states,), - attentions=None, + past_key_values=past_key_values if use_cache else None, ) @@ -1591,41 +1419,56 @@ def __init__(self, config: IsaacConfig): @auto_docstring @can_return_tuple - @check_model_inputs + @merge_with_config_defaults def forward( self, - input_ids: Optional[torch.LongTensor] = None, - packed_inputs: Optional[dict[str, torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: 
Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor | None = None, + modality_tensor: torch.LongTensor | None = None, + vision_patches: torch.Tensor | None = None, + vision_patch_attention_mask: torch.Tensor | None = None, + vision_token_grids: torch.LongTensor | None = None, + vision_token_offsets: torch.LongTensor | None = None, + vision_token_lengths: torch.LongTensor | None = None, + vision_image_attention_mask: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | CausalLMOutputWithPast: - """Run multimodal CausalLM forward, accepting packed vision/text inputs. - - Args: - packed_inputs (`dict`, *optional*): - Packed vision/text payload from ``IsaacProcessor`` containing modality ids, MRoPE position ids, and - vision patch tensors/grids (with optional offsets/lengths) used to rebuild embeddings. - - Returns: - CausalLMOutputWithPast: logits, optional loss, caches, hidden states, attentions. + r""" + modality_tensor (`torch.LongTensor`, *optional*): + Modality identifiers aligned with the token sequence, shaped `(batch_size, seq_len)`. + vision_patches (`torch.FloatTensor`, *optional*): + Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`. + vision_patch_attention_mask (`torch.LongTensor`, *optional*): + Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. + vision_token_grids (`torch.LongTensor`, *optional*): + Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`. + vision_token_offsets (`torch.LongTensor`, *optional*): + Start offsets inside the per-image vision embedding sequence, shape `(batch_size, max_images)`. + vision_token_lengths (`torch.LongTensor`, *optional*): + Number of vision tokens to consume per image, shape `(batch_size, max_images)`. + vision_image_attention_mask (`torch.LongTensor`, *optional*): + Mask indicating which image slots are populated, shape `(batch_size, max_images)`. 
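# Hedged end-to-end sketch of the new calling convention: with `packed_inputs` gone, the
# processor output is a flat BatchFeature whose vision tensors map one-to-one onto the
# forward/generate kwargs documented above. The checkpoint id is hypothetical, for
# illustration only.
import torch
from PIL import Image as PILImage
from transformers import IsaacForConditionalGeneration, IsaacProcessor

repo = "perceptron/isaac"  # hypothetical repo id
processor = IsaacProcessor.from_pretrained(repo)
model = IsaacForConditionalGeneration.from_pretrained(repo)

image = PILImage.new("RGB", (224, 224), color=(128, 128, 128))
prompt = f"Describe this image: {processor.vision_token}"
inputs = processor(text=prompt, images=[image])

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=32)
print(processor.tokenizer.decode(output_ids[0], skip_special_tokens=True))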
""" - output_attentions = kwargs.pop("output_attentions", None) - outputs = self.model( input_ids=input_ids, - packed_inputs=packed_inputs, + modality_tensor=modality_tensor, + vision_patches=vision_patches, + vision_patch_attention_mask=vision_patch_attention_mask, + vision_token_grids=vision_token_grids, + vision_token_offsets=vision_token_offsets, + vision_token_lengths=vision_token_lengths, + vision_image_attention_mask=vision_image_attention_mask, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, cache_position=cache_position, **kwargs, ) @@ -1640,18 +1483,24 @@ def forward( logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, - attentions=outputs.attentions if output_attentions else None, + attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, - past_key_values: Optional[list[torch.FloatTensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - packed_inputs: Optional[dict[str, torch.Tensor]] = None, - cache_position: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + past_key_values: list[torch.FloatTensor] | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + modality_tensor: torch.LongTensor | None = None, + vision_patches: torch.Tensor | None = None, + vision_patch_attention_mask: torch.Tensor | None = None, + vision_token_grids: torch.LongTensor | None = None, + vision_token_offsets: torch.LongTensor | None = None, + vision_token_lengths: torch.LongTensor | None = None, + vision_image_attention_mask: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, **kwargs, ) -> dict[str, Any]: model_inputs = super().prepare_inputs_for_generation( @@ -1660,16 +1509,26 @@ def prepare_inputs_for_generation( attention_mask=attention_mask, inputs_embeds=inputs_embeds, cache_position=cache_position, - position_ids=position_ids, + position_ids=None, **kwargs, ) - if packed_inputs is None: + multimodal_inputs = { + "modality_tensor": modality_tensor, + "vision_patches": vision_patches, + "vision_patch_attention_mask": vision_patch_attention_mask, + "vision_token_grids": vision_token_grids, + "vision_token_offsets": vision_token_offsets, + "vision_token_lengths": vision_token_lengths, + "vision_image_attention_mask": vision_image_attention_mask, + } + if not any(value is not None for value in multimodal_inputs.values()): return model_inputs past_len = past_key_values.get_seq_length() if past_key_values is not None else 0 first_step = past_len == 0 - model_inputs["packed_inputs"] = packed_inputs if first_step else None - model_inputs["position_ids"] = None + for key, value in multimodal_inputs.items(): + model_inputs[key] = value if first_step else None + model_inputs["position_ids"] = position_ids if first_step else None return model_inputs diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index c7308d98d425..81688aa74144 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -4,7 +4,6 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. 
One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# coding=utf-8 # Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any from ...feature_extraction_utils import BatchFeature from ...processing_utils import ProcessorMixin @@ -49,7 +48,7 @@ class IsaacProcessor(ProcessorMixin): config (IsaacConfig | dict, optional): If provided, overrides processor defaults from the model config. Returns: - BatchFeature: Contains ``input_ids`` and ``packed_inputs`` (patch tensors, grids, offsets, lengths, modality, positions). + BatchFeature: Top-level batched text and vision tensors. """ attributes = ["image_processor", "tokenizer"] @@ -64,8 +63,8 @@ def __init__( *, vision_token: str = "", max_sequence_length: int = 16384, - rescale_factor: Optional[float] = None, - config: Optional[Union[IsaacConfig, dict]] = None, + rescale_factor: float | None = None, + config: IsaacConfig | dict | None = None, ) -> None: if isinstance(config, dict): config = IsaacConfig(**config) @@ -95,254 +94,260 @@ def __init__( self.vision_token = vision_token self.max_sequence_length = max_sequence_length - def _pack_batch( - self, texts: list[str], images_list: Optional[list[Optional[list[Image]]]] - ) -> dict[str, Optional[torch.Tensor]]: - if images_list is None: - pairs = ((t, None) for t in texts) - else: + def _build_batch( + self, + text: str | list[str], + images: Image | list[Image] | None = None, + ) -> dict[str, torch.Tensor | None]: + texts = [text] if isinstance(text, str) else text + if images is None: + pairs = ((text_value, None) for text_value in texts) + elif isinstance(images, list) and len(images) == len(texts): + if not images: + images_list = [] + elif isinstance(images[0], list): + images_list = images + else: + images_list = [[image] for image in images] pairs = zip(texts, images_list, strict=True) - - per_sample: list[dict[str, Optional[torch.Tensor]]] = [] - for txt, imgs in pairs: - if imgs is not None and isinstance(imgs, Image): - imgs = [imgs] - per_sample.append(self._pack_single(txt, imgs)) - - lengths = [int(p["input_ids"].shape[1]) for p in per_sample] - max_len = max(lengths, default=0) - batch = len(per_sample) - - # Use first device with data as anchor - base_device = torch.device("cpu") - for p in per_sample: - if p["input_ids"].numel() > 0: - base_device = p["input_ids"].device - break - - pad_id = self.text_pad_token_id - padded_input_ids = torch.full((batch, max_len), pad_id, device=base_device, dtype=torch.long) - padded_modality = torch.full((batch, max_len), ModalityType.text.value, device=base_device, dtype=torch.long) - padded_position_ids = torch.zeros((batch, max_len, 3), device=base_device, dtype=torch.long) - - for i, (sample, l) in enumerate(zip(per_sample, lengths)): - if l: - padded_input_ids[i, -l:] = sample["input_ids"][0] - padded_modality[i, -l:] = sample["modality_tensor"][0] - padded_position_ids[i, -l:] = sample["position_ids"][0] - - # Vision-side aggregation - v_samples = [(b, s) for b, s in enumerate(per_sample) if s["vision_patches"] is not None] - if v_samples: - vision_patches_list = [s["vision_patches"] for _, s in 
v_samples] - vision_grids_list = [s["vision_token_grids"] for _, s in v_samples] - vision_offsets_list = [s["vision_token_offsets"] for _, s in v_samples] - vision_lengths_list = [s["vision_token_lengths"] for _, s in v_samples] - vision_batch_indices = [torch.full_like(s["vision_token_offsets"], b) for b, s in v_samples] - - vision_patches = torch.cat(vision_patches_list, dim=0) - vision_token_grids = torch.cat(vision_grids_list, dim=0) - vision_token_offsets = torch.cat(vision_offsets_list, dim=0) - vision_token_lengths = torch.cat(vision_lengths_list, dim=0) - vision_token_batch_indices = torch.cat(vision_batch_indices, dim=0) else: - vision_patches = vision_token_grids = vision_token_offsets = vision_token_lengths = ( - vision_token_batch_indices - ) = None - - return { - "input_ids": padded_input_ids, - "vision_patches": vision_patches, - "vision_token_grids": vision_token_grids, - "vision_token_offsets": vision_token_offsets, - "vision_token_lengths": vision_token_lengths, - "vision_token_batch_indices": vision_token_batch_indices, - "modality_tensor": padded_modality, - "position_ids": padded_position_ids, - } - - def _pack_single(self, text: str, images: Optional[list[Image]]) -> dict[str, Optional[torch.Tensor]]: - segments = text.split(self.vision_token) # Parse by vision_token; interleave text segments and image segments. - num_images = len(segments) - 1 - items: list[dict[str, Any]] = [] - total = 0 - num_provided_images = len(images) if images is not None else 0 - if not num_images == num_provided_images: - raise ValueError( - f"IsaacProcessor expects one image per image token, got {num_images} tokens and {num_provided_images} images in sample with text {text} " + pairs = ( + ( + text_value, + None + if text_value.count(self.vision_token) == 0 + else images + if isinstance(images, list) + else [images], + ) + for text_value in texts ) - for index, segment in enumerate(segments): - if segment: - tok = ( - self.tokenizer.encode(segment, add_special_tokens=False, return_tensors="pt") - .squeeze(0) - .to(torch.long) + sample_input_ids: list[torch.Tensor] = [] + sample_modality: list[torch.Tensor] = [] + sample_position_ids: list[torch.Tensor] = [] + sample_vision_patches: list[list[torch.Tensor]] = [] + sample_vision_grids: list[torch.Tensor] = [] + sample_vision_offsets: list[torch.Tensor] = [] + sample_vision_lengths: list[torch.Tensor] = [] + + for text_value, sample_images in pairs: + segments = text_value.split(self.vision_token) + num_images = len(segments) - 1 + num_provided_images = len(sample_images) if sample_images is not None else 0 + if num_images != num_provided_images: + raise ValueError( + f"IsaacProcessor expects one image per image token, got {num_images} tokens and {num_provided_images} images in sample with text {text_value} " ) - segment_length = int(tok.numel()) - items.append({"type": "text", "segment_length": segment_length, "tok": tok}) - total += segment_length - - if index < num_images: - feat = self.image_processor(images=images[index], return_tensors=TensorType.PYTORCH) - patches = feat["patches"][0].reshape(-1, feat["patches"].shape[-1]) - - virtual_pixel_size = feat["virtual_pixel_size"][0].to(torch.long).tolist() - real_pixel_size = feat["real_pixel_size"][0].to(torch.long).tolist() - dims = tuple((virtual_pixel_size + [1, 1, 1])[:3]) # (T,H,W) in virtual space - segment_length = int(dims[0] * dims[1] * dims[2]) - - items.append( - { - "type": "image", - "segment_length": segment_length, - "dims": dims, - "patches": patches, - "grid": 
(int(real_pixel_size[1]), int(real_pixel_size[2])), - } - ) - total += segment_length - - # Tail crop window. - start = max(0, total - self.max_sequence_length) - end = total - - image_pad_value = self.image_pad_token_id - base_device: Optional[torch.device] = None - position_ids, modality, input_ids = [], [], [] - vpatches, grids, vision_token_offsets, vision_token_lengths = [], [], [], [] - - global_offset = 0 - position_offset = 0 - - for item in items: - segment_length = int(item["segment_length"]) - current_window_start = max(start, global_offset) - current_window_end = min(end, global_offset + segment_length) - has_overlap = current_window_end > current_window_start - - if has_overlap and base_device is None: - base_device = item["patches"].device if item["type"] == "image" else item["tok"].device - - if has_overlap: - segment_local_start = int(current_window_start - global_offset) - segment_local_end = int(current_window_end - global_offset) - segment_local_indices = torch.arange( - segment_local_start, segment_local_end, device=base_device, dtype=torch.long - ) - segment_kept_length = segment_local_end - segment_local_start - - if item["type"] == "text": - slice_index = segment_local_indices + position_offset - zero_axis_pad = torch.zeros_like(slice_index) - position_ids.append(torch.stack((slice_index, zero_axis_pad, zero_axis_pad), -1)) - modality.append( - torch.full( - (segment_kept_length,), ModalityType.text.value, device=base_device, dtype=torch.long - ) + + items: list[dict[str, Any]] = [] + total = 0 + for index, segment in enumerate(segments): + if segment: + text_tokens = ( + self.tokenizer.encode(segment, add_special_tokens=False, return_tensors="pt") + .squeeze(0) + .to(torch.long) ) - input_ids.append(item["tok"].to(base_device)[segment_local_start:segment_local_end]) - position_offset += segment_length - else: - num_pos_slices, grid_height_tokens, grid_width_tokens = item["dims"] - hw = grid_height_tokens * grid_width_tokens - slice_index = (segment_local_indices // hw) + position_offset - rem = segment_local_indices % hw - row_index = rem // grid_width_tokens - col_index = rem % grid_width_tokens - position_ids.append(torch.stack((slice_index, row_index, col_index), -1)) - modality.append( - torch.full( - (segment_kept_length,), ModalityType.image.value, device=base_device, dtype=torch.long - ) + segment_length = int(text_tokens.numel()) + items.append({"type": "text", "segment_length": segment_length, "tokens": text_tokens}) + total += segment_length + + if index < num_images: + feature = self.image_processor(images=sample_images[index], return_tensors=TensorType.PYTORCH) + patches = feature["patches"][0].reshape(-1, feature["patches"].shape[-1]) + virtual_pixel_size = feature["virtual_pixel_size"][0].to(torch.long).tolist() + real_pixel_size = feature["real_pixel_size"][0].to(torch.long).tolist() + dims = tuple((virtual_pixel_size + [1, 1, 1])[:3]) + segment_length = int(dims[0] * dims[1] * dims[2]) + items.append( + { + "type": "image", + "segment_length": segment_length, + "dims": dims, + "patches": patches, + "grid": (int(real_pixel_size[1]), int(real_pixel_size[2])), + } ) - input_ids.append( - torch.full((segment_kept_length,), image_pad_value, device=base_device, dtype=torch.long) + total += segment_length + + start = max(0, total - self.max_sequence_length) + end = total + base_device: torch.device | None = None + input_ids_chunks, modality_chunks, position_chunks = [], [], [] + vision_patches, vision_grids, vision_offsets, vision_lengths = [], [], [], [] + 
global_offset = 0 + position_offset = 0 + + for item in items: + segment_length = int(item["segment_length"]) + current_window_start = max(start, global_offset) + current_window_end = min(end, global_offset + segment_length) + has_overlap = current_window_end > current_window_start + + if has_overlap and base_device is None: + base_device = item["patches"].device if item["type"] == "image" else item["tokens"].device + + if has_overlap: + segment_local_start = int(current_window_start - global_offset) + segment_local_end = int(current_window_end - global_offset) + segment_local_indices = torch.arange( + segment_local_start, segment_local_end, device=base_device, dtype=torch.long ) + segment_kept_length = segment_local_end - segment_local_start + + if item["type"] == "text": + slice_index = segment_local_indices + position_offset + zero_axis = torch.zeros_like(slice_index) + position_chunks.append(torch.stack((slice_index, zero_axis, zero_axis), -1)) + modality_chunks.append( + torch.full( + (segment_kept_length,), ModalityType.text.value, device=base_device, dtype=torch.long + ) + ) + input_ids_chunks.append(item["tokens"].to(base_device)[segment_local_start:segment_local_end]) + position_offset += segment_length + else: + num_pos_slices, grid_height_tokens, grid_width_tokens = item["dims"] + hw = grid_height_tokens * grid_width_tokens + slice_index = (segment_local_indices // hw) + position_offset + rem = segment_local_indices % hw + position_chunks.append( + torch.stack((slice_index, rem // grid_width_tokens, rem % grid_width_tokens), -1) + ) + modality_chunks.append( + torch.full( + (segment_kept_length,), ModalityType.image.value, device=base_device, dtype=torch.long + ) + ) + input_ids_chunks.append( + torch.full( + (segment_kept_length,), self.image_pad_token_id, device=base_device, dtype=torch.long + ) + ) - vpatches.append(item["patches"].to(base_device)) # full patches; slice later via offsets/lengths - # Record per-image slice boundaries so we can drop cropped virtual tokens - # after pixel shuffle without re-packing the entire vision stream. 
- grids.append(item["grid"]) - vision_token_offsets.append(segment_local_start) - vision_token_lengths.append(segment_kept_length) - - position_offset += int(num_pos_slices) + vision_patches.append(item["patches"].to(base_device)) + vision_grids.append(item["grid"]) + vision_offsets.append(segment_local_start) + vision_lengths.append(segment_kept_length) + position_offset += int(num_pos_slices) + else: + position_offset += segment_length if item["type"] == "text" else int(item["dims"][0]) - else: - position_offset += segment_length if item["type"] == "text" else int(item["dims"][0]) + global_offset += segment_length - global_offset += segment_length + if base_device is None: + base_device = torch.device("cpu") - if base_device is None: - base_device = torch.device("cpu") + sample_input_ids.append( + torch.cat(input_ids_chunks, 0) + if input_ids_chunks + else torch.zeros((0,), device=base_device, dtype=torch.long) + ) + sample_modality.append( + torch.cat(modality_chunks, 0) + if modality_chunks + else torch.zeros((0,), device=base_device, dtype=torch.long) + ) + sample_position_ids.append( + torch.cat(position_chunks, 0) + if position_chunks + else torch.zeros((0, 3), device=base_device, dtype=torch.long) + ) + sample_vision_patches.append(vision_patches) + if vision_patches: + sample_vision_grids.append(torch.tensor(vision_grids, device=base_device, dtype=torch.long)) + sample_vision_offsets.append(torch.tensor(vision_offsets, device=base_device, dtype=torch.long)) + sample_vision_lengths.append(torch.tensor(vision_lengths, device=base_device, dtype=torch.long)) + else: + sample_vision_grids.append(torch.zeros((0, 2), device=base_device, dtype=torch.long)) + sample_vision_offsets.append(torch.zeros((0,), device=base_device, dtype=torch.long)) + sample_vision_lengths.append(torch.zeros((0,), device=base_device, dtype=torch.long)) - modality_tensor = ( - torch.cat(modality, 0).unsqueeze(0) - if modality - else torch.zeros((1, 0), device=base_device, dtype=torch.long) - ) - position_ids = ( - torch.cat(position_ids, 0).unsqueeze(0) - if position_ids - else torch.zeros((1, 0, 3), device=base_device, dtype=torch.long) - ) - input_ids = ( - torch.cat(input_ids, 0).unsqueeze(0) - if input_ids - else torch.zeros((1, 0), device=base_device, dtype=torch.long) + batch_size = len(sample_input_ids) + lengths = [int(sample_input.shape[0]) for sample_input in sample_input_ids] + max_len = max(lengths, default=0) + base_device = next( + (sample_input.device for sample_input in sample_input_ids if sample_input.numel() > 0), + torch.device("cpu"), ) - if vpatches: - vision_patches = torch.cat(vpatches, 0) - vision_token_grids = torch.tensor(grids, device=base_device, dtype=torch.long) - vision_token_offsets = torch.tensor(vision_token_offsets, device=base_device, dtype=torch.long) - vision_token_lengths = torch.tensor(vision_token_lengths, device=base_device, dtype=torch.long) + input_ids = torch.full((batch_size, max_len), self.text_pad_token_id, device=base_device, dtype=torch.long) + attention_mask = torch.zeros((batch_size, max_len), device=base_device, dtype=torch.long) + modality_tensor = torch.full( + (batch_size, max_len), ModalityType.text.value, device=base_device, dtype=torch.long + ) + position_ids = torch.zeros((batch_size, max_len, 3), device=base_device, dtype=torch.long) + + for batch_idx, length in enumerate(lengths): + if length == 0: + continue + input_ids[batch_idx, -length:] = sample_input_ids[batch_idx] + attention_mask[batch_idx, -length:] = 1 + modality_tensor[batch_idx, -length:] = 
sample_modality[batch_idx] + position_ids[batch_idx, -length:] = sample_position_ids[batch_idx] + + image_counts = [len(patches) for patches in sample_vision_patches] + max_images = max(image_counts, default=0) + if max_images == 0: + vision_patches = None + vision_patch_attention_mask = None + vision_token_grids = None + vision_token_offsets = None + vision_token_lengths = None + vision_image_attention_mask = None else: - vision_patches = vision_token_grids = vision_token_offsets = vision_token_lengths = None + first_patch = next((patches[0] for patches in sample_vision_patches if patches), None) + patch_dim = first_patch.shape[-1] + patch_dtype = first_patch.dtype + max_patches = max((patch.shape[0] for patches in sample_vision_patches for patch in patches), default=0) + + vision_patches = torch.zeros( + (batch_size, max_images, max_patches, patch_dim), device=base_device, dtype=patch_dtype + ) + vision_patch_attention_mask = torch.zeros( + (batch_size, max_images, max_patches), device=base_device, dtype=torch.long + ) + vision_token_grids = torch.zeros((batch_size, max_images, 2), device=base_device, dtype=torch.long) + vision_token_offsets = torch.zeros((batch_size, max_images), device=base_device, dtype=torch.long) + vision_token_lengths = torch.zeros((batch_size, max_images), device=base_device, dtype=torch.long) + vision_image_attention_mask = torch.zeros((batch_size, max_images), device=base_device, dtype=torch.long) + + for batch_idx, sample_patches in enumerate(sample_vision_patches): + sample_image_count = len(sample_patches) + if sample_image_count == 0: + continue + vision_token_grids[batch_idx, :sample_image_count] = sample_vision_grids[batch_idx] + vision_token_offsets[batch_idx, :sample_image_count] = sample_vision_offsets[batch_idx] + vision_token_lengths[batch_idx, :sample_image_count] = sample_vision_lengths[batch_idx] + vision_image_attention_mask[batch_idx, :sample_image_count] = 1 + + for image_idx, patches in enumerate(sample_patches): + patch_count = int(patches.shape[0]) + vision_patches[batch_idx, image_idx, :patch_count] = patches + vision_patch_attention_mask[batch_idx, image_idx, :patch_count] = 1 return { "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "modality_tensor": modality_tensor, "vision_patches": vision_patches, + "vision_patch_attention_mask": vision_patch_attention_mask, "vision_token_grids": vision_token_grids, "vision_token_offsets": vision_token_offsets, "vision_token_lengths": vision_token_lengths, - "modality_tensor": modality_tensor, - "position_ids": position_ids, + "vision_image_attention_mask": vision_image_attention_mask, } def __call__( self, - text: Union[str, list[str]], - images: Optional[Union[Image, list[Image]]] = None, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + text: str | list[str], + images: Image | list[Image] | None = None, + return_tensors: str | TensorType | None = TensorType.PYTORCH, **kwargs, ) -> BatchFeature: - texts = [text] if isinstance(text, str) else text - images_list: Optional[list[Optional[list[Image]]]] = None - if images is not None: - if isinstance(images, list) and len(images) == len(texts): - if not images: - images_list = [] - elif isinstance(images[0], list): - images_list = images # already per-sample - else: - images_list = [[img] for img in images] # list of images, one per sample - else: - images_list = [] - for t in texts: - n_tok = t.count(self.vision_token) - if n_tok == 0: - images_list.append(None) - else: - if 
isinstance(images, list): - images_list.append(images) - else: - images_list.append([images]) - - packed = self._pack_batch(texts, images_list) - input_ids = packed.pop("input_ids") - return BatchFeature(data={"input_ids": input_ids, "packed_inputs": packed}) + return BatchFeature(data=self._build_batch(text=text, images=images), tensor_type=return_tensors) __all__ = ["IsaacProcessor"] diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py index 66da717cd0b6..6f7fa5984620 100644 --- a/tests/models/isaac/test_modeling_isaac.py +++ b/tests/models/isaac/test_modeling_isaac.py @@ -38,10 +38,12 @@ is_torch_available, ) from transformers.image_utils import load_image +from transformers.masking_utils import create_bidirectional_mask from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast from transformers.models.isaac.modeling_isaac import ( IsaacVisionAttention, IsaacVisionConfig, + pixel_shuffle_padded, ) from transformers.models.isaac.processing_isaac import IsaacProcessor from transformers.testing_utils import ( @@ -233,6 +235,35 @@ def infer_pad_from_tail(sequence: torch.Tensor) -> tuple[int | None, int]: return pad_candidate, idx +def _pixel_shuffle_reference(x: torch.Tensor, token_grids: torch.Tensor, scale_factor: int): + num_images, _, embed_dim = x.shape + output_lengths = [] + for i in range(num_images): + h, w = token_grids[i].tolist() + output_lengths.append((h // scale_factor) * (w // scale_factor)) + + max_output_tokens = max(output_lengths, default=0) + output_dim = embed_dim * scale_factor * scale_factor + out = x.new_zeros((num_images, max_output_tokens, output_dim)) + out_mask = torch.zeros((num_images, max_output_tokens), device=x.device, dtype=torch.long) + + for i in range(num_images): + h, w = token_grids[i].tolist() + if h == 0 or w == 0: + continue + seq_len = h * w + tokens = x[i, :seq_len] + hb, wb = h // scale_factor, w // scale_factor + t = tokens.view(h, w, embed_dim).permute(2, 0, 1).unsqueeze(0) + t = torch.nn.functional.pixel_unshuffle(t, downscale_factor=scale_factor) + t = t.view(1, embed_dim, scale_factor, scale_factor, hb, wb) + t = t.permute(0, 4, 5, 2, 3, 1).contiguous().view(hb * wb, output_dim) + out[i, : hb * wb] = t + out_mask[i, : hb * wb] = 1 + + return out, out_mask, torch.tensor(output_lengths, device=x.device, dtype=torch.long) + + def create_isaac_processor( tokenizer, isaac_config, @@ -278,6 +309,24 @@ def create_isaac_processor( ) +def to_model_multimodal_inputs(processor_output, device): + keys = ( + "modality_tensor", + "position_ids", + "vision_patches", + "vision_patch_attention_mask", + "vision_token_grids", + "vision_token_offsets", + "vision_token_lengths", + "vision_image_attention_mask", + ) + return { + key: (value.to(device) if isinstance(value, torch.Tensor) else value) + for key, value in processor_output.items() + if key in keys + } + + @lru_cache(maxsize=1) def _load_red_dot_image(): if Image is None: @@ -387,7 +436,7 @@ def __init__( self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.is_training = True - self.expected_num_hidden_layers = 1 + self.expected_num_hidden_layers = num_hidden_layers + 1 self.text_config = { "bos_token_id": 0, @@ -565,6 +614,33 @@ def test_isaac_for_conditional_generation_loss_and_generate_flag(self): self.assertEqual(outputs.logits.shape, (batch_size, seq_len, config.vocab_size)) +@require_torch +class IsaacPixelShufflePaddedTest(unittest.TestCase): + def 
test_pixel_shuffle_padded_matches_reference_no_attention_mask(self): + x = torch.arange(2 * 16 * 4, device=torch_device, dtype=torch.float32).view(2, 16, 4) + token_grids = torch.tensor([[4, 4], [2, 4]], device=torch_device, dtype=torch.long) + expected_hidden, expected_mask, expected_lengths = _pixel_shuffle_reference(x, token_grids, scale_factor=2) + + hidden = pixel_shuffle_padded(x=x, token_grids=token_grids, scale_factor=2) + + torch.testing.assert_close(hidden, expected_hidden) + + def test_pixel_shuffle_padded_raises_on_non_divisible_grid(self): + x = torch.randn(1, 15, 8, device=torch_device) + token_grids = torch.tensor([[3, 5]], device=torch_device, dtype=torch.long) + + with pytest.raises(ValueError, match="divisible"): + pixel_shuffle_padded(x=x, token_grids=token_grids, scale_factor=2) + + def test_pixel_shuffle_padded_zero_grid(self): + x = torch.randn(1, 4, 8, device=torch_device) + token_grids = torch.tensor([[0, 0]], device=torch_device, dtype=torch.long) + + hidden = pixel_shuffle_padded(x=x, token_grids=token_grids, scale_factor=2) + + self.assertEqual(hidden.shape, (1, 0, 32)) + + @require_torch @require_flash_attn class IsaacAttentionDtypeTest(unittest.TestCase): @@ -624,7 +700,7 @@ def test_flash_attention_matches_weight_dtype_bf16_with_padding(self): assert attn_output.dtype == attn.out_proj.weight.dtype assert attn_output.dtype == hidden_states.dtype - def test_flash_attention_matches_weight_dtype_bf16_with_cu_seqlens(self): + def test_flash_attention_matches_weight_dtype_bf16_with_prepared_mask(self): self._skip_if_no_cuda_bf16() torch.manual_seed(0) @@ -635,10 +711,15 @@ def test_flash_attention_matches_weight_dtype_bf16_with_cu_seqlens(self): attn = IsaacVisionAttention(config).to(device=device, dtype=torch.bfloat16).eval() hidden_states = torch.randn(1, 5, config.hidden_size, device=device, dtype=torch.bfloat16) - cu_seqlens = torch.tensor([0, 3, 5], device=device, dtype=torch.int32) + attention_mask = torch.tensor([[1, 1, 1, 0, 0]], device=device, dtype=torch.long) + prepared_attention_mask = create_bidirectional_mask( + config=config, + inputs_embeds=hidden_states, + attention_mask=attention_mask, + ) with torch.no_grad(): - attn_output, _ = attn(hidden_states, cu_seqlens=cu_seqlens, max_seqlen=3) + attn_output, _ = attn(hidden_states, attention_mask=prepared_attention_mask) assert attn_output.dtype == attn.out_proj.weight.dtype assert attn_output.dtype == hidden_states.dtype @@ -702,7 +783,6 @@ def setUp(self): def _generate_from_messages(self, messages, images, num_tokens=None): prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip() processor_output = self.processor(text=prompt, images=images, return_tensors="pt") - packed_inputs = processor_output["packed_inputs"] input_ids = processor_output["input_ids"].to(self.device) attention_mask = processor_output.get("attention_mask") if attention_mask is None: @@ -712,16 +792,13 @@ def _generate_from_messages(self, messages, images, num_tokens=None): attention_mask = processor_output["input_ids"].ne(pad_id).long() attention_mask = attention_mask.to(self.device) prompt_len = input_ids.shape[1] - packed_inputs = { - key: (value.to(self.device) if isinstance(value, torch.Tensor) else value) - for key, value in packed_inputs.items() - } + multimodal_inputs = to_model_multimodal_inputs(processor_output, self.device) with torch.no_grad(): outputs = self.model.generate( input_ids=input_ids, attention_mask=attention_mask, - packed_inputs=packed_inputs, + 
**multimodal_inputs, max_new_tokens=num_tokens or self.max_new_tokens, do_sample=False, pad_token_id=self.tokenizer.eos_token_id, @@ -796,16 +873,13 @@ def _generate_batch(self, prompts, images_list, num_tokens=None): input_ids = input_ids.to(self.device) attention_mask = attention_mask.to(self.device) - packed_inputs = { - k: (v.to(self.device) if isinstance(v, torch.Tensor) else v) - for k, v in processor_output["packed_inputs"].items() - } + multimodal_inputs = to_model_multimodal_inputs(processor_output, self.device) with torch.no_grad(): outputs = self.model.generate( input_ids=input_ids, attention_mask=attention_mask, - packed_inputs=packed_inputs, + **multimodal_inputs, max_new_tokens=num_tokens or self.max_new_tokens, do_sample=False, pad_token_id=self.tokenizer.eos_token_id, @@ -836,17 +910,15 @@ def test_logit_equivalence(self): ] prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip() processor_output = self.processor(text=prompt, images=images, return_tensors="pt") - packed_inputs = processor_output["packed_inputs"] input_ids = processor_output["input_ids"] device = next(self.model.parameters()).device input_ids = input_ids.to(device) - # Move packed tensors to model device - packed_inputs = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in packed_inputs.items()} + multimodal_inputs = to_model_multimodal_inputs(processor_output, device) with torch.no_grad(): outputs = self.model.generate( input_ids=input_ids, - packed_inputs=packed_inputs, + **multimodal_inputs, max_new_tokens=num_tokens or self.max_new_tokens, do_sample=False, pad_token_id=self.tokenizer.eos_token_id, @@ -938,22 +1010,16 @@ def test_batched_generation_matches_individual(self): ] batch_outputs = self.processor(text=prompts, images=images_list, return_tensors="pt") batch_input_ids = batch_outputs["input_ids"] - batch_packed = batch_outputs["packed_inputs"] + batch_packed = batch_outputs sample_lengths = [output["input_ids"].squeeze(0).shape[0] for output in per_sample_outputs] max_length = max(sample_lengths) - expected_vision_patches = [] - expected_vision_grids = [] - expected_vision_offsets = [] - expected_vision_lengths = [] - expected_vision_batch_indices = [] - for i, (single_output, batch_ids, single_len) in enumerate( zip(per_sample_outputs, batch_input_ids, sample_lengths) ): single_ids = single_output["input_ids"].squeeze(0) - single_packed = single_output["packed_inputs"] + single_packed = single_output torch.testing.assert_close(batch_ids[-single_len:], single_ids) @@ -975,11 +1041,22 @@ def test_batched_generation_matches_individual(self): torch.testing.assert_close(batch_positions_row, expected_positions) if single_packed["vision_patches"] is not None: - expected_vision_patches.append(single_packed["vision_patches"]) - expected_vision_grids.append(single_packed["vision_token_grids"]) - expected_vision_offsets.append(single_packed["vision_token_offsets"]) - expected_vision_lengths.append(single_packed["vision_token_lengths"]) - expected_vision_batch_indices.append(torch.full_like(single_packed["vision_token_batch_indices"], i)) + expected_image_count = int(single_packed["vision_image_attention_mask"].sum().item()) + batch_image_count = int(batch_packed["vision_image_attention_mask"][i].sum().item()) + assert batch_image_count == expected_image_count + if expected_image_count > 0: + torch.testing.assert_close( + batch_packed["vision_token_grids"][i, :expected_image_count], + single_packed["vision_token_grids"][0, 
:expected_image_count], + ) + torch.testing.assert_close( + batch_packed["vision_token_offsets"][i, :expected_image_count], + single_packed["vision_token_offsets"][0, :expected_image_count], + ) + torch.testing.assert_close( + batch_packed["vision_token_lengths"][i, :expected_image_count], + single_packed["vision_token_lengths"][0, :expected_image_count], + ) if single_len == max_length: continue @@ -991,20 +1068,11 @@ def test_batched_generation_matches_individual(self): assert not torch.any(attention_mask[: max_length - single_len]), f"sample {i} mask ones inside left pad" assert torch.all(attention_mask[-single_len:]), f"sample {i} mask zeros inside content" - if expected_vision_patches: - torch.testing.assert_close(batch_packed["vision_patches"], torch.cat(expected_vision_patches, dim=0)) - torch.testing.assert_close(batch_packed["vision_token_grids"], torch.cat(expected_vision_grids, dim=0)) - torch.testing.assert_close(batch_packed["vision_token_offsets"], torch.cat(expected_vision_offsets, dim=0)) - torch.testing.assert_close(batch_packed["vision_token_lengths"], torch.cat(expected_vision_lengths, dim=0)) - torch.testing.assert_close( - batch_packed["vision_token_batch_indices"], torch.cat(expected_vision_batch_indices, dim=0) - ) - else: - assert batch_packed["vision_patches"] is None - assert batch_packed["vision_token_grids"] is None - assert batch_packed["vision_token_offsets"] is None - assert batch_packed["vision_token_lengths"] is None - assert batch_packed["vision_token_batch_indices"] is None + assert batch_packed["vision_patches"] is not None + assert batch_packed["vision_token_grids"] is not None + assert batch_packed["vision_token_offsets"] is not None + assert batch_packed["vision_token_lengths"] is not None + assert batch_packed["vision_image_attention_mask"] is not None batch_texts = self._generate_batch(prompts, images_list, num_tokens=100) assert len(batch_texts) == len(single_texts) == 3 @@ -1058,18 +1126,14 @@ def test_hf_generate_box_points(self): messages, images = document_to_messages(document, vision_token=self.hf_config.vision_token) prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip() processor_output = self.processor(text=prompt, images=images, return_tensors="pt") - packed_inputs = processor_output["packed_inputs"] input_ids = processor_output["input_ids"].to(self.device) prompt_len = input_ids.shape[1] - packed_inputs = { - key: (value.to(self.device) if isinstance(value, torch.Tensor) else value) - for key, value in packed_inputs.items() - } + multimodal_inputs = to_model_multimodal_inputs(processor_output, self.device) with torch.no_grad(): outputs = self.model.generate( input_ids=input_ids, - packed_inputs=packed_inputs, + **multimodal_inputs, max_new_tokens=self.max_new_tokens, do_sample=False, pad_token_id=self.tokenizer.eos_token_id, From bf501ddbf1e0c5704555faa643c8cb191b873939 Mon Sep 17 00:00:00 2001 From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com> Date: Thu, 19 Mar 2026 17:55:45 +0400 Subject: [PATCH 64/77] feat: rely on qwen3 backbone, flatten vision components, misc style changes, processor post-processing, expanded tests (#15) * fix: update imports * fix: replace removed check_model_inputs with merge_with_config_defaults and capture_outputs * fix: no capture outputs within capture outputs * refactor: move isaac vision internals to padded batched flow * refactor: align isaac vision attention with standard mask interfaces * refactor: remove packed_inputs from 
isaac model api and generation path * chore: purge isaac packing internals and sync modular outputs * refactor: remove isaac packing pipeline and align with transformers batched attention standards * refactor: drop final isaac packed compatibility path * refactor: use OutputRecorder for isaac hidden states * refactor: remove manual output_attentions handling in isaac model * refactor: rely on output recorder for isaac attentions * fix: do not deepcopy text config * style: remove overly defensive checks * style: remove unneeded pops * refactor: simplify pixshuf * style: drop unused vision_model alias * wip simplify * wip simplify 2 * perf: remove device syncs * test: add isaac pixel shuffle strict invariant characterization * refactor: make isaac pixel shuffle tensor-only with strict invariants * chore: regenerate isaac generated files after modular pixel shuffle refactor * style: drop redundant check * refactor: simplify config wiring * refactor: unify multimodal check for input preparation * refactor: drop now redundant init override * style: drop unused attention mask flow through pixel shuffle * style: collapse resize callsite for readability * style: drop more redundant checks * refactor: rely on siglip2 for vision attention * refactor: enforce invariant * refactor: simplify processor * fix: add post init call to vision transformer * style: update year * style: drop redundant sdpa standard * refactor: drop unneeded decorators when inheriting from pretrained model * refactor: delegate to siglip init * refactor: base class does rgb check * style: no need for positional args * docs: split up isaac vision transformer docstrings * style: don't save self attributes unused in forward * style: drop unneeded image processor identifier * style: remove explicit setting of now auto-discovered config settings * style: remove option for positional args * style: remove self.config; is unused * style: drop redundant vocab size handling * style: drop redundant can generate decorator * style: drop unneeded license * refactor: rely on _tied_weight_keys attr * refactor: chat template in init * style: drop redundant fields * fix: proper types * refactor: rely on qwen3vl functionality * fix: config * test: rope ids * fix: get rope working by adapting to qwen3vl properly * style: drop unused current processor * fix: remove need for config in processor init * docs: move IsaacProcessor to auto_docstring * refactor: WIP big refac * refactor: explicitly pass vision token components * style: drop unneeded siglip args * fix: use is_first_iteration and use_cache to catch edge cases * feat: remove cache position * fix: assume inputs are on correct device * refactor: remove properties * refactor: move input validation * refactor: WIP use get_image_features and placeholder_mask standard * refactor: wip 2 get_image_features and placeholder_mask * fix: drop unneeded set_input_embeddings * fix: simplify vision inputs check * test: multimodal test inputs * fix: ignore isaac specific keys at text config init * chore: generated files * refactor: WIP isolated text model * fix: wip drop double capture * fix: return tuple to properly track outputs * refactor: wip rely on base config * fix: fix access pattern in projector * refactor: rely on qwen3 config for num hidden layers * refactor: stop mirroring hidden size * test: drop useless test * style: drop unneeded vocab size setting * test: don't read from base config * style: drop custom type * refactor: move to canonical approach for scattering image features * chore: post merge
re-generation * refactor: remove cache position * refactor: drop extra check for generation phase wip * refactor: wip rope index change update * fix: input ids is never None * style: attention mask is never none * refactor: inline helper logic * refactor: delegate attention mask handling to backbone * refactor: rely on transformers utilities and invariants to batch images * refactor: derive token type ids from input ids at the very end * refactor: move image level logic to isaacimageprocessorfast * refactor: call on batch of images! * refactor: all image logic in image processor * fix: drop device handling for processors * refactor: more image processor isolation * refactor: operate on expanded text wip 2 * refactor: reduce verbosity wip 4 * feat: post processor * WIP 1: get vision position ids * refactor: drop needless posid handling * refactor: drop unneeded error * refactor: move position id handling to model in get_rope_index wip 2 * refactor: deduplicate position id logic * refactor: use tokenizer with left side padding * refactor: drop virtual dims * refactor: remove vision image attention mask * test: post processing * wip * wip 2 simplify * wip 3 * wip 4 * test: image processor tests * test: use processing common standard for testing processor * wip 5 * refactor: rely on config for special token attribute tracking * style: remove rope index arbitrary posargs * docs: no pixel_values arg exists * refactor: rely on library standard assumptions * docs: improve docstrings * test: use processor utility directly in model test * wip * wip 2 * wip 3 * fix: set defaults * refactor: remove high level vision embed class * refactor: simplify position id computation branching logic * chore: artifacts --- src/transformers/models/isaac/__init__.py | 2 +- .../models/isaac/configuration_isaac.py | 164 +- .../isaac/image_processing_isaac_fast.py | 117 +- .../models/isaac/modeling_isaac.py | 1488 +++++++++------ .../models/isaac/modular_isaac.py | 1605 +++++++++-------- .../models/isaac/processing_isaac.py | 502 +++--- tests/models/isaac/__init__.py | 13 - .../isaac/test_image_processing_isaac.py | 417 +++++ tests/models/isaac/test_modeling_isaac.py | 227 +-- .../isaac/test_post_processing_isaac.py | 102 ++ tests/models/isaac/test_processing_isaac.py | 741 ++++---- 11 files changed, 3266 insertions(+), 2112 deletions(-) create mode 100644 tests/models/isaac/test_image_processing_isaac.py create mode 100644 tests/models/isaac/test_post_processing_isaac.py diff --git a/src/transformers/models/isaac/__init__.py b/src/transformers/models/isaac/__init__.py index 8ff2b88ec9af..bc0f3fcc6d7c 100644 --- a/src/transformers/models/isaac/__init__.py +++ b/src/transformers/models/isaac/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index ddd12e55c958..28fc96790ca3 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -4,7 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this.
# ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. +# Copyright 2026 Perceptron, Inc and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,9 +19,11 @@ # limitations under the License. from ...configuration_utils import PreTrainedConfig, PretrainedConfig, layer_type_validation -from ...models.qwen3.configuration_qwen3 import Qwen3Config +from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring +@auto_docstring(checkpoint="google/isaac-base-patch16-naflex") class IsaacVisionConfig(PreTrainedConfig): """Vision configuration for Isaac with Pixel Shuffle support. @@ -30,8 +32,6 @@ class IsaacVisionConfig(PreTrainedConfig): Args: pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1): Spatial factor applied before pixel shuffle reduces the resolution. - num_patches (`int`, *optional*, defaults to 256): - Maximum number of learnable positional embeddings to initialize. """ model_type = "isaac_vision" @@ -64,13 +64,118 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.hidden_act = hidden_act self.num_patches = num_patches - # Add our custom fields self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor - # Ensure a sensible default attention backend - if getattr(self, "_attn_implementation", None) is None: - self._attn_implementation = "sdpa" + +@auto_docstring(checkpoint="Qwen/IsaacText-8B") +class IsaacTextConfig(PreTrainedConfig): + r""" + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any + additional layer afterwards will use SWA (Sliding Window Attention). 
+ + ```python + >>> from transformers import IsaacTextModel, IsaacTextConfig + + >>> # Initializing a IsaacText style configuration + >>> configuration = IsaacTextConfig() + + >>> # Initializing a model from the IsaacText-8B style configuration + >>> model = IsaacTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "isaac_text" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `IsaacText` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size: int | None = 151936, + hidden_size: int | None = 4096, + intermediate_size: int | None = 22016, + num_hidden_layers: int | None = 32, + num_attention_heads: int | None = 32, + num_key_value_heads: int | None = 32, + head_dim: int | None = 128, + hidden_act: str | None = "silu", + max_position_embeddings: int | None = 32768, + initializer_range: float | None = 0.02, + rms_norm_eps: float | None = 1e-6, + use_cache: bool | None = True, + tie_word_embeddings: bool | None = False, + rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None, + attention_bias: bool | None = False, + use_sliding_window: bool | None = False, + sliding_window: int | None = 4096, + max_window_layers: int | None = 28, + layer_types: list[str] | None = None, + attention_dropout: float | None = 0.0, + pad_token_id: int | None = None, + bos_token_id: int | None = None, + eos_token_id: int | None = None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if self.use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.tie_word_embeddings = tie_word_embeddings + self.rope_parameters = rope_parameters + + super().__init__(**kwargs) class IsaacConfig(PretrainedConfig): @@ 
-78,16 +183,27 @@ class IsaacConfig(PretrainedConfig): This configuration corresponds to checkpoints such as [Perceptron/isaac-base](https://huggingface.co/Perceptron/isaac-base). + + Args: + vision_config (`IsaacVisionConfig`, *optional*): + Configuration for the Isaac vision tower. If unset, the default [`IsaacVisionConfig`] is used. + text_config (`IsaacTextConfig` or `dict`, *optional*): + Configuration for the text backbone. Dictionaries are converted to [`IsaacTextConfig`]. + vision_rescale_factor (`float`, *optional*, defaults to 1 / 255): + Rescale factor applied by the image processor before normalization. + max_sequence_length (`int`, *optional*, defaults to 16384): + Maximum multimodal sequence length produced by the processor and expected by the model. + vision_token (`str`, *optional*, defaults to `""`): + Placeholder string inserted into text prompts to mark image positions. """ model_type = "isaac" - sub_configs = {"vision_config": IsaacVisionConfig, "text_config": Qwen3Config} - image_processor_type = "IsaacImageProcessor" + sub_configs = {"vision_config": IsaacVisionConfig, "text_config": IsaacTextConfig} def __init__( self, vision_config: IsaacVisionConfig | None = None, - text_config: Qwen3Config | dict | None = None, + text_config: IsaacTextConfig | dict | None = None, vision_rescale_factor: float = 1 / 255, max_sequence_length: int = 16384, vision_token: str = "", @@ -95,7 +211,7 @@ def __init__( ): if isinstance(text_config, dict): self.text_config = self.sub_configs["text_config"](**text_config) - elif isinstance(text_config, Qwen3Config): + elif isinstance(text_config, IsaacTextConfig): self.text_config = text_config elif text_config is None: self.text_config = self.sub_configs["text_config"]() @@ -107,32 +223,14 @@ def __init__( elif vision_config is None: self.vision_config = self.sub_configs["vision_config"]() - # Seed RoPE parameters before base init so the shared mixin can standardize/validate them. - self.rope_parameters = getattr(self.text_config, "rope_parameters", None) - self.layer_types = getattr(self.text_config, "layer_types", None) - super().__init__(**kwargs) - # Keep rope parameters aligned between the composite and text sub-configs. - self.text_config.rope_parameters = self.rope_parameters - - # Mirror frequently accessed Qwen3 attributes at the composite config level - self.vocab_size = self.text_config.vocab_size - self.hidden_size = self.text_config.hidden_size - self.num_hidden_layers = self.text_config.num_hidden_layers - self.num_attention_heads = self.text_config.num_attention_heads - self.head_dim = self.text_config.head_dim - self.hidden_act = self.text_config.hidden_act + # Mirror frequently accessed composite-level attributes. 
self.use_cache = self.text_config.use_cache - self.rope_theta = self.rope_parameters["rope_theta"] + self.rope_theta = self.text_config.rope_parameters["rope_theta"] self.max_position_embeddings = getattr(self.text_config, "max_position_embeddings", max_sequence_length) self.text_config.max_position_embeddings = self.max_position_embeddings - self.layer_types = getattr(self.text_config, "layer_types", None) - layer_type_validation(self.layer_types, self.num_hidden_layers) - - if getattr(self, "_attn_implementation", None) is None: - self._attn_implementation = "sdpa" # Vision normalization parameters self.vision_rescale_factor = float(vision_rescale_factor) @@ -141,4 +239,4 @@ def __init__( self.vision_token = vision_token -__all__ = ["IsaacConfig"] +__all__ = ["IsaacConfig", "IsaacTextConfig", "IsaacVisionConfig"] diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py b/src/transformers/models/isaac/image_processing_isaac_fast.py index fe2b3d5fa8dd..db1a8b52a756 100644 --- a/src/transformers/models/isaac/image_processing_isaac_fast.py +++ b/src/transformers/models/isaac/image_processing_isaac_fast.py @@ -4,7 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. +# Copyright 2026 Perceptron, Inc and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
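The configuration hunks above rewire `IsaacConfig` around two explicit sub-configs and drop the mirrored Qwen3 attributes. As a quick orientation aid (not part of the patch), here is a minimal sketch of how the new wiring could be exercised; class names follow the diff above, while the import path and the tiny sizes are illustrative assumptions:

```python
from transformers.models.isaac import IsaacConfig, IsaacTextConfig  # path assumed

# Sub-configs may be passed as instances or plain dicts; dicts are converted via
# IsaacConfig.sub_configs, and omitting one falls back to that sub-config's defaults.
config = IsaacConfig(
    text_config=IsaacTextConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=4),
    vision_config={"pixel_shuffle_scale_factor": 2},
    max_sequence_length=4096,
)

# Per the diff, only a few composite-level attributes are mirrored from the text
# sub-config after init (use_cache, rope_theta, max_position_embeddings).
assert config.use_cache == config.text_config.use_cache
assert config.rope_theta == config.text_config.rope_parameters["rope_theta"]
```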
@@ -24,7 +24,7 @@ from ...feature_extraction_utils import BatchFeature from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict, group_images_by_shape, reorder_images -from ...image_utils import PILImageResampling +from ...image_utils import ImageInput, PILImageResampling, make_nested_list_of_images from ...utils import TensorType, auto_docstring from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD @@ -156,7 +156,11 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): MAX_PIXELS = 60_000_000 # 60โ€‘megapixel ceiling โ‰ˆ 8200 ร— 7300 px resample = PILImageResampling.BILINEAR - model_input_names = ["patches", "token_grids"] + model_input_names = [ + "vision_patches", + "vision_patch_attention_mask", + "vision_token_grids", + ] valid_kwargs = IsaacImageProcessorFastKwargs unused_kwargs = ["size", "do_center_crop", "crop_size", "pad_size", "do_pad"] @@ -179,6 +183,14 @@ def _validate_preprocess_kwargs(self, **kwargs): kwargs.pop("do_resize", None) return super()._validate_preprocess_kwargs(**kwargs) + def _prepare_images_structure( + self, + images: ImageInput, + expected_ndims: int = 3, + ) -> ImageInput: + images = self.fetch_images(images) + return make_nested_list_of_images(images, expected_ndims=expected_ndims) + def resize( self, image: torch.Tensor, @@ -189,7 +201,7 @@ def resize( def _preprocess( self, - images: list[torch.Tensor], + images: list[list[torch.Tensor]], do_resize: bool, interpolation: Any | None, do_rescale: bool | None, @@ -199,23 +211,29 @@ def _preprocess( image_std: float | Sequence[float] | None, disable_grouping: bool | None = None, return_tensors: str | TensorType | None = None, - *, patch_size: int | None = None, max_num_patches: int | None = None, min_num_patches: int | None = None, pixel_shuffle_scale: int | None = None, **kwargs, ) -> BatchFeature: - grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + batch_size = len(images) + if all(len(sample_images) == 0 for sample_images in images): + tensors = { + "vision_patches": torch.zeros((batch_size, 0, 0, 0), dtype=torch.float32), + "vision_patch_attention_mask": torch.zeros((batch_size, 0, 0), dtype=torch.long), + "vision_token_grids": torch.zeros((batch_size, 0, 2), dtype=torch.long), + } + return BatchFeature(data=tensors, tensor_type=return_tensors) + + grouped_images, grouped_images_index = group_images_by_shape( + images, disable_grouping=disable_grouping, is_nested=True + ) grouped_outputs = {} for shape, stacked_images in grouped_images.items(): - batch_size, channels, original_height, original_width = stacked_images.shape - - if bool(self.do_convert_rgb) and channels == 1: - stacked_images = stacked_images.repeat(1, 3, 1, 1) - + grouped_batch_size, channels, original_height, original_width = stacked_images.shape target_height, target_width = get_image_size_for_max_num_patches( original_height, original_width, @@ -245,49 +263,66 @@ def _preprocess( ) patches = torch_extract_patches(image_batch, patch_size, patch_size) - _, height_tokens, width_tokens, _ = patches.shape + _, height_tokens, width_tokens, patch_dim = patches.shape token_grid = ( - torch.tensor([height_tokens, width_tokens], device=patches.device).long().expand(batch_size, 2) - ) - - real_dim = ( - torch.tensor( - [1, height_tokens, width_tokens], - dtype=torch.long, - device=patches.device, - ) - .unsqueeze(0) - .repeat(batch_size, 1) + torch.tensor([height_tokens, width_tokens], 
device=patches.device).long().expand(grouped_batch_size, 2) ) if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): raise ValueError( f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale}; adjust resize/patch parameters or disable pixel shuffle." ) - virtual_height = height_tokens // pixel_shuffle_scale - virtual_width = width_tokens // pixel_shuffle_scale - - virtual_dim = ( - torch.tensor( - [1, virtual_height, virtual_width], - dtype=torch.long, - device=patches.device, - ) - .unsqueeze(0) - .repeat(batch_size, 1) - ) - grouped_outputs[shape] = (patches, token_grid, virtual_dim, real_dim) - keys = ("patches", "token_grids", "virtual_pixel_size", "real_pixel_size") - tensors: dict[str, torch.Tensor] = {} + grouped_outputs[shape] = ( + patches.reshape(grouped_batch_size, -1, patch_dim), + token_grid, + ) - for i, key in enumerate(keys): - slices = reorder_images( + keys = ("vision_patches", "vision_token_grids") + nested_outputs = { + key: reorder_images( {shape: values[i] for shape, values in grouped_outputs.items()}, grouped_images_index, + is_nested=True, ) - tensors[key] = torch.stack(slices, dim=0) + for i, key in enumerate(keys) + } + + first_patch = next( + patches for sample_patches in nested_outputs["vision_patches"] for patches in sample_patches + ) + max_images = max(len(sample_patches) for sample_patches in nested_outputs["vision_patches"]) + patch_dim = first_patch.shape[-1] + patch_dtype = first_patch.dtype + patch_device = first_patch.device + max_patches = max( + patches.shape[0] for sample_patches in nested_outputs["vision_patches"] for patches in sample_patches + ) + + tensors = { + "vision_patches": torch.zeros( + (batch_size, max_images, max_patches, patch_dim), device=patch_device, dtype=patch_dtype + ), + "vision_patch_attention_mask": torch.zeros( + (batch_size, max_images, max_patches), device=patch_device, dtype=torch.long + ), + "vision_token_grids": torch.zeros((batch_size, max_images, 2), device=patch_device, dtype=torch.long), + } + + for batch_idx, sample_patches in enumerate(nested_outputs["vision_patches"]): + sample_image_count = len(sample_patches) + if sample_image_count == 0: + continue + + for image_idx, patches in enumerate(sample_patches): + patch_count = int(patches.shape[0]) + + tensors["vision_patches"][batch_idx, image_idx, :patch_count] = patches + tensors["vision_patch_attention_mask"][batch_idx, image_idx, :patch_count] = 1 + tensors["vision_token_grids"][batch_idx, image_idx] = nested_outputs["vision_token_grids"][batch_idx][ + image_idx + ] return BatchFeature(data=tensors, tensor_type=return_tensors) diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 5ce71788ceae..e1236be08020 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -4,7 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. +# Copyright 2026 Perceptron, Inc and The HuggingFace Team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,8 +21,7 @@ import copy from collections.abc import Callable -from enum import IntEnum -from typing import Any, Optional +from typing import Any, NamedTuple, Optional from ... import initialization as init from ...activations import ACT2FN @@ -30,13 +29,18 @@ from ...generation.utils import GenerationMixin from ...image_processing_utils_fast import ImagesKwargs from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func -from ...masking_utils import create_bidirectional_mask, create_masks_for_generate +from ...masking_utils import create_bidirectional_mask, create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + CausalLMOutputWithPast, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3DecoderLayer, Qwen3Model, Qwen3PreTrainedModel +from ...models.qwen3.modeling_qwen3 import Qwen3PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, torch_compilable_check from ...utils.generic import TransformersKwargs, can_return_tuple, maybe_autocast, merge_with_config_defaults @@ -44,8 +48,8 @@ is_torch_available, is_torchdynamo_compiling, ) -from ...utils.output_capturing import OutputRecorder, capture_outputs -from .configuration_isaac import IsaacConfig, IsaacVisionConfig +from ...utils.output_capturing import capture_outputs +from .configuration_isaac import IsaacConfig, IsaacTextConfig, IsaacVisionConfig if is_torch_available(): @@ -54,17 +58,18 @@ import torch.nn.functional as F -class ModalityType(IntEnum): - """ - Modality identifiers for events. +class SinglePoint(NamedTuple): + x: int + y: int + mention: str | None = None + t: float | None = None - Members: - image: Vision tokens (e.g., patches). - text: Textual tokens. - """ - image = 0 - text = 1 +class BoundingBox(NamedTuple): + top_left: SinglePoint + bottom_right: SinglePoint + mention: str | None = None + t: float | None = None class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False): @@ -79,10 +84,10 @@ class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False): Pixel-shuffle reduction factor applied in the vision tower. """ - patch_size: int | None - max_num_patches: int | None - min_num_patches: int | None - pixel_shuffle_scale: int | None + patch_size: int + max_num_patches: int + min_num_patches: int + pixel_shuffle_scale: int class IsaacVisionEmbeddings(nn.Module): @@ -113,7 +118,6 @@ def __init__(self, config: IsaacVisionConfig): self.embed_dim, ) ) - nn.init.normal_(self.position_embedding) @staticmethod def resize_positional_embeddings( @@ -177,8 +181,6 @@ def resize_positional_embeddings( return resulted_positional_embeddings - @merge_with_config_defaults - @capture_outputs def forward( self, pixel_values: torch.Tensor, @@ -389,10 +391,8 @@ def pixel_shuffle_padded( Spatial down-sampling factor. 
Returns: - Tuple of: - - pixel-shuffled embeddings `(num_images, max_tokens, hidden_size * scale_factor**2)` - - attention mask `(num_images, max_tokens)` - - per-image valid token lengths `(num_images,)` + `torch.Tensor`: Pixel-shuffled embeddings of shape + `(num_images, max_tokens, hidden_size * scale_factor**2)`. """ num_images, max_patches, embed_dim = x.shape output_dim = embed_dim * scale_factor * scale_factor @@ -441,19 +441,14 @@ class IsaacVisionTransformer(PreTrainedModel): Args: config (IsaacVisionConfig): Vision configuration with pixel-shuffle and patching parameters. - Inputs: - vision_tokens (Tuple[Tensor, Tensor, Optional[Tensor]]): - `(patches, token_grids, patch_attention_mask)` where: - - `patches`: `(num_images, max_patches, patch_dim)` - - `token_grids`: `(num_images, 2)` with per-image `(H_tokens, W_tokens)` - - `patch_attention_mask`: `(num_images, max_patches)` or `None` - - Returns: - Tuple of `(pixel_shuffled_features, attention_mask, token_lengths)`. """ _supports_sdpa = True _supports_flash_attn = True + _can_record_outputs = { + "hidden_states": IsaacVisionEncoderLayer, + "attentions": IsaacVisionAttention, + } def __init__(self, config: IsaacVisionConfig): super().__init__(config) @@ -471,18 +466,30 @@ def _init_weights(self, module): if isinstance(module, IsaacVisionEmbeddings): init.zeros_(module.position_embedding) + @merge_with_config_defaults + @capture_outputs(tie_last_hidden_states=False) def forward( self, - vision_tokens: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor | None], - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if len(vision_tokens) == 2: - seq_patches, token_grids = vision_tokens - vision_patch_attention_mask = None - else: - seq_patches, token_grids, vision_patch_attention_mask = vision_tokens + vision_patches: torch.Tensor, + vision_token_grids: torch.Tensor, + vision_patch_attention_mask: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + """ + Inputs: + vision_patches (`torch.Tensor`): + Patches shaped `(num_images, max_patches, patch_dim)`. + vision_token_grids (`torch.Tensor`): + Token grids shaped `(num_images, 2)` with per-image `(H_tokens, W_tokens)`. + vision_patch_attention_mask (`torch.Tensor`): + Patch mask shaped `(num_images, max_patches)`. + + Returns: + `BaseModelOutputWithPooling` with pixel-shuffled embeddings in `last_hidden_state`. 
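+
+        Example (editor's illustrative sketch; `patch_size=16`, 3 input channels, and
+        `pixel_shuffle_scale_factor=2` are assumptions here, not values read from a checkpoint):
+
+        ```python
+        >>> import torch
+        >>> # `vision_tower` stands for an `IsaacVisionTransformer` instance
+        >>> # a single image resized to a 14 x 14 patch grid; patch_dim = 3 * 16 * 16 = 768
+        >>> patches = torch.randn(1, 196, 768)
+        >>> grids = torch.tensor([[14, 14]])
+        >>> mask = torch.ones(1, 196, dtype=torch.long)
+        >>> outputs = vision_tower(vision_patches=patches, vision_token_grids=grids, vision_patch_attention_mask=mask)
+        >>> # pixel shuffle folds each 2 x 2 block of patch embeddings into one token:
+        >>> # 196 patches -> 49 tokens, each `scale_factor**2` times wider
+        >>> outputs.last_hidden_state.shape[:2]  # (num_images, max_tokens) == (1, 49)
+        ```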
+ """ hidden_states = self.embeddings( - seq_patches, - token_grids, + vision_patches, + vision_token_grids, attention_mask=vision_patch_attention_mask, ) @@ -491,28 +498,34 @@ def forward( inputs_embeds=hidden_states, attention_mask=vision_patch_attention_mask, ) - encoder_outputs = self.encoder(inputs_embeds=hidden_states, attention_mask=encoder_attention_mask) + encoder_outputs = self.encoder(inputs_embeds=hidden_states, attention_mask=encoder_attention_mask, **kwargs) hidden_states = self.post_layernorm(encoder_outputs.last_hidden_state) - return pixel_shuffle_padded( + hidden_states = pixel_shuffle_padded( x=hidden_states, - token_grids=token_grids, + token_grids=vision_token_grids, scale_factor=self.pixel_shuffle_scale_factor, ) + return BaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=None, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + class IsaacMultiModalProjector(nn.Module): """Maps vision tower outputs to the text hidden size with a SiLU MLP.""" def __init__(self, config: IsaacConfig): super().__init__() - self.vision_hidden_size = config.vision_config.hidden_size * ( - config.vision_config.pixel_shuffle_scale_factor**2 - ) - self.backbone_hidden_size = config.hidden_size - self.linear_1 = nn.Linear(self.vision_hidden_size, 4 * self.vision_hidden_size, bias=False) + text_config = config.get_text_config() + vision_hidden_size = config.vision_config.hidden_size * (config.vision_config.pixel_shuffle_scale_factor**2) + backbone_hidden_size = text_config.hidden_size + self.linear_1 = nn.Linear(vision_hidden_size, 4 * vision_hidden_size, bias=False) self.silu = nn.SiLU() - self.linear_2 = nn.Linear(4 * self.vision_hidden_size, self.backbone_hidden_size, bias=False) + self.linear_2 = nn.Linear(4 * vision_hidden_size, backbone_hidden_size, bias=False) def forward(self, image_features): hidden_states = self.linear_1(image_features) @@ -521,32 +534,14 @@ def forward(self, image_features): return hidden_states -class IsaacVisionEmbedding(nn.Module): - def __init__(self, config: IsaacConfig): - super().__init__() - vision_cfg = config.vision_config - - self.vision_tower = IsaacVisionTransformer(vision_cfg) - self.multimodal_projector = IsaacMultiModalProjector(config) - - def forward( - self, - vision_tokens: tuple[torch.Tensor, torch.Tensor, torch.Tensor | None], - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - vision_patches, token_grids, vision_patch_attention_mask = vision_tokens - hidden_states = self.vision_tower((vision_patches, token_grids, vision_patch_attention_mask)) - projected = self.multimodal_projector(hidden_states) - return projected - - class IsaacRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` - def __init__(self, config: IsaacConfig, device=None): + def __init__(self, config: IsaacConfig | IsaacTextConfig, device=None): super().__init__() - rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config - rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} + rope_source_cfg = config.get_text_config() config_for_rope = copy.copy(rope_source_cfg) + rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} config_for_rope.rope_scaling = rope_scaling self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings @@ -562,13 +557,12 @@ def __init__(self, config: IsaacConfig, device=None): self.register_buffer("inv_freq", inv_freq, 
persistent=False) self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) - rotary_half_dim = self.inv_freq.shape[0] - self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim) - self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size + self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), self.inv_freq.shape[0]) + self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) @staticmethod def compute_default_rope_parameters( - config: IsaacConfig | None = None, + config: IsaacTextConfig | None = None, device: Optional["torch.device"] = None, seq_len: int | None = None, ) -> tuple["torch.Tensor", float]: @@ -596,48 +590,38 @@ def compute_default_rope_parameters( ) return inv_freq, attention_factor - # Ignore copy - def forward( - self, - position_ids: torch.Tensor, - modality_tensor: torch.Tensor, - hidden_states: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if hidden_states is None: - batch, seq_len, _ = position_ids.shape - hidden_states = torch.zeros( - batch, - seq_len, - self.hidden_size, - dtype=torch.float32, - device=position_ids.device, - ) - - with torch.no_grad(): - pos = position_ids.clone() - not_spatial = modality_tensor == 1 - data_1d = pos[not_spatial][..., 0].unsqueeze(-1) # Collapse non-vision modalities to 1D positions - pos[not_spatial] = data_1d.expand(-1, pos.shape[-1]) - pos_axes = pos.permute(2, 0, 1).contiguous() - - inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, pos_axes.shape[1], -1, 1) - pos_axes_expanded = pos_axes[:, :, None, :].float() # shape (3, bs, 1, positions) + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, Isaac has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) - device_type = ( - hidden_states.device.type - if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps" - else "cpu" - ) + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ pos_axes_expanded.float()).transpose(2, 3) + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling - cos_axes, sin_axes = cos.to(hidden_states.dtype), sin.to(hidden_states.dtype) - cos_combined, sin_combined = self._combine_axes(cos_axes), self._combine_axes(sin_axes) - - return cos_combined, sin_combined + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def apply_interleaved_mrope(self, freqs: torch.Tensor, mrope_section: list[int]) -> torch.Tensor: + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THWTHWTHW...TT], preserving frequency continuity. 
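+        For example, with `head_dim // 2 == 4` and `mrope_section == [1, 1, 1, 1]`, the four
+        width-1 chunks are taken from axes (0, 1, 2, 0) in turn, giving the layout [T, H, W, T].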
+ args: + x: (3, bs, seq_len, head_dim // 2) + mrope_section: (3,) + returns: + x_t: (bs, seq_len, head_dim // 2) + """ + chunks = freqs.split(tuple(mrope_section), dim=-1) + return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) @staticmethod def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> list[int]: @@ -650,334 +634,12 @@ def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> l section = [int(v) for v in section] return section - def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor: - split_sections = tuple(self.mrope_section * 2) - chunks = tensor.split(split_sections, dim=-1) - return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) - - -@auto_docstring -class IsaacModel(PreTrainedModel): - config: IsaacConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["IsaacDecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn = True - _supports_sdpa = True - _supports_flex_attn = False - _can_compile_fullgraph = False - _supports_attention_backend = True - _can_record_outputs = { - "hidden_states": OutputRecorder(Qwen3DecoderLayer), - "attentions": Qwen3Attention, - "vision_attentions": IsaacVisionAttention, - } - all_tied_weights_keys: dict[str, str] = {} - - def __init__(self, config: IsaacConfig): - Qwen3PreTrainedModel.__init__(self, config) - self.text_model = Qwen3Model._from_config(config.text_config) - - self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) - - self.vision_embedding = IsaacVisionEmbedding(config) - self.max_sequence_length = config.max_sequence_length - self.vision_rescale_factor = config.vision_rescale_factor - self.vision_token = config.vision_token - self.rope_deltas = None - - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.text_model.get_input_embeddings() - - def set_input_embeddings(self, value: nn.Module) -> None: - self.text_model.set_input_embeddings(value) - vocab_size = getattr(value, "num_embeddings", None) - if vocab_size is not None: - self.config.vocab_size = vocab_size - if hasattr(self.config, "text_config"): - self.config.text_config.vocab_size = vocab_size - self.text_model.config.vocab_size = vocab_size - - @property - def final_norm(self) -> nn.Module: - return self.text_model.norm - - @property - def embed_tokens(self) -> nn.Module: - return self.text_model.embed_tokens - - @embed_tokens.setter - def embed_tokens(self, value: nn.Module) -> None: - self.text_model.embed_tokens = value - - def embed_multimodal_inputs( - self, - input_ids: torch.Tensor, - modality_tensor: torch.Tensor, - vision_patches: torch.Tensor, - vision_token_grids: torch.Tensor, - vision_patch_attention_mask: torch.Tensor | None = None, - vision_token_offsets: torch.Tensor | None = None, - vision_token_lengths: torch.Tensor | None = None, - vision_image_attention_mask: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - modality = modality_tensor.to(device=input_ids.device, dtype=torch.long) - embeds = self.text_model.embed_tokens(input_ids) - image_token_mask = modality == ModalityType.image.value - - if vision_patches is None or vision_token_grids is None: - if torch.any(image_token_mask): - raise ValueError("Image placeholders require `vision_patches` and `vision_token_grids`.") - return embeds, modality - - vision_patches = vision_patches.to(device=embeds.device) - token_grids = vision_token_grids.to(device=embeds.device, 
dtype=torch.long) - image_attention_mask = ( - vision_image_attention_mask.to(device=embeds.device, dtype=torch.bool) - if vision_image_attention_mask is not None - else torch.ones(token_grids.shape[:2], device=embeds.device, dtype=torch.bool) - ) - patch_attention_mask = ( - vision_patch_attention_mask.to(device=embeds.device, dtype=torch.long) - if vision_patch_attention_mask is not None - else torch.ones(vision_patches.shape[:3], device=embeds.device, dtype=torch.long) - ) - offsets = ( - vision_token_offsets.to(device=embeds.device, dtype=torch.long) - if vision_token_offsets is not None - else torch.zeros(token_grids.shape[:2], device=embeds.device, dtype=torch.long) - ) - reduction_factor = int(self.config.vision_config.pixel_shuffle_scale_factor) ** 2 - lengths = ( - vision_token_lengths.to(device=embeds.device, dtype=torch.long) - if vision_token_lengths is not None - else token_grids.prod(-1).div(reduction_factor, rounding_mode="floor").to(dtype=torch.long) - ) - - flat_vision_patches = vision_patches[image_attention_mask] - flat_patch_attention_mask = patch_attention_mask[image_attention_mask] - flat_token_grids = token_grids[image_attention_mask] - flat_offsets = offsets[image_attention_mask] - flat_lengths = lengths[image_attention_mask] - - vision_embeddings = self.vision_embedding((flat_vision_patches, flat_token_grids, flat_patch_attention_mask)) - token_positions = torch.arange(flat_lengths.max(), device=embeds.device, dtype=torch.long) - gather_positions = flat_offsets[:, None] + token_positions[None, :] - gather_mask = token_positions[None, :] < flat_lengths[:, None] - image_features = vision_embeddings[ - torch.arange(vision_embeddings.shape[0], device=embeds.device, dtype=torch.long)[:, None], - gather_positions, - ][gather_mask] - scatter_mask = image_token_mask.unsqueeze(-1).expand_as(embeds) - embeds = embeds.masked_scatter(scatter_mask, image_features) - - return embeds, modality - - def get_rope_index( - self, - *, - position_ids: torch.Tensor | None = None, - attention_mask: torch.Tensor, - inputs_embeds: torch.Tensor, - cache_position: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Prepare multimodal RoPE positions and carry forward per-batch offsets. - - Unlike vanilla 1D RoPE, Isaac builds 3-axis indices for text and vision tokens. - If callers do not supply positions, we synthesize them from `cache_position` and - use `attention_mask` to strip left padding so pad tokens never consume RoPE slots. 
- The returned `rope_deltas` capture any custom offset (i.e., prefill length) and - are reused across generation steps so newly decoded tokens keep counting forward - after the cached prefix.""" - - device = inputs_embeds.device - batch_size, seq_len = inputs_embeds.shape[:2] - - if position_ids is None: - cp = cache_position.to(device=device, dtype=torch.long) - if cp.ndim == 1: - cp = cp.view(1, -1).expand(batch_size or 1, -1) - - is_new_prefill = cp[:, :1].eq(0).all(dim=1, keepdim=True) - if self.rope_deltas is None: - base_delta = torch.zeros((batch_size, 1), device=device, dtype=torch.long) - else: - previous_delta = torch.as_tensor(self.rope_deltas, device=device, dtype=torch.long).reshape(-1, 1) - previous_delta = torch.broadcast_to(previous_delta, (batch_size, 1)) - base_delta = torch.where(is_new_prefill, torch.zeros_like(previous_delta), previous_delta) - - mask_delta = attention_mask.to(device=device, dtype=torch.long).sum(1, keepdim=True) - attention_mask.size( - 1 - ) - rope_position = cp + base_delta + mask_delta - pos_3d = rope_position.unsqueeze(-1).expand(-1, -1, 3) - return pos_3d, base_delta - - position_ids = position_ids.to(device=device) - if position_ids.ndim == 2: - position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) - - if position_ids.shape[1] != seq_len: - start_positions = position_ids[:, :1, 0] - position_ids = torch.arange(seq_len, device=position_ids.device).view(1, -1) + start_positions - position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) - - attn = attention_mask.to(device=device, dtype=torch.long) - m_per_batch = position_ids.amax(dim=(1, 2)) - seq_lens = attn.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=device) - rope_deltas = (m_per_batch + 1 - seq_lens).to(dtype=position_ids.dtype).unsqueeze(1) - return position_ids, rope_deltas - - @auto_docstring - @merge_with_config_defaults - @capture_outputs - def forward( - self, - input_ids: torch.LongTensor | None = None, - modality_tensor: torch.LongTensor | None = None, - vision_patches: torch.Tensor | None = None, - vision_patch_attention_mask: torch.Tensor | None = None, - vision_token_grids: torch.LongTensor | None = None, - vision_token_offsets: torch.LongTensor | None = None, - vision_token_lengths: torch.LongTensor | None = None, - vision_image_attention_mask: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - use_cache: bool | None = None, - cache_position: torch.LongTensor | None = None, - **kwargs: Unpack[TransformersKwargs], - ) -> tuple | BaseModelOutputWithPast: - """ - Forward pass with MRoPE position embeddings. - - Computes position embeddings once and passes them through all layers. - - Args: - modality_tensor (`torch.LongTensor`, *optional*): - Modality identifiers aligned with the embedded sequence, shaped `(batch_size, seq_len)` and containing - values from `ModalityType`. Treated as text-only when omitted. - vision_patches (`torch.FloatTensor`, *optional*): - Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`. - vision_patch_attention_mask (`torch.LongTensor`, *optional*): - Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. - vision_token_grids (`torch.LongTensor`, *optional*): - Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`. 
- vision_token_offsets (`torch.LongTensor`, *optional*): - Start offsets inside the per-image vision embedding sequence, shape `(batch_size, max_images)`. - vision_token_lengths (`torch.LongTensor`, *optional*): - Number of vision tokens to consume per image, shape `(batch_size, max_images)`. - vision_image_attention_mask (`torch.LongTensor`, *optional*): - Mask indicating which image slots are populated, shape `(batch_size, max_images)`. - """ - - if inputs_embeds is None: - if input_ids is None: - raise ValueError("`input_ids` or `inputs_embeds` must be provided.") - - has_vision_inputs = any( - value is not None - for value in ( - vision_patches, - vision_patch_attention_mask, - vision_token_grids, - vision_token_offsets, - vision_token_lengths, - vision_image_attention_mask, - ) - ) - if modality_tensor is not None or has_vision_inputs: - if modality_tensor is None: - modality_tensor = torch.full_like(input_ids, ModalityType.text.value) - inputs_embeds, modality_tensor = self.embed_multimodal_inputs( - input_ids=input_ids, - modality_tensor=modality_tensor, - vision_patches=vision_patches, - vision_patch_attention_mask=vision_patch_attention_mask, - vision_token_grids=vision_token_grids, - vision_token_offsets=vision_token_offsets, - vision_token_lengths=vision_token_lengths, - vision_image_attention_mask=vision_image_attention_mask, - ) - else: - inputs_embeds = self.text_model.embed_tokens(input_ids) - - if modality_tensor is None: - batch_size, seq_len = inputs_embeds.shape[:2] - modality_tensor = torch.full( - (batch_size, seq_len), ModalityType.text.value, device=inputs_embeds.device, dtype=torch.long - ) - - device = inputs_embeds.device - batch_size, seq_len = inputs_embeds.shape[:2] - - if use_cache and past_key_values is None: - past_key_values = DynamicCache(config=self.config.get_text_config()) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=device) - - if attention_mask is None: - attention_mask = torch.ones(inputs_embeds.shape[:2], device=inputs_embeds.device, dtype=torch.long) - - position_ids, rope_deltas = self.get_rope_index( - position_ids=position_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - cache_position=cache_position, - ) - self.rope_deltas = rope_deltas - - cos, sin = self.rotary_emb(position_ids, modality_tensor, hidden_states=inputs_embeds) - - decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids - - if not isinstance(attention_mask, dict): - attention_mask = create_masks_for_generate( - config=self.config, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - cache_position=cache_position, - past_key_values=past_key_values, - position_ids=decoder_position_ids, - ) - - is_mask_dict = isinstance(attention_mask, dict) - hidden_states = inputs_embeds - - for layer in self.text_model.layers: - layer_mask = attention_mask[layer.attention_type] if is_mask_dict else attention_mask - layer_outputs = layer( - hidden_states, - attention_mask=layer_mask, - position_ids=decoder_position_ids, - past_key_values=past_key_values, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=(cos, sin), - **kwargs, - ) - - hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs - - hidden_states = self.final_norm(hidden_states) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - 
past_key_values=past_key_values if use_cache else None, - ) - @use_kernel_forward_from_hub("RMSNorm") -class IsaacRMSNorm(nn.Module): +class IsaacTextRMSNorm(nn.Module): def __init__(self, hidden_size, eps: float = 1e-6) -> None: """ - IsaacRMSNorm is equivalent to T5LayerNorm + IsaacTextRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) @@ -1001,6 +663,18 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + @use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -1027,23 +701,11 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): return q_embed, k_embed -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - @use_kernelized_func(apply_rotary_pos_emb) -class IsaacAttention(nn.Module): +class IsaacTextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: IsaacConfig, layer_idx: int): + def __init__(self, config: IsaacTextConfig, layer_idx: int): super().__init__() self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None self.config = config @@ -1066,9 +728,10 @@ def __init__(self, config: IsaacConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - self.q_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! - self.k_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape - self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + self.q_norm = IsaacTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
+ self.k_norm = IsaacTextRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape def forward( self, @@ -1076,7 +739,6 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: torch.Tensor | None, past_key_values: Cache | None = None, - cache_position: torch.LongTensor | None = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: input_shape = hidden_states.shape[:-1] @@ -1090,9 +752,7 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward @@ -1106,7 +766,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - sliding_window=self.sliding_window, # diff with Llama **kwargs, ) @@ -1115,27 +774,41 @@ def forward( return attn_output, attn_weights -class IsaacDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: IsaacConfig, layer_idx: int): +class IsaacTextMLP(nn.Module): + def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] - self.self_attn = IsaacAttention(config=config, layer_idx=layer_idx) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj - self.mlp = IsaacMLP(config) - self.input_layernorm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] + +class IsaacTextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: IsaacTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = IsaacTextAttention(config=config, layer_idx=layer_idx) + + self.mlp = IsaacTextMLP(config) + self.input_layernorm = IsaacTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = IsaacTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, use_cache: bool | None = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: residual = hidden_states @@ -1147,7 +820,6 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, - cache_position=cache_position, 
position_embeddings=position_embeddings, **kwargs, ) @@ -1161,103 +833,805 @@ def forward( return hidden_states +class IsaacVisionRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + @auto_docstring class IsaacPreTrainedModel(PreTrainedModel): config: IsaacConfig base_model_prefix = "model" + input_modalities = ("image", "video", "text") supports_gradient_checkpointing = True - _no_split_modules = ["IsaacDecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] + _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionBlock"] + _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_sdpa = True - _supports_flex_attn = True _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { - "hidden_states": IsaacDecoderLayer, - "attentions": IsaacAttention, + "hidden_states": IsaacTextDecoderLayer, + "attentions": IsaacTextAttention, } + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, IsaacVisionRotaryEmbedding): + inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim)) + init.copy_(module.inv_freq, inv_freq) -@auto_docstring -class IsaacForConditionalGeneration(IsaacPreTrainedModel, GenerationMixin): - _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"} - _tp_plan = {"lm_head": "colwise_gather_output"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - config_class = IsaacConfig - _can_compile_fullgraph = False - all_tied_weights_keys: dict[str, str] = {"lm_head.weight": "model.text_model.embed_tokens.weight"} - def __init__(self, config: IsaacConfig): - super().__init__(config) - self.model = IsaacModel(config) +@auto_docstring( + custom_intro=( + "Text part of Isaac, " + "not a pure text-only model, as DeepStack integrates visual features into the early hidden states." 
+ ) +) +class IsaacTextModel(IsaacPreTrainedModel): + config: IsaacTextConfig + input_modalities = ("text",) + _no_split_modules = ["IsaacTextDecoderLayer"] + + def __init__(self, config: IsaacTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [IsaacTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = IsaacTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = IsaacRotaryEmbedding(config=config, device=self.device) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + # args for deepstack + visual_pos_masks: torch.Tensor | None = None, + deepstack_visual_embeds: list[torch.Tensor] | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple | BaseModelOutputWithPast: + r""" + visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*): + The mask of the visual positions. + deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): + The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). + The feature is extracted from the different visual encoder layers, and fed to the decoder + hidden states. It's from the paper DeepStack(https://arxiv.org/abs/2406.04334). + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # the hard coded `4` is for text, temporal, height and width. 
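+        # e.g. a pure-text prefill of length 4 produces, for every batch row:
+        #   row 0 (text axis)       -> [0, 1, 2, 3]
+        #   rows 1-3 (t, h, w axes) -> identical copies of [0, 1, 2, 3]
+        # multimodal callers instead pass explicit (t, h, w) indices, see `IsaacModel.get_rope_index`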
+ if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.view(1, 1, -1).expand(4, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(4, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = None + + attention_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + for layer_idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=text_position_ids, + past_key_values=past_key_values, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = layer_outputs + + # add visual features to the hidden states of first several layers + if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + def _deepstack_process( + self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor + ): + visual_pos_masks = visual_pos_masks.to(hidden_states.device) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + hidden_states = hidden_states.clone() + local_this = hidden_states[visual_pos_masks, :] + visual_embeds + hidden_states[visual_pos_masks, :] = local_this + return hidden_states + + +@use_kernel_forward_from_hub("RMSNorm") +class IsaacRMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + IsaacRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +@use_kernelized_func(apply_rotary_pos_emb) +class IsaacAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: IsaacConfig, layer_idx: int): + super().__init__() + self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = 
config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! + self.k_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape + self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class IsaacDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: IsaacConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = IsaacAttention(config=config, layer_idx=layer_idx) + + self.mlp = IsaacMLP(config) + self.input_layernorm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = 
residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class IsaacModel(PreTrainedModel): + config: IsaacConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["IsaacDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = False + _can_compile_fullgraph = False + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": IsaacDecoderLayer, + "attentions": IsaacAttention, + } + _tied_weights_keys = {} + + def __init__(self, config: IsaacConfig): + Qwen3PreTrainedModel.__init__(self, config) + self.text_model = IsaacTextModel._from_config(config.text_config) + + self.vision_tower = IsaacVisionTransformer(config.vision_config) + self.multimodal_projector = IsaacMultiModalProjector(config) + self.max_sequence_length = config.max_sequence_length + self.vision_rescale_factor = config.vision_rescale_factor + self.vision_token = config.vision_token + self.rope_deltas = None + + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value: nn.Module) -> None: + self.text_model.set_input_embeddings(value) + + @can_return_tuple @auto_docstring + def get_image_features( + self, + pixel_values: torch.Tensor, + image_token_grids: torch.Tensor, + image_patch_attention_mask: torch.Tensor | None = None, + image_token_offsets: torch.Tensor | None = None, + image_token_lengths: torch.Tensor | None = None, + image_attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + """ + Args: + pixel_values (`torch.Tensor`): + Padded per-image patch vectors with shape `(batch_size, max_images, max_patches, patch_dim)`. + image_token_grids (`torch.Tensor`): + Per-image token grids shaped `(batch_size, max_images, 2)` with `(height, width)` entries. + image_patch_attention_mask (`torch.Tensor`, *optional*): + Mask for valid patch rows in `pixel_values`, shaped `(batch_size, max_images, max_patches)`. + image_token_offsets (`torch.Tensor`, *optional*): + Start offsets inside each per-image embedding sequence, shaped `(batch_size, max_images)`. + image_token_lengths (`torch.Tensor`, *optional*): + Number of image tokens to gather per image for placeholder scattering, shaped `(batch_size, max_images)`. + image_attention_mask (`torch.Tensor`, *optional*): + Mask indicating which image slots are populated, shaped `(batch_size, max_images)`. 
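+
+        Example (editor's sketch of the expected shapes; `pixel_shuffle_scale_factor=2` and
+        `patch_dim=768` are assumptions, and `model` stands for an `IsaacModel` instance):
+
+        ```python
+        >>> import torch
+        >>> # one sample holding a single 14 x 14 patch-grid image
+        >>> pixel_values = torch.randn(1, 1, 196, 768)
+        >>> grids = torch.tensor([[[14, 14]]])
+        >>> outputs = model.get_image_features(pixel_values=pixel_values, image_token_grids=grids)
+        >>> # last_hidden_state: (1, 1, 49, text_hidden_size), padded per image slot;
+        >>> # pooler_output: the 49 valid tokens flattened, ready to scatter into placeholders
+        ```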
+ """ + device = self.text_model.embed_tokens.weight.device + pixel_values = pixel_values.to(device=device) + image_token_grids = image_token_grids.to(device=device, dtype=torch.long) + patch_attention_mask = image_patch_attention_mask.to(device=device, dtype=torch.long) + if image_attention_mask is None: + if image_token_lengths is not None: + image_attention_mask = image_token_lengths.to(device=device, dtype=torch.long) > 0 + else: + image_attention_mask = image_token_grids.any(dim=-1) + else: + image_attention_mask = image_attention_mask.to(device=device, dtype=torch.bool) + + batch_size, max_images = pixel_values.shape[:2] + hidden_size = self.config.get_text_config().hidden_size + + if image_attention_mask.any(): + vision_kwargs = { + key: value + for key in ("output_hidden_states", "output_attentions") + if (value := kwargs.get(key)) is not None + } + vision_outputs = self.vision_tower( + vision_patches=pixel_values[image_attention_mask], + vision_token_grids=image_token_grids[image_attention_mask], + vision_patch_attention_mask=patch_attention_mask[image_attention_mask], + return_dict=True, + **vision_kwargs, + ) + flat_projected_features = self.multimodal_projector(vision_outputs.last_hidden_state) + max_tokens = flat_projected_features.shape[1] + projected_features = flat_projected_features.new_zeros((batch_size, max_images, max_tokens, hidden_size)) + projected_features[image_attention_mask] = flat_projected_features + offsets = ( + image_token_offsets.to(device=device, dtype=torch.long) + if image_token_offsets is not None + else torch.zeros((batch_size, max_images), device=device, dtype=torch.long) + ) + lengths = ( + image_token_lengths.to(device=device, dtype=torch.long) + if image_token_lengths is not None + else torch.full((batch_size, max_images), max_tokens, device=device, dtype=torch.long) + ) + flat_offsets = offsets[image_attention_mask] + flat_lengths = lengths[image_attention_mask] + token_positions = torch.arange(flat_lengths.max(), device=device, dtype=torch.long) + gather_positions = flat_offsets[:, None] + token_positions[None, :] + gather_mask = token_positions[None, :] < flat_lengths[:, None] + image_features = flat_projected_features[ + torch.arange(flat_projected_features.shape[0], device=device, dtype=torch.long)[:, None], + gather_positions, + ][gather_mask] + hidden_states = vision_outputs.hidden_states + attentions = vision_outputs.attentions + else: + projected_features = pixel_values.new_zeros((batch_size, max_images, 0, hidden_size)) + image_features = pixel_values.new_zeros((0, hidden_size)) + hidden_states = None + attentions = None + + return BaseModelOutputWithPooling( + last_hidden_state=projected_features, + pooler_output=image_features, + hidden_states=hidden_states, + attentions=attentions, + ) + + def get_placeholder_mask( + self, + mm_token_type_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor, + ) -> torch.BoolTensor: + image_token_mask = mm_token_type_ids.to(device=inputs_embeds.device, dtype=torch.long) == 1 + n_image_tokens = image_token_mask.sum() + image_token_mask = image_token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[image_token_mask].numel() == image_features.numel(), + f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", + ) + return image_token_mask + + def get_vision_position_ids( + self, + start_position: int, + grid_hw: torch.LongTensor, + token_offset: 
int, + token_length: int, + ) -> torch.LongTensor: + height, width = grid_hw[0].item(), grid_hw[1].item() + token_positions = torch.arange(height * width, device=grid_hw.device, dtype=torch.long) + vision_position_ids = torch.stack( + ( + torch.full((token_positions.shape[0],), start_position, device=grid_hw.device, dtype=torch.long), + token_positions.div(width, rounding_mode="floor"), + token_positions.remainder(width), + ), + dim=0, + ) + return vision_position_ids[:, token_offset : token_offset + token_length] + + def get_rope_index( + self, + mm_token_type_ids: torch.Tensor, + image_token_grids: torch.Tensor, + image_token_offsets: torch.Tensor, + image_token_lengths: torch.Tensor, + attention_mask: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare multimodal RoPE positions for the current prefill sequence. + + Unlike vanilla 1D RoPE, Isaac builds 3-axis indices for text and vision tokens. + If callers do not supply positions, we synthesize text-style positions from + `attention_mask`. The returned `rope_deltas` capture any custom offset between + the attended sequence length and Isaac's multimodal positions so decode steps can + keep counting forward from the cached prefix.""" + + device = attention_mask.device + batch_size, seq_len = attention_mask.shape + mm_token_type_ids = mm_token_type_ids.to(device=device, dtype=torch.long) + image_token_grids = image_token_grids.to(device=device, dtype=torch.long) + image_token_offsets = image_token_offsets.to(device=device, dtype=torch.long) + image_token_lengths = image_token_lengths.to(device=device, dtype=torch.long) + attention_mask = attention_mask.to(device=device, dtype=torch.long) + image_attention_mask = image_token_lengths > 0 + + position_ids = torch.zeros((3, batch_size, seq_len), device=device, dtype=torch.long) + rope_deltas = torch.zeros((batch_size, 1), device=device, dtype=torch.long) + pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor + + for batch_idx in range(batch_size): + sample_attention_mask = attention_mask[batch_idx].bool() + sample_token_types = mm_token_type_ids[batch_idx][sample_attention_mask] + sample_grids = image_token_grids[batch_idx][image_attention_mask[batch_idx]] + sample_offsets = image_token_offsets[batch_idx][image_attention_mask[batch_idx]] + sample_lengths = image_token_lengths[batch_idx][image_attention_mask[batch_idx]] + + current_pos = 0 + image_idx = 0 + seq_pos = 0 + llm_pos_ids_list = [] + + while seq_pos < sample_token_types.shape[0]: + modality_type = int(sample_token_types[seq_pos].item()) + group_end = seq_pos + 1 + while group_end < sample_token_types.shape[0] and sample_token_types[group_end] == modality_type: + group_end += 1 + + group_length = group_end - seq_pos + if modality_type == 0: + llm_pos_ids_list.append( + torch.arange(group_length, device=device, dtype=torch.long).view(1, -1).expand(3, -1) + + current_pos + ) + current_pos += group_length + else: + grid_hw = sample_grids[image_idx].div(pixel_shuffle_scale, rounding_mode="floor") + token_offset = int(sample_offsets[image_idx].item()) + token_length = int(sample_lengths[image_idx].item()) + llm_pos_ids_list.append( + self.get_vision_position_ids(current_pos, grid_hw, token_offset, token_length) + ) + current_pos += 1 + image_idx += 1 + + seq_pos = group_end + + llm_positions = torch.cat(llm_pos_ids_list, dim=1) + position_ids[:, batch_idx, sample_attention_mask] = llm_positions + rope_deltas[batch_idx, 0] = llm_positions.max() + 1 - 
sample_token_types.shape[0] + + return position_ids, rope_deltas + + def compute_3d_position_ids( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | None, + mm_token_type_ids: torch.Tensor | None = None, + image_token_grids: torch.Tensor | None = None, + image_token_offsets: torch.Tensor | None = None, + image_token_lengths: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + past_key_values: torch.Tensor | None = None, + ) -> torch.Tensor: + past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() + + if image_token_lengths is not None and image_token_lengths.gt(0).any() and past_seen_tokens == 0: + position_ids, rope_deltas = self.get_rope_index( + mm_token_type_ids=mm_token_type_ids, + image_token_grids=image_token_grids, + image_token_offsets=image_token_offsets, + image_token_lengths=image_token_lengths, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + ) + self.rope_deltas = rope_deltas + return position_ids + + if position_ids is not None and past_seen_tokens == 0: + position_ids = position_ids.to(device=inputs_embeds.device) + if position_ids.ndim == 2: + return position_ids.view(1, position_ids.shape[0], -1).expand(3, -1, -1) + if position_ids.ndim == 3 and position_ids.shape[0] in (1, 4): + return position_ids + + if self.rope_deltas is None: + return None + + rope_deltas = torch.as_tensor(self.rope_deltas, device=inputs_embeds.device, dtype=torch.long).reshape(-1, 1) + if rope_deltas.shape[0] != inputs_embeds.shape[0]: + if inputs_embeds.shape[0] % rope_deltas.shape[0] == 0: + rope_deltas = rope_deltas.repeat_interleave(inputs_embeds.shape[0] // rope_deltas.shape[0], dim=0) + else: + rope_deltas = rope_deltas[:1].expand(inputs_embeds.shape[0], -1) + + if attention_mask is not None and attention_mask.shape[-1] > inputs_embeds.shape[1]: + rope_position = attention_mask.long().cumsum(dim=-1) - 1 + rope_position = rope_position.masked_fill(attention_mask == 0, 0) + rope_position = rope_position[:, -inputs_embeds.shape[1] :] + else: + rope_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + dtype=torch.long, + ).view(1, -1) + rope_position = rope_position.expand(inputs_embeds.shape[0], -1) + + position_ids = rope_position.view(1, inputs_embeds.shape[0], -1).expand(3, -1, -1) + return position_ids + rope_deltas.to(device=inputs_embeds.device).unsqueeze(0) + + @auto_docstring( + custom_intro=""" + Forward pass with multimodal MRoPE position ids. + + When image placeholders are present, Isaac computes vision features, scatters them into the token + embeddings, and runs the shared text backbone on the mixed sequence. 
+ """, + ) @can_return_tuple @merge_with_config_defaults def forward( self, input_ids: torch.LongTensor | None = None, - modality_tensor: torch.LongTensor | None = None, + mm_token_type_ids: torch.LongTensor | None = None, vision_patches: torch.Tensor | None = None, vision_patch_attention_mask: torch.Tensor | None = None, + image_patch_attention_mask: torch.Tensor | None = None, vision_token_grids: torch.LongTensor | None = None, + image_token_grids: torch.LongTensor | None = None, vision_token_offsets: torch.LongTensor | None = None, vision_token_lengths: torch.LongTensor | None = None, - vision_image_attention_mask: torch.LongTensor | None = None, + image_attention_mask: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPast: + """ + Args: + mm_token_type_ids (`torch.LongTensor`, *optional*): + Multimodal token type ids aligned with the embedded sequence, shaped `(batch_size, seq_len)`. Isaac + follows the standard convention `0 -> text`, `1 -> image`. Treated as text-only when omitted. + vision_patches (`torch.FloatTensor`, *optional*): + Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`. + vision_patch_attention_mask (`torch.LongTensor`, *optional*): + Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. + image_patch_attention_mask (`torch.LongTensor`, *optional*): + Alias for `vision_patch_attention_mask`. + vision_token_grids (`torch.LongTensor`, *optional*): + Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`. + image_token_grids (`torch.LongTensor`, *optional*): + Alias for `vision_token_grids`. + vision_token_offsets (`torch.LongTensor`, *optional*): + Start offsets inside the per-image vision embedding sequence, shape `(batch_size, max_images)`. + vision_token_lengths (`torch.LongTensor`, *optional*): + Number of vision tokens to consume per image, shape `(batch_size, max_images)`. + image_attention_mask (`torch.LongTensor`, *optional*): + Backward-compatible override for populated image slots. When omitted, the model derives it from + `vision_token_lengths > 0`. 
+ """ + created_inputs_embeds = inputs_embeds is None + if created_inputs_embeds: + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if mm_token_type_ids is None: + batch_size, seq_len = inputs_embeds.shape[:2] + mm_token_type_ids = torch.full((batch_size, seq_len), 0, device=inputs_embeds.device, dtype=torch.long) + else: + mm_token_type_ids = mm_token_type_ids.to(device=inputs_embeds.device, dtype=torch.long) + + image_token_mask = mm_token_type_ids == 1 + if created_inputs_embeds and torch.any(image_token_mask): + image_outputs = self.get_image_features( + pixel_values=vision_patches, + image_token_grids=vision_token_grids, + image_patch_attention_mask=vision_patch_attention_mask, + image_token_offsets=vision_token_offsets, + image_token_lengths=vision_token_lengths, + image_attention_mask=image_attention_mask, + return_dict=True, + ) + image_features = image_outputs.pooler_output.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype) + scatter_mask = self.get_placeholder_mask( + mm_token_type_ids=mm_token_type_ids, + inputs_embeds=inputs_embeds, + image_features=image_features, + ) + inputs_embeds = inputs_embeds.masked_scatter(scatter_mask, image_features) + + if isinstance(attention_mask, dict): + attention_mask = attention_mask.get("full_attention", next(iter(attention_mask.values()))) + + position_ids = self.compute_3d_position_ids( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + mm_token_type_ids=mm_token_type_ids, + image_token_grids=vision_token_grids, + image_token_offsets=vision_token_offsets, + image_token_lengths=vision_token_lengths, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + ) + + text_model_outputs = self.text_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + return BaseModelOutputWithPast( + last_hidden_state=text_model_outputs.last_hidden_state, + past_key_values=text_model_outputs.past_key_values, + hidden_states=text_model_outputs.hidden_states, + attentions=text_model_outputs.attentions, + ) + + +@auto_docstring +class IsaacForConditionalGeneration(IsaacPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + config_class = IsaacConfig + _can_compile_fullgraph = False + + def __init__(self, config: IsaacConfig): + super().__init__(config) + self.model = IsaacModel(config) + self.vocab_size = config.get_text_config().vocab_size + self.lm_head = nn.Linear(config.get_text_config().hidden_size, config.get_text_config().vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor | None = None, + mm_token_type_ids: torch.LongTensor | None = None, + vision_patches: torch.Tensor | None = None, + pixel_values: torch.Tensor | None = None, + vision_patch_attention_mask: torch.Tensor | None = None, + image_patch_attention_mask: torch.Tensor | None = None, + vision_token_grids: torch.LongTensor | None = None, + image_token_grids: torch.LongTensor | None = None, + vision_token_offsets: torch.LongTensor | None = None, + vision_token_lengths: torch.LongTensor | None = None, + image_attention_mask: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: 
torch.LongTensor | None = None, past_key_values: list[torch.FloatTensor] | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, - cache_position: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | CausalLMOutputWithPast: r""" - modality_tensor (`torch.LongTensor`, *optional*): - Modality identifiers aligned with the token sequence, shaped `(batch_size, seq_len)`. + mm_token_type_ids (`torch.LongTensor`, *optional*): + Multimodal token type ids aligned with the token sequence, shaped `(batch_size, seq_len)`, using + `0 -> text` and `1 -> image`. vision_patches (`torch.FloatTensor`, *optional*): Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`. + pixel_values (`torch.FloatTensor`, *optional*): + Alias for `vision_patches` accepted by generic image-feature and generation helpers. vision_patch_attention_mask (`torch.LongTensor`, *optional*): Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. + image_patch_attention_mask (`torch.LongTensor`, *optional*): + Alias for `vision_patch_attention_mask`. vision_token_grids (`torch.LongTensor`, *optional*): Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`. + image_token_grids (`torch.LongTensor`, *optional*): + Alias for `vision_token_grids`. vision_token_offsets (`torch.LongTensor`, *optional*): Start offsets inside the per-image vision embedding sequence, shape `(batch_size, max_images)`. vision_token_lengths (`torch.LongTensor`, *optional*): Number of vision tokens to consume per image, shape `(batch_size, max_images)`. - vision_image_attention_mask (`torch.LongTensor`, *optional*): - Mask indicating which image slots are populated, shape `(batch_size, max_images)`. + image_attention_mask (`torch.LongTensor`, *optional*): + Backward-compatible override for populated image slots. When omitted, the model derives it from + `vision_token_lengths > 0`. 
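A hypothetical end-to-end call for the classes this patch adds. The checkpoint id is taken from the config docstring, and the top-level imports assume the usual auto-generated exports; neither is verified here.

    import torch
    from PIL import Image
    from transformers import IsaacForConditionalGeneration, IsaacProcessor

    processor = IsaacProcessor.from_pretrained("Perceptron/isaac-base")
    model = IsaacForConditionalGeneration.from_pretrained("Perceptron/isaac-base", dtype=torch.bfloat16)

    image = Image.open("example.png")
    prompt = f"Describe {processor.vision_token} briefly."  # placeholder marks where the image goes
    inputs = processor(text=prompt, images=[image])
    generated = model.generate(**inputs, max_new_tokens=32)
    print(processor.post_process_image_text_to_text(generated)[0])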
""" outputs = self.model( input_ids=input_ids, - modality_tensor=modality_tensor, + mm_token_type_ids=mm_token_type_ids, vision_patches=vision_patches, vision_patch_attention_mask=vision_patch_attention_mask, vision_token_grids=vision_token_grids, vision_token_offsets=vision_token_offsets, vision_token_lengths=vision_token_lengths, - vision_image_attention_mask=vision_image_attention_mask, + image_attention_mask=image_attention_mask, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - cache_position=cache_position, **kwargs, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=logits.shape[-1]) return CausalLMOutputWithPast( loss=loss, @@ -1273,59 +1647,117 @@ def prepare_inputs_for_generation( past_key_values: list[torch.FloatTensor] | None = None, attention_mask: torch.Tensor | None = None, inputs_embeds: torch.FloatTensor | None = None, - modality_tensor: torch.LongTensor | None = None, + mm_token_type_ids: torch.LongTensor | None = None, vision_patches: torch.Tensor | None = None, + pixel_values: torch.Tensor | None = None, vision_patch_attention_mask: torch.Tensor | None = None, + image_patch_attention_mask: torch.Tensor | None = None, vision_token_grids: torch.LongTensor | None = None, + image_token_grids: torch.LongTensor | None = None, vision_token_offsets: torch.LongTensor | None = None, vision_token_lengths: torch.LongTensor | None = None, - vision_image_attention_mask: torch.LongTensor | None = None, - cache_position: torch.LongTensor | None = None, + image_attention_mask: torch.LongTensor | None = None, position_ids: torch.LongTensor | None = None, + is_first_iteration=False, + use_cache=True, **kwargs, ) -> dict[str, Any]: + if vision_patches is None: + vision_patch_attention_mask = ( + image_patch_attention_mask if vision_patch_attention_mask is None else vision_patch_attention_mask + ) + vision_token_grids = image_token_grids if vision_token_grids is None else vision_token_grids + if position_ids is None or position_ids.ndim == 2: + position_ids = self._prepare_position_ids_for_generation( + input_ids, + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "mm_token_type_ids": mm_token_type_ids, + "vision_token_grids": vision_token_grids, + "vision_token_offsets": vision_token_offsets, + "vision_token_lengths": vision_token_lengths, + }, + ) model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, - cache_position=cache_position, - position_ids=None, + position_ids=position_ids, + is_first_iteration=is_first_iteration, + use_cache=use_cache, **kwargs, ) multimodal_inputs = { - "modality_tensor": modality_tensor, + "mm_token_type_ids": mm_token_type_ids, "vision_patches": vision_patches, "vision_patch_attention_mask": vision_patch_attention_mask, "vision_token_grids": vision_token_grids, "vision_token_offsets": vision_token_offsets, "vision_token_lengths": vision_token_lengths, - "vision_image_attention_mask": vision_image_attention_mask, } - if not any(value is not None for value in multimodal_inputs.values()): - return model_inputs - - past_len = past_key_values.get_seq_length() if past_key_values is not None else 0 - first_step 
= past_len == 0 + is_prefill = is_first_iteration or not use_cache for key, value in multimodal_inputs.items(): - model_inputs[key] = value if first_step else None - model_inputs["position_ids"] = position_ids if first_step else None + model_inputs[key] = value if is_prefill else None return model_inputs - @classmethod - def can_generate(cls) -> bool: - return True + def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs): + text_positions = super()._prepare_position_ids_for_generation(inputs_tensor, model_kwargs) + + past_length = 0 + if (cache := model_kwargs.get("past_key_values")) is not None: + past_length = cache.get_seq_length() + if past_length != 0 and self.model.rope_deltas is not None: + return text_positions[None, ...] + self.model.rope_deltas + + if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0: + inputs_tensor = model_kwargs["input_ids"] + + if ( + model_kwargs.get("vision_token_lengths") is not None + and len(inputs_tensor.shape) == 2 + and inputs_tensor.dtype in [torch.int, torch.long] + ): + vision_positions, rope_deltas = self.model.get_rope_index( + mm_token_type_ids=model_kwargs["mm_token_type_ids"], + image_token_grids=model_kwargs["vision_token_grids"], + image_token_offsets=model_kwargs["vision_token_offsets"], + image_token_lengths=model_kwargs["vision_token_lengths"], + attention_mask=model_kwargs.get("attention_mask"), + ) + self.model.rope_deltas = rope_deltas + else: + vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1) + self.model.rope_deltas = torch.zeros( + inputs_tensor.shape[0], 1, dtype=torch.long, device=inputs_tensor.device + ) - def set_input_embeddings(self, value: nn.Module) -> None: - self.model.set_input_embeddings(value) - vocab_size = getattr(value, "num_embeddings", None) - self.config.vocab_size = vocab_size - self.model.config.vocab_size = vocab_size - self.model.text_model.config.vocab_size = vocab_size - if self.lm_head.weight.shape[0] != vocab_size: - self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False) - self.lm_head.weight = self.model.text_model.embed_tokens.weight + return torch.cat([text_positions[None, ...], vision_positions], dim=0) + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: torch.LongTensor | None = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + position_ids = model_kwargs.pop("position_ids", None) + input_ids, model_kwargs = super()._expand_inputs_for_generation( + expand_size=expand_size, + is_encoder_decoder=is_encoder_decoder, + input_ids=input_ids, + **model_kwargs, + ) + if position_ids is not None: + dim = 1 if position_ids.ndim == 3 else 0 + model_kwargs["position_ids"] = position_ids.repeat_interleave(expand_size, dim=dim) + return input_ids, model_kwargs + + def get_input_embeddings(self) -> nn.Module: + return self.model.get_input_embeddings() -__all__ = ["IsaacModel", "IsaacPreTrainedModel", "IsaacForConditionalGeneration"] +__all__ = ["IsaacTextModel", "IsaacModel", "IsaacPreTrainedModel", "IsaacForConditionalGeneration"] diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 851f9b4bc339..0264896bcfd4 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -1,4 +1,4 @@ -# Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. +# Copyright 2026 Perceptron, Inc and The HuggingFace Team.
All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,13 +16,12 @@ import copy import math +import re from collections.abc import Sequence -from enum import IntEnum -from typing import Any +from typing import Any, NamedTuple from ... import initialization as init -from ...cache_utils import DynamicCache -from ...configuration_utils import PretrainedConfig, layer_type_validation +from ...configuration_utils import PretrainedConfig from ...feature_extraction_utils import BatchFeature from ...generation.utils import GenerationMixin from ...image_processing_utils_fast import ( @@ -33,32 +32,36 @@ reorder_images, ) from ...image_utils import ( + ImageInput, PILImageResampling, + make_nested_list_of_images, ) -from ...masking_utils import create_bidirectional_mask, create_masks_for_generate -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...masking_utils import create_bidirectional_mask +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel from ...models.qwen3.configuration_qwen3 import Qwen3Config from ...models.qwen3.modeling_qwen3 import ( - Qwen3Attention, - Qwen3DecoderLayer, Qwen3ForCausalLM, - Qwen3Model, Qwen3PreTrainedModel, ) -from ...processing_utils import ProcessorMixin, Unpack +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...utils import TensorType, auto_docstring, torch_compilable_check from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD -from ...utils.generic import TransformersKwargs, can_return_tuple, maybe_autocast, merge_with_config_defaults +from ...utils.generic import TransformersKwargs, can_return_tuple, merge_with_config_defaults from ...utils.import_utils import ( is_torch_available, is_torchdynamo_compiling, is_torchvision_available, is_vision_available, ) -from ...utils.output_capturing import OutputRecorder, capture_outputs -from ..qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding +from ...utils.output_capturing import capture_outputs +from ..qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLTextAttention, + Qwen3VLTextDecoderLayer, + Qwen3VLTextModel, + Qwen3VLTextRotaryEmbedding, +) from ..siglip2.configuration_siglip2 import Siglip2VisionConfig from ..siglip2.modeling_siglip2 import ( Siglip2Attention, @@ -80,17 +83,93 @@ from ..pix2struct.image_processing_pix2struct_fast import torch_extract_patches -class ModalityType(IntEnum): - """ - Modality identifiers for events. +class SinglePoint(NamedTuple): + x: int + y: int + mention: str | None = None + t: float | None = None + + +class BoundingBox(NamedTuple): + top_left: SinglePoint + bottom_right: SinglePoint + mention: str | None = None + t: float | None = None + + +_POINT_OR_BOX_TAG = re.compile( + r"<(?P<tag>point|point_box)(?P<attrs>[^>]*)>(?P<body>[\s\S]*?)</(?P=tag)>", re.IGNORECASE +) +_ATTR_RE = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))") +_COORD_RE = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)") - Members: - image: Vision tokens (e.g., patches). - text: Textual tokens.
- """ - image = 0 - text = 1 +class IsaacProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": True, + "return_attention_mask": True, + }, + } + + +def _maybe_float(value: str | None) -> float | None: + if value is None: + return None + try: + return float(value) + except ValueError: + return None + + +def _parse_attrs(attr_text: str) -> dict[str, str]: + attrs = {} + for match in _ATTR_RE.finditer(attr_text or ""): + key = match.group(1) + value = match.group(2) or match.group(3) or "" + attrs[key] = value + return attrs + + +def _parse_point_body(body: str, mention: str | None = None, t: str | None = None) -> SinglePoint: + match = _COORD_RE.search(body) + if not match: + raise ValueError(f"Malformed tag: {body!r}") + x, y = int(match.group(1)), int(match.group(2)) + return SinglePoint(x=x, y=y, mention=mention, t=_maybe_float(t)) + + +def _parse_box_body(body: str, mention: str | None = None, t: str | None = None) -> BoundingBox: + coords = list(_COORD_RE.finditer(body)) + if len(coords) < 2: + raise ValueError(f"Malformed tag: {body!r}") + + top_left = SinglePoint(x=int(coords[0].group(1)), y=int(coords[0].group(2))) + bottom_right = SinglePoint(x=int(coords[1].group(1)), y=int(coords[1].group(2))) + return BoundingBox(top_left=top_left, bottom_right=bottom_right, mention=mention, t=_maybe_float(t)) + + +def clean_text_and_extract_points( + text: str, + expected: str | None = None, +) -> tuple[str, list[SinglePoint | BoundingBox]]: + results = [] + for match in _POINT_OR_BOX_TAG.finditer(text or ""): + tag = match.group("tag").lower() + attrs = _parse_attrs(match.group("attrs")) + mention = attrs.get("mention") + t = attrs.get("t") + if tag == "point": + if expected not in (None, "point"): + continue + results.append(_parse_point_body(match.group("body"), mention=mention, t=t)) + else: + if expected not in (None, "box"): + continue + results.append(_parse_box_body(match.group("body"), mention=mention, t=t)) + + clean_text = re.sub(r"\s+", " ", _POINT_OR_BOX_TAG.sub("", text or "")).strip() + return clean_text, results class IsaacVisionConfig(Siglip2VisionConfig): @@ -101,8 +180,6 @@ class IsaacVisionConfig(Siglip2VisionConfig): Args: pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1): Spatial factor applied before pixel shuffle reduces the resolution. - num_patches (`int`, *optional*, defaults to 256): - Maximum number of learnable positional embeddings to initialize. 
""" model_type = "isaac_vision" @@ -110,38 +187,19 @@ class IsaacVisionConfig(Siglip2VisionConfig): def __init__( self, - hidden_size=768, - intermediate_size=3072, - num_hidden_layers=12, - num_attention_heads=12, - num_channels=3, - num_patches=256, - patch_size=16, - hidden_act="gelu_pytorch_tanh", - layer_norm_eps=1e-6, - attention_dropout=0.0, pixel_shuffle_scale_factor=1, - **kwargs, + **super_kwargs, ): - super().__init__(**kwargs) - - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_channels = num_channels - self.patch_size = patch_size - self.attention_dropout = attention_dropout - self.layer_norm_eps = layer_norm_eps - self.hidden_act = hidden_act - self.num_patches = num_patches - + super().__init__(**super_kwargs) # Add our custom fields self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor - # Ensure a sensible default attention backend - if getattr(self, "_attn_implementation", None) is None: - self._attn_implementation = "sdpa" + +class IsaacTextConfig(Qwen3Config): + model_type = "isaac_text" + + def __init__(self, **super_kwargs): + super().__init__(ignore_keys_at_rope_validation={"mrope_section", "mrope_interleaved"}, **super_kwargs) class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False): @@ -156,10 +214,10 @@ class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False): Pixel-shuffle reduction factor applied in the vision tower. """ - patch_size: int | None - max_num_patches: int | None - min_num_patches: int | None - pixel_shuffle_scale: int | None + patch_size: int + max_num_patches: int + min_num_patches: int + pixel_shuffle_scale: int @auto_docstring @@ -167,7 +225,11 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): MAX_PIXELS = 60_000_000 # 60โ€‘megapixel ceiling โ‰ˆ 8200 ร— 7300 px resample = PILImageResampling.BILINEAR - model_input_names = ["patches", "token_grids"] + model_input_names = [ + "vision_patches", + "vision_patch_attention_mask", + "vision_token_grids", + ] valid_kwargs = IsaacImageProcessorFastKwargs unused_kwargs = ["size", "do_center_crop", "crop_size", "pad_size", "do_pad"] @@ -190,6 +252,14 @@ def _validate_preprocess_kwargs(self, **kwargs): kwargs.pop("do_resize", None) return super()._validate_preprocess_kwargs(**kwargs) + def _prepare_images_structure( + self, + images: ImageInput, + expected_ndims: int = 3, + ) -> ImageInput: + images = self.fetch_images(images) + return make_nested_list_of_images(images, expected_ndims=expected_ndims) + def resize( self, image: torch.Tensor, @@ -200,7 +270,7 @@ def resize( def _preprocess( self, - images: list[torch.Tensor], + images: list[list[torch.Tensor]], do_resize: bool, interpolation: Any | None, do_rescale: bool | None, @@ -210,23 +280,29 @@ def _preprocess( image_std: float | Sequence[float] | None, disable_grouping: bool | None = None, return_tensors: str | TensorType | None = None, - *, patch_size: int | None = None, max_num_patches: int | None = None, min_num_patches: int | None = None, pixel_shuffle_scale: int | None = None, **kwargs, ) -> BatchFeature: - grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + batch_size = len(images) + if all(len(sample_images) == 0 for sample_images in images): + tensors = { + "vision_patches": torch.zeros((batch_size, 0, 0, 0), dtype=torch.float32), + "vision_patch_attention_mask": torch.zeros((batch_size, 0, 0), dtype=torch.long), + "vision_token_grids": 
torch.zeros((batch_size, 0, 2), dtype=torch.long), + } + return BatchFeature(data=tensors, tensor_type=return_tensors) + + grouped_images, grouped_images_index = group_images_by_shape( + images, disable_grouping=disable_grouping, is_nested=True + ) grouped_outputs = {} for shape, stacked_images in grouped_images.items(): - batch_size, channels, original_height, original_width = stacked_images.shape - - if bool(self.do_convert_rgb) and channels == 1: - stacked_images = stacked_images.repeat(1, 3, 1, 1) - + grouped_batch_size, channels, original_height, original_width = stacked_images.shape target_height, target_width = get_image_size_for_max_num_patches( original_height, original_width, @@ -256,49 +332,66 @@ def _preprocess( ) patches = torch_extract_patches(image_batch, patch_size, patch_size) - _, height_tokens, width_tokens, _ = patches.shape + _, height_tokens, width_tokens, patch_dim = patches.shape token_grid = ( - torch.tensor([height_tokens, width_tokens], device=patches.device).long().expand(batch_size, 2) - ) - - real_dim = ( - torch.tensor( - [1, height_tokens, width_tokens], - dtype=torch.long, - device=patches.device, - ) - .unsqueeze(0) - .repeat(batch_size, 1) + torch.tensor([height_tokens, width_tokens], device=patches.device).long().expand(grouped_batch_size, 2) ) if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): raise ValueError( f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale}; adjust resize/patch parameters or disable pixel shuffle." ) - virtual_height = height_tokens // pixel_shuffle_scale - virtual_width = width_tokens // pixel_shuffle_scale - - virtual_dim = ( - torch.tensor( - [1, virtual_height, virtual_width], - dtype=torch.long, - device=patches.device, - ) - .unsqueeze(0) - .repeat(batch_size, 1) - ) - grouped_outputs[shape] = (patches, token_grid, virtual_dim, real_dim) - keys = ("patches", "token_grids", "virtual_pixel_size", "real_pixel_size") - tensors: dict[str, torch.Tensor] = {} + grouped_outputs[shape] = ( + patches.reshape(grouped_batch_size, -1, patch_dim), + token_grid, + ) - for i, key in enumerate(keys): - slices = reorder_images( + keys = ("vision_patches", "vision_token_grids") + nested_outputs = { + key: reorder_images( {shape: values[i] for shape, values in grouped_outputs.items()}, grouped_images_index, + is_nested=True, ) - tensors[key] = torch.stack(slices, dim=0) + for i, key in enumerate(keys) + } + + first_patch = next( + patches for sample_patches in nested_outputs["vision_patches"] for patches in sample_patches + ) + max_images = max(len(sample_patches) for sample_patches in nested_outputs["vision_patches"]) + patch_dim = first_patch.shape[-1] + patch_dtype = first_patch.dtype + patch_device = first_patch.device + max_patches = max( + patches.shape[0] for sample_patches in nested_outputs["vision_patches"] for patches in sample_patches + ) + + tensors = { + "vision_patches": torch.zeros( + (batch_size, max_images, max_patches, patch_dim), device=patch_device, dtype=patch_dtype + ), + "vision_patch_attention_mask": torch.zeros( + (batch_size, max_images, max_patches), device=patch_device, dtype=torch.long + ), + "vision_token_grids": torch.zeros((batch_size, max_images, 2), device=patch_device, dtype=torch.long), + } + + for batch_idx, sample_patches in enumerate(nested_outputs["vision_patches"]): + sample_image_count = len(sample_patches) + if sample_image_count == 0: + continue + + for image_idx, patches in enumerate(sample_patches): + 
patch_count = int(patches.shape[0]) + + tensors["vision_patches"][batch_idx, image_idx, :patch_count] = patches + tensors["vision_patch_attention_mask"][batch_idx, image_idx, :patch_count] = 1 + tensors["vision_token_grids"][batch_idx, image_idx] = nested_outputs["vision_token_grids"][batch_idx][ + image_idx + ] return BatchFeature(data=tensors, tensor_type=return_tensors) @@ -313,17 +406,6 @@ class IsaacVisionEmbeddings(Siglip2VisionEmbeddings): def __init__(self, config: IsaacVisionConfig): super().__init__(config) - self.config = config - self.embed_dim = config.hidden_size - self.patch_size = config.patch_size - - self.patch_embedding = nn.Linear( - in_features=config.num_channels * self.patch_size * self.patch_size, - out_features=self.embed_dim, - ) - - self.num_patches = config.num_patches - self.position_embedding_size = int(self.num_patches**0.5) self.position_embedding = nn.Parameter( torch.empty( self.position_embedding_size, @@ -331,10 +413,7 @@ def __init__(self, config: IsaacVisionConfig): self.embed_dim, ) ) - nn.init.normal_(self.position_embedding) - @merge_with_config_defaults - @capture_outputs def forward( self, pixel_values: torch.Tensor, @@ -396,10 +475,8 @@ def pixel_shuffle_padded( Spatial down-sampling factor. Returns: - Tuple of: - - pixel-shuffled embeddings `(num_images, max_tokens, hidden_size * scale_factor**2)` - - attention mask `(num_images, max_tokens)` - - per-image valid token lengths `(num_images,)` + `torch.Tensor`: Pixel-shuffled embeddings of shape + `(num_images, max_tokens, hidden_size * scale_factor**2)`. """ num_images, max_patches, embed_dim = x.shape output_dim = embed_dim * scale_factor * scale_factor @@ -448,19 +525,14 @@ class IsaacVisionTransformer(PreTrainedModel): Args: config (IsaacVisionConfig): Vision configuration with pixel-shuffle and patching parameters. - Inputs: - vision_tokens (Tuple[Tensor, Tensor, Optional[Tensor]]): - `(patches, token_grids, patch_attention_mask)` where: - - `patches`: `(num_images, max_patches, patch_dim)` - - `token_grids`: `(num_images, 2)` with per-image `(H_tokens, W_tokens)` - - `patch_attention_mask`: `(num_images, max_patches)` or `None` - - Returns: - Tuple of `(pixel_shuffled_features, attention_mask, token_lengths)`. """ _supports_sdpa = True _supports_flash_attn = True + _can_record_outputs = { + "hidden_states": IsaacVisionEncoderLayer, + "attentions": IsaacVisionAttention, + } def __init__(self, config: IsaacVisionConfig): super().__init__(config) @@ -478,18 +550,30 @@ def _init_weights(self, module): if isinstance(module, IsaacVisionEmbeddings): init.zeros_(module.position_embedding) + @merge_with_config_defaults + @capture_outputs(tie_last_hidden_states=False) def forward( self, - vision_tokens: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor | None], - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if len(vision_tokens) == 2: - seq_patches, token_grids = vision_tokens - vision_patch_attention_mask = None - else: - seq_patches, token_grids, vision_patch_attention_mask = vision_tokens + vision_patches: torch.Tensor, + vision_token_grids: torch.Tensor, + vision_patch_attention_mask: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + """ + Inputs: + vision_patches (`torch.Tensor`): + Patches shaped `(num_images, max_patches, patch_dim)`. + vision_token_grids (`torch.Tensor`): + Token grids shaped `(num_images, 2)` with per-image `(H_tokens, W_tokens)`. 
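The space-to-depth move behind `pixel_shuffle_padded`, shown on a toy unpadded grid (editorial sketch; the real helper additionally handles per-image padding via the token grids and attention masks):

    import torch

    num_images, h, w, hidden, s = 1, 4, 4, 64, 2
    x = torch.randn(num_images, h * w, hidden)              # 16 tokens of width 64
    x = x.view(num_images, h // s, s, w // s, s, hidden)
    x = x.permute(0, 1, 3, 2, 4, 5)                         # gather each s x s neighborhood
    x = x.reshape(num_images, (h // s) * (w // s), hidden * s * s)
    assert x.shape == (1, 4, 256)                           # 4x fewer tokens, 4x wider features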
+ vision_patch_attention_mask (`torch.Tensor`): + Patch mask shaped `(num_images, max_patches)`. + + Returns: + `BaseModelOutputWithPooling` with pixel-shuffled embeddings in `last_hidden_state`. + """ hidden_states = self.embeddings( - seq_patches, - token_grids, + vision_patches, + vision_token_grids, attention_mask=vision_patch_attention_mask, ) @@ -498,28 +582,34 @@ def forward( inputs_embeds=hidden_states, attention_mask=vision_patch_attention_mask, ) - encoder_outputs = self.encoder(inputs_embeds=hidden_states, attention_mask=encoder_attention_mask) + encoder_outputs = self.encoder(inputs_embeds=hidden_states, attention_mask=encoder_attention_mask, **kwargs) hidden_states = self.post_layernorm(encoder_outputs.last_hidden_state) - return pixel_shuffle_padded( + hidden_states = pixel_shuffle_padded( x=hidden_states, - token_grids=token_grids, + token_grids=vision_token_grids, scale_factor=self.pixel_shuffle_scale_factor, ) + return BaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=None, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + class IsaacMultiModalProjector(nn.Module): """Maps vision tower outputs to the text hidden size with a SiLU MLP.""" def __init__(self, config: IsaacConfig): super().__init__() - self.vision_hidden_size = config.vision_config.hidden_size * ( - config.vision_config.pixel_shuffle_scale_factor**2 - ) - self.backbone_hidden_size = config.hidden_size - self.linear_1 = nn.Linear(self.vision_hidden_size, 4 * self.vision_hidden_size, bias=False) + text_config = config.get_text_config() + vision_hidden_size = config.vision_config.hidden_size * (config.vision_config.pixel_shuffle_scale_factor**2) + backbone_hidden_size = text_config.hidden_size + self.linear_1 = nn.Linear(vision_hidden_size, 4 * vision_hidden_size, bias=False) self.silu = nn.SiLU() - self.linear_2 = nn.Linear(4 * self.vision_hidden_size, self.backbone_hidden_size, bias=False) + self.linear_2 = nn.Linear(4 * vision_hidden_size, backbone_hidden_size, bias=False) def forward(self, image_features): hidden_states = self.linear_1(image_features) @@ -528,24 +618,6 @@ def forward(self, image_features): return hidden_states -class IsaacVisionEmbedding(nn.Module): - def __init__(self, config: IsaacConfig): - super().__init__() - vision_cfg = config.vision_config - - self.vision_tower = IsaacVisionTransformer(vision_cfg) - self.multimodal_projector = IsaacMultiModalProjector(config) - - def forward( - self, - vision_tokens: tuple[torch.Tensor, torch.Tensor, torch.Tensor | None], - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - vision_patches, token_grids, vision_patch_attention_mask = vision_tokens - hidden_states = self.vision_tower((vision_patches, token_grids, vision_patch_attention_mask)) - projected = self.multimodal_projector(hidden_states) - return projected - - def get_scaled_image_size( scale: float, original_size: int, @@ -642,16 +714,27 @@ class IsaacConfig(PretrainedConfig): This configuration corresponds to checkpoints such as [Perceptron/isaac-base](https://huggingface.co/Perceptron/isaac-base). + + Args: + vision_config (`IsaacVisionConfig`, *optional*): + Configuration for the Isaac vision tower. If unset, the default [`IsaacVisionConfig`] is used. + text_config (`IsaacTextConfig` or `dict`, *optional*): + Configuration for the text backbone. Dictionaries are converted to [`IsaacTextConfig`]. 
+ vision_rescale_factor (`float`, *optional*, defaults to 1 / 255): + Rescale factor applied by the image processor before normalization. + max_sequence_length (`int`, *optional*, defaults to 16384): + Maximum multimodal sequence length produced by the processor and expected by the model. + vision_token (`str`, *optional*, defaults to `""`): + Placeholder string inserted into text prompts to mark image positions. """ model_type = "isaac" - sub_configs = {"vision_config": IsaacVisionConfig, "text_config": Qwen3Config} - image_processor_type = "IsaacImageProcessor" + sub_configs = {"vision_config": IsaacVisionConfig, "text_config": IsaacTextConfig} def __init__( self, vision_config: IsaacVisionConfig | None = None, - text_config: Qwen3Config | dict | None = None, + text_config: IsaacTextConfig | dict | None = None, vision_rescale_factor: float = 1 / 255, max_sequence_length: int = 16384, vision_token: str = "", @@ -659,7 +742,7 @@ def __init__( ): if isinstance(text_config, dict): self.text_config = self.sub_configs["text_config"](**text_config) - elif isinstance(text_config, Qwen3Config): + elif isinstance(text_config, IsaacTextConfig): self.text_config = text_config elif text_config is None: self.text_config = self.sub_configs["text_config"]() @@ -671,32 +754,14 @@ def __init__( elif vision_config is None: self.vision_config = self.sub_configs["vision_config"]() - # Seed RoPE parameters before base init so the shared mixin can standardize/validate them. - self.rope_parameters = getattr(self.text_config, "rope_parameters", None) - self.layer_types = getattr(self.text_config, "layer_types", None) - super().__init__(**kwargs) - # Keep rope parameters aligned between the composite and text sub-configs. - self.text_config.rope_parameters = self.rope_parameters - - # Mirror frequently accessed Qwen3 attributes at the composite config level - self.vocab_size = self.text_config.vocab_size - self.hidden_size = self.text_config.hidden_size - self.num_hidden_layers = self.text_config.num_hidden_layers - self.num_attention_heads = self.text_config.num_attention_heads - self.head_dim = self.text_config.head_dim - self.hidden_act = self.text_config.hidden_act + # Mirror frequently accessed composite-level attributes. self.use_cache = self.text_config.use_cache - self.rope_theta = self.rope_parameters["rope_theta"] + self.rope_theta = self.text_config.rope_parameters["rope_theta"] self.max_position_embeddings = getattr(self.text_config, "max_position_embeddings", max_sequence_length) self.text_config.max_position_embeddings = self.max_position_embeddings - self.layer_types = getattr(self.text_config, "layer_types", None) - layer_type_validation(self.layer_types, self.num_hidden_layers) - - if getattr(self, "_attn_implementation", None) is None: - self._attn_implementation = "sdpa" # Vision normalization parameters self.vision_rescale_factor = float(vision_rescale_factor) @@ -705,302 +770,162 @@ def __init__( self.vision_token = vision_token +@auto_docstring class IsaacProcessor(ProcessorMixin): - """Processor that pairs the Isaac image processor with the Qwen2 tokenizer. - - Args: - image_processor: Vision preprocessor (fast) used for patch extraction. - tokenizer: Qwen2 tokenizer instance. - vision_token (str, optional): Placeholder token marking image locations. Defaults to "". - max_sequence_length (int, optional): Maximum combined text+vision tokens kept. Defaults to 16384. - rescale_factor (float, optional): Image rescale factor; defaults to 1/255. 
- config (IsaacConfig | dict, optional): If provided, overrides processor defaults from the model config. - - Returns: - BatchFeature: Top-level batched text and vision tensors. - """ - - attributes = ["image_processor", "tokenizer"] - image_processor_class = ("IsaacImageProcessorFast",) - tokenizer_class = ("Qwen2Tokenizer",) - pad_token_id = 151643 - def __init__( self, image_processor, tokenizer, - *, + chat_template: str | dict[str, str] | None = None, vision_token: str = "", max_sequence_length: int = 16384, rescale_factor: float | None = None, - config: IsaacConfig | dict | None = None, - ) -> None: - if isinstance(config, dict): - config = IsaacConfig(**config) - - if config is not None: - vision_token = config.vision_token - max_sequence_length = config.max_sequence_length - rescale_factor = config.vision_rescale_factor - - resolved_rescale_factor = float(rescale_factor) if rescale_factor is not None else float(1 / 255) - if config is not None: - config.vision_rescale_factor = resolved_rescale_factor + ): + """ + Args: + chat_template (`str` or `dict[str, str]`, *optional*): + Chat template override forwarded to [`~processing_utils.ProcessorMixin`]. + vision_token (`str`, *optional*, defaults to `""`): + Placeholder token used inside text prompts to mark image positions. + max_sequence_length (`int`, *optional*, defaults to 16384): + Maximum packed multimodal sequence length produced by the processor. + rescale_factor (`float`, *optional*): + Deprecated compatibility argument accepted for backward compatibility. + """ + if chat_template is None: + chat_template = getattr(tokenizer, "chat_template", None) self.image_processor = image_processor - super().__init__(image_processor, tokenizer) - - text_pad_token_id = getattr(self.tokenizer, "pad_token_id", None) - image_pad_token_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>") + super().__init__(image_processor, tokenizer, chat_template=chat_template) + self.text_pad_token_id = self.pad_token_id = tokenizer.pad_token_id + self.image_pad_token_id = tokenizer.image_pad_token_id + self.image_token = tokenizer.image_pad_token + self.image_token_id = self.image_pad_token_id - self.text_pad_token_id = int(text_pad_token_id) - self.image_pad_token_id = int(image_pad_token_id) - self.pad_token_id = self.text_pad_token_id - - self.current_processor = self.image_processor - self.config = config - self.chat_template = getattr(self.tokenizer, "chat_template", None) self.vision_token = vision_token self.max_sequence_length = max_sequence_length def _build_batch( self, text: str | list[str], - images: Image | list[Image] | None = None, + images: ImageInput | None = None, + text_kwargs: dict[str, Any] | None = None, ) -> dict[str, torch.Tensor | None]: + text_kwargs = copy.deepcopy(text_kwargs) if text_kwargs is not None else {} + truncation = text_kwargs.pop("truncation", None) + max_length = text_kwargs.pop("max_length", None) + padding = text_kwargs.pop("padding", True) + padding_side = text_kwargs.pop("padding_side", "left") + return_attention_mask = text_kwargs.pop("return_attention_mask", True) + pad_to_multiple_of = text_kwargs.pop("pad_to_multiple_of", None) + text_kwargs.pop("return_tensors", None) + text_kwargs.setdefault("add_special_tokens", False) + texts = [text] if isinstance(text, str) else text if images is None: - pairs = ((text_value, None) for text_value in texts) - elif isinstance(images, list) and len(images) == len(texts): - if not images: - images_list = [] - elif isinstance(images[0], list): - images_list = images - 
else: - images_list = [[image] for image in images] - pairs = zip(texts, images_list, strict=True) + batched_images = [[] for _ in texts] else: - pairs = ( - ( - text_value, - None - if text_value.count(self.vision_token) == 0 - else images - if isinstance(images, list) - else [images], + fetched_images = self.image_processor.fetch_images(images) + batched_images = make_nested_list_of_images(fetched_images) + if len(batched_images) != len(texts): + num_images_in_text = [text_value.count(self.vision_token) for text_value in texts] + num_images_in_images = [len(sample_images) for sample_images in batched_images] + add_message = "" + if sum(num_images_in_text) == sum(num_images_in_images): + add_message = " Make sure to pass your images as a nested list, where each sub-list holds images for one text sample." + + raise ValueError( + f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(texts)}).{add_message}" ) - for text_value in texts - ) + + pairs = list(zip(texts, batched_images, strict=True)) + image_inputs = self.image_processor(images=batched_images, return_tensors=TensorType.PYTORCH) + vision_token_grids = image_inputs["vision_token_grids"] + vision_segment_lengths = (vision_token_grids[..., 0] // self.image_processor.pixel_shuffle_scale) * ( + vision_token_grids[..., 1] // self.image_processor.pixel_shuffle_scale + ) + vision_token_offsets = torch.zeros_like(vision_segment_lengths) + vision_token_lengths = torch.zeros_like(vision_segment_lengths) sample_input_ids: list[torch.Tensor] = [] - sample_modality: list[torch.Tensor] = [] - sample_position_ids: list[torch.Tensor] = [] - sample_vision_patches: list[list[torch.Tensor]] = [] - sample_vision_grids: list[torch.Tensor] = [] - sample_vision_offsets: list[torch.Tensor] = [] - sample_vision_lengths: list[torch.Tensor] = [] - - for text_value, sample_images in pairs: + expanded_texts = [] + expected_image_lengths_per_sample = [] + + for batch_idx, (text_value, sample_images) in enumerate(pairs): segments = text_value.split(self.vision_token) num_images = len(segments) - 1 - num_provided_images = len(sample_images) if sample_images is not None else 0 + num_provided_images = len(sample_images) if num_images != num_provided_images: raise ValueError( f"IsaacProcessor expects one image per image token, got {num_images} tokens and {num_provided_images} images in sample with text {text_value} " ) - items: list[dict[str, Any]] = [] - total = 0 - for index, segment in enumerate(segments): - if segment: - text_tokens = ( - self.tokenizer.encode(segment, add_special_tokens=False, return_tensors="pt") - .squeeze(0) - .to(torch.long) - ) - segment_length = int(text_tokens.numel()) - items.append({"type": "text", "segment_length": segment_length, "tokens": text_tokens}) - total += segment_length - - if index < num_images: - feature = self.image_processor(images=sample_images[index], return_tensors=TensorType.PYTORCH) - patches = feature["patches"][0].reshape(-1, feature["patches"].shape[-1]) - virtual_pixel_size = feature["virtual_pixel_size"][0].to(torch.long).tolist() - real_pixel_size = feature["real_pixel_size"][0].to(torch.long).tolist() - dims = tuple((virtual_pixel_size + [1, 1, 1])[:3]) - segment_length = int(dims[0] * dims[1] * dims[2]) - items.append( - { - "type": "image", - "segment_length": segment_length, - "dims": dims, - "patches": patches, - "grid": (int(real_pixel_size[1]), int(real_pixel_size[2])), - } - ) - total += segment_length - - start = max(0, total - self.max_sequence_length) - end = total 
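The new `_build_batch` above replaces each user-facing vision token with one pad token per post-shuffle patch before tokenization; a toy sketch of that expansion, assuming the `<|image_pad|>` placeholder referenced elsewhere in this file:

    grid_h, grid_w, shuffle_scale = 8, 8, 2
    segment_length = (grid_h // shuffle_scale) * (grid_w // shuffle_scale)  # 16 image tokens
    expanded = "Describe " + "<|image_pad|>" * segment_length + " briefly."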
- base_device: torch.device | None = None - input_ids_chunks, modality_chunks, position_chunks = [], [], [] - vision_patches, vision_grids, vision_offsets, vision_lengths = [], [], [], [] - global_offset = 0 - position_offset = 0 - - for item in items: - segment_length = int(item["segment_length"]) - current_window_start = max(start, global_offset) - current_window_end = min(end, global_offset + segment_length) - has_overlap = current_window_end > current_window_start - - if has_overlap and base_device is None: - base_device = item["patches"].device if item["type"] == "image" else item["tokens"].device - - if has_overlap: - segment_local_start = int(current_window_start - global_offset) - segment_local_end = int(current_window_end - global_offset) - segment_local_indices = torch.arange( - segment_local_start, segment_local_end, device=base_device, dtype=torch.long - ) - segment_kept_length = segment_local_end - segment_local_start - - if item["type"] == "text": - slice_index = segment_local_indices + position_offset - zero_axis = torch.zeros_like(slice_index) - position_chunks.append(torch.stack((slice_index, zero_axis, zero_axis), -1)) - modality_chunks.append( - torch.full( - (segment_kept_length,), ModalityType.text.value, device=base_device, dtype=torch.long - ) - ) - input_ids_chunks.append(item["tokens"].to(base_device)[segment_local_start:segment_local_end]) - position_offset += segment_length - else: - num_pos_slices, grid_height_tokens, grid_width_tokens = item["dims"] - hw = grid_height_tokens * grid_width_tokens - slice_index = (segment_local_indices // hw) + position_offset - rem = segment_local_indices % hw - position_chunks.append( - torch.stack((slice_index, rem // grid_width_tokens, rem % grid_width_tokens), -1) - ) - modality_chunks.append( - torch.full( - (segment_kept_length,), ModalityType.image.value, device=base_device, dtype=torch.long - ) - ) - input_ids_chunks.append( - torch.full( - (segment_kept_length,), self.image_pad_token_id, device=base_device, dtype=torch.long - ) - ) - - vision_patches.append(item["patches"].to(base_device)) - vision_grids.append(item["grid"]) - vision_offsets.append(segment_local_start) - vision_lengths.append(segment_kept_length) - position_offset += int(num_pos_slices) - else: - position_offset += segment_length if item["type"] == "text" else int(item["dims"][0]) - - global_offset += segment_length - - if base_device is None: - base_device = torch.device("cpu") - - sample_input_ids.append( - torch.cat(input_ids_chunks, 0) - if input_ids_chunks - else torch.zeros((0,), device=base_device, dtype=torch.long) - ) - sample_modality.append( - torch.cat(modality_chunks, 0) - if modality_chunks - else torch.zeros((0,), device=base_device, dtype=torch.long) - ) - sample_position_ids.append( - torch.cat(position_chunks, 0) - if position_chunks - else torch.zeros((0, 3), device=base_device, dtype=torch.long) - ) - sample_vision_patches.append(vision_patches) - if vision_patches: - sample_vision_grids.append(torch.tensor(vision_grids, device=base_device, dtype=torch.long)) - sample_vision_offsets.append(torch.tensor(vision_offsets, device=base_device, dtype=torch.long)) - sample_vision_lengths.append(torch.tensor(vision_lengths, device=base_device, dtype=torch.long)) - else: - sample_vision_grids.append(torch.zeros((0, 2), device=base_device, dtype=torch.long)) - sample_vision_offsets.append(torch.zeros((0,), device=base_device, dtype=torch.long)) - sample_vision_lengths.append(torch.zeros((0,), device=base_device, dtype=torch.long)) - - batch_size 
= len(sample_input_ids) - lengths = [int(sample_input.shape[0]) for sample_input in sample_input_ids] - max_len = max(lengths, default=0) - base_device = next( - (sample_input.device for sample_input in sample_input_ids if sample_input.numel() > 0), - torch.device("cpu"), - ) - - input_ids = torch.full((batch_size, max_len), self.text_pad_token_id, device=base_device, dtype=torch.long) - attention_mask = torch.zeros((batch_size, max_len), device=base_device, dtype=torch.long) - modality_tensor = torch.full( - (batch_size, max_len), ModalityType.text.value, device=base_device, dtype=torch.long + expected_image_lengths = [ + int(vision_segment_lengths[batch_idx, image_idx].item()) for image_idx in range(num_images) + ] + expected_image_lengths_per_sample.append(expected_image_lengths) + + expanded_text = segments[0] + for image_idx, segment_length in enumerate(expected_image_lengths): + expanded_text += (self.image_token * segment_length) + segments[image_idx + 1] + expanded_texts.append(expanded_text) + + text_inputs = self.tokenizer(expanded_texts, return_tensors=None, **text_kwargs) + self._check_special_mm_tokens(expanded_texts, text_inputs, modalities=["image"]) + + effective_max_length = self.max_sequence_length + if truncation and max_length is not None: + effective_max_length = max_length + + for batch_idx, (expected_image_lengths, sample_input_ids_list) in enumerate( + zip(expected_image_lengths_per_sample, text_inputs["input_ids"], strict=True) + ): + sample_input = torch.tensor(sample_input_ids_list, dtype=torch.long) + image_positions = sample_input.eq(self.image_pad_token_id).nonzero(as_tuple=False).flatten() + image_spans = image_positions.split(expected_image_lengths) if expected_image_lengths else () + image_bounds = [] + + for image_idx, (segment_length, image_span) in enumerate( + zip(expected_image_lengths, image_spans, strict=True) + ): + image_start = int(image_span[0].item()) + image_end = int(image_span[-1].item()) + 1 + image_bounds.append((image_start, image_end)) + total = int(sample_input.shape[0]) + start = max(0, total - effective_max_length) + sample_input_ids.append(sample_input[start:]) + + for image_idx, (image_start, image_end) in enumerate(image_bounds): + kept_start = max(start, image_start) + kept_end = min(total, image_end) + if kept_end > kept_start: + vision_token_offsets[batch_idx, image_idx] = kept_start - image_start + vision_token_lengths[batch_idx, image_idx] = kept_end - kept_start + + text_inputs = self.tokenizer.pad( + {"input_ids": [sample_input.tolist() for sample_input in sample_input_ids]}, + padding=padding, + max_length=max_length if padding == "max_length" else None, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + return_tensors=TensorType.PYTORCH, ) - position_ids = torch.zeros((batch_size, max_len, 3), device=base_device, dtype=torch.long) + input_ids = text_inputs["input_ids"] + attention_mask = text_inputs.get("attention_mask") + if attention_mask is None: + attention_mask = input_ids.ne(self.pad_token_id).to(dtype=torch.long) - for batch_idx, length in enumerate(lengths): - if length == 0: - continue - input_ids[batch_idx, -length:] = sample_input_ids[batch_idx] - attention_mask[batch_idx, -length:] = 1 - modality_tensor[batch_idx, -length:] = sample_modality[batch_idx] - position_ids[batch_idx, -length:] = sample_position_ids[batch_idx] - - image_counts = [len(patches) for patches in sample_vision_patches] - max_images = max(image_counts, default=0) - if 
max_images == 0: - vision_patches = None - vision_patch_attention_mask = None - vision_token_grids = None - vision_token_offsets = None - vision_token_lengths = None - vision_image_attention_mask = None - else: - first_patch = next((patches[0] for patches in sample_vision_patches if patches), None) - patch_dim = first_patch.shape[-1] - patch_dtype = first_patch.dtype - max_patches = max((patch.shape[0] for patches in sample_vision_patches for patch in patches), default=0) + mm_token_type_ids = input_ids.eq(self.image_pad_token_id).to(dtype=torch.long) + vision_image_attention_mask = vision_token_lengths.gt(0).to(dtype=torch.long) - vision_patches = torch.zeros( - (batch_size, max_images, max_patches, patch_dim), device=base_device, dtype=patch_dtype - ) - vision_patch_attention_mask = torch.zeros( - (batch_size, max_images, max_patches), device=base_device, dtype=torch.long - ) - vision_token_grids = torch.zeros((batch_size, max_images, 2), device=base_device, dtype=torch.long) - vision_token_offsets = torch.zeros((batch_size, max_images), device=base_device, dtype=torch.long) - vision_token_lengths = torch.zeros((batch_size, max_images), device=base_device, dtype=torch.long) - vision_image_attention_mask = torch.zeros((batch_size, max_images), device=base_device, dtype=torch.long) - - for batch_idx, sample_patches in enumerate(sample_vision_patches): - sample_image_count = len(sample_patches) - if sample_image_count == 0: - continue - vision_token_grids[batch_idx, :sample_image_count] = sample_vision_grids[batch_idx] - vision_token_offsets[batch_idx, :sample_image_count] = sample_vision_offsets[batch_idx] - vision_token_lengths[batch_idx, :sample_image_count] = sample_vision_lengths[batch_idx] - vision_image_attention_mask[batch_idx, :sample_image_count] = 1 - - for image_idx, patches in enumerate(sample_patches): - patch_count = int(patches.shape[0]) - vision_patches[batch_idx, image_idx, :patch_count] = patches - vision_patch_attention_mask[batch_idx, image_idx, :patch_count] = 1 + vision_patches = image_inputs["vision_patches"] + vision_patch_attention_mask = image_inputs["vision_patch_attention_mask"] return { "input_ids": input_ids, "attention_mask": attention_mask, - "position_ids": position_ids, - "modality_tensor": modality_tensor, + "mm_token_type_ids": mm_token_type_ids, "vision_patches": vision_patches, "vision_patch_attention_mask": vision_patch_attention_mask, "vision_token_grids": vision_token_grids, @@ -1009,21 +934,53 @@ def _build_batch( "vision_image_attention_mask": vision_image_attention_mask, } + def post_process_generation( + self, + text: str, + expected: str | None = None, + cleanup_and_extract: bool = True, + ) -> str | tuple[str, list[SinglePoint | BoundingBox]]: + if cleanup_and_extract: + return clean_text_and_extract_points(text, expected=expected) + return text + + def post_process_image_text_to_text( + self, + generated_outputs, + skip_special_tokens: bool = True, + cleanup_and_extract: bool = False, + expected: str | None = None, + **kwargs, + ): + generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs) + return [ + self.post_process_generation(text, expected=expected, cleanup_and_extract=cleanup_and_extract) + for text in generated_texts + ] + def __call__( self, text: str | list[str], - images: Image | list[Image] | None = None, + images: ImageInput | None = None, return_tensors: str | TensorType | None = TensorType.PYTORCH, **kwargs, ) -> BatchFeature: - return 
BatchFeature(data=self._build_batch(text=text, images=images), tensor_type=return_tensors) + output_kwargs = self._merge_kwargs( + IsaacProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + return BatchFeature( + data=self._build_batch(text=text, images=images, text_kwargs=output_kwargs["text_kwargs"]), + tensor_type=return_tensors, + ) -class IsaacRotaryEmbedding(Qwen2_5_VLRotaryEmbedding): - def __init__(self, config: IsaacConfig, device=None): - rope_source_cfg = config.get_text_config() if hasattr(config, "get_text_config") else config - rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} +class IsaacRotaryEmbedding(Qwen3VLTextRotaryEmbedding): + def __init__(self, config: IsaacConfig | IsaacTextConfig, device=None): + rope_source_cfg = config.get_text_config() config_for_rope = copy.copy(rope_source_cfg) + rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} config_for_rope.rope_scaling = rope_scaling super().__init__( @@ -1031,9 +988,8 @@ def __init__(self, config: IsaacConfig, device=None): device=device if device is not None and getattr(device, "type", None) != "meta" else None, ) - rotary_half_dim = self.inv_freq.shape[0] - self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), rotary_half_dim) - self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size + self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), self.inv_freq.shape[0]) + self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) @staticmethod def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> list[int]: @@ -1046,52 +1002,23 @@ def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> l section = [int(v) for v in section] return section - def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor: - split_sections = tuple(self.mrope_section * 2) - chunks = tensor.split(split_sections, dim=-1) + def apply_interleaved_mrope(self, freqs: torch.Tensor, mrope_section: list[int]) -> torch.Tensor: + chunks = freqs.split(tuple(mrope_section), dim=-1) return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) - def forward( - self, - position_ids: torch.Tensor, - modality_tensor: torch.Tensor, - hidden_states: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if hidden_states is None: - batch, seq_len, _ = position_ids.shape - hidden_states = torch.zeros( - batch, - seq_len, - self.hidden_size, - dtype=torch.float32, - device=position_ids.device, - ) - with torch.no_grad(): - pos = position_ids.clone() - not_spatial = modality_tensor == 1 - data_1d = pos[not_spatial][..., 0].unsqueeze(-1) # Collapse non-vision modalities to 1D positions - pos[not_spatial] = data_1d.expand(-1, pos.shape[-1]) - pos_axes = pos.permute(2, 0, 1).contiguous() +class IsaacTextAttention(Qwen3VLTextAttention): + pass - inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, pos_axes.shape[1], -1, 1) - pos_axes_expanded = pos_axes[:, :, None, :].float() # shape (3, bs, 1, positions) - device_type = ( - hidden_states.device.type - if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps" - else "cpu" - ) - with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ pos_axes_expanded.float()).transpose(2, 3) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * 
self.attention_scaling +class IsaacTextDecoderLayer(Qwen3VLTextDecoderLayer): + pass - cos_axes, sin_axes = cos.to(hidden_states.dtype), sin.to(hidden_states.dtype) - cos_combined, sin_combined = self._combine_axes(cos_axes), self._combine_axes(sin_axes) - return cos_combined, sin_combined +class IsaacTextModel(Qwen3VLTextModel): + def __init__(self, config: IsaacTextConfig): + super().__init__(config) + self.rotary_emb = IsaacRotaryEmbedding(config=config, device=self.device) @auto_docstring @@ -1099,20 +1026,14 @@ class IsaacModel(Qwen3PreTrainedModel): supports_gradient_checkpointing = True _can_compile_fullgraph = False _supports_flex_attn = False - _can_record_outputs = { - "hidden_states": OutputRecorder(Qwen3DecoderLayer), - "attentions": Qwen3Attention, - "vision_attentions": IsaacVisionAttention, - } - all_tied_weights_keys: dict[str, str] = {} + _tied_weights_keys = {} def __init__(self, config: IsaacConfig): Qwen3PreTrainedModel.__init__(self, config) - self.text_model = Qwen3Model._from_config(config.text_config) + self.text_model = IsaacTextModel._from_config(config.text_config) - self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) - - self.vision_embedding = IsaacVisionEmbedding(config) + self.vision_tower = IsaacVisionTransformer(config.vision_config) + self.multimodal_projector = IsaacMultiModalProjector(config) self.max_sequence_length = config.max_sequence_length self.vision_rescale_factor = config.vision_rescale_factor self.vision_token = config.vision_token @@ -1125,282 +1046,376 @@ def get_input_embeddings(self) -> nn.Module: def set_input_embeddings(self, value: nn.Module) -> None: self.text_model.set_input_embeddings(value) - vocab_size = getattr(value, "num_embeddings", None) - if vocab_size is not None: - self.config.vocab_size = vocab_size - if hasattr(self.config, "text_config"): - self.config.text_config.vocab_size = vocab_size - self.text_model.config.vocab_size = vocab_size - - @property - def final_norm(self) -> nn.Module: - return self.text_model.norm - - @property - def embed_tokens(self) -> nn.Module: - return self.text_model.embed_tokens - - @embed_tokens.setter - def embed_tokens(self, value: nn.Module) -> None: - self.text_model.embed_tokens = value - - def embed_multimodal_inputs( + + @can_return_tuple + @auto_docstring + def get_image_features( self, - input_ids: torch.Tensor, - modality_tensor: torch.Tensor, - vision_patches: torch.Tensor, - vision_token_grids: torch.Tensor, - vision_patch_attention_mask: torch.Tensor | None = None, - vision_token_offsets: torch.Tensor | None = None, - vision_token_lengths: torch.Tensor | None = None, - vision_image_attention_mask: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - modality = modality_tensor.to(device=input_ids.device, dtype=torch.long) - embeds = self.text_model.embed_tokens(input_ids) - image_token_mask = modality == ModalityType.image.value - - if vision_patches is None or vision_token_grids is None: - if torch.any(image_token_mask): - raise ValueError("Image placeholders require `vision_patches` and `vision_token_grids`.") - return embeds, modality - - vision_patches = vision_patches.to(device=embeds.device) - token_grids = vision_token_grids.to(device=embeds.device, dtype=torch.long) - image_attention_mask = ( - vision_image_attention_mask.to(device=embeds.device, dtype=torch.bool) - if vision_image_attention_mask is not None - else torch.ones(token_grids.shape[:2], device=embeds.device, dtype=torch.bool) - ) - patch_attention_mask = ( - 
vision_patch_attention_mask.to(device=embeds.device, dtype=torch.long) - if vision_patch_attention_mask is not None - else torch.ones(vision_patches.shape[:3], device=embeds.device, dtype=torch.long) - ) - offsets = ( - vision_token_offsets.to(device=embeds.device, dtype=torch.long) - if vision_token_offsets is not None - else torch.zeros(token_grids.shape[:2], device=embeds.device, dtype=torch.long) + pixel_values: torch.Tensor, + image_token_grids: torch.Tensor, + image_patch_attention_mask: torch.Tensor | None = None, + image_token_offsets: torch.Tensor | None = None, + image_token_lengths: torch.Tensor | None = None, + image_attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + """ + Args: + pixel_values (`torch.Tensor`): + Padded per-image patch vectors with shape `(batch_size, max_images, max_patches, patch_dim)`. + image_token_grids (`torch.Tensor`): + Per-image token grids shaped `(batch_size, max_images, 2)` with `(height, width)` entries. + image_patch_attention_mask (`torch.Tensor`, *optional*): + Mask for valid patch rows in `pixel_values`, shaped `(batch_size, max_images, max_patches)`. + image_token_offsets (`torch.Tensor`, *optional*): + Start offsets inside each per-image embedding sequence, shaped `(batch_size, max_images)`. + image_token_lengths (`torch.Tensor`, *optional*): + Number of image tokens to gather per image for placeholder scattering, shaped `(batch_size, max_images)`. + image_attention_mask (`torch.Tensor`, *optional*): + Mask indicating which image slots are populated, shaped `(batch_size, max_images)`. + """ + device = self.text_model.embed_tokens.weight.device + pixel_values = pixel_values.to(device=device) + image_token_grids = image_token_grids.to(device=device, dtype=torch.long) + patch_attention_mask = image_patch_attention_mask.to(device=device, dtype=torch.long) + if image_attention_mask is None: + if image_token_lengths is not None: + image_attention_mask = image_token_lengths.to(device=device, dtype=torch.long) > 0 + else: + image_attention_mask = image_token_grids.any(dim=-1) + else: + image_attention_mask = image_attention_mask.to(device=device, dtype=torch.bool) + + batch_size, max_images = pixel_values.shape[:2] + hidden_size = self.config.get_text_config().hidden_size + + if image_attention_mask.any(): + vision_kwargs = { + key: value + for key in ("output_hidden_states", "output_attentions") + if (value := kwargs.get(key)) is not None + } + vision_outputs = self.vision_tower( + vision_patches=pixel_values[image_attention_mask], + vision_token_grids=image_token_grids[image_attention_mask], + vision_patch_attention_mask=patch_attention_mask[image_attention_mask], + return_dict=True, + **vision_kwargs, + ) + flat_projected_features = self.multimodal_projector(vision_outputs.last_hidden_state) + max_tokens = flat_projected_features.shape[1] + projected_features = flat_projected_features.new_zeros((batch_size, max_images, max_tokens, hidden_size)) + projected_features[image_attention_mask] = flat_projected_features + offsets = ( + image_token_offsets.to(device=device, dtype=torch.long) + if image_token_offsets is not None + else torch.zeros((batch_size, max_images), device=device, dtype=torch.long) + ) + lengths = ( + image_token_lengths.to(device=device, dtype=torch.long) + if image_token_lengths is not None + else torch.full((batch_size, max_images), max_tokens, device=device, dtype=torch.long) + ) + flat_offsets = offsets[image_attention_mask] + flat_lengths = 
lengths[image_attention_mask] + token_positions = torch.arange(flat_lengths.max(), device=device, dtype=torch.long) + gather_positions = flat_offsets[:, None] + token_positions[None, :] + gather_mask = token_positions[None, :] < flat_lengths[:, None] + image_features = flat_projected_features[ + torch.arange(flat_projected_features.shape[0], device=device, dtype=torch.long)[:, None], + gather_positions, + ][gather_mask] + hidden_states = vision_outputs.hidden_states + attentions = vision_outputs.attentions + else: + projected_features = pixel_values.new_zeros((batch_size, max_images, 0, hidden_size)) + image_features = pixel_values.new_zeros((0, hidden_size)) + hidden_states = None + attentions = None + + return BaseModelOutputWithPooling( + last_hidden_state=projected_features, + pooler_output=image_features, + hidden_states=hidden_states, + attentions=attentions, ) - reduction_factor = int(self.config.vision_config.pixel_shuffle_scale_factor) ** 2 - lengths = ( - vision_token_lengths.to(device=embeds.device, dtype=torch.long) - if vision_token_lengths is not None - else token_grids.prod(-1).div(reduction_factor, rounding_mode="floor").to(dtype=torch.long) + + def get_placeholder_mask( + self, + mm_token_type_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor, + ) -> torch.BoolTensor: + image_token_mask = mm_token_type_ids.to(device=inputs_embeds.device, dtype=torch.long) == 1 + n_image_tokens = image_token_mask.sum() + image_token_mask = image_token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[image_token_mask].numel() == image_features.numel(), + f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", ) + return image_token_mask - flat_vision_patches = vision_patches[image_attention_mask] - flat_patch_attention_mask = patch_attention_mask[image_attention_mask] - flat_token_grids = token_grids[image_attention_mask] - flat_offsets = offsets[image_attention_mask] - flat_lengths = lengths[image_attention_mask] - - vision_embeddings = self.vision_embedding((flat_vision_patches, flat_token_grids, flat_patch_attention_mask)) - token_positions = torch.arange(flat_lengths.max(), device=embeds.device, dtype=torch.long) - gather_positions = flat_offsets[:, None] + token_positions[None, :] - gather_mask = token_positions[None, :] < flat_lengths[:, None] - image_features = vision_embeddings[ - torch.arange(vision_embeddings.shape[0], device=embeds.device, dtype=torch.long)[:, None], - gather_positions, - ][gather_mask] - scatter_mask = image_token_mask.unsqueeze(-1).expand_as(embeds) - embeds = embeds.masked_scatter(scatter_mask, image_features) - - return embeds, modality + def get_vision_position_ids( + self, + start_position: int, + grid_hw: torch.LongTensor, + token_offset: int, + token_length: int, + ) -> torch.LongTensor: + height, width = grid_hw[0].item(), grid_hw[1].item() + token_positions = torch.arange(height * width, device=grid_hw.device, dtype=torch.long) + vision_position_ids = torch.stack( + ( + torch.full((token_positions.shape[0],), start_position, device=grid_hw.device, dtype=torch.long), + token_positions.div(width, rounding_mode="floor"), + token_positions.remainder(width), + ), + dim=0, + ) + return vision_position_ids[:, token_offset : token_offset + token_length] def get_rope_index( self, - *, - position_ids: torch.Tensor | None = None, + mm_token_type_ids: torch.Tensor, + image_token_grids: torch.Tensor, + 
image_token_offsets: torch.Tensor, + image_token_lengths: torch.Tensor, attention_mask: torch.Tensor, - inputs_embeds: torch.Tensor, - cache_position: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Prepare multimodal RoPE positions and carry forward per-batch offsets. + """Prepare multimodal RoPE positions for the current prefill sequence. Unlike vanilla 1D RoPE, Isaac builds 3-axis indices for text and vision tokens. - If callers do not supply positions, we synthesize them from `cache_position` and - use `attention_mask` to strip left padding so pad tokens never consume RoPE slots. - The returned `rope_deltas` capture any custom offset (i.e., prefill length) and - are reused across generation steps so newly decoded tokens keep counting forward - after the cached prefix.""" - - device = inputs_embeds.device - batch_size, seq_len = inputs_embeds.shape[:2] - - if position_ids is None: - cp = cache_position.to(device=device, dtype=torch.long) - if cp.ndim == 1: - cp = cp.view(1, -1).expand(batch_size or 1, -1) - - is_new_prefill = cp[:, :1].eq(0).all(dim=1, keepdim=True) - if self.rope_deltas is None: - base_delta = torch.zeros((batch_size, 1), device=device, dtype=torch.long) - else: - previous_delta = torch.as_tensor(self.rope_deltas, device=device, dtype=torch.long).reshape(-1, 1) - previous_delta = torch.broadcast_to(previous_delta, (batch_size, 1)) - base_delta = torch.where(is_new_prefill, torch.zeros_like(previous_delta), previous_delta) + If callers do not supply positions, we synthesize text-style positions from + `attention_mask`. The returned `rope_deltas` capture any custom offset between + the attended sequence length and Isaac's multimodal positions so decode steps can + keep counting forward from the cached prefix.""" + + device = attention_mask.device + batch_size, seq_len = attention_mask.shape + mm_token_type_ids = mm_token_type_ids.to(device=device, dtype=torch.long) + image_token_grids = image_token_grids.to(device=device, dtype=torch.long) + image_token_offsets = image_token_offsets.to(device=device, dtype=torch.long) + image_token_lengths = image_token_lengths.to(device=device, dtype=torch.long) + attention_mask = attention_mask.to(device=device, dtype=torch.long) + image_attention_mask = image_token_lengths > 0 + + position_ids = torch.zeros((3, batch_size, seq_len), device=device, dtype=torch.long) + rope_deltas = torch.zeros((batch_size, 1), device=device, dtype=torch.long) + pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor + + for batch_idx in range(batch_size): + sample_attention_mask = attention_mask[batch_idx].bool() + sample_token_types = mm_token_type_ids[batch_idx][sample_attention_mask] + sample_grids = image_token_grids[batch_idx][image_attention_mask[batch_idx]] + sample_offsets = image_token_offsets[batch_idx][image_attention_mask[batch_idx]] + sample_lengths = image_token_lengths[batch_idx][image_attention_mask[batch_idx]] + + current_pos = 0 + image_idx = 0 + seq_pos = 0 + llm_pos_ids_list = [] + + while seq_pos < sample_token_types.shape[0]: + modality_type = int(sample_token_types[seq_pos].item()) + group_end = seq_pos + 1 + while group_end < sample_token_types.shape[0] and sample_token_types[group_end] == modality_type: + group_end += 1 + + group_length = group_end - seq_pos + if modality_type == 0: + llm_pos_ids_list.append( + torch.arange(group_length, device=device, dtype=torch.long).view(1, -1).expand(3, -1) + + current_pos + ) + current_pos += group_length + else: + 
grid_hw = sample_grids[image_idx].div(pixel_shuffle_scale, rounding_mode="floor") + token_offset = int(sample_offsets[image_idx].item()) + token_length = int(sample_lengths[image_idx].item()) + llm_pos_ids_list.append( + self.get_vision_position_ids(current_pos, grid_hw, token_offset, token_length) + ) + current_pos += 1 + image_idx += 1 + + seq_pos = group_end + + llm_positions = torch.cat(llm_pos_ids_list, dim=1) + position_ids[:, batch_idx, sample_attention_mask] = llm_positions + rope_deltas[batch_idx, 0] = llm_positions.max() + 1 - sample_token_types.shape[0] - mask_delta = attention_mask.to(device=device, dtype=torch.long).sum(1, keepdim=True) - attention_mask.size( - 1 - ) - rope_position = cp + base_delta + mask_delta - pos_3d = rope_position.unsqueeze(-1).expand(-1, -1, 3) - return pos_3d, base_delta - - position_ids = position_ids.to(device=device) - if position_ids.ndim == 2: - position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) - - if position_ids.shape[1] != seq_len: - start_positions = position_ids[:, :1, 0] - position_ids = torch.arange(seq_len, device=position_ids.device).view(1, -1) + start_positions - position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) - - attn = attention_mask.to(device=device, dtype=torch.long) - m_per_batch = position_ids.amax(dim=(1, 2)) - seq_lens = attn.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=device) - rope_deltas = (m_per_batch + 1 - seq_lens).to(dtype=position_ids.dtype).unsqueeze(1) return position_ids, rope_deltas - @auto_docstring + def compute_3d_position_ids( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | None, + mm_token_type_ids: torch.Tensor | None = None, + image_token_grids: torch.Tensor | None = None, + image_token_offsets: torch.Tensor | None = None, + image_token_lengths: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + past_key_values: torch.Tensor | None = None, + ) -> torch.Tensor: + past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() + + if image_token_lengths is not None and image_token_lengths.gt(0).any() and past_seen_tokens == 0: + position_ids, rope_deltas = self.get_rope_index( + mm_token_type_ids=mm_token_type_ids, + image_token_grids=image_token_grids, + image_token_offsets=image_token_offsets, + image_token_lengths=image_token_lengths, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + ) + self.rope_deltas = rope_deltas + return position_ids + + if position_ids is not None and past_seen_tokens == 0: + position_ids = position_ids.to(device=inputs_embeds.device) + if position_ids.ndim == 2: + return position_ids.view(1, position_ids.shape[0], -1).expand(3, -1, -1) + if position_ids.ndim == 3 and position_ids.shape[0] in (1, 4): + return position_ids + + if self.rope_deltas is None: + return None + + rope_deltas = torch.as_tensor(self.rope_deltas, device=inputs_embeds.device, dtype=torch.long).reshape(-1, 1) + if rope_deltas.shape[0] != inputs_embeds.shape[0]: + if inputs_embeds.shape[0] % rope_deltas.shape[0] == 0: + rope_deltas = rope_deltas.repeat_interleave(inputs_embeds.shape[0] // rope_deltas.shape[0], dim=0) + else: + rope_deltas = rope_deltas[:1].expand(inputs_embeds.shape[0], -1) + + if attention_mask is not None and attention_mask.shape[-1] > inputs_embeds.shape[1]: + rope_position = attention_mask.long().cumsum(dim=-1) - 1 + rope_position = rope_position.masked_fill(attention_mask == 0, 0) + rope_position = rope_position[:, -inputs_embeds.shape[1] :] + else: + 
rope_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + dtype=torch.long, + ).view(1, -1) + rope_position = rope_position.expand(inputs_embeds.shape[0], -1) + + position_ids = rope_position.view(1, inputs_embeds.shape[0], -1).expand(3, -1, -1) + return position_ids + rope_deltas.to(device=inputs_embeds.device).unsqueeze(0) + + @auto_docstring( + custom_intro=""" + Forward pass with multimodal MRoPE position ids. + + When image placeholders are present, Isaac computes vision features, scatters them into the token + embeddings, and runs the shared text backbone on the mixed sequence. + """, + ) + @can_return_tuple @merge_with_config_defaults - @capture_outputs def forward( self, input_ids: torch.LongTensor | None = None, - modality_tensor: torch.LongTensor | None = None, + mm_token_type_ids: torch.LongTensor | None = None, vision_patches: torch.Tensor | None = None, vision_patch_attention_mask: torch.Tensor | None = None, + image_patch_attention_mask: torch.Tensor | None = None, vision_token_grids: torch.LongTensor | None = None, + image_token_grids: torch.LongTensor | None = None, vision_token_offsets: torch.LongTensor | None = None, vision_token_lengths: torch.LongTensor | None = None, - vision_image_attention_mask: torch.LongTensor | None = None, + image_attention_mask: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: list[torch.FloatTensor] | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, - cache_position: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPast: """ - Forward pass with MRoPE position embeddings. - - Computes position embeddings once and passes them through all layers. - Args: - modality_tensor (`torch.LongTensor`, *optional*): - Modality identifiers aligned with the embedded sequence, shaped `(batch_size, seq_len)` and containing - values from `ModalityType`. Treated as text-only when omitted. + mm_token_type_ids (`torch.LongTensor`, *optional*): + Multimodal token type ids aligned with the embedded sequence, shaped `(batch_size, seq_len)`. Isaac + follows the standard convention `0 -> text`, `1 -> image`. Treated as text-only when omitted. vision_patches (`torch.FloatTensor`, *optional*): Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`. vision_patch_attention_mask (`torch.LongTensor`, *optional*): Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. + image_patch_attention_mask (`torch.LongTensor`, *optional*): + Alias for `vision_patch_attention_mask`. vision_token_grids (`torch.LongTensor`, *optional*): Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`. + image_token_grids (`torch.LongTensor`, *optional*): + Alias for `vision_token_grids`. vision_token_offsets (`torch.LongTensor`, *optional*): Start offsets inside the per-image vision embedding sequence, shape `(batch_size, max_images)`. vision_token_lengths (`torch.LongTensor`, *optional*): Number of vision tokens to consume per image, shape `(batch_size, max_images)`. - vision_image_attention_mask (`torch.LongTensor`, *optional*): - Mask indicating which image slots are populated, shape `(batch_size, max_images)`. + image_attention_mask (`torch.LongTensor`, *optional*): + Backward-compatible override for populated image slots. 
When omitted, the model derives it from + `vision_token_lengths > 0`. """ + created_inputs_embeds = inputs_embeds is None + if created_inputs_embeds: + inputs_embeds = self.text_model.embed_tokens(input_ids) - if inputs_embeds is None: - if input_ids is None: - raise ValueError("`input_ids` or `inputs_embeds` must be provided.") - - has_vision_inputs = any( - value is not None - for value in ( - vision_patches, - vision_patch_attention_mask, - vision_token_grids, - vision_token_offsets, - vision_token_lengths, - vision_image_attention_mask, - ) - ) - if modality_tensor is not None or has_vision_inputs: - if modality_tensor is None: - modality_tensor = torch.full_like(input_ids, ModalityType.text.value) - inputs_embeds, modality_tensor = self.embed_multimodal_inputs( - input_ids=input_ids, - modality_tensor=modality_tensor, - vision_patches=vision_patches, - vision_patch_attention_mask=vision_patch_attention_mask, - vision_token_grids=vision_token_grids, - vision_token_offsets=vision_token_offsets, - vision_token_lengths=vision_token_lengths, - vision_image_attention_mask=vision_image_attention_mask, - ) - else: - inputs_embeds = self.text_model.embed_tokens(input_ids) - - if modality_tensor is None: + if mm_token_type_ids is None: batch_size, seq_len = inputs_embeds.shape[:2] - modality_tensor = torch.full( - (batch_size, seq_len), ModalityType.text.value, device=inputs_embeds.device, dtype=torch.long + mm_token_type_ids = torch.full((batch_size, seq_len), 0, device=inputs_embeds.device, dtype=torch.long) + else: + mm_token_type_ids = mm_token_type_ids.to(device=inputs_embeds.device, dtype=torch.long) + + image_token_mask = mm_token_type_ids == 1 + if created_inputs_embeds and torch.any(image_token_mask): + image_outputs = self.get_image_features( + pixel_values=vision_patches, + image_token_grids=vision_token_grids, + image_patch_attention_mask=vision_patch_attention_mask, + image_token_offsets=vision_token_offsets, + image_token_lengths=vision_token_lengths, + image_attention_mask=image_attention_mask, + return_dict=True, ) + image_features = image_outputs.pooler_output.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype) + scatter_mask = self.get_placeholder_mask( + mm_token_type_ids=mm_token_type_ids, + inputs_embeds=inputs_embeds, + image_features=image_features, + ) + inputs_embeds = inputs_embeds.masked_scatter(scatter_mask, image_features) - device = inputs_embeds.device - batch_size, seq_len = inputs_embeds.shape[:2] - - if use_cache and past_key_values is None: - past_key_values = DynamicCache(config=self.config.get_text_config()) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=device) - - if attention_mask is None: - attention_mask = torch.ones(inputs_embeds.shape[:2], device=inputs_embeds.device, dtype=torch.long) + if isinstance(attention_mask, dict): + attention_mask = attention_mask.get("full_attention", next(iter(attention_mask.values()))) - position_ids, rope_deltas = self.get_rope_index( + position_ids = self.compute_3d_position_ids( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + mm_token_type_ids=mm_token_type_ids, + image_token_grids=vision_token_grids, + image_token_offsets=vision_token_offsets, + image_token_lengths=vision_token_lengths, position_ids=position_ids, attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - cache_position=cache_position, + past_key_values=past_key_values, ) - 
self.rope_deltas = rope_deltas - - cos, sin = self.rotary_emb(position_ids, modality_tensor, hidden_states=inputs_embeds) - - decoder_position_ids = position_ids[..., 0] if position_ids.ndim == 3 else position_ids - - if not isinstance(attention_mask, dict): - attention_mask = create_masks_for_generate( - config=self.config, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - cache_position=cache_position, - past_key_values=past_key_values, - position_ids=decoder_position_ids, - ) - - is_mask_dict = isinstance(attention_mask, dict) - hidden_states = inputs_embeds - - for layer in self.text_model.layers: - layer_mask = attention_mask[layer.attention_type] if is_mask_dict else attention_mask - layer_outputs = layer( - hidden_states, - attention_mask=layer_mask, - position_ids=decoder_position_ids, - past_key_values=past_key_values, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=(cos, sin), - **kwargs, - ) - - hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs - hidden_states = self.final_norm(hidden_states) + text_model_outputs = self.text_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, + last_hidden_state=text_model_outputs.last_hidden_state, + past_key_values=text_model_outputs.past_key_values, + hidden_states=text_model_outputs.hidden_states, + attentions=text_model_outputs.attentions, ) @@ -1409,74 +1424,81 @@ class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin): config_class = IsaacConfig _can_compile_fullgraph = False _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"} - all_tied_weights_keys: dict[str, str] = {"lm_head.weight": "model.text_model.embed_tokens.weight"} def __init__(self, config: IsaacConfig): super().__init__(config) self.model = IsaacModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.vocab_size = config.get_text_config().vocab_size + self.lm_head = nn.Linear(config.get_text_config().hidden_size, config.get_text_config().vocab_size, bias=False) @auto_docstring @can_return_tuple - @merge_with_config_defaults def forward( self, input_ids: torch.LongTensor | None = None, - modality_tensor: torch.LongTensor | None = None, + mm_token_type_ids: torch.LongTensor | None = None, vision_patches: torch.Tensor | None = None, + pixel_values: torch.Tensor | None = None, vision_patch_attention_mask: torch.Tensor | None = None, + image_patch_attention_mask: torch.Tensor | None = None, vision_token_grids: torch.LongTensor | None = None, + image_token_grids: torch.LongTensor | None = None, vision_token_offsets: torch.LongTensor | None = None, vision_token_lengths: torch.LongTensor | None = None, - vision_image_attention_mask: torch.LongTensor | None = None, + image_attention_mask: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: list[torch.FloatTensor] | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, - cache_position: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | CausalLMOutputWithPast: r""" - modality_tensor 
(`torch.LongTensor`, *optional*): - Modality identifiers aligned with the token sequence, shaped `(batch_size, seq_len)`. + mm_token_type_ids (`torch.LongTensor`, *optional*): + Multimodal token type ids aligned with the token sequence, shaped `(batch_size, seq_len)`, using + `0 -> text` and `1 -> image`. vision_patches (`torch.FloatTensor`, *optional*): Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`. + pixel_values (`torch.FloatTensor`, *optional*): + Alias for `vision_patches` accepted by generic image-feature and generation helpers. vision_patch_attention_mask (`torch.LongTensor`, *optional*): Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. + image_patch_attention_mask (`torch.LongTensor`, *optional*): + Alias for `vision_patch_attention_mask`. vision_token_grids (`torch.LongTensor`, *optional*): Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`. + image_token_grids (`torch.LongTensor`, *optional*): + Alias for `vision_token_grids`. vision_token_offsets (`torch.LongTensor`, *optional*): Start offsets inside the per-image vision embedding sequence, shape `(batch_size, max_images)`. vision_token_lengths (`torch.LongTensor`, *optional*): Number of vision tokens to consume per image, shape `(batch_size, max_images)`. - vision_image_attention_mask (`torch.LongTensor`, *optional*): - Mask indicating which image slots are populated, shape `(batch_size, max_images)`. + image_attention_mask (`torch.LongTensor`, *optional*): + Backward-compatible override for populated image slots. When omitted, the model derives it from + `vision_token_lengths > 0`. """ outputs = self.model( input_ids=input_ids, - modality_tensor=modality_tensor, + mm_token_type_ids=mm_token_type_ids, vision_patches=vision_patches, vision_patch_attention_mask=vision_patch_attention_mask, vision_token_grids=vision_token_grids, vision_token_offsets=vision_token_offsets, vision_token_lengths=vision_token_lengths, - vision_image_attention_mask=vision_image_attention_mask, + image_attention_mask=image_attention_mask, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - cache_position=cache_position, **kwargs, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=logits.shape[-1]) return CausalLMOutputWithPast( loss=loss, @@ -1492,63 +1514,124 @@ def prepare_inputs_for_generation( past_key_values: list[torch.FloatTensor] | None = None, attention_mask: torch.Tensor | None = None, inputs_embeds: torch.FloatTensor | None = None, - modality_tensor: torch.LongTensor | None = None, + mm_token_type_ids: torch.LongTensor | None = None, vision_patches: torch.Tensor | None = None, + pixel_values: torch.Tensor | None = None, vision_patch_attention_mask: torch.Tensor | None = None, + image_patch_attention_mask: torch.Tensor | None = None, vision_token_grids: torch.LongTensor | None = None, + image_token_grids: torch.LongTensor | None = None, vision_token_offsets: torch.LongTensor | None = None, vision_token_lengths: torch.LongTensor | None = None, - vision_image_attention_mask: torch.LongTensor | None = None, - cache_position: torch.LongTensor | None = None, + image_attention_mask: torch.LongTensor | None = None, position_ids: 
torch.LongTensor | None = None, + is_first_iteration=False, + use_cache=True, **kwargs, ) -> dict[str, Any]: + if vision_patches is None: + vision_patch_attention_mask = ( + image_patch_attention_mask if vision_patch_attention_mask is None else vision_patch_attention_mask + ) + vision_token_grids = image_token_grids if vision_token_grids is None else vision_token_grids + if position_ids is None or position_ids.ndim == 2: + position_ids = self._prepare_position_ids_for_generation( + input_ids, + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "mm_token_type_ids": mm_token_type_ids, + "vision_token_grids": vision_token_grids, + "vision_token_offsets": vision_token_offsets, + "vision_token_lengths": vision_token_lengths, + }, + ) model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, - cache_position=cache_position, - position_ids=None, + position_ids=position_ids, + is_first_iteration=is_first_iteration, + use_cache=use_cache, **kwargs, ) multimodal_inputs = { - "modality_tensor": modality_tensor, + "mm_token_type_ids": mm_token_type_ids, "vision_patches": vision_patches, "vision_patch_attention_mask": vision_patch_attention_mask, "vision_token_grids": vision_token_grids, "vision_token_offsets": vision_token_offsets, "vision_token_lengths": vision_token_lengths, - "vision_image_attention_mask": vision_image_attention_mask, } - if not any(value is not None for value in multimodal_inputs.values()): - return model_inputs - - past_len = past_key_values.get_seq_length() if past_key_values is not None else 0 - first_step = past_len == 0 + is_prefill = is_first_iteration or not use_cache for key, value in multimodal_inputs.items(): - model_inputs[key] = value if first_step else None - model_inputs["position_ids"] = position_ids if first_step else None + model_inputs[key] = value if is_prefill else None return model_inputs - @classmethod - def can_generate(cls) -> bool: - return True + def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs): + text_positions = super()._prepare_position_ids_for_generation(inputs_tensor, model_kwargs) + + past_length = 0 + if (cache := model_kwargs.get("past_key_values")) is not None: + past_length = cache.get_seq_length() + if past_length != 0 and self.model.rope_deltas is not None: + return text_positions[None, ...] 
+ self.model.rope_deltas + + if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0: + inputs_tensor = model_kwargs["input_ids"] + + if ( + model_kwargs.get("image_token_lengths") is not None + and len(inputs_tensor.shape) == 2 + and inputs_tensor.dtype in [torch.int, torch.long] + ): + vision_positions, rope_deltas = self.model.get_rope_index( + mm_token_type_ids=model_kwargs["mm_token_type_ids"], + image_token_grids=model_kwargs["vision_token_grids"], + image_token_offsets=model_kwargs["vision_token_offsets"], + image_token_lengths=model_kwargs["vision_token_lengths"], + attention_mask=model_kwargs.get("attention_mask"), + ) + self.model.rope_deltas = rope_deltas + else: + vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1) + self.model.rope_deltas = torch.zeros( + inputs_tensor.shape[0], 1, dtype=torch.long, device=inputs_tensor.device + ) - def set_input_embeddings(self, value: nn.Module) -> None: - self.model.set_input_embeddings(value) - vocab_size = getattr(value, "num_embeddings", None) - self.config.vocab_size = vocab_size - self.model.config.vocab_size = vocab_size - self.model.text_model.config.vocab_size = vocab_size - if self.lm_head.weight.shape[0] != vocab_size: - self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False) - self.lm_head.weight = self.model.text_model.embed_tokens.weight + return torch.cat([text_positions[None, ...], vision_positions], dim=0) + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: torch.LongTensor | None = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + position_ids = model_kwargs.pop("position_ids", None) + input_ids, model_kwargs = super()._expand_inputs_for_generation( + expand_size=expand_size, + is_encoder_decoder=is_encoder_decoder, + input_ids=input_ids, + **model_kwargs, + ) + if position_ids is not None: + dim = 1 if position_ids.ndim == 3 else 0 + model_kwargs["position_ids"] = position_ids.repeat_interleave(expand_size, dim=dim) + return input_ids, model_kwargs + + def get_input_embeddings(self) -> nn.Module: + return self.model.get_input_embeddings() __all__ = [ "IsaacConfig", + "IsaacTextConfig", + "IsaacTextModel", + "IsaacVisionConfig", "IsaacModel", "IsaacPreTrainedModel", # noqa: F822 "IsaacForConditionalGeneration", diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 81688aa74144..93e991702649 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -4,7 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. +# Copyright 2026 Perceptron, Inc and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,320 +18,254 @@ # See the License for the specific language governing permissions and # limitations under the License. 
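The `apply_interleaved_mrope` helper introduced above replaces Isaac's old `_combine_axes`: the rotary frequency tensor stacks one set of frequencies per position axis (temporal, height, width), and consecutive sections of the rotary half-dimension are drawn from those axes in a repeating t/h/w cycle. A minimal standalone sketch of that combination, assuming the same tensor layout; the sizes and the `[22, 22, 20]` section split are illustrative, not the model's real configuration:

```python
# Standalone sketch of the interleaved MRoPE combination used above.
# `freqs` stacks the three rotary axes (t, h, w) in dim 0; `mrope_section`
# must sum to the rotary half-dimension. All sizes here are illustrative.
import torch


def apply_interleaved_mrope(freqs: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
    # Split the half-dimension into sections, then pull section i from axis i % 3.
    chunks = freqs.split(tuple(mrope_section), dim=-1)
    return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)


freqs = torch.randn(3, 2, 5, 64)  # (axes, batch, seq_len, half_dim)
combined = apply_interleaved_mrope(freqs, [22, 22, 20])
assert combined.shape == (2, 5, 64)  # one frequency tensor per token again
```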
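`get_rope_index` above walks contiguous runs of `mm_token_type_ids`: a text run advances all three position axes together, an image occupies `h * w` sequence slots but only one temporal step, and `rope_deltas` records `max_position + 1 - seq_len` so decoding can keep counting past the cached prefix. A toy re-derivation of that bookkeeping with made-up token types and a 2x2 image grid:

```python
# Toy illustration of the position bookkeeping in `get_rope_index` above.
# Token types and the grid are invented; the real method also handles
# padding, per-image offsets, and partially truncated images.
import torch

token_types = [0, 0, 1, 1, 1, 1, 0]  # 2 text tokens, one 2x2 image, 1 text token
grid_h, grid_w = 2, 2

pos, current, i = [], 0, 0
while i < len(token_types):
    if token_types[i] == 0:
        pos.append(torch.tensor([current] * 3))  # text: all axes move together
        current += 1
        i += 1
    else:
        for k in range(grid_h * grid_w):  # image: one (t, h, w) triple per token
            pos.append(torch.tensor([current, k // grid_w, k % grid_w]))
        current += 1  # the whole image costs a single temporal step
        i += grid_h * grid_w

position_ids = torch.stack(pos, dim=1)  # shape (3, seq_len)
rope_delta = int(position_ids.max()) + 1 - len(token_types)
print(rope_delta)  # -3: multimodal positions are denser than plain 1D ones
```

During decode, `compute_3d_position_ids` adds this stored delta to ordinary 1D positions (derived from the attention mask or the cache length), so newly generated tokens continue directly after the multimodal prefill.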
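When image placeholders are present, `IsaacModel.forward` computes image features once, validates counts with `get_placeholder_mask`, and writes the features into the embedding rows flagged by `mm_token_type_ids == 1` using `masked_scatter`. A small sketch of that substitution with illustrative dimensions:

```python
# Sketch of the placeholder substitution in `IsaacModel.forward` above:
# image features land on exactly the rows whose token type is 1 (image).
# Hidden size, sequence length, and values are illustrative.
import torch

hidden = 8
inputs_embeds = torch.zeros(1, 5, hidden)            # 5-token sequence
mm_token_type_ids = torch.tensor([[0, 1, 1, 0, 0]])  # tokens 1-2 are image pads
image_features = torch.ones(2, hidden)               # one row per image token

mask = (mm_token_type_ids == 1).unsqueeze(-1).expand_as(inputs_embeds)
mixed = inputs_embeds.masked_scatter(mask, image_features)
assert torch.equal(mixed[0, 1:3], image_features)    # features fill the pad slots
```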
+ +import copy +import re from typing import Any from ...feature_extraction_utils import BatchFeature -from ...processing_utils import ProcessorMixin -from ...utils import TensorType -from ...utils.import_utils import is_torch_available, is_vision_available -from .configuration_isaac import IsaacConfig -from .modeling_isaac import ModalityType +from ...image_utils import ImageInput, make_nested_list_of_images +from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...utils import TensorType, auto_docstring +from ...utils.import_utils import is_torch_available +from .modeling_isaac import BoundingBox, SinglePoint if is_torch_available(): import torch -if is_vision_available(): - from PIL.Image import Image -else: - Image = None - -class IsaacProcessor(ProcessorMixin): - """Processor that pairs the Isaac image processor with the Qwen2 tokenizer. - Args: - image_processor: Vision preprocessor (fast) used for patch extraction. - tokenizer: Qwen2 tokenizer instance. - vision_token (str, optional): Placeholder token marking image locations. Defaults to "". - max_sequence_length (int, optional): Maximum combined text+vision tokens kept. Defaults to 16384. - rescale_factor (float, optional): Image rescale factor; defaults to 1/255. - config (IsaacConfig | dict, optional): If provided, overrides processor defaults from the model config. +class IsaacProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": True, + "return_attention_mask": True, + }, + } + + +_POINT_OR_BOX_TAG = re.compile( + r"<(?Ppoint|point_box)(?P[^>]*)>(?P[\s\S]*?)", re.IGNORECASE +) +_ATTR_RE = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))") +_COORD_RE = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)") + + +def _maybe_float(value: str | None) -> float | None: + if value is None: + return None + try: + return float(value) + except ValueError: + return None + + +def _parse_attrs(attr_text: str) -> dict[str, str]: + attrs = {} + for match in _ATTR_RE.finditer(attr_text or ""): + key = match.group(1) + value = match.group(2) or match.group(3) or "" + attrs[key] = value + return attrs + + +def _parse_point_body(body: str, mention: str | None = None, t: str | None = None) -> SinglePoint: + match = _COORD_RE.search(body) + if not match: + raise ValueError(f"Malformed tag: {body!r}") + x, y = int(match.group(1)), int(match.group(2)) + return SinglePoint(x=x, y=y, mention=mention, t=_maybe_float(t)) + + +def _parse_box_body(body: str, mention: str | None = None, t: str | None = None) -> BoundingBox: + coords = list(_COORD_RE.finditer(body)) + if len(coords) < 2: + raise ValueError(f"Malformed tag: {body!r}") + + top_left = SinglePoint(x=int(coords[0].group(1)), y=int(coords[0].group(2))) + bottom_right = SinglePoint(x=int(coords[1].group(1)), y=int(coords[1].group(2))) + return BoundingBox(top_left=top_left, bottom_right=bottom_right, mention=mention, t=_maybe_float(t)) + + +def clean_text_and_extract_points( + text: str, + expected: str | None = None, +) -> tuple[str, list[SinglePoint | BoundingBox]]: + results = [] + for match in _POINT_OR_BOX_TAG.finditer(text or ""): + tag = match.group("tag").lower() + attrs = _parse_attrs(match.group("attrs")) + mention = attrs.get("mention") + t = attrs.get("t") + if tag == "point": + if expected not in (None, "point"): + continue + results.append(_parse_point_body(match.group("body"), mention=mention, t=t)) + else: + if expected not in (None, "box"): + continue + results.append(_parse_box_body(match.group("body"), mention=mention, t=t)) - 
Returns: - BatchFeature: Top-level batched text and vision tensors. - """ + clean_text = re.sub(r"\s+", " ", _POINT_OR_BOX_TAG.sub("", text or "")).strip() + return clean_text, results - attributes = ["image_processor", "tokenizer"] - image_processor_class = ("IsaacImageProcessorFast",) - tokenizer_class = ("Qwen2Tokenizer",) - pad_token_id = 151643 +@auto_docstring +class IsaacProcessor(ProcessorMixin): def __init__( self, image_processor, tokenizer, - *, + chat_template: str | dict[str, str] | None = None, vision_token: str = "", max_sequence_length: int = 16384, rescale_factor: float | None = None, - config: IsaacConfig | dict | None = None, - ) -> None: - if isinstance(config, dict): - config = IsaacConfig(**config) - - if config is not None: - vision_token = config.vision_token - max_sequence_length = config.max_sequence_length - rescale_factor = config.vision_rescale_factor - - resolved_rescale_factor = float(rescale_factor) if rescale_factor is not None else float(1 / 255) - if config is not None: - config.vision_rescale_factor = resolved_rescale_factor + ): + """ + Args: + chat_template (`str` or `dict[str, str]`, *optional*): + Chat template override forwarded to [`~processing_utils.ProcessorMixin`]. + vision_token (`str`, *optional*, defaults to `""`): + Placeholder token used inside text prompts to mark image positions. + max_sequence_length (`int`, *optional*, defaults to 16384): + Maximum packed multimodal sequence length produced by the processor. + rescale_factor (`float`, *optional*): + Deprecated compatibility argument accepted for backward compatibility. + """ + if chat_template is None: + chat_template = getattr(tokenizer, "chat_template", None) self.image_processor = image_processor - super().__init__(image_processor, tokenizer) - - text_pad_token_id = getattr(self.tokenizer, "pad_token_id", None) - image_pad_token_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>") + super().__init__(image_processor, tokenizer, chat_template=chat_template) + self.text_pad_token_id = self.pad_token_id = tokenizer.pad_token_id + self.image_pad_token_id = tokenizer.image_pad_token_id + self.image_token = tokenizer.image_pad_token + self.image_token_id = self.image_pad_token_id - self.text_pad_token_id = int(text_pad_token_id) - self.image_pad_token_id = int(image_pad_token_id) - self.pad_token_id = self.text_pad_token_id - - self.current_processor = self.image_processor - self.config = config - self.chat_template = getattr(self.tokenizer, "chat_template", None) self.vision_token = vision_token self.max_sequence_length = max_sequence_length def _build_batch( self, text: str | list[str], - images: Image | list[Image] | None = None, + images: ImageInput | None = None, + text_kwargs: dict[str, Any] | None = None, ) -> dict[str, torch.Tensor | None]: + text_kwargs = copy.deepcopy(text_kwargs) if text_kwargs is not None else {} + truncation = text_kwargs.pop("truncation", None) + max_length = text_kwargs.pop("max_length", None) + padding = text_kwargs.pop("padding", True) + padding_side = text_kwargs.pop("padding_side", "left") + return_attention_mask = text_kwargs.pop("return_attention_mask", True) + pad_to_multiple_of = text_kwargs.pop("pad_to_multiple_of", None) + text_kwargs.pop("return_tensors", None) + text_kwargs.setdefault("add_special_tokens", False) + texts = [text] if isinstance(text, str) else text if images is None: - pairs = ((text_value, None) for text_value in texts) - elif isinstance(images, list) and len(images) == len(texts): - if not images: - images_list = [] - elif 
isinstance(images[0], list): - images_list = images - else: - images_list = [[image] for image in images] - pairs = zip(texts, images_list, strict=True) + batched_images = [[] for _ in texts] else: - pairs = ( - ( - text_value, - None - if text_value.count(self.vision_token) == 0 - else images - if isinstance(images, list) - else [images], + fetched_images = self.image_processor.fetch_images(images) + batched_images = make_nested_list_of_images(fetched_images) + if len(batched_images) != len(texts): + num_images_in_text = [text_value.count(self.vision_token) for text_value in texts] + num_images_in_images = [len(sample_images) for sample_images in batched_images] + add_message = "" + if sum(num_images_in_text) == sum(num_images_in_images): + add_message = " Make sure to pass your images as a nested list, where each sub-list holds images for one text sample." + + raise ValueError( + f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(texts)}).{add_message}" ) - for text_value in texts - ) + + pairs = list(zip(texts, batched_images, strict=True)) + image_inputs = self.image_processor(images=batched_images, return_tensors=TensorType.PYTORCH) + vision_token_grids = image_inputs["vision_token_grids"] + vision_segment_lengths = (vision_token_grids[..., 0] // self.image_processor.pixel_shuffle_scale) * ( + vision_token_grids[..., 1] // self.image_processor.pixel_shuffle_scale + ) + vision_token_offsets = torch.zeros_like(vision_segment_lengths) + vision_token_lengths = torch.zeros_like(vision_segment_lengths) sample_input_ids: list[torch.Tensor] = [] - sample_modality: list[torch.Tensor] = [] - sample_position_ids: list[torch.Tensor] = [] - sample_vision_patches: list[list[torch.Tensor]] = [] - sample_vision_grids: list[torch.Tensor] = [] - sample_vision_offsets: list[torch.Tensor] = [] - sample_vision_lengths: list[torch.Tensor] = [] - - for text_value, sample_images in pairs: + expanded_texts = [] + expected_image_lengths_per_sample = [] + + for batch_idx, (text_value, sample_images) in enumerate(pairs): segments = text_value.split(self.vision_token) num_images = len(segments) - 1 - num_provided_images = len(sample_images) if sample_images is not None else 0 + num_provided_images = len(sample_images) if num_images != num_provided_images: raise ValueError( f"IsaacProcessor expects one image per image token, got {num_images} tokens and {num_provided_images} images in sample with text {text_value} " ) - items: list[dict[str, Any]] = [] - total = 0 - for index, segment in enumerate(segments): - if segment: - text_tokens = ( - self.tokenizer.encode(segment, add_special_tokens=False, return_tensors="pt") - .squeeze(0) - .to(torch.long) - ) - segment_length = int(text_tokens.numel()) - items.append({"type": "text", "segment_length": segment_length, "tokens": text_tokens}) - total += segment_length - - if index < num_images: - feature = self.image_processor(images=sample_images[index], return_tensors=TensorType.PYTORCH) - patches = feature["patches"][0].reshape(-1, feature["patches"].shape[-1]) - virtual_pixel_size = feature["virtual_pixel_size"][0].to(torch.long).tolist() - real_pixel_size = feature["real_pixel_size"][0].to(torch.long).tolist() - dims = tuple((virtual_pixel_size + [1, 1, 1])[:3]) - segment_length = int(dims[0] * dims[1] * dims[2]) - items.append( - { - "type": "image", - "segment_length": segment_length, - "dims": dims, - "patches": patches, - "grid": (int(real_pixel_size[1]), int(real_pixel_size[2])), - } - ) - total += segment_length - - start = 
max(0, total - self.max_sequence_length) - end = total - base_device: torch.device | None = None - input_ids_chunks, modality_chunks, position_chunks = [], [], [] - vision_patches, vision_grids, vision_offsets, vision_lengths = [], [], [], [] - global_offset = 0 - position_offset = 0 - - for item in items: - segment_length = int(item["segment_length"]) - current_window_start = max(start, global_offset) - current_window_end = min(end, global_offset + segment_length) - has_overlap = current_window_end > current_window_start - - if has_overlap and base_device is None: - base_device = item["patches"].device if item["type"] == "image" else item["tokens"].device - - if has_overlap: - segment_local_start = int(current_window_start - global_offset) - segment_local_end = int(current_window_end - global_offset) - segment_local_indices = torch.arange( - segment_local_start, segment_local_end, device=base_device, dtype=torch.long - ) - segment_kept_length = segment_local_end - segment_local_start - - if item["type"] == "text": - slice_index = segment_local_indices + position_offset - zero_axis = torch.zeros_like(slice_index) - position_chunks.append(torch.stack((slice_index, zero_axis, zero_axis), -1)) - modality_chunks.append( - torch.full( - (segment_kept_length,), ModalityType.text.value, device=base_device, dtype=torch.long - ) - ) - input_ids_chunks.append(item["tokens"].to(base_device)[segment_local_start:segment_local_end]) - position_offset += segment_length - else: - num_pos_slices, grid_height_tokens, grid_width_tokens = item["dims"] - hw = grid_height_tokens * grid_width_tokens - slice_index = (segment_local_indices // hw) + position_offset - rem = segment_local_indices % hw - position_chunks.append( - torch.stack((slice_index, rem // grid_width_tokens, rem % grid_width_tokens), -1) - ) - modality_chunks.append( - torch.full( - (segment_kept_length,), ModalityType.image.value, device=base_device, dtype=torch.long - ) - ) - input_ids_chunks.append( - torch.full( - (segment_kept_length,), self.image_pad_token_id, device=base_device, dtype=torch.long - ) - ) - - vision_patches.append(item["patches"].to(base_device)) - vision_grids.append(item["grid"]) - vision_offsets.append(segment_local_start) - vision_lengths.append(segment_kept_length) - position_offset += int(num_pos_slices) - else: - position_offset += segment_length if item["type"] == "text" else int(item["dims"][0]) - - global_offset += segment_length - - if base_device is None: - base_device = torch.device("cpu") - - sample_input_ids.append( - torch.cat(input_ids_chunks, 0) - if input_ids_chunks - else torch.zeros((0,), device=base_device, dtype=torch.long) - ) - sample_modality.append( - torch.cat(modality_chunks, 0) - if modality_chunks - else torch.zeros((0,), device=base_device, dtype=torch.long) - ) - sample_position_ids.append( - torch.cat(position_chunks, 0) - if position_chunks - else torch.zeros((0, 3), device=base_device, dtype=torch.long) - ) - sample_vision_patches.append(vision_patches) - if vision_patches: - sample_vision_grids.append(torch.tensor(vision_grids, device=base_device, dtype=torch.long)) - sample_vision_offsets.append(torch.tensor(vision_offsets, device=base_device, dtype=torch.long)) - sample_vision_lengths.append(torch.tensor(vision_lengths, device=base_device, dtype=torch.long)) - else: - sample_vision_grids.append(torch.zeros((0, 2), device=base_device, dtype=torch.long)) - sample_vision_offsets.append(torch.zeros((0,), device=base_device, dtype=torch.long)) - 
sample_vision_lengths.append(torch.zeros((0,), device=base_device, dtype=torch.long)) - - batch_size = len(sample_input_ids) - lengths = [int(sample_input.shape[0]) for sample_input in sample_input_ids] - max_len = max(lengths, default=0) - base_device = next( - (sample_input.device for sample_input in sample_input_ids if sample_input.numel() > 0), - torch.device("cpu"), + expected_image_lengths = [ + int(vision_segment_lengths[batch_idx, image_idx].item()) for image_idx in range(num_images) + ] + expected_image_lengths_per_sample.append(expected_image_lengths) + + expanded_text = segments[0] + for image_idx, segment_length in enumerate(expected_image_lengths): + expanded_text += (self.image_token * segment_length) + segments[image_idx + 1] + expanded_texts.append(expanded_text) + + text_inputs = self.tokenizer(expanded_texts, return_tensors=None, **text_kwargs) + self._check_special_mm_tokens(expanded_texts, text_inputs, modalities=["image"]) + + effective_max_length = self.max_sequence_length + if truncation and max_length is not None: + effective_max_length = max_length + + for batch_idx, (expected_image_lengths, sample_input_ids_list) in enumerate( + zip(expected_image_lengths_per_sample, text_inputs["input_ids"], strict=True) + ): + sample_input = torch.tensor(sample_input_ids_list, dtype=torch.long) + image_positions = sample_input.eq(self.image_pad_token_id).nonzero(as_tuple=False).flatten() + image_spans = image_positions.split(expected_image_lengths) if expected_image_lengths else () + image_bounds = [] + + for image_idx, (segment_length, image_span) in enumerate( + zip(expected_image_lengths, image_spans, strict=True) + ): + image_start = int(image_span[0].item()) + image_end = int(image_span[-1].item()) + 1 + image_bounds.append((image_start, image_end)) + total = int(sample_input.shape[0]) + start = max(0, total - effective_max_length) + sample_input_ids.append(sample_input[start:]) + + for image_idx, (image_start, image_end) in enumerate(image_bounds): + kept_start = max(start, image_start) + kept_end = min(total, image_end) + if kept_end > kept_start: + vision_token_offsets[batch_idx, image_idx] = kept_start - image_start + vision_token_lengths[batch_idx, image_idx] = kept_end - kept_start + + text_inputs = self.tokenizer.pad( + {"input_ids": [sample_input.tolist() for sample_input in sample_input_ids]}, + padding=padding, + max_length=max_length if padding == "max_length" else None, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + return_tensors=TensorType.PYTORCH, ) + input_ids = text_inputs["input_ids"] + attention_mask = text_inputs.get("attention_mask") + if attention_mask is None: + attention_mask = input_ids.ne(self.pad_token_id).to(dtype=torch.long) - input_ids = torch.full((batch_size, max_len), self.text_pad_token_id, device=base_device, dtype=torch.long) - attention_mask = torch.zeros((batch_size, max_len), device=base_device, dtype=torch.long) - modality_tensor = torch.full( - (batch_size, max_len), ModalityType.text.value, device=base_device, dtype=torch.long - ) - position_ids = torch.zeros((batch_size, max_len, 3), device=base_device, dtype=torch.long) + mm_token_type_ids = input_ids.eq(self.image_pad_token_id).to(dtype=torch.long) + vision_image_attention_mask = vision_token_lengths.gt(0).to(dtype=torch.long) - for batch_idx, length in enumerate(lengths): - if length == 0: - continue - input_ids[batch_idx, -length:] = sample_input_ids[batch_idx] - attention_mask[batch_idx, -length:] = 1 - 
modality_tensor[batch_idx, -length:] = sample_modality[batch_idx] - position_ids[batch_idx, -length:] = sample_position_ids[batch_idx] - - image_counts = [len(patches) for patches in sample_vision_patches] - max_images = max(image_counts, default=0) - if max_images == 0: - vision_patches = None - vision_patch_attention_mask = None - vision_token_grids = None - vision_token_offsets = None - vision_token_lengths = None - vision_image_attention_mask = None - else: - first_patch = next((patches[0] for patches in sample_vision_patches if patches), None) - patch_dim = first_patch.shape[-1] - patch_dtype = first_patch.dtype - max_patches = max((patch.shape[0] for patches in sample_vision_patches for patch in patches), default=0) - - vision_patches = torch.zeros( - (batch_size, max_images, max_patches, patch_dim), device=base_device, dtype=patch_dtype - ) - vision_patch_attention_mask = torch.zeros( - (batch_size, max_images, max_patches), device=base_device, dtype=torch.long - ) - vision_token_grids = torch.zeros((batch_size, max_images, 2), device=base_device, dtype=torch.long) - vision_token_offsets = torch.zeros((batch_size, max_images), device=base_device, dtype=torch.long) - vision_token_lengths = torch.zeros((batch_size, max_images), device=base_device, dtype=torch.long) - vision_image_attention_mask = torch.zeros((batch_size, max_images), device=base_device, dtype=torch.long) - - for batch_idx, sample_patches in enumerate(sample_vision_patches): - sample_image_count = len(sample_patches) - if sample_image_count == 0: - continue - vision_token_grids[batch_idx, :sample_image_count] = sample_vision_grids[batch_idx] - vision_token_offsets[batch_idx, :sample_image_count] = sample_vision_offsets[batch_idx] - vision_token_lengths[batch_idx, :sample_image_count] = sample_vision_lengths[batch_idx] - vision_image_attention_mask[batch_idx, :sample_image_count] = 1 - - for image_idx, patches in enumerate(sample_patches): - patch_count = int(patches.shape[0]) - vision_patches[batch_idx, image_idx, :patch_count] = patches - vision_patch_attention_mask[batch_idx, image_idx, :patch_count] = 1 + vision_patches = image_inputs["vision_patches"] + vision_patch_attention_mask = image_inputs["vision_patch_attention_mask"] return { "input_ids": input_ids, "attention_mask": attention_mask, - "position_ids": position_ids, - "modality_tensor": modality_tensor, + "mm_token_type_ids": mm_token_type_ids, "vision_patches": vision_patches, "vision_patch_attention_mask": vision_patch_attention_mask, "vision_token_grids": vision_token_grids, @@ -340,14 +274,46 @@ def _build_batch( "vision_image_attention_mask": vision_image_attention_mask, } + def post_process_generation( + self, + text: str, + expected: str | None = None, + cleanup_and_extract: bool = True, + ) -> str | tuple[str, list[SinglePoint | BoundingBox]]: + if cleanup_and_extract: + return clean_text_and_extract_points(text, expected=expected) + return text + + def post_process_image_text_to_text( + self, + generated_outputs, + skip_special_tokens: bool = True, + cleanup_and_extract: bool = False, + expected: str | None = None, + **kwargs, + ): + generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs) + return [ + self.post_process_generation(text, expected=expected, cleanup_and_extract=cleanup_and_extract) + for text in generated_texts + ] + def __call__( self, text: str | list[str], - images: Image | list[Image] | None = None, + images: ImageInput | None = None, return_tensors: str | TensorType | None = 
TensorType.PYTORCH, **kwargs, ) -> BatchFeature: - return BatchFeature(data=self._build_batch(text=text, images=images), tensor_type=return_tensors) + output_kwargs = self._merge_kwargs( + IsaacProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + return BatchFeature( + data=self._build_batch(text=text, images=images, text_kwargs=output_kwargs["text_kwargs"]), + tensor_type=return_tensors, + ) __all__ = ["IsaacProcessor"] diff --git a/tests/models/isaac/__init__.py b/tests/models/isaac/__init__.py index 2f76d5676d10..e69de29bb2d1 100644 --- a/tests/models/isaac/__init__.py +++ b/tests/models/isaac/__init__.py @@ -1,13 +0,0 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/models/isaac/test_image_processing_isaac.py b/tests/models/isaac/test_image_processing_isaac.py new file mode 100644 index 000000000000..1b968087b277 --- /dev/null +++ b/tests/models/isaac/test_image_processing_isaac.py @@ -0,0 +1,417 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
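+
+"""Tests for the Isaac fast image processor.
+
+Each call is expected to return exactly three batch-first tensors:
+``vision_patches`` (float32), ``vision_patch_attention_mask`` (long) and
+``vision_token_grids`` (long), with padded patch rows zeroed out;
+``_assert_output_contract`` below checks that contract.
+"""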
+ + +import unittest + +import numpy as np +import pytest + +from transformers.testing_utils import ( + require_torch, + require_torch_accelerator, + require_vision, + slow, + torch_device, +) +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + +if is_torchvision_available(): + from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast + + +def _make_dummy_image(size=(32, 32), color=(255, 0, 0)): + return Image.new("RGB", size, color=color) + + +class IsaacImageProcessingTester: + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + min_resolution=30, + max_resolution=80, + do_resize=True, + do_rescale=True, + rescale_factor=1 / 255, + do_normalize=True, + image_mean=None, + image_std=None, + patch_size=16, + max_num_patches=16, + min_num_patches=4, + pixel_shuffle_scale=1, + do_convert_rgb=True, + ): + if image_mean is None: + image_mean = [0.485, 0.456, 0.406] + if image_std is None: + image_std = [0.229, 0.224, 0.225] + + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.patch_size = patch_size + self.max_num_patches = max_num_patches + self.min_num_patches = min_num_patches + self.pixel_shuffle_scale = pixel_shuffle_scale + self.do_convert_rgb = do_convert_rgb + + @property + def patch_dim(self): + return self.num_channels * self.patch_size * self.patch_size + + def prepare_image_processor_dict(self): + return { + "do_resize": self.do_resize, + "do_rescale": self.do_rescale, + "rescale_factor": self.rescale_factor, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "patch_size": self.patch_size, + "max_num_patches": self.max_num_patches, + "min_num_patches": self.min_num_patches, + "pixel_shuffle_scale": self.pixel_shuffle_scale, + "do_convert_rgb": self.do_convert_rgb, + } + + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + images = prepare_image_inputs( + batch_size=self.batch_size, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + num_channels=self.num_channels, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + return [[image] for image in images] + + +@require_torch +@require_vision +class IsaacImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = None + fast_image_processing_class = IsaacImageProcessorFast if is_torchvision_available() else None + test_slow_image_processor = False + + def setUp(self): + super().setUp() + self.image_processor_tester = IsaacImageProcessingTester(self) + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def _assert_output_contract( + self, + encoding, + *, + expected_batch_size=None, + expected_image_slots=None, + expected_patch_dim=None, + ): + self.assertEqual( + set(encoding.keys()), + {"vision_patches", "vision_patch_attention_mask", "vision_token_grids"}, + ) + + vision_patches = 
encoding["vision_patches"] + vision_patch_attention_mask = encoding["vision_patch_attention_mask"] + vision_token_grids = encoding["vision_token_grids"] + + self.assertEqual(vision_patches.dtype, torch.float32) + self.assertEqual(vision_patch_attention_mask.dtype, torch.long) + self.assertEqual(vision_token_grids.dtype, torch.long) + + if expected_batch_size is not None: + self.assertEqual(vision_patches.shape[0], expected_batch_size) + if expected_image_slots is not None: + self.assertEqual(vision_patches.shape[1], expected_image_slots) + if expected_patch_dim is not None: + self.assertEqual(vision_patches.shape[-1], expected_patch_dim) + + self.assertEqual(tuple(vision_patch_attention_mask.shape), tuple(vision_patches.shape[:-1])) + self.assertEqual(tuple(vision_token_grids.shape), tuple(vision_patches.shape[:2]) + (2,)) + + expected_patch_counts = torch.prod(vision_token_grids, dim=-1) + torch.testing.assert_close(vision_patch_attention_mask.sum(dim=-1), expected_patch_counts) + + padded_patch_rows = vision_patches[vision_patch_attention_mask == 0] + if padded_patch_rows.numel() > 0: + self.assertTrue(torch.all(padded_patch_rows == 0)) + + def _assert_encoding_close(self, eager_encoding, compiled_encoding): + torch.testing.assert_close( + eager_encoding["vision_patches"], + compiled_encoding["vision_patches"], + atol=1e-4, + rtol=1e-4, + ) + torch.testing.assert_close( + eager_encoding["vision_patch_attention_mask"], + compiled_encoding["vision_patch_attention_mask"], + ) + torch.testing.assert_close(eager_encoding["vision_token_grids"], compiled_encoding["vision_token_grids"]) + + def test_image_processor_properties(self): + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processor, "do_resize")) + self.assertTrue(hasattr(image_processor, "do_rescale")) + self.assertTrue(hasattr(image_processor, "rescale_factor")) + self.assertTrue(hasattr(image_processor, "do_normalize")) + self.assertTrue(hasattr(image_processor, "image_mean")) + self.assertTrue(hasattr(image_processor, "image_std")) + self.assertTrue(hasattr(image_processor, "patch_size")) + self.assertTrue(hasattr(image_processor, "max_num_patches")) + self.assertTrue(hasattr(image_processor, "min_num_patches")) + self.assertTrue(hasattr(image_processor, "pixel_shuffle_scale")) + self.assertTrue(hasattr(image_processor, "do_convert_rgb")) + + def test_call_pil(self): + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) + + for image in image_inputs: + self.assertIsInstance(image[0], Image.Image) + + single_output = image_processor(image_inputs[0], return_tensors="pt") + self._assert_output_contract( + single_output, + expected_batch_size=1, + expected_image_slots=1, + expected_patch_dim=self.image_processor_tester.patch_dim, + ) + + batched_output = image_processor(image_inputs, return_tensors="pt") + self._assert_output_contract( + batched_output, + expected_batch_size=self.image_processor_tester.batch_size, + expected_image_slots=1, + expected_patch_dim=self.image_processor_tester.patch_dim, + ) + + def test_call_numpy(self): + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, 
numpify=True) + + for image in image_inputs: + self.assertIsInstance(image[0], np.ndarray) + + single_output = image_processor(image_inputs[0], return_tensors="pt") + self._assert_output_contract( + single_output, + expected_batch_size=1, + expected_image_slots=1, + expected_patch_dim=self.image_processor_tester.patch_dim, + ) + + batched_output = image_processor(image_inputs, return_tensors="pt") + self._assert_output_contract( + batched_output, + expected_batch_size=self.image_processor_tester.batch_size, + expected_image_slots=1, + expected_patch_dim=self.image_processor_tester.patch_dim, + ) + + def test_call_pytorch(self): + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + + for image in image_inputs: + self.assertIsInstance(image[0], torch.Tensor) + + single_output = image_processor(image_inputs[0], return_tensors="pt") + self._assert_output_contract( + single_output, + expected_batch_size=1, + expected_image_slots=1, + expected_patch_dim=self.image_processor_tester.patch_dim, + ) + + batched_output = image_processor(image_inputs, return_tensors="pt") + self._assert_output_contract( + batched_output, + expected_batch_size=self.image_processor_tester.batch_size, + expected_image_slots=1, + expected_patch_dim=self.image_processor_tester.patch_dim, + ) + + @unittest.skip(reason="Isaac image processor 4-channel coverage is not defined yet") + def test_call_numpy_4_channels(self): + pass + + def test_nested_multi_image_batch_preserves_grids_and_padding(self): + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class( + **{ + **self.image_processor_dict, + "do_resize": False, + "patch_size": 16, + "max_num_patches": 64, + "min_num_patches": 1, + "pixel_shuffle_scale": 1, + } + ) + image_inputs = [ + [_make_dummy_image(size=(32, 32), color=(255, 0, 0))], + [ + _make_dummy_image(size=(48, 32), color=(0, 255, 0)), + _make_dummy_image(size=(32, 48), color=(0, 0, 255)), + ], + ] + + encoding = image_processor(image_inputs, return_tensors="pt") + self._assert_output_contract( + encoding, + expected_batch_size=2, + expected_image_slots=2, + expected_patch_dim=768, + ) + + expected_grids = torch.tensor( + [ + [[2, 2], [0, 0]], + [[2, 3], [3, 2]], + ], + dtype=torch.long, + ) + expected_patch_counts = torch.tensor( + [ + [4, 0], + [6, 6], + ], + dtype=torch.long, + ) + + torch.testing.assert_close(encoding["vision_token_grids"], expected_grids) + torch.testing.assert_close(encoding["vision_patch_attention_mask"].sum(dim=-1), expected_patch_counts) + + def test_all_empty_images_returns_zero_sized_tensors(self): + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + encoding = image_processor([[], []], return_tensors="pt") + + self.assertEqual( + set(encoding.keys()), {"vision_patches", "vision_patch_attention_mask", "vision_token_grids"} + ) + self.assertEqual(tuple(encoding["vision_patches"].shape), (2, 0, 0, 0)) + self.assertEqual(tuple(encoding["vision_patch_attention_mask"].shape), (2, 0, 0)) + self.assertEqual(tuple(encoding["vision_token_grids"].shape), (2, 0, 2)) + self.assertEqual(encoding["vision_patches"].dtype, torch.float32) + self.assertEqual(encoding["vision_patch_attention_mask"].dtype, torch.long) + self.assertEqual(encoding["vision_token_grids"].dtype, torch.long) + + def 
test_do_resize_false_requires_patch_divisibility(self): + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class( + **{ + **self.image_processor_dict, + "do_resize": False, + "patch_size": 16, + } + ) + + with self.assertRaisesRegex(ValueError, "must be divisible by patch_size"): + image_processor([[_make_dummy_image(size=(31, 32))]], return_tensors="pt") + + def test_pixel_shuffle_scale_requires_divisible_token_grid(self): + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class( + **{ + **self.image_processor_dict, + "do_resize": False, + "patch_size": 16, + "pixel_shuffle_scale": 2, + } + ) + + with self.assertRaisesRegex(ValueError, "must be divisible by pixel_shuffle_scale"): + image_processor([[_make_dummy_image(size=(32, 16))]], return_tensors="pt") + + def test_cast_dtype_device(self): + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + + encoding = image_processor(image_inputs, return_tensors="pt") + self.assertEqual(encoding["vision_patches"].device, torch.device("cpu")) + self.assertEqual(encoding["vision_patches"].dtype, torch.float32) + self.assertEqual(encoding["vision_patch_attention_mask"].dtype, torch.long) + self.assertEqual(encoding["vision_token_grids"].dtype, torch.long) + + encoding = image_processor(image_inputs, return_tensors="pt").to(torch.float16) + self.assertEqual(encoding["vision_patches"].device, torch.device("cpu")) + self.assertEqual(encoding["vision_patches"].dtype, torch.float16) + self.assertEqual(encoding["vision_patch_attention_mask"].dtype, torch.long) + self.assertEqual(encoding["vision_token_grids"].dtype, torch.long) + + encoding = image_processor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16) + self.assertEqual(encoding["vision_patches"].device, torch.device("cpu")) + self.assertEqual(encoding["vision_patches"].dtype, torch.bfloat16) + self.assertEqual(encoding["vision_patch_attention_mask"].dtype, torch.long) + self.assertEqual(encoding["vision_token_grids"].dtype, torch.long) + + with self.assertRaises(TypeError): + _ = image_processor(image_inputs, return_tensors="pt").to(torch.bfloat16, "cpu") + + encoding = image_processor(image_inputs, return_tensors="pt") + encoding.update({"input_ids": torch.LongTensor([[1, 2, 3], [4, 5, 6]])}) + encoding = encoding.to(torch.float16) + + self.assertEqual(encoding["vision_patches"].device, torch.device("cpu")) + self.assertEqual(encoding["vision_patches"].dtype, torch.float16) + self.assertEqual(encoding["vision_patch_attention_mask"].dtype, torch.long) + self.assertEqual(encoding["vision_token_grids"].dtype, torch.long) + self.assertEqual(encoding["input_ids"].dtype, torch.long) + + @slow + @require_torch_accelerator + @require_vision + @pytest.mark.torch_compile_test + def test_can_compile_fast_image_processor(self): + if self.fast_image_processing_class is None: + self.skipTest("Skipping compilation test as fast image processor is not defined") + + torch.compiler.reset() + input_image = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8) + image_processor = self.fast_image_processing_class(**self.image_processor_dict) + output_eager = image_processor(input_image, device=torch_device, return_tensors="pt") + + image_processor = torch.compile(image_processor, mode="reduce-overhead") + output_compiled = 
image_processor(input_image, device=torch_device, return_tensors="pt")
+        self._assert_encoding_close(output_eager, output_compiled)
diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py
index 6f7fa5984620..df4ee148e6c3 100644
--- a/tests/models/isaac/test_modeling_isaac.py
+++ b/tests/models/isaac/test_modeling_isaac.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,11 +17,10 @@
 import base64
 import io
 import os
-import re
 import unittest
-from collections import namedtuple
 from functools import lru_cache
 from pathlib import Path
+from unittest.mock import patch
 
 import pytest
 from huggingface_hub import is_offline_mode
@@ -67,74 +66,6 @@
 if is_torch_available():
     import torch
 
-SinglePoint = namedtuple("SinglePoint", ["x", "y", "mention", "t"], defaults=(None, None))
-BoundingBox = namedtuple(
-    "BoundingBox",
-    ["top_left", "bottom_right", "mention", "t"],
-    defaults=(None, None),
-)
-
-_POINT_OR_BOX_TAG = re.compile(
-    r"<(?P<tag>point|point_box)(?P<attrs>[^>]*)>(?P<body>[\s\S]*?)</(?P=tag)>", re.IGNORECASE
-)
-_ATTR_RE = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))")
-_COORD_RE = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")
-
-
-def _maybe_float(val):
-    if val is None:
-        return None
-    try:
-        return float(val)
-    except ValueError:
-        return None
-
-
-def _parse_attrs(attr_text: str) -> dict:
-    attrs = {}
-    for match in _ATTR_RE.finditer(attr_text or ""):
-        key = match.group(1)
-        val = match.group(2) or match.group(3) or ""
-        attrs[key] = val
-    return attrs
-
-
-def _parse_point_body(body: str, mention=None, t=None):
-    match = _COORD_RE.search(body)
-    if not match:
-        raise ValueError(f"Malformed tag: {body!r}")
-    x, y = int(match.group(1)), int(match.group(2))
-    return SinglePoint(x, y, mention, _maybe_float(t))
-
-
-def _parse_box_body(body: str, mention=None, t=None):
-    coords = list(_COORD_RE.finditer(body))
-    if len(coords) < 2:
-        raise ValueError(f"Malformed tag: {body!r}")
-    x1, y1 = int(coords[0].group(1)), int(coords[0].group(2))
-    x2, y2 = int(coords[1].group(1)), int(coords[1].group(2))
-    return BoundingBox(SinglePoint(x1, y1, None, None), SinglePoint(x2, y2, None, None), mention, _maybe_float(t))
-
-
-def extract_points(text: str, expected: str | None = None):
-    """Minimal parser for Isaac pointing tags used in tests."""
-
-    results = []
-    for match in _POINT_OR_BOX_TAG.finditer(text or ""):
-        tag = match.group("tag").lower()
-        attrs = _parse_attrs(match.group("attrs"))
-        mention = attrs.get("mention")
-        t = attrs.get("t")
-        if tag == "point":
-            if expected not in (None, "point"):
-                continue
-            results.append(_parse_point_body(match.group("body"), mention=mention, t=t))
-        elif tag == "point_box":
-            if expected not in (None, "box"):
-                continue
-            results.append(_parse_box_body(match.group("body"), mention=mention, t=t))
-    return results
-
 BASE_MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base")
 MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1")
@@ -311,14 +242,12 @@ def create_isaac_processor(
 def to_model_multimodal_inputs(processor_output, device):
     keys = (
-        "modality_tensor",
-        "position_ids",
+        "mm_token_type_ids",
         "vision_patches",
         "vision_patch_attention_mask",
         "vision_token_grids",
         "vision_token_offsets",
         "vision_token_lengths",
-        "vision_image_attention_mask",
     )
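+    # Only the whitelisted keys above are forwarded to the model; tensor values are
+    # moved to the target device while non-tensor entries pass through unchanged.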
return { key: (value.to(device) if isinstance(value, torch.Tensor) else value) @@ -476,10 +405,10 @@ def get_config(self): text_config=self.text_config, vision_config=self.vision_config, ) - # Rely on vanilla SDPA so the tests do not need flash attention. - config._attn_implementation = "sdpa" - config.text_config._attn_implementation = "sdpa" - config.vision_attn_implementation = "sdpa" + # Rely on eager attention so output_attentions tests remain compatible without flash attention. + config._attn_implementation = "eager" + config.text_config._attn_implementation = "eager" + config.vision_attn_implementation = "eager" return config def prepare_config_and_inputs(self): @@ -495,12 +424,22 @@ def prepare_config_and_inputs(self): def prepare_config_and_inputs_for_common(self): config, input_ids, attention_mask, labels = self.prepare_config_and_inputs() - position_ids = torch.arange(self.seq_length, device=torch_device).view(1, -1) - position_ids = position_ids.expand(self.batch_size, -1).unsqueeze(2).expand(-1, -1, 3) + position_ids = torch.arange(self.seq_length, device=torch_device).view(1, -1).expand(self.batch_size, -1) + patch_size = self.vision_config["patch_size"] + patch_dim = self.vision_config["num_channels"] * patch_size * patch_size + num_image_patches = 4 inputs_dict = { "input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids, + "pixel_values": torch.randn( + (self.batch_size, 1, num_image_patches, patch_dim), device=torch_device, dtype=torch.float32 + ), + "image_patch_attention_mask": torch.ones( + (self.batch_size, 1, num_image_patches), device=torch_device, dtype=torch.long + ), + "image_token_grids": torch.tensor([[[2, 2]]] * self.batch_size, device=torch_device, dtype=torch.long), + "image_attention_mask": torch.ones((self.batch_size, 1), device=torch_device, dtype=torch.long), } if labels is not None: inputs_dict["labels"] = labels @@ -558,18 +497,69 @@ def test_prompt_lookup_decoding_matches_greedy_search(self): def test_retain_grad_hidden_states_attentions(self): pass - def test_model_forward(self): + def test_text_only_forward_ignores_metadata_without_vision_patches(self): config, input_ids, attention_mask, _ = self.model_tester.prepare_config_and_inputs() model = IsaacModel(config) model.to(torch_device) model.eval() + + vision_token_grids = torch.zeros((self.model_tester.batch_size, 0, 2), device=torch_device, dtype=torch.long) + with torch.no_grad(): - result = model(input_ids=input_ids, attention_mask=attention_mask) + reference = model(input_ids=input_ids, attention_mask=attention_mask) + + with patch.object(model, "get_image_features", wraps=model.get_image_features) as mock_get_image_features: + with torch.no_grad(): + result = model( + input_ids=input_ids, + attention_mask=attention_mask, + vision_token_grids=vision_token_grids, + ) - self.assertEqual( - result.last_hidden_state.shape, - (self.model_tester.batch_size, self.model_tester.seq_length, config.hidden_size), + mock_get_image_features.assert_not_called() + torch.testing.assert_close(result.last_hidden_state, reference.last_hidden_state) + + def test_get_image_features_pooler_output_is_scatter_ready(self): + config = self.model_tester.get_config() + model = IsaacModel(config) + model.to(torch_device) + model.eval() + + patch_size = self.model_tester.vision_config["patch_size"] + patch_dim = self.model_tester.vision_config["num_channels"] * patch_size * patch_size + pixel_values = torch.randn((2, 2, 4, patch_dim), device=torch_device, dtype=torch.float32) + image_token_grids 
= torch.tensor( + [[[2, 2], [2, 2]], [[2, 2], [0, 0]]], + device=torch_device, + dtype=torch.long, ) + image_patch_attention_mask = torch.ones((2, 2, 4), device=torch_device, dtype=torch.long) + image_attention_mask = torch.tensor([[1, 1], [1, 0]], device=torch_device, dtype=torch.long) + image_token_offsets = torch.tensor([[1, 0], [2, 0]], device=torch_device, dtype=torch.long) + image_token_lengths = torch.tensor([[2, 1], [1, 0]], device=torch_device, dtype=torch.long) + + with torch.no_grad(): + outputs = model.get_image_features( + pixel_values=pixel_values, + image_token_grids=image_token_grids, + image_patch_attention_mask=image_patch_attention_mask, + image_attention_mask=image_attention_mask, + image_token_offsets=image_token_offsets, + image_token_lengths=image_token_lengths, + return_dict=True, + ) + + expected = torch.cat( + ( + outputs.last_hidden_state[0, 0, 1:3], + outputs.last_hidden_state[0, 1, 0:1], + outputs.last_hidden_state[1, 0, 2:3], + ), + dim=0, + ) + + self.assertEqual(outputs.pooler_output.ndim, 2) + torch.testing.assert_close(outputs.pooler_output, expected) def test_for_conditional_generation(self): config, input_ids, attention_mask, labels = self.model_tester.prepare_config_and_inputs() @@ -592,12 +582,15 @@ def test_isaac_for_conditional_generation_initialization(self): self.assertTrue(hasattr(model, "model")) self.assertTrue(hasattr(model, "lm_head")) - self.assertTrue(hasattr(model.model, "vision_embedding")) + self.assertTrue(hasattr(model.model, "vision_tower")) + self.assertTrue(hasattr(model.model, "multimodal_projector")) - input_ids = torch.randint(0, config.vocab_size, (1, 10), device=torch_device, dtype=torch.long) + input_vocab_size = model.get_input_embeddings().num_embeddings + output_vocab_size = model.get_output_embeddings().out_features + input_ids = torch.randint(0, input_vocab_size, (1, 10), device=torch_device, dtype=torch.long) with torch.no_grad(): outputs = model(input_ids=input_ids, return_dict=True) - self.assertEqual(outputs.logits.shape, (1, 10, config.vocab_size)) + self.assertEqual(outputs.logits.shape, (1, 10, output_vocab_size)) def test_isaac_for_conditional_generation_loss_and_generate_flag(self): config = self.model_tester.get_config() @@ -605,13 +598,42 @@ def test_isaac_for_conditional_generation_loss_and_generate_flag(self): self.assertTrue(model.can_generate()) batch_size, seq_len = 1, 8 - input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=torch_device) - labels = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=torch_device) + input_vocab_size = model.get_input_embeddings().num_embeddings + output_vocab_size = model.get_output_embeddings().out_features + input_ids = torch.randint(0, input_vocab_size, (batch_size, seq_len), device=torch_device) + labels = torch.randint(0, output_vocab_size, (batch_size, seq_len), device=torch_device) with torch.no_grad(): outputs = model(input_ids=input_ids, labels=labels, return_dict=True) self.assertIsNotNone(outputs.loss) self.assertEqual(outputs.loss.ndim, 0) - self.assertEqual(outputs.logits.shape, (batch_size, seq_len, config.vocab_size)) + self.assertEqual(outputs.logits.shape, (batch_size, seq_len, output_vocab_size)) + + +@require_torch +class IsaacPixelShufflePaddedTest(unittest.TestCase): + def test_pixel_shuffle_padded_matches_reference_no_attention_mask(self): + x = torch.arange(2 * 16 * 4, device=torch_device, dtype=torch.float32).view(2, 16, 4) + token_grids = torch.tensor([[4, 4], [2, 4]], device=torch_device, dtype=torch.long) + 
expected_hidden, expected_mask, expected_lengths = _pixel_shuffle_reference(x, token_grids, scale_factor=2) + + hidden = pixel_shuffle_padded(x=x, token_grids=token_grids, scale_factor=2) + + torch.testing.assert_close(hidden, expected_hidden) + + def test_pixel_shuffle_padded_raises_on_non_divisible_grid(self): + x = torch.randn(1, 15, 8, device=torch_device) + token_grids = torch.tensor([[3, 5]], device=torch_device, dtype=torch.long) + + with pytest.raises(ValueError, match="divisible"): + pixel_shuffle_padded(x=x, token_grids=token_grids, scale_factor=2) + + def test_pixel_shuffle_padded_zero_grid(self): + x = torch.randn(1, 4, 8, device=torch_device) + token_grids = torch.tensor([[0, 0]], device=torch_device, dtype=torch.long) + + hidden = pixel_shuffle_padded(x=x, token_grids=token_grids, scale_factor=2) + + self.assertEqual(hidden.shape, (1, 0, 32)) @require_torch @@ -782,7 +804,7 @@ def setUp(self): def _generate_from_messages(self, messages, images, num_tokens=None): prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip() - processor_output = self.processor(text=prompt, images=images, return_tensors="pt") + processor_output = self.processor(text=prompt, images=images or None, return_tensors="pt") input_ids = processor_output["input_ids"].to(self.device) attention_mask = processor_output.get("attention_mask") if attention_mask is None: @@ -1006,7 +1028,8 @@ def test_batched_generation_matches_individual(self): pad_id = getattr(self.processor, "pad_token_id", 0) per_sample_outputs = [ - self.processor(text=prompt, images=imgs, return_tensors="pt") for prompt, imgs in zip(prompts, images_list) + self.processor(text=prompt, images=imgs or None, return_tensors="pt") + for prompt, imgs in zip(prompts, images_list) ] batch_outputs = self.processor(text=prompts, images=images_list, return_tensors="pt") batch_input_ids = batch_outputs["input_ids"] @@ -1023,26 +1046,19 @@ def test_batched_generation_matches_individual(self): torch.testing.assert_close(batch_ids[-single_len:], single_ids) - batch_modality_row = batch_packed["modality_tensor"][i] + batch_modality_row = batch_packed["mm_token_type_ids"][i] expected_modality = torch.full( (max_length,), batch_modality_row[-1].item(), dtype=batch_modality_row.dtype, device=batch_modality_row.device, ) - expected_modality[-single_len:] = single_packed["modality_tensor"].squeeze(0) + expected_modality[-single_len:] = single_packed["mm_token_type_ids"].squeeze(0) torch.testing.assert_close(batch_modality_row, expected_modality) - batch_positions_row = batch_packed["position_ids"][i] - expected_positions = torch.zeros( - (max_length, 3), dtype=batch_positions_row.dtype, device=batch_positions_row.device - ) - expected_positions[-single_len:] = single_packed["position_ids"].squeeze(0) - torch.testing.assert_close(batch_positions_row, expected_positions) - if single_packed["vision_patches"] is not None: - expected_image_count = int(single_packed["vision_image_attention_mask"].sum().item()) - batch_image_count = int(batch_packed["vision_image_attention_mask"][i].sum().item()) + expected_image_count = int(single_packed["vision_token_lengths"].gt(0).sum().item()) + batch_image_count = int(batch_packed["vision_token_lengths"][i].gt(0).sum().item()) assert batch_image_count == expected_image_count if expected_image_count > 0: torch.testing.assert_close( @@ -1072,7 +1088,6 @@ def test_batched_generation_matches_individual(self): assert batch_packed["vision_token_grids"] is not None assert 
batch_packed["vision_token_offsets"] is not None assert batch_packed["vision_token_lengths"] is not None - assert batch_packed["vision_image_attention_mask"] is not None batch_texts = self._generate_batch(prompts, images_list, num_tokens=100) assert len(batch_texts) == len(single_texts) == 3 @@ -1144,7 +1159,7 @@ def test_hf_generate_box_points(self): generated_ids = outputs.sequences hf_generated_tail = generated_ids[:, prompt_len:] hf_generated_text = self.tokenizer.decode(hf_generated_tail[0], skip_special_tokens=True) - points = extract_points(hf_generated_text) + clean_text, points = self.processor.post_process_generation(hf_generated_text, expected="box") assert len(points) == 1 first_point = points[0] assert first_point.top_left.x < first_point.bottom_right.x diff --git a/tests/models/isaac/test_post_processing_isaac.py b/tests/models/isaac/test_post_processing_isaac.py new file mode 100644 index 000000000000..613ce3a9732f --- /dev/null +++ b/tests/models/isaac/test_post_processing_isaac.py @@ -0,0 +1,102 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Isaac processor post-processing helpers.""" + +import pytest + +from transformers import PythonBackend +from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast +from transformers.models.isaac.processing_isaac import IsaacProcessor +from transformers.testing_utils import require_torch + + +class SimpleIsaacTokenizer(PythonBackend): + vocab_files_names = {} + model_input_names = ["input_ids"] + + def __init__(self): + self._vocab = { + "": 0, + "": 1, + "": 2, + "": 3, + "": 4, + } + self._ids_to_tokens = {idx: tok for tok, idx in self._vocab.items()} + super().__init__( + bos_token="", + eos_token="", + pad_token="", + unk_token="", + extra_special_tokens=[""], + model_max_length=512, + ) + + def get_vocab(self): + return dict(self._vocab) + + def _tokenize(self, text): + clean = text.replace("\n", " ").strip() + if not clean: + return [] + return [token for token in clean.split(" ") if token] + + def _convert_token_to_id(self, token): + if token not in self._vocab: + next_id = len(self._vocab) + self._vocab[token] = next_id + self._ids_to_tokens[next_id] = token + return self._vocab[token] + + def _convert_id_to_token(self, index): + return self._ids_to_tokens.get(index, self.unk_token) + + @property + def vocab_size(self) -> int: + return len(self._vocab) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + if token_ids_1 is not None: + token_ids_0 = token_ids_0 + token_ids_1 + return [self.bos_token_id] + list(token_ids_0) + [self.eos_token_id] + + def save_vocabulary(self, save_directory, filename_prefix=None): + return () + + +def _make_processor(): + return IsaacProcessor(image_processor=IsaacImageProcessorFast(), tokenizer=SimpleIsaacTokenizer()) + + +@require_torch +def test_post_process_generation_extracts_boxes_and_cleans_text(): + processor = _make_processor() + + generated_text = ( + 
"No, it is not safe to cross the street. " + '(808, 247), (863, 386)' + ) + + clean_text, annotations = processor.post_process_generation(generated_text) + + assert clean_text == "No, it is not safe to cross the street." + assert len(annotations) == 1 + box = annotations[0] + assert box.mention == "traffic light" + assert box.t == pytest.approx(0.5) + assert box.top_left.x == 808 + assert box.top_left.y == 247 + assert box.bottom_right.x == 863 + assert box.bottom_right.y == 386 diff --git a/tests/models/isaac/test_processing_isaac.py b/tests/models/isaac/test_processing_isaac.py index 944ab62645ad..659ce2bfbb07 100644 --- a/tests/models/isaac/test_processing_isaac.py +++ b/tests/models/isaac/test_processing_isaac.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ """Testing suite for the Isaac processor.""" import os +import re import unittest from pathlib import Path @@ -23,13 +24,12 @@ from huggingface_hub import is_offline_mode from transformers import IsaacConfig, PythonBackend -from transformers.image_processing_utils import ImageProcessingMixin from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast -from transformers.models.isaac.modeling_isaac import ModalityType from transformers.models.isaac.processing_isaac import IsaacProcessor from transformers.testing_utils import require_torch, require_vision from transformers.utils import is_vision_available -from transformers.utils.generic import TensorType + +from ...test_processing_common import ProcessorTesterMixin if is_vision_available(): @@ -38,6 +38,19 @@ Image = None +ISAAC_OUTPUT_KEYS = { + "input_ids", + "attention_mask", + "mm_token_type_ids", + "vision_patches", + "vision_patch_attention_mask", + "vision_token_grids", + "vision_token_offsets", + "vision_token_lengths", + "vision_image_attention_mask", +} + + class SimpleIsaacTokenizer(PythonBackend): vocab_files_names = {} model_input_names = ["input_ids"] @@ -49,6 +62,7 @@ def __init__(self): "": 2, "": 3, "": 4, + "<|image_pad|>": 5, } self._ids_to_tokens = {idx: tok for tok, idx in self._vocab.items()} super().__init__( @@ -56,15 +70,10 @@ def __init__(self): eos_token="", pad_token="", unk_token="", - extra_special_tokens=[""], + additional_special_tokens=[""], + extra_special_tokens={"image_pad_token": "<|image_pad|>"}, model_max_length=512, ) - self.chat_template = ( - "{% for message in messages %}" - "{{ message['role'] }}: {{ message['content'] | trim }}\n" - "{% endfor %}" - "{% if add_generation_prompt %}assistant:{% endif %}" - ) def get_vocab(self): return dict(self._vocab) @@ -73,7 +82,25 @@ def _tokenize(self, text): clean = text.replace("\n", " ").strip() if not clean: return [] - return [token for token in clean.split(" ") if token] + + special_tokens = sorted( + (token for token in self._vocab if token.startswith("<") and token.endswith(">")), + key=len, + reverse=True, + ) + if not special_tokens: + return [token for token in clean.split(" ") if token] + + split_pattern = "(" + "|".join(re.escape(token) for token in special_tokens) + ")" + tokens = [] + for chunk in re.split(split_pattern, clean): + if not chunk or chunk.isspace(): + continue + if chunk in self._vocab: + tokens.append(chunk) + else: + tokens.extend(token for token in chunk.split(" ") if token) + return tokens 
    def _convert_token_to_id(self, token):
@@ -98,6 +125,73 @@ def save_vocabulary(self, save_directory, filename_prefix=None):
         return ()
 
 
+class SimpleIsaacTokenizerWithNamedImagePad(PythonBackend):
+    vocab_files_names = {}
+    model_input_names = ["input_ids"]
+
+    def __init__(self):
+        self._vocab = {
+            "<s>": 0,
+            "</s>": 1,
+            "<pad>": 2,
+            "<unk>": 3,
+            "<image>": 4,
+            "<image_pad>": 5,
+            "<|image_pad|>": 6,
+        }
+        self._ids_to_tokens = {idx: tok for tok, idx in self._vocab.items()}
+        super().__init__(
+            bos_token="<s>",
+            eos_token="</s>",
+            pad_token="<pad>",
+            unk_token="<unk>",
+            extra_special_tokens={"image_pad_token": "<image_pad>"},
+            model_max_length=512,
+        )
+
+    def get_vocab(self):
+        return dict(self._vocab)
+
+    def _tokenize(self, text):
+        clean = text.replace("\n", " ").strip()
+        if not clean:
+            return []
+
+        special_tokens = sorted(
+            (token for token in self._vocab if token.startswith("<") and token.endswith(">")),
+            key=len,
+            reverse=True,
+        )
+        split_pattern = "(" + "|".join(re.escape(token) for token in special_tokens) + ")"
+        tokens = []
+        for chunk in re.split(split_pattern, clean):
+            if not chunk or chunk.isspace():
+                continue
+            if chunk in self._vocab:
+                tokens.append(chunk)
+            else:
+                tokens.extend(token for token in chunk.split(" ") if token)
+        return tokens
+
+    def _convert_token_to_id(self, token):
+        return self._vocab.get(token, self._vocab["<unk>"])
+
+    def _convert_id_to_token(self, index):
+        return self._ids_to_tokens.get(index, self.unk_token)
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self._vocab)
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        if token_ids_1 is not None:
+            token_ids_0 = token_ids_0 + token_ids_1
+        return [self.bos_token_id] + list(token_ids_0) + [self.eos_token_id]
+
+    def save_vocabulary(self, save_directory, filename_prefix=None):
+        return ()
+
+
 def _make_dummy_image(size=(32, 32), color=(255, 0, 0)):
     if Image is None:
         raise RuntimeError("PIL.Image is not available in this environment.")
@@ -114,91 +208,94 @@ def _make_processor_with_max_len(tokenizer, base_config, max_len):
         pixel_shuffle_scale=vision_config.pixel_shuffle_scale_factor,
         rescale_factor=config.vision_rescale_factor,
     )
-    return IsaacProcessor(image_processor=image_processor, tokenizer=tokenizer, config=config)
+    return IsaacProcessor(
+        image_processor=image_processor,
+        tokenizer=tokenizer,
+        vision_token=config.vision_token,
+        max_sequence_length=config.max_sequence_length,
+    )
 
 
 def _run_processor(processor, text, images=None):
     return processor(text=text, images=images, return_tensors="pt")
 
 
-def _assert_common(outputs):
-    assert set(outputs.keys()) == {"input_ids", "packed_inputs"}
-    input_ids = outputs["input_ids"]
-    packed_inputs = outputs["packed_inputs"]
-
-    expected_packed_keys = {
-        "vision_patches",
-        "vision_token_grids",
-        "vision_token_offsets",
-        "vision_token_lengths",
-        "vision_token_batch_indices",
-        "modality_tensor",
-        "position_ids",
-    }
-    assert set(packed_inputs.keys()) == expected_packed_keys
+def _make_post_process_processor():
+    return IsaacProcessor(image_processor=IsaacImageProcessorFast(), tokenizer=SimpleIsaacTokenizer())
 
-    assert input_ids.shape[0] == 1
-    assert input_ids.dtype == torch.long
-    modality = packed_inputs["modality_tensor"]
-    position_ids = packed_inputs["position_ids"]
-    assert modality.shape == (1, input_ids.shape[1])
-    assert position_ids.shape == (1, input_ids.shape[1], 3)
-    assert modality.dtype == torch.long
-    assert position_ids.dtype == torch.long
-    assert modality.device == input_ids.device == position_ids.device
+
+def test_processor_prefers_named_image_pad_token():
+    processor = IsaacProcessor(
+        image_processor=IsaacImageProcessorFast(), tokenizer=SimpleIsaacTokenizerWithNamedImagePad()
+    )
+
+    assert processor.image_token == "<image>"
+    assert processor.image_pad_token_id == processor.tokenizer.image_pad_token_id
+    assert processor.image_pad_token_id != processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
+
+
+def _assert_common(outputs, batch_size=1):
+    assert set(outputs.keys()) == ISAAC_OUTPUT_KEYS
+
+    input_ids = outputs["input_ids"]
+    attention_mask = outputs["attention_mask"]
+    mm_token_type_ids = outputs["mm_token_type_ids"]
+    vision_patches = outputs["vision_patches"]
+    vision_patch_attention_mask = outputs["vision_patch_attention_mask"]
+    vision_token_grids = outputs["vision_token_grids"]
+    vision_token_offsets = outputs["vision_token_offsets"]
+    vision_token_lengths = outputs["vision_token_lengths"]
+    vision_image_attention_mask = outputs["vision_image_attention_mask"]
+
+    assert input_ids.shape[0] == batch_size
+    assert attention_mask.shape == input_ids.shape
+    assert mm_token_type_ids.shape == input_ids.shape
+    assert input_ids.dtype == torch.long
+    assert attention_mask.dtype == torch.long
+    assert mm_token_type_ids.dtype == torch.long
 
-    return input_ids, packed_inputs
+    assert vision_patches.shape[:2] == vision_patch_attention_mask.shape[:2]
+    assert vision_patches.shape[0] == batch_size
+    assert vision_token_grids.shape == (batch_size, vision_patches.shape[1], 2)
+    assert vision_token_offsets.shape == (batch_size, vision_patches.shape[1])
+    assert vision_token_lengths.shape == (batch_size, vision_patches.shape[1])
+    assert vision_image_attention_mask.shape == (batch_size, vision_patches.shape[1])
+    return outputs
 
 
-def _assert_no_vision(packed_inputs):
-    assert packed_inputs["vision_patches"] is None
-    assert packed_inputs["vision_token_grids"] is None
-    assert packed_inputs["vision_token_offsets"] is None
-    assert packed_inputs["vision_token_lengths"] is None
-    assert packed_inputs["vision_token_batch_indices"] is None
+def _assert_no_vision(outputs, batch_index=0):
+    assert outputs["vision_patch_attention_mask"][batch_index].sum().item() == 0
+    assert outputs["vision_token_grids"][batch_index].sum().item() == 0
+    assert outputs["vision_token_offsets"][batch_index].sum().item() == 0
+    assert outputs["vision_token_lengths"][batch_index].sum().item() == 0
+    assert outputs["vision_image_attention_mask"][batch_index].sum().item() == 0
+    assert not outputs["mm_token_type_ids"][batch_index].eq(1).any()
 
 
-def _assert_vision_segments(packed_inputs, expected_segments):
-    assert packed_inputs["vision_patches"] is not None
-    assert packed_inputs["vision_token_grids"] is not None
-    assert packed_inputs["vision_token_offsets"] is not None
-    assert packed_inputs["vision_token_lengths"] is not None
-    assert packed_inputs["vision_token_batch_indices"] is not None
-    assert packed_inputs["vision_token_grids"].shape[0] == expected_segments
-    assert packed_inputs["vision_token_offsets"].shape == (expected_segments,)
-    assert packed_inputs["vision_token_lengths"].shape == (expected_segments,)
-    assert packed_inputs["vision_token_batch_indices"].shape == (expected_segments,)
+def _assert_vision_segments(outputs, expected_segments, batch_index=0):
+    active_segments = int(outputs["vision_image_attention_mask"][batch_index].sum().item())
+    assert active_segments == expected_segments
+    assert torch.all(outputs["vision_token_lengths"][batch_index, :expected_segments] > 0)
+    assert
torch.all(outputs["vision_patch_attention_mask"][batch_index, :expected_segments].sum(dim=-1) > 0) -def _count_modality(packed_inputs, modality_value): - modality = packed_inputs["modality_tensor"] - return int((modality == modality_value).sum().item()) +def _count_modality(outputs, modality_value, batch_index=0): + return int( + (outputs["attention_mask"][batch_index].bool() & outputs["mm_token_type_ids"][batch_index].eq(modality_value)) + .sum() + .item() + ) -def _pad_to_max(tensors: list[torch.Tensor], pad_value: int) -> torch.Tensor: - """Pad a list of (L, ...) tensors to (B, L_max, ...).""" - max_len = max(t.shape[0] for t in tensors) - batch = len(tensors) - if tensors[0].ndim == 1: - out = torch.full((batch, max_len), pad_value, device=tensors[0].device, dtype=tensors[0].dtype) - for i, t in enumerate(tensors): - out[i, : t.shape[0]] = t - return out - # assume (L, K) - k = tensors[0].shape[1] - out = torch.full((batch, max_len, k), pad_value, device=tensors[0].device, dtype=tensors[0].dtype) - for i, t in enumerate(tensors): - out[i, : t.shape[0]] = t - return out +def _get_active_vision_grids(outputs, batch_index=0): + mask = outputs["vision_image_attention_mask"][batch_index].bool() + return outputs["vision_token_grids"][batch_index][mask] -def _get_image_token_length(processor, image, vision_token): - outputs = _run_processor(processor, text=vision_token, images=[image]) - _, packed = _assert_common(outputs) - _assert_vision_segments(packed, expected_segments=1) - return packed["vision_token_lengths"][0].item() +def _get_active_vision_lengths(outputs, batch_index=0): + mask = outputs["vision_image_attention_mask"][batch_index].bool() + return outputs["vision_token_lengths"][batch_index][mask] @pytest.fixture @@ -258,137 +355,232 @@ def isaac_processor(isaac_tokenizer, isaac_tiny_config): return IsaacProcessor( image_processor=image_processor, tokenizer=isaac_tokenizer, - config=isaac_tiny_config, + vision_token=isaac_tiny_config.vision_token, + max_sequence_length=isaac_tiny_config.max_sequence_length, ) +BASE_MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base") +BASE_MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/3") or None +LOCAL_CHECKPOINT = os.environ.get("ISAAC_TEST_MODEL_PATH") + + +def _checkpoint_or_skip(model_id=BASE_MODEL_ID): + if LOCAL_CHECKPOINT: + resolved = Path(LOCAL_CHECKPOINT).expanduser() + if not resolved.exists(): + pytest.skip(f"Local checkpoint path {resolved} does not exist.") + return str(resolved) + if is_offline_mode(): + pytest.skip("Offline mode: set ISAAC_TEST_MODEL_PATH to a local checkpoint to run these tests.") + return model_id + + @require_torch @require_vision -def test_isaac_processor_matches_config_defaults(isaac_processor, isaac_tiny_config): - assert isaac_processor.vision_token == isaac_tiny_config.vision_token - assert isaac_processor.max_sequence_length == isaac_tiny_config.max_sequence_length - assert isaac_processor.config is isaac_tiny_config - assert isinstance(isaac_processor.image_processor, IsaacImageProcessorFast) - assert isaac_processor.image_processor.rescale_factor == pytest.approx(isaac_tiny_config.vision_rescale_factor) +class IsaacProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = IsaacProcessor + model_id = BASE_MODEL_ID + images_input_name = "vision_patches" + + @classmethod + def _setup_from_pretrained(cls, model_id, **kwargs): + checkpoint = _checkpoint_or_skip(model_id) + return super()._setup_from_pretrained( + checkpoint, + 
revision=BASE_MODEL_REVISION,
+            patch_size=4,
+            max_num_patches=4,
+            **kwargs,
+        )
+
+    @classmethod
+    def _setup_test_attributes(cls, processor):
+        cls.image_token = processor.vision_token
+        cls.pad_token_id = processor.tokenizer.pad_token_id
+        cls.image_pad_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
+
+    def prepare_image_inputs(self, batch_size: int | None = None, nested: bool = False):
+        if batch_size is None:
+            return _make_dummy_image(size=(16, 16))
+        images = [_make_dummy_image(size=(16, 16), color=(50 * (i + 1), 0, 0)) for i in range(batch_size)]
+        if nested:
+            return [[image] for image in images]
+        return images
+
+    def test_model_input_names(self):
+        processor = self.get_processor()
+        inputs = processor(
+            text=self.prepare_text_inputs(modalities="image"),
+            images=self.prepare_image_inputs(),
+            return_tensors="pt",
+        )
+
+        expected_input_names = set(processor.model_input_names) | {
+            "mm_token_type_ids",
+            "vision_token_offsets",
+            "vision_token_lengths",
+            "vision_image_attention_mask",
+        }
+        self.assertSetEqual(set(inputs.keys()), expected_input_names)
+
+    @unittest.skip("IsaacProcessor expands image placeholders into image pad tokens before tokenization")
+    def test_tokenizer_defaults(self):
+        pass
+
+    @unittest.skip("IsaacProcessor does not return offset mappings needed for assistant masks")
+    def test_apply_chat_template_assistant_mask(self):
+        pass
+
+    def test_single_vs_batched_consistency(self):
+        processor = self.get_processor()
+        prompt = f"hello {processor.vision_token} world"
+        image = self.prepare_image_inputs()
+
+        single = _assert_common(processor(text=prompt, images=[image], return_tensors="pt"))
+        batch = _assert_common(
+            processor(text=[prompt, "short"], images=[[image], []], return_tensors="pt"), batch_size=2
+        )
+
+        single_ids = single["input_ids"].squeeze(0)
+        batch_ids = batch["input_ids"][0]
+        self.assertTrue(torch.equal(batch_ids[-single_ids.size(0) :], single_ids))
+
+        image_positions = batch["mm_token_type_ids"][0].eq(1)
+        if image_positions.any():
+            self.assertTrue(torch.all(batch_ids[image_positions] == self.image_pad_token_id))
+            self.assertTrue(torch.all(batch["attention_mask"][0][image_positions] == 1))
+
+        _assert_vision_segments(batch, expected_segments=1, batch_index=0)
+        _assert_no_vision(batch, batch_index=1)
 
 
 @require_torch
 @require_vision
 def test_text_only_has_no_vision_fields(isaac_processor):
-    outputs = _run_processor(isaac_processor, text="Hello, how are you?", images=None)
-    _, packed = _assert_common(outputs)
-    _assert_no_vision(packed)
+    outputs = _assert_common(_run_processor(isaac_processor, text="Hello, how are you?", images=None))
+    _assert_no_vision(outputs)
 
 
 @require_torch
-@require_vision
-def test_accepts_batchencoding_chat_template(isaac_processor):
-    messages = [{"role": "user", "content": "Hello, how are you?"}]
-    batch_encoding = isaac_processor.apply_chat_template(messages, add_generation_prompt=True)
+def test_post_process_generation_extracts_boxes_and_cleans_text():
+    processor = _make_post_process_processor()
+
+    generated_text = (
+        "No, it is not safe to cross the street. "
+        '<point_box mention="traffic light" t="0.5">(808, 247), (863, 386)</point_box>'
+    )
 
-    outputs = _run_processor(isaac_processor, text=batch_encoding, images=None)
-    _, packed = _assert_common(outputs)
-    _assert_no_vision(packed)
+    clean_text, annotations = processor.post_process_generation(generated_text)
+
+    assert clean_text == "No, it is not safe to cross the street."
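+    # post_process_generation returns (clean_text, annotations); each <point_box> tag is
+    # parsed into a BoundingBox whose top_left/bottom_right corners are SinglePoint values.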
+ assert len(annotations) == 1 + box = annotations[0] + assert box.mention == "traffic light" + assert box.t == pytest.approx(0.5) + assert box.top_left.x == 808 + assert box.top_left.y == 247 + assert box.bottom_right.x == 863 + assert box.bottom_right.y == 386 @require_torch @require_vision def test_single_image_returns_offsets_and_lengths(isaac_processor): vision_token = isaac_processor.vision_token - text = f"Look at this {vision_token} and describe it." - image = _make_dummy_image() - - outputs = _run_processor(isaac_processor, text=text, images=[image]) - _, packed = _assert_common(outputs) - _assert_vision_segments(packed, expected_segments=1) + outputs = _assert_common( + _run_processor( + isaac_processor, text=f"Look at this {vision_token} and describe it.", images=[_make_dummy_image()] + ) + ) + _assert_vision_segments(outputs, expected_segments=1) - grid_tokens = torch.prod(packed["vision_token_grids"], dim=-1) - torch.testing.assert_close(packed["vision_token_lengths"], grid_tokens) - torch.testing.assert_close(packed["vision_token_offsets"], torch.zeros_like(packed["vision_token_offsets"])) + grid_tokens = torch.prod(_get_active_vision_grids(outputs), dim=-1) + torch.testing.assert_close(_get_active_vision_lengths(outputs), grid_tokens) + torch.testing.assert_close( + outputs["vision_token_offsets"][0, :1], torch.zeros_like(outputs["vision_token_offsets"][0, :1]) + ) @require_torch @require_vision def test_multiple_images_have_matching_offsets_lengths_and_grids(isaac_processor): vision_token = isaac_processor.vision_token - text = f"First {vision_token} then {vision_token}" images = [_make_dummy_image(color=(255, 0, 0)), _make_dummy_image(color=(0, 255, 0))] - outputs = _run_processor(isaac_processor, text=text, images=images) - _, packed = _assert_common(outputs) - _assert_vision_segments(packed, expected_segments=2) + outputs = _assert_common( + _run_processor(isaac_processor, text=f"First {vision_token} then {vision_token}", images=images) + ) + _assert_vision_segments(outputs, expected_segments=2) - grid_tokens = torch.prod(packed["vision_token_grids"], dim=-1) - torch.testing.assert_close(packed["vision_token_lengths"], grid_tokens) - torch.testing.assert_close(packed["vision_token_offsets"], torch.zeros_like(packed["vision_token_offsets"])) + grid_tokens = torch.prod(_get_active_vision_grids(outputs), dim=-1) + torch.testing.assert_close(_get_active_vision_lengths(outputs), grid_tokens) + torch.testing.assert_close( + outputs["vision_token_offsets"][0, :2], torch.zeros_like(outputs["vision_token_offsets"][0, :2]) + ) @require_torch @require_vision def test_error_on_image_mismatch(isaac_processor): vision_token = isaac_processor.vision_token - text = f"{vision_token} {vision_token}" - image = _make_dummy_image() - with pytest.raises(ValueError, match="one image per"): - _run_processor(isaac_processor, text=text, images=[image]) + _run_processor(isaac_processor, text=f"{vision_token} {vision_token}", images=[_make_dummy_image()]) @require_torch @require_vision def test_consecutive_vision_tokens_allow_empty_text_segments(isaac_processor): vision_token = isaac_processor.vision_token - text = f"prefix {vision_token}{vision_token} suffix" images = [_make_dummy_image(), _make_dummy_image(color=(0, 0, 255))] - outputs = _run_processor(isaac_processor, text=text, images=images) - _, packed = _assert_common(outputs) - _assert_vision_segments(packed, expected_segments=2) + outputs = _assert_common( + _run_processor(isaac_processor, text=f"prefix {vision_token}{vision_token} 
suffix", images=images) + ) + _assert_vision_segments(outputs, expected_segments=2) - torch.testing.assert_close(packed["vision_token_offsets"], torch.zeros_like(packed["vision_token_offsets"])) - grid_tokens = torch.prod(packed["vision_token_grids"], dim=-1) - torch.testing.assert_close(packed["vision_token_lengths"], grid_tokens) + torch.testing.assert_close( + outputs["vision_token_offsets"][0, :2], torch.zeros_like(outputs["vision_token_offsets"][0, :2]) + ) + grid_tokens = torch.prod(_get_active_vision_grids(outputs), dim=-1) + torch.testing.assert_close(_get_active_vision_lengths(outputs), grid_tokens) @require_torch @require_vision def test_device_and_dtype_consistency(isaac_processor): vision_token = isaac_processor.vision_token - text = f"Describe this {vision_token}" - image = _make_dummy_image() - - outputs = _run_processor(isaac_processor, text=text, images=[image]) - input_ids, packed = _assert_common(outputs) - _assert_vision_segments(packed, expected_segments=1) + outputs = _assert_common( + _run_processor(isaac_processor, text=f"Describe this {vision_token}", images=[_make_dummy_image()]) + ) + _assert_vision_segments(outputs, expected_segments=1) tensors = [ - input_ids, - packed["position_ids"], - packed["modality_tensor"], - packed["vision_token_offsets"], - packed["vision_token_lengths"], - packed["vision_token_grids"], + outputs["input_ids"], + outputs["attention_mask"], + outputs["mm_token_type_ids"], + outputs["vision_token_offsets"], + outputs["vision_token_lengths"], + outputs["vision_token_grids"], ] - devices = {t.device for t in tensors} + devices = {tensor.device for tensor in tensors} assert len(devices) == 1 - for t in tensors: - assert t.dtype == torch.long + for tensor in tensors: + assert tensor.dtype == torch.long @require_torch @require_vision def test_no_crop_when_total_below_max(isaac_processor): vision_token = isaac_processor.vision_token - text = f"hello {vision_token} world" - image = _make_dummy_image() - - outputs = _run_processor(isaac_processor, text=text, images=[image]) - input_ids, packed = _assert_common(outputs) - _assert_vision_segments(packed, expected_segments=1) + outputs = _assert_common( + _run_processor(isaac_processor, text=f"hello {vision_token} world", images=[_make_dummy_image()]) + ) + _assert_vision_segments(outputs, expected_segments=1) - grid_tokens = torch.prod(packed["vision_token_grids"], dim=-1) - text_tokens = _count_modality(packed, ModalityType.text.value) - assert input_ids.shape[1] == grid_tokens.item() + text_tokens + grid_tokens = torch.prod(_get_active_vision_grids(outputs), dim=-1) + text_tokens = _count_modality(outputs, 0) + assert outputs["input_ids"].shape[1] == grid_tokens.item() + text_tokens @require_torch @@ -398,18 +590,16 @@ def test_exact_fit_keeps_all_tokens(isaac_processor, isaac_tokenizer, isaac_tiny text = f"hey {vision_token} there" image = _make_dummy_image() - base_outputs = _run_processor(isaac_processor, text=text, images=[image]) + base_outputs = _assert_common(_run_processor(isaac_processor, text=text, images=[image])) base_length = base_outputs["input_ids"].shape[1] - base_packed = base_outputs["packed_inputs"] - base_vision_length = base_packed["vision_token_lengths"][0].item() + base_vision_length = _get_active_vision_lengths(base_outputs).item() processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, base_length) - outputs = _run_processor(processor, text=text, images=[image]) + outputs = _assert_common(_run_processor(processor, text=text, images=[image])) - 
input_ids, packed = _assert_common(outputs) - _assert_vision_segments(packed, expected_segments=1) - assert input_ids.shape[1] == base_length - assert packed["vision_token_lengths"].item() == base_vision_length + _assert_vision_segments(outputs, expected_segments=1) + assert outputs["input_ids"].shape[1] == base_length + assert _get_active_vision_lengths(outputs).item() == base_vision_length @require_torch @@ -417,27 +607,24 @@ def test_exact_fit_keeps_all_tokens(isaac_processor, isaac_tokenizer, isaac_tiny def test_crop_truncates_text_segment_only(isaac_processor, isaac_tokenizer, isaac_tiny_config): vision_token = isaac_processor.vision_token text_prefix_tokens = " ".join([f"t{i}" for i in range(8)]) - text_suffix = "tail end" - text = f"{text_prefix_tokens} {vision_token} {text_suffix}" + text = f"{text_prefix_tokens} {vision_token} tail end" image = _make_dummy_image() - base_outputs = _run_processor(isaac_processor, text=text, images=[image]) - base_packed = base_outputs["packed_inputs"] - full_text_tokens = _count_modality(base_packed, ModalityType.text.value) - vision_length = base_packed["vision_token_lengths"][0].item() + base_outputs = _assert_common(_run_processor(isaac_processor, text=text, images=[image])) + full_text_tokens = _count_modality(base_outputs, 0) + vision_length = _get_active_vision_lengths(base_outputs).item() max_len = base_outputs["input_ids"].shape[1] - 4 processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, max_len) - outputs = _run_processor(processor, text=text, images=[image]) - - input_ids, packed = _assert_common(outputs) - _assert_vision_segments(packed, expected_segments=1) - assert input_ids.shape[1] == max_len + outputs = _assert_common(_run_processor(processor, text=text, images=[image])) - kept_text_tokens = _count_modality(packed, ModalityType.text.value) - assert kept_text_tokens == full_text_tokens - 4 - torch.testing.assert_close(packed["vision_token_offsets"], torch.zeros_like(packed["vision_token_offsets"])) - assert packed["vision_token_lengths"].item() == vision_length + _assert_vision_segments(outputs, expected_segments=1) + assert outputs["input_ids"].shape[1] == max_len + assert _count_modality(outputs, 0) == full_text_tokens - 4 + torch.testing.assert_close( + outputs["vision_token_offsets"][0, :1], torch.zeros_like(outputs["vision_token_offsets"][0, :1]) + ) + assert _get_active_vision_lengths(outputs).item() == vision_length @require_torch @@ -449,9 +636,8 @@ def test_crop_cuts_through_image_segment(isaac_processor, isaac_tokenizer, isaac text = f"{text_before} {vision_token} {text_after}" image = _make_dummy_image() - base_outputs = _run_processor(isaac_processor, text=text, images=[image]) - base_packed = base_outputs["packed_inputs"] - vision_full = base_packed["vision_token_lengths"][0].item() + base_outputs = _assert_common(_run_processor(isaac_processor, text=text, images=[image])) + vision_full = _get_active_vision_lengths(base_outputs).item() text_before_len = len(isaac_tokenizer.encode(text_before, add_special_tokens=False)) text_after_len = len(isaac_tokenizer.encode(text_after, add_special_tokens=False)) total_length = vision_full + text_before_len + text_after_len @@ -462,33 +648,13 @@ def test_crop_cuts_through_image_segment(isaac_processor, isaac_tokenizer, isaac expected_length = vision_full - expected_offset processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, max_len) - outputs = _run_processor(processor, text=text, images=[image]) - - input_ids, packed = 
_assert_common(outputs) - _assert_vision_segments(packed, expected_segments=1) - - assert input_ids.shape[1] == max_len - assert packed["vision_token_offsets"].item() == expected_offset - assert packed["vision_token_lengths"].item() == expected_length - assert _count_modality(packed, ModalityType.text.value) == text_after_len - - -@require_torch -@require_vision -def test_crop_removes_all_vision_when_window_excludes_images(isaac_processor, isaac_tokenizer, isaac_tiny_config): - vision_token = isaac_processor.vision_token - text_tail = "closing" - text = f"{vision_token} {text_tail}" - image = _make_dummy_image() - - tail_tokens = len(isaac_processor.tokenizer.encode(text_tail, add_special_tokens=False)) - processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, tail_tokens) - outputs = _run_processor(processor, text=text, images=[image]) + outputs = _assert_common(_run_processor(processor, text=text, images=[image])) - input_ids, packed = _assert_common(outputs) - _assert_no_vision(packed) - assert input_ids.shape[1] == tail_tokens - assert _count_modality(packed, ModalityType.text.value) == tail_tokens + _assert_vision_segments(outputs, expected_segments=1) + assert outputs["input_ids"].shape[1] == max_len + assert outputs["vision_token_offsets"][0, 0].item() == expected_offset + assert _get_active_vision_lengths(outputs).item() == expected_length + assert _count_modality(outputs, 0) == text_after_len @require_torch @@ -496,174 +662,27 @@ def test_crop_removes_all_vision_when_window_excludes_images(isaac_processor, is def test_batch_outputs_match_individual_calls(isaac_processor): texts = ["hi", "this one is longer"] - per_sample = [_run_processor(isaac_processor, text=t, images=None) for t in texts] - batch_outputs = _run_processor(isaac_processor, text=texts, images=None) - - assert set(batch_outputs.keys()) == {"input_ids", "packed_inputs"} - batch_input_ids = batch_outputs["input_ids"] - batch_packed = batch_outputs["packed_inputs"] - - assert set(batch_packed.keys()) == { - "vision_patches", - "vision_token_grids", - "vision_token_offsets", - "vision_token_lengths", - "vision_token_batch_indices", - "modality_tensor", - "position_ids", - } - - assert batch_input_ids.shape[0] == len(texts) - assert batch_packed["modality_tensor"].shape[0] == len(texts) - assert batch_packed["position_ids"].shape[0] == len(texts) + per_sample = [_assert_common(_run_processor(isaac_processor, text=text, images=None)) for text in texts] + batch_outputs = _assert_common(_run_processor(isaac_processor, text=texts, images=None), batch_size=len(texts)) - sample_lengths = [output["input_ids"].squeeze(0).shape[0] for output in per_sample] - max_length = max(sample_lengths) pad_id = isaac_processor.pad_token_id - - for i, (single_output, batch_ids, single_len) in enumerate(zip(per_sample, batch_input_ids, sample_lengths)): + for index, single_output in enumerate(per_sample): single_ids = single_output["input_ids"].squeeze(0) - single_packed = single_output["packed_inputs"] - - torch.testing.assert_close(batch_ids[-single_len:], single_ids) - - batch_modality_row = batch_packed["modality_tensor"][i] - expected_modality = torch.full( - (max_length,), - batch_modality_row[-1].item(), - dtype=batch_modality_row.dtype, - device=batch_modality_row.device, - ) - expected_modality[-single_len:] = single_packed["modality_tensor"].squeeze(0) - torch.testing.assert_close(batch_modality_row, expected_modality) - - batch_positions_row = batch_packed["position_ids"][i] - expected_positions = torch.zeros( 
- (max_length, 3), dtype=batch_positions_row.dtype, device=batch_positions_row.device - ) - expected_positions[-single_len:] = single_packed["position_ids"].squeeze(0) - torch.testing.assert_close(batch_positions_row, expected_positions) - - if single_len == max_length: - continue - - pad_span = batch_ids[: max_length - single_len] - assert torch.all(pad_span == pad_id) - - attention_mask = batch_ids.ne(pad_id).long() - assert not torch.any(attention_mask[: max_length - single_len]) - assert torch.all(attention_mask[-single_len:]) - - _assert_no_vision(batch_packed) - - -class StubTokenizer(SimpleIsaacTokenizer): - def __init__(self): - super().__init__() - self._base = 2000 - - def encode(self, text, add_special_tokens=False, return_tensors=None): - token_ids = torch.tensor([self._base + b for b in text.encode("utf-8")], dtype=torch.long) - if return_tensors in {"pt", TensorType.PYTORCH}: - return token_ids.unsqueeze(0) - return token_ids - - def convert_tokens_to_ids(self, token): - if token == "<|image_pad|>": - return 151655 - if token == self.pad_token: - return super().convert_tokens_to_ids(token) - return None - - -class StubImageProcessor(ImageProcessingMixin): - def __call__(self, images=None, return_tensors=None): - patches = torch.ones((1, 2, 2, 3), dtype=torch.float32) - sizes = torch.tensor([[1, 2, 2]], dtype=torch.long) - return { - "patches": patches, - "virtual_pixel_size": sizes, - "real_pixel_size": sizes, - } - - -BASE_MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base") -BASE_MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/3") or None -LOCAL_CHECKPOINT = os.environ.get("ISAAC_TEST_MODEL_PATH") - - -def _checkpoint_or_skip(): - if LOCAL_CHECKPOINT: - resolved = Path(LOCAL_CHECKPOINT).expanduser() - if not resolved.exists(): - pytest.skip(f"Local checkpoint path {resolved} does not exist.") - return str(resolved) - if is_offline_mode(): - pytest.skip("Offline mode: set ISAAC_TEST_MODEL_PATH to a local checkpoint to run these tests.") - return BASE_MODEL_ID - + single_mask = single_output["attention_mask"].squeeze(0) + single_mm = single_output["mm_token_type_ids"].squeeze(0) -def _create_real_processor(): - checkpoint = _checkpoint_or_skip() - config = IsaacConfig.from_pretrained(checkpoint, revision=BASE_MODEL_REVISION) - processor = IsaacProcessor.from_pretrained(checkpoint, revision=BASE_MODEL_REVISION) - tokenizer = processor.tokenizer - return processor, tokenizer, config + batch_ids = batch_outputs["input_ids"][index] + batch_mask = batch_outputs["attention_mask"][index] + batch_mm = batch_outputs["mm_token_type_ids"][index] + single_len = single_ids.shape[0] + assert torch.equal(batch_ids[-single_len:], single_ids) + assert torch.equal(batch_mask[-single_len:], single_mask) + assert torch.equal(batch_mm[-single_len:], single_mm) -@require_torch -@require_vision -class TestIsaacProcessorRealPadding(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.processor, cls.tokenizer, cls.config = _create_real_processor() - cls.dummy_image = _make_dummy_image() - cls.vision_token = cls.config.vision_token - cls.pad_id = cls.tokenizer.pad_token_id - cls.image_pad_id = cls.tokenizer.convert_tokens_to_ids("<|image_pad|>") - if cls.pad_id is None or cls.image_pad_id is None: - pytest.skip("pad/image pad ids unavailable for processor") - - def _check_padding_and_masks(self, input_ids: torch.Tensor, pad_id: int): - for row in range(input_ids.size(0)): - row_ids = input_ids[row] - nonpad_positions = (row_ids != 
pad_id).nonzero(as_tuple=False)
-            last_nonpad = int(nonpad_positions.max()) if nonpad_positions.numel() else -1
-            if last_nonpad + 1 < row_ids.numel():
-                tail = row_ids[last_nonpad + 1 :]
-                assert torch.all(tail == pad_id)
-            attn = (row_ids != pad_id).long()
-            if last_nonpad >= 0:
-                assert torch.all(attn[: last_nonpad + 1] == 1)
-            assert int(attn[last_nonpad + 1 :].sum()) == 0
+        if single_len < batch_ids.shape[0]:
+            pad_span = batch_ids[: batch_ids.shape[0] - single_len]
+            assert torch.all(pad_span == pad_id)
+            assert not torch.any(batch_mask[: batch_ids.shape[0] - single_len])

-    def test_single_vs_batched_consistency(self):
-        prompt = f"hello {self.vision_token} world"
-        images_single = [self.dummy_image]
-
-        single = self.processor(text=prompt, images=images_single, return_tensors="pt")
-        single_ids = single["input_ids"].squeeze(0)
-
-        batch_prompts = [prompt, "short"]
-        batch_images = [images_single, None]
-        batch = self.processor(text=batch_prompts, images=batch_images, return_tensors="pt")
-        batch_ids = batch["input_ids"][0]
-        modality = batch["packed_inputs"]["modality_tensor"][0]
-
-        assert torch.equal(batch_ids[: single_ids.size(0)], single_ids)
-
-        image_positions = modality == ModalityType.image.value
-        if image_positions.any():
-            assert torch.all(batch_ids[image_positions] == self.image_pad_id)
-            assert torch.all(batch_ids[image_positions] != self.pad_id)
-
-        nonpad = (batch_ids != self.pad_id).nonzero(as_tuple=False)
-        last_nonpad = int(nonpad.max()) if nonpad.numel() else -1
-        if last_nonpad + 1 < batch_ids.numel():
-            tail = batch_ids[last_nonpad + 1 :]
-            assert torch.all(tail == self.pad_id)
-
-        attn = (batch_ids != self.pad_id).long()
-        if last_nonpad >= 0:
-            assert torch.all(attn[: last_nonpad + 1] == 1)
-            assert int(attn[last_nonpad + 1 :].sum()) == 0
+        _assert_no_vision(batch_outputs, batch_index=index)

From 778d8c5a13980067d4288c2368639d0947e59d49 Mon Sep 17 00:00:00 2001
From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com>
Date: Tue, 24 Mar 2026 21:14:47 +0400
Subject: [PATCH 65/77] feat: config updates, image processor backend, assorted changes/tests (#16)

* fix: use TorchvisionBackend
* fix: import IsaacImageProcessor
* fix: resample not interpolation
* style: organize imports
* chore: auto processing auto from main
* feat: register isaac image processor according to new convention
* fix: update to new config style
* fix: correct pix2struct import
* docs: initial doc update
* feat: re-register isaac processor to auto
* refactor: move max_position_embeddings to isaac config
* TEMP pop!
* docs: update date
* style: remove removed attr
* style: add config attr for completeness
* style: drop redundant merge_with_config_defaults
* style: remove redundant position ids handling
* refactor: rely on base class for setting embeddings
* fix: always use full attention
* style: clarify padding logic
* chore: remove stale artifact
* fix: kwargs name!
* refactor: isolate custom padding to image processor pad method
* feat: no device movement
* style: align with transformers standard for loading rope params
* refactor: drop unneeded arg filter
* feat: compile check image presence instead
* docs: add clarifying comment for keeping empty tensors
* refactor: move broadcasting to forward WIP
* style: use new layer validation functionality
* feat: update embedding access mixin to support nested paths!
* chore: convert artifacts * refactor: inline build batch * style: drop duplicate test * test: polygons * feat: polygon extraction * test: polygon generation test --- docs/source/en/model_doc/isaac.md | 113 +++-- src/transformers/modeling_utils.py | 37 +- .../models/auto/image_processing_auto.py | 376 ++++++++------- .../models/auto/processing_auto.py | 1 + .../models/isaac/configuration_isaac.py | 155 +++--- .../models/isaac/image_processing_isaac.py | 343 +++++++++++++- .../isaac/image_processing_isaac_fast.py | 330 ------------- .../models/isaac/modeling_isaac.py | 189 +++----- .../models/isaac/modular_isaac.py | 440 +++++++++--------- .../models/isaac/processing_isaac.py | 144 +++--- tests/models/isaac/test_modeling_isaac.py | 98 ++-- tests/models/isaac/test_processing_isaac.py | 38 ++ tests/utils/test_modeling_utils.py | 61 +++ 13 files changed, 1232 insertions(+), 1093 deletions(-) delete mode 100644 src/transformers/models/isaac/image_processing_isaac_fast.py diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md index 0d4c74b3a06c..eff6634c4b29 100644 --- a/docs/source/en/model_doc/isaac.md +++ b/docs/source/en/model_doc/isaac.md @@ -1,4 +1,4 @@ - *This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-30.* -*This model was added to Hugging Face Transformers in 2025.*
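The commit bullet about nested paths corresponds to the `EmbeddingAccessMixin` hunk in `modeling_utils.py` further below: `_input_embed_layer` may now be a dotted path instead of a plain attribute name. A minimal sketch of the resolution logic that hunk adds, using a throwaway module tree rather than a real model:

```py
import torch.nn as nn

# Illustrative stand-in for a composite model whose token embeddings
# live on a nested submodule instead of directly on the model.
model = nn.Module()
model.text_model = nn.Module()
model.text_model.embed_tokens = nn.Embedding(1000, 64)

# The mixin splits the dotted path on its last dot and resolves the
# parent module via nn.Module.get_submodule, mirroring the hunk below.
module_path, _, attribute_name = "text_model.embed_tokens".rpartition(".")
parent = model.get_submodule(module_path)
assert getattr(parent, attribute_name) is model.text_model.embed_tokens
```

`get_input_embeddings`/`set_input_embeddings` then read or write that leaf attribute, so a composite model can point the accessors at a nested embedding without overriding them.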
@@ -26,31 +25,20 @@ rendered properly in your Markdown viewer. # Isaac +## Overview + Isaac is Perceptron's vision-language model (VLM) that pairs a SigLIP2 vision encoder with a Qwen3 decoder-only stack. The -architecture is designed for efficient long-context multimodal interactions, and supports interleaving images with -text. The vision encoder has variable-resolution capability and with optional pixel shuffle to merge -neighboring patches before they reach the decoder, which keeps the KV-cache and compute requirements manageable on long -conversations. Text and vision tokens are unified via the [`TensorStream`](https://github.com/perceptron-ai-inc/perceptron/tree/main/src/perceptron/tensorstream) abstraction so -that modal boundaries, spatial coordinates, and rescaling parameters are preserved throughout the model stack. For more information, refer to the [technical report](https://github.com/perceptron-ai-inc/perceptron/blob/main/papers/isaac_01.pdf). - -Key implementation notes: - -- **Packed vision attention** โ€“ `IsaacVisionEncoder` keeps track of per-image patch lengths and uses specialized attention - kernels with custom `AttentionMaskConverter` utilities so the decoder only applies attention to real patches while supporting - both FlashAttention and SDPA. -- **TensorStream-first pipeline** โ€“ `IsaacProcessor` converts chat templates into multimodal streams where every image gets a - dedicated event with spatial metadata. `IsaacModel` can embed that stream directly (using `embed_stream`) and automatically - derive multi-dimensional RoPE coordinates, so you only need to provide the `tensor_stream` during the first decoding step. -- **Fast image pre-processing** โ€“ `IsaacImageProcessorFast` solves for the closest resolution that fits within the requested context. +Transformers implementation supports text-only and image-conditioned generation, including prompts with multiple interleaved +images. Isaac uses variable-resolution image preprocessing and can optionally reduce spatial tokens with pixel shuffle to keep +long multimodal prompts manageable. For more information, refer to the [technical report](https://github.com/perceptron-ai-inc/perceptron/blob/main/papers/isaac_01.pdf). Isaac checkpoints are distributed under Perceptron's Non-Production license; please review the license that ships with the weights before using them in commercial settings. -## Usage example +## Usage -`IsaacProcessor` expects that every `` token in the rendered prompt has a -matching image. The processor returns both standard tokenized inputs and a `TensorStream`. You should pass the stream to the -model (only the first generation step requires it) alongside the regular tensors. +Isaac uses explicit image placeholders in the rendered prompt. Every occurrence of `processor.vision_token` (usually +``) must have a matching image in the `images` argument. ```py import torch @@ -68,47 +56,86 @@ model = IsaacForConditionalGeneration.from_pretrained( images = [Image.open("chart.png"), Image.open("panel.jpg")] messages = [ - { - "role": "user", - "content": [ - {"type": "image", "image": images[0]}, - {"type": "image", "image": images[1]}, - {"type": "text", "text": "Compare the two figures and explain what changed."}, - ], - } + {"role": "user", "content": "Compare the two figures and explain what changed."}, + {"role": "user", "content": f"{processor.vision_token}{processor.vision_token}"}, ] -# Render the chat template to text so we can pass text+images together. 
prompt = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, +).strip() + +inputs = processor(text=prompt, images=images, return_tensors="pt") +multimodal_keys = ( + "input_ids", + "attention_mask", + "mm_token_type_ids", + "vision_patches", + "vision_patch_attention_mask", + "vision_token_grids", + "vision_token_offsets", + "vision_token_lengths", ) - -# IsaacProcessor builds TensorStream events internally when both text and images are provided. -batch = processor(text=prompt, images=images, return_tensors="pt") +model_inputs = {key: inputs[key].to(model.device) for key in multimodal_keys} with torch.inference_mode(): - generated = model.generate( - **inputs, - tensor_stream=tensor_stream, + generated_ids = model.generate( + **model_inputs, max_new_tokens=256, - temperature=0.2, + do_sample=False, eos_token_id=processor.tokenizer.eos_token_id, pad_token_id=processor.tokenizer.eos_token_id, ) -response = processor.post_process_image_text_to_text( - generated, - skip_special_tokens=True, -)[0] +generated_ids = generated_ids[:, model_inputs["input_ids"].shape[1] :] +response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] print(response) ``` +`IsaacProcessor` returns standard multimodal tensors that can be passed directly to the model, including `input_ids`, +`attention_mask`, `mm_token_type_ids`, `vision_patches`, `vision_patch_attention_mask`, `vision_token_grids`, +`vision_token_offsets`, `vision_token_lengths`, and `vision_image_attention_mask`. + +Important notes: + +- Pass the full processor output to `generate()`. Isaac uses the multimodal tensors during prefill and handles cached + decoding internally. +- Batched inputs can mix text-only and multimodal samples. For batched multimodal inputs, pass images as a nested list such + as `[[image_a], [image_b, image_c], []]`. +- If truncation is enabled, the processor keeps the rightmost part of the packed multimodal sequence and updates + `vision_token_offsets` and `vision_token_lengths` automatically. + +### Post-processing grounded outputs + +Isaac can generate grounded points and boxes in tagged text spans. Use `post_process_generation()` to strip the tags and +recover structured annotations. + +```py +clean_text, annotations = processor.post_process_generation(response, expected="box") +print(clean_text) +print(annotations) +``` + +Set `expected="point"` to extract point annotations, or leave `expected=None` to collect both points and boxes. + +## IsaacVisionConfig + +[[autodoc]] IsaacVisionConfig + +## IsaacTextConfig + +[[autodoc]] IsaacTextConfig + ## IsaacConfig [[autodoc]] IsaacConfig +## IsaacTextModel + +[[autodoc]] IsaacTextModel + - forward + ## IsaacModel [[autodoc]] IsaacModel @@ -123,6 +150,10 @@ print(response) [[autodoc]] IsaacProcessor +## IsaacImageProcessor + +[[autodoc]] IsaacImageProcessor + ## IsaacImageProcessorFast [[autodoc]] IsaacImageProcessorFast diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1cdb033cb709..2866a7908704 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1003,6 +1003,33 @@ class EmbeddingAccessMixin: _input_embed_layer = "embed_tokens" # default layer that holds input embeddings. + def _resolve_input_embed_layer(self) -> tuple[nn.Module | None, str]: + """ + Returns the parent module and leaf attribute for `_input_embed_layer`. + + Supports both a simple attribute name such as `embed_tokens` and a dotted path such as + `text_model.embed_tokens`. 
+ """ + + name = getattr(self, "_input_embed_layer", "embed_tokens") + if "." not in name: + return None, name + + module_path, _, attribute_name = name.rpartition(".") + try: + module = self.get_submodule(module_path) + except AttributeError as error: + raise NotImplementedError( + f"`_input_embed_layer={name}` could not be resolved for {self.__class__.__name__}." + ) from error + + if not hasattr(module, attribute_name): + raise NotImplementedError( + f"`_input_embed_layer={name}` could not be resolved for {self.__class__.__name__}." + ) + + return module, attribute_name + def get_input_embeddings(self) -> nn.Module: """ Returns the model's input embeddings. @@ -1011,7 +1038,9 @@ def get_input_embeddings(self) -> nn.Module: `nn.Module`: A torch module mapping vocabulary to hidden states. """ - name = getattr(self, "_input_embed_layer", "embed_tokens") + module, name = self._resolve_input_embed_layer() + if module is not None: + return getattr(module, name) # 1) Direct attribute (most NLP models). if (default_embedding := getattr(self, name, None)) is not None: @@ -1044,7 +1073,11 @@ def set_input_embeddings(self, value: nn.Module): should) override for exotic layouts. """ - name = getattr(self, "_input_embed_layer", "embed_tokens") + module, name = self._resolve_input_embed_layer() + if module is not None: + setattr(module, name, value) + return + # 1) Direct attribute (most NLP models) if hasattr(self, name): setattr(self, name, value) diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index e5c127c55e61..16ab8cd2d3b5 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -61,176 +61,212 @@ else: IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict( [ - ("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")), - ("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")), - ("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")), - ("altclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")), - ("aria", ("AriaImageProcessor", None)), - ("aya_vision", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")), - ("beit", ("BeitImageProcessor", "BeitImageProcessorFast")), - ("bit", ("BitImageProcessor", "BitImageProcessorFast")), - ("blip", ("BlipImageProcessor", "BlipImageProcessorFast")), - ("blip-2", ("BlipImageProcessor", "BlipImageProcessorFast")), - ("bridgetower", ("BridgeTowerImageProcessor", "BridgeTowerImageProcessorFast")), - ("chameleon", ("ChameleonImageProcessor", "ChameleonImageProcessorFast")), - ("chinese_clip", ("ChineseCLIPImageProcessor", "ChineseCLIPImageProcessorFast")), - ("chmv2", (None, "CHMv2ImageProcessorFast")), - ("clip", ("CLIPImageProcessor", "CLIPImageProcessorFast")), - ("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")), - ("cohere2_vision", (None, "Cohere2VisionImageProcessorFast")), - ("colpali", ("SiglipImageProcessor", "SiglipImageProcessorFast")), - ("colqwen2", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), - ("conditional_detr", ("ConditionalDetrImageProcessor", "ConditionalDetrImageProcessorFast")), - ("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), - ("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), - ("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), - ("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")), - ("deepseek_vl", ("DeepseekVLImageProcessor", 
"DeepseekVLImageProcessorFast")), - ("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor", "DeepseekVLHybridImageProcessorFast")), - ("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")), - ("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")), - ("depth_anything", ("DPTImageProcessor", "DPTImageProcessorFast")), - ("depth_pro", ("DepthProImageProcessor", "DepthProImageProcessorFast")), - ("detr", ("DetrImageProcessor", "DetrImageProcessorFast")), - ("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")), - ("dinov2", ("BitImageProcessor", "BitImageProcessorFast")), - ("dinov3_vit", (None, "DINOv3ViTImageProcessorFast")), - ("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")), - ("dpt", ("DPTImageProcessor", "DPTImageProcessorFast")), - ("edgetam", (None, "Sam2ImageProcessorFast")), - ("efficientloftr", ("EfficientLoFTRImageProcessor", "EfficientLoFTRImageProcessorFast")), - ("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")), - ("emu3", ("Emu3ImageProcessor", None)), - ("eomt", ("EomtImageProcessor", "EomtImageProcessorFast")), - ("eomt_dinov3", ("EomtImageProcessor", "EomtImageProcessorFast")), - ("ernie4_5_vl_moe", ("Ernie4_5_VLMoeImageProcessor", "Ernie4_5_VLMoeImageProcessorFast")), - ("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")), - ("florence2", ("CLIPImageProcessor", "CLIPImageProcessorFast")), - ("focalnet", ("BitImageProcessor", "BitImageProcessorFast")), - ("fuyu", ("FuyuImageProcessor", "FuyuImageProcessorFast")), - ("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")), - ("gemma3n", ("SiglipImageProcessor", "SiglipImageProcessorFast")), - ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")), - ("glm46v", ("Glm46VImageProcessor", "Glm46VImageProcessorFast")), - ("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")), - ("glm_image", ("GlmImageImageProcessor", "GlmImageImageProcessorFast")), - ("glpn", ("GLPNImageProcessor", "GLPNImageProcessorFast")), - ("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")), - ("grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")), - ("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")), - ("hiera", ("BitImageProcessor", "BitImageProcessorFast")), - ("idefics", ("IdeficsImageProcessor", None)), - ("idefics2", ("Idefics2ImageProcessor", "Idefics2ImageProcessorFast")), - ("idefics3", ("Idefics3ImageProcessor", "Idefics3ImageProcessorFast")), - ("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")), - ("imagegpt", ("ImageGPTImageProcessor", "ImageGPTImageProcessorFast")), - ("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")), - ("internvl", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")), - ("isaac", (None, "IsaacImageProcessorFast")), - ("janus", ("JanusImageProcessor", "JanusImageProcessorFast")), - ("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")), - ("kosmos-2.5", ("Kosmos2_5ImageProcessor", "Kosmos2_5ImageProcessorFast")), - ("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")), - ("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")), - ("layoutxlm", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessor")), - ("levit", ("LevitImageProcessor", "LevitImageProcessorFast")), - ("lfm2_vl", (None, "Lfm2VlImageProcessorFast")), - ("lightglue", ("LightGlueImageProcessor", "LightGlueImageProcessorFast")), - ("lighton_ocr", ("PixtralImageProcessor", "PixtralImageProcessorFast")), - 
("llama4", (None, "Llama4ImageProcessorFast")), - ("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")), - ("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")), - ("llava_next_video", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")), - ("llava_onevision", ("LlavaOnevisionImageProcessor", "LlavaOnevisionImageProcessorFast")), - ("lw_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")), - ("mask2former", ("Mask2FormerImageProcessor", "Mask2FormerImageProcessorFast")), - ("maskformer", ("MaskFormerImageProcessor", "MaskFormerImageProcessorFast")), - ("metaclip_2", ("CLIPImageProcessor", "CLIPImageProcessorFast")), - ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")), - ("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")), - ("mlcd", ("CLIPImageProcessor", "CLIPImageProcessorFast")), - ("mllama", ("MllamaImageProcessor", "MllamaImageProcessorFast")), - ("mm-grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")), - ("mobilenet_v1", ("MobileNetV1ImageProcessor", "MobileNetV1ImageProcessorFast")), - ("mobilenet_v2", ("MobileNetV2ImageProcessor", "MobileNetV2ImageProcessorFast")), - ("mobilevit", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")), - ("mobilevitv2", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")), - ("nougat", ("NougatImageProcessor", "NougatImageProcessorFast")), - ("omdet-turbo", ("DetrImageProcessor", "DetrImageProcessorFast")), - ("oneformer", ("OneFormerImageProcessor", "OneFormerImageProcessorFast")), - ("ovis2", ("Ovis2ImageProcessor", "Ovis2ImageProcessorFast")), - ("owlv2", ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")), - ("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")), - ("paddleocr_vl", ("PaddleOCRVLImageProcessor", "PaddleOCRVLImageProcessorFast")), - ("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")), - ("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")), - ("perception_lm", (None, "PerceptionLMImageProcessorFast")), - ("phi4_multimodal", (None, "Phi4MultimodalImageProcessorFast")), - ("pix2struct", ("Pix2StructImageProcessor", "Pix2StructImageProcessorFast")), - ("pixio", ("BitImageProcessor", "BitImageProcessorFast")), - ("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")), - ("poolformer", ("PoolFormerImageProcessor", "PoolFormerImageProcessorFast")), - ("pp_doclayout_v2", (None, "PPDocLayoutV2ImageProcessorFast")), - ("pp_doclayout_v3", (None, "PPDocLayoutV3ImageProcessorFast")), - ("pp_lcnet", (None, "PPLCNetImageProcessorFast")), - ("pp_ocrv5_mobile_det", (None, "PPOCRV5ServerDetImageProcessorFast")), - ("pp_ocrv5_server_det", (None, "PPOCRV5ServerDetImageProcessorFast")), - ("prompt_depth_anything", ("PromptDepthAnythingImageProcessor", "PromptDepthAnythingImageProcessorFast")), - ("pvt", ("PvtImageProcessor", "PvtImageProcessorFast")), - ("pvt_v2", ("PvtImageProcessor", "PvtImageProcessorFast")), - ("qwen2_5_omni", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), - ("qwen2_5_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), - ("qwen2_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), - ("qwen3_5", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), - ("qwen3_5_moe", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), - ("qwen3_omni_moe", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), - ("qwen3_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), - ("regnet", ("ConvNextImageProcessor", 
"ConvNextImageProcessorFast")), - ("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), - ("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")), - ("sam", ("SamImageProcessor", "SamImageProcessorFast")), - ("sam2", (None, "Sam2ImageProcessorFast")), - ("sam2_video", (None, "Sam2ImageProcessorFast")), - ("sam3", (None, "Sam3ImageProcessorFast")), - ("sam3_tracker", (None, "Sam3ImageProcessorFast")), - ("sam3_tracker_video", (None, "Sam3ImageProcessorFast")), - ("sam3_video", (None, "Sam3ImageProcessorFast")), - ("sam_hq", ("SamImageProcessor", "SamImageProcessorFast")), - ("segformer", ("SegformerImageProcessor", "SegformerImageProcessorFast")), - ("seggpt", ("SegGptImageProcessor", None)), - ("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")), - ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")), - ("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")), - ("smolvlm", ("SmolVLMImageProcessor", "SmolVLMImageProcessorFast")), - ("superglue", ("SuperGlueImageProcessor", "SuperGlueImageProcessorFast")), - ("superpoint", ("SuperPointImageProcessor", "SuperPointImageProcessorFast")), - ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")), - ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")), - ("swin2sr", ("Swin2SRImageProcessor", "Swin2SRImageProcessorFast")), - ("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")), - ("t5gemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")), - ("t5gemma2_encoder", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")), - ("table-transformer", ("DetrImageProcessor", "DetrImageProcessorFast")), - ("textnet", ("TextNetImageProcessor", "TextNetImageProcessorFast")), - ("timesformer", ("VideoMAEImageProcessor", None)), - ("timm_wrapper", ("TimmWrapperImageProcessor", None)), - ("trocr", ("ViTImageProcessor", "ViTImageProcessorFast")), - ("tvp", ("TvpImageProcessor", "TvpImageProcessorFast")), - ("udop", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")), - ("upernet", ("SegformerImageProcessor", "SegformerImageProcessorFast")), - ("video_llama_3", ("VideoLlama3ImageProcessor", "VideoLlama3ImageProcessorFast")), - ("video_llava", ("VideoLlavaImageProcessor", None)), - ("videomae", ("VideoMAEImageProcessor", None)), - ("vilt", ("ViltImageProcessor", "ViltImageProcessorFast")), - ("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")), - ("vit", ("ViTImageProcessor", "ViTImageProcessorFast")), - ("vit_mae", ("ViTImageProcessor", "ViTImageProcessorFast")), - ("vit_msn", ("ViTImageProcessor", "ViTImageProcessorFast")), - ("vitmatte", ("VitMatteImageProcessor", "VitMatteImageProcessorFast")), - ("vitpose", ("VitPoseImageProcessor", "VitPoseImageProcessorFast")), - ("xclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")), - ("yolos", ("YolosImageProcessor", "YolosImageProcessorFast")), - ("zoedepth", ("ZoeDepthImageProcessor", "ZoeDepthImageProcessorFast")), + ("aimv2", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), + ("aimv2_vision_model", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), + ("align", {"torchvision": "EfficientNetImageProcessor", "pil": "EfficientNetImageProcessorPil"}), + ("altclip", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), + ("aria", {"torchvision": "AriaImageProcessor", "pil": "AriaImageProcessorPil"}), + ("aya_vision", {"torchvision": "GotOcr2ImageProcessor", "pil": "GotOcr2ImageProcessorPil"}), + ("beit", {"torchvision": "BeitImageProcessor", "pil": 
"BeitImageProcessorPil"}), + ("bit", {"torchvision": "BitImageProcessor", "pil": "BitImageProcessorPil"}), + ("blip", {"torchvision": "BlipImageProcessor", "pil": "BlipImageProcessorPil"}), + ("blip-2", {"torchvision": "BlipImageProcessor", "pil": "BlipImageProcessorPil"}), + ("bridgetower", {"torchvision": "BridgeTowerImageProcessor", "pil": "BridgeTowerImageProcessorPil"}), + ("chameleon", {"torchvision": "ChameleonImageProcessor", "pil": "ChameleonImageProcessorPil"}), + ("chinese_clip", {"torchvision": "ChineseCLIPImageProcessor", "pil": "ChineseCLIPImageProcessorPil"}), + ("chmv2", {"torchvision": "CHMv2ImageProcessor"}), + ("clip", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), + ("clipseg", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), + ("cohere2_vision", {"torchvision": "Cohere2VisionImageProcessor"}), + ("colpali", {"torchvision": "SiglipImageProcessor", "pil": "SiglipImageProcessorPil"}), + ("colqwen2", {"torchvision": "Qwen2VLImageProcessor", "pil": "Qwen2VLImageProcessorPil"}), + ( + "conditional_detr", + {"torchvision": "ConditionalDetrImageProcessor", "pil": "ConditionalDetrImageProcessorPil"}, + ), + ("convnext", {"torchvision": "ConvNextImageProcessor", "pil": "ConvNextImageProcessorPil"}), + ("convnextv2", {"torchvision": "ConvNextImageProcessor", "pil": "ConvNextImageProcessorPil"}), + ("cvt", {"torchvision": "ConvNextImageProcessor", "pil": "ConvNextImageProcessorPil"}), + ("data2vec-vision", {"torchvision": "BeitImageProcessor", "pil": "BeitImageProcessorPil"}), + ("deepseek_vl", {"torchvision": "DeepseekVLImageProcessor", "pil": "DeepseekVLImageProcessorPil"}), + ( + "deepseek_vl_hybrid", + {"torchvision": "DeepseekVLHybridImageProcessor", "pil": "DeepseekVLHybridImageProcessorPil"}, + ), + ( + "deformable_detr", + {"torchvision": "DeformableDetrImageProcessor", "pil": "DeformableDetrImageProcessorPil"}, + ), + ("deit", {"torchvision": "DeiTImageProcessor", "pil": "DeiTImageProcessorPil"}), + ("depth_anything", {"torchvision": "DPTImageProcessor", "pil": "DPTImageProcessorPil"}), + ("depth_pro", {"torchvision": "DepthProImageProcessor"}), + ("detr", {"torchvision": "DetrImageProcessor", "pil": "DetrImageProcessorPil"}), + ("dinat", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), + ("dinov2", {"torchvision": "BitImageProcessor", "pil": "BitImageProcessorPil"}), + ("dinov3_vit", {"torchvision": "DINOv3ViTImageProcessor"}), + ("donut-swin", {"torchvision": "DonutImageProcessor", "pil": "DonutImageProcessorPil"}), + ("dpt", {"torchvision": "DPTImageProcessor", "pil": "DPTImageProcessorPil"}), + ("edgetam", {"torchvision": "Sam2ImageProcessor"}), + ( + "efficientloftr", + {"torchvision": "EfficientLoFTRImageProcessor", "pil": "EfficientLoFTRImageProcessorPil"}, + ), + ("efficientnet", {"torchvision": "EfficientNetImageProcessor", "pil": "EfficientNetImageProcessorPil"}), + ("emu3", {"pil": "Emu3ImageProcessor"}), + ("eomt", {"torchvision": "EomtImageProcessor", "pil": "EomtImageProcessorPil"}), + ("eomt_dinov3", {"torchvision": "EomtImageProcessor", "pil": "EomtImageProcessorPil"}), + ( + "ernie4_5_vl_moe", + {"torchvision": "Ernie4_5_VLMoeImageProcessor", "pil": "Ernie4_5_VLMoeImageProcessorPil"}, + ), + ("flava", {"torchvision": "FlavaImageProcessor", "pil": "FlavaImageProcessorPil"}), + ("florence2", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), + ("focalnet", {"torchvision": "BitImageProcessor", "pil": "BitImageProcessorPil"}), + ("fuyu", {"torchvision": 
"FuyuImageProcessor", "pil": "FuyuImageProcessorPil"}), + ("gemma3", {"torchvision": "Gemma3ImageProcessor", "pil": "Gemma3ImageProcessorPil"}), + ("gemma3n", {"torchvision": "SiglipImageProcessor", "pil": "SiglipImageProcessorPil"}), + ("git", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), + ("glm46v", {"torchvision": "Glm46VImageProcessor", "pil": "Glm46VImageProcessorPil"}), + ("glm4v", {"torchvision": "Glm4vImageProcessor", "pil": "Glm4vImageProcessorPil"}), + ("glm_image", {"torchvision": "GlmImageImageProcessor", "pil": "GlmImageImageProcessorPil"}), + ("glpn", {"torchvision": "GLPNImageProcessor", "pil": "GLPNImageProcessorPil"}), + ("got_ocr2", {"torchvision": "GotOcr2ImageProcessor", "pil": "GotOcr2ImageProcessorPil"}), + ( + "grounding-dino", + {"torchvision": "GroundingDinoImageProcessor", "pil": "GroundingDinoImageProcessorPil"}, + ), + ("groupvit", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), + ("hiera", {"torchvision": "BitImageProcessor", "pil": "BitImageProcessorPil"}), + ("idefics", {"torchvision": "IdeficsImageProcessor", "pil": "IdeficsImageProcessorPil"}), + ("idefics2", {"torchvision": "Idefics2ImageProcessor", "pil": "Idefics2ImageProcessorPil"}), + ("idefics3", {"torchvision": "Idefics3ImageProcessor", "pil": "Idefics3ImageProcessorPil"}), + ("ijepa", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), + ("imagegpt", {"torchvision": "ImageGPTImageProcessor", "pil": "ImageGPTImageProcessorPil"}), + ("instructblip", {"torchvision": "BlipImageProcessor", "pil": "BlipImageProcessorPil"}), + ("internvl", {"torchvision": "GotOcr2ImageProcessor", "pil": "GotOcr2ImageProcessorPil"}), + ("isaac", {"torchvision": "IsaacImageProcessor"}), + ("janus", {"torchvision": "JanusImageProcessor", "pil": "JanusImageProcessorPil"}), + ("kosmos-2", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), + ("kosmos-2.5", {"torchvision": "Kosmos2_5ImageProcessor", "pil": "Kosmos2_5ImageProcessorPil"}), + ("layoutlmv2", {"torchvision": "LayoutLMv2ImageProcessor", "pil": "LayoutLMv2ImageProcessorPil"}), + ("layoutlmv3", {"torchvision": "LayoutLMv3ImageProcessor", "pil": "LayoutLMv3ImageProcessorPil"}), + ("layoutxlm", {"torchvision": "LayoutLMv2ImageProcessor", "pil": "LayoutLMv2ImageProcessorPil"}), + ("levit", {"torchvision": "LevitImageProcessor", "pil": "LevitImageProcessorPil"}), + ("lfm2_vl", {"torchvision": "Lfm2VlImageProcessor"}), + ("lightglue", {"torchvision": "LightGlueImageProcessor", "pil": "LightGlueImageProcessorPil"}), + ("lighton_ocr", {"torchvision": "PixtralImageProcessor", "pil": "PixtralImageProcessorPil"}), + ("llama4", {"torchvision": "Llama4ImageProcessor"}), + ("llava", {"torchvision": "LlavaImageProcessor", "pil": "LlavaImageProcessorPil"}), + ("llava_next", {"torchvision": "LlavaNextImageProcessor", "pil": "LlavaNextImageProcessorPil"}), + ("llava_next_video", {"torchvision": "LlavaNextImageProcessor", "pil": "LlavaNextImageProcessorPil"}), + ( + "llava_onevision", + {"torchvision": "LlavaOnevisionImageProcessor", "pil": "LlavaOnevisionImageProcessorPil"}, + ), + ("lw_detr", {"torchvision": "DeformableDetrImageProcessor", "pil": "DeformableDetrImageProcessorPil"}), + ("mask2former", {"torchvision": "Mask2FormerImageProcessor", "pil": "Mask2FormerImageProcessorPil"}), + ("maskformer", {"torchvision": "MaskFormerImageProcessor", "pil": "MaskFormerImageProcessorPil"}), + ("metaclip_2", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), + ("mgp-str", 
{"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), + ("mistral3", {"torchvision": "PixtralImageProcessor", "pil": "PixtralImageProcessorPil"}), + ("mlcd", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), + ("mllama", {"torchvision": "MllamaImageProcessor", "pil": "MllamaImageProcessorPil"}), + ( + "mm-grounding-dino", + {"torchvision": "GroundingDinoImageProcessor", "pil": "GroundingDinoImageProcessorPil"}, + ), + ("mobilenet_v1", {"torchvision": "MobileNetV1ImageProcessor", "pil": "MobileNetV1ImageProcessorPil"}), + ("mobilenet_v2", {"torchvision": "MobileNetV2ImageProcessor", "pil": "MobileNetV2ImageProcessorPil"}), + ("mobilevit", {"torchvision": "MobileViTImageProcessor", "pil": "MobileViTImageProcessorPil"}), + ("mobilevitv2", {"torchvision": "MobileViTImageProcessor", "pil": "MobileViTImageProcessorPil"}), + ("nougat", {"torchvision": "NougatImageProcessor", "pil": "NougatImageProcessorPil"}), + ("omdet-turbo", {"torchvision": "DetrImageProcessor", "pil": "DetrImageProcessorPil"}), + ("oneformer", {"torchvision": "OneFormerImageProcessor", "pil": "OneFormerImageProcessorPil"}), + ("ovis2", {"torchvision": "Ovis2ImageProcessor", "pil": "Ovis2ImageProcessorPil"}), + ("owlv2", {"torchvision": "Owlv2ImageProcessor", "pil": "Owlv2ImageProcessorPil"}), + ("owlvit", {"torchvision": "OwlViTImageProcessor", "pil": "OwlViTImageProcessorPil"}), + ("paddleocr_vl", {"torchvision": "PaddleOCRVLImageProcessor", "pil": "PaddleOCRVLImageProcessorPil"}), + ("paligemma", {"torchvision": "SiglipImageProcessor", "pil": "SiglipImageProcessorPil"}), + ("perceiver", {"torchvision": "PerceiverImageProcessor", "pil": "PerceiverImageProcessorPil"}), + ("perception_lm", {"torchvision": "PerceptionLMImageProcessor"}), + ("phi4_multimodal", {"torchvision": "Phi4MultimodalImageProcessor"}), + ("pi0", {"torchvision": "PI0ImageProcessor"}), + ("pix2struct", {"torchvision": "Pix2StructImageProcessor", "pil": "Pix2StructImageProcessorPil"}), + ("pixio", {"torchvision": "BitImageProcessor", "pil": "BitImageProcessorPil"}), + ("pixtral", {"torchvision": "PixtralImageProcessor", "pil": "PixtralImageProcessorPil"}), + ("poolformer", {"torchvision": "PoolFormerImageProcessor", "pil": "PoolFormerImageProcessorPil"}), + ( + "pp_chart2table", + {"torchvision": "PPChart2TableImageProcessor", "pil": "PPChart2TableImageProcessorPil"}, + ), + ("pp_doclayout_v2", {"torchvision": "PPDocLayoutV2ImageProcessor"}), + ("pp_doclayout_v3", {"torchvision": "PPDocLayoutV3ImageProcessor"}), + ("pp_lcnet", {"torchvision": "PPLCNetImageProcessor"}), + ("pp_ocrv5_mobile_det", {"torchvision": "PPOCRV5ServerDetImageProcessor"}), + ("pp_ocrv5_mobile_rec", {"torchvision": "PPOCRV5ServerRecImageProcessor"}), + ("pp_ocrv5_server_det", {"torchvision": "PPOCRV5ServerDetImageProcessor"}), + ("pp_ocrv5_server_rec", {"torchvision": "PPOCRV5ServerRecImageProcessor"}), + ( + "prompt_depth_anything", + {"torchvision": "PromptDepthAnythingImageProcessor", "pil": "PromptDepthAnythingImageProcessorPil"}, + ), + ("pvt", {"torchvision": "PvtImageProcessor", "pil": "PvtImageProcessorPil"}), + ("pvt_v2", {"torchvision": "PvtImageProcessor", "pil": "PvtImageProcessorPil"}), + ("qwen2_5_omni", {"torchvision": "Qwen2VLImageProcessor", "pil": "Qwen2VLImageProcessorPil"}), + ("qwen2_5_vl", {"torchvision": "Qwen2VLImageProcessor", "pil": "Qwen2VLImageProcessorPil"}), + ("qwen2_vl", {"torchvision": "Qwen2VLImageProcessor", "pil": "Qwen2VLImageProcessorPil"}), + ("qwen3_5", {"torchvision": "Qwen2VLImageProcessor", "pil": 
"Qwen2VLImageProcessorPil"}), + ("qwen3_5_moe", {"torchvision": "Qwen2VLImageProcessor", "pil": "Qwen2VLImageProcessorPil"}), + ("qwen3_omni_moe", {"torchvision": "Qwen2VLImageProcessor", "pil": "Qwen2VLImageProcessorPil"}), + ("qwen3_vl", {"torchvision": "Qwen2VLImageProcessor", "pil": "Qwen2VLImageProcessorPil"}), + ("regnet", {"torchvision": "ConvNextImageProcessor", "pil": "ConvNextImageProcessorPil"}), + ("resnet", {"torchvision": "ConvNextImageProcessor", "pil": "ConvNextImageProcessorPil"}), + ("rt_detr", {"torchvision": "RTDetrImageProcessor", "pil": "RTDetrImageProcessorPil"}), + ("sam", {"torchvision": "SamImageProcessor", "pil": "SamImageProcessorPil"}), + ("sam2", {"torchvision": "Sam2ImageProcessor"}), + ("sam2_video", {"torchvision": "Sam2ImageProcessor"}), + ("sam3", {"torchvision": "Sam3ImageProcessor"}), + ("sam3_tracker", {"torchvision": "Sam3ImageProcessor"}), + ("sam3_tracker_video", {"torchvision": "Sam3ImageProcessor"}), + ("sam3_video", {"torchvision": "Sam3ImageProcessor"}), + ("sam_hq", {"torchvision": "SamImageProcessor", "pil": "SamImageProcessorPil"}), + ("segformer", {"torchvision": "SegformerImageProcessor", "pil": "SegformerImageProcessorPil"}), + ("seggpt", {"torchvision": "SegGptImageProcessor", "pil": "SegGptImageProcessorPil"}), + ("shieldgemma2", {"torchvision": "Gemma3ImageProcessor", "pil": "Gemma3ImageProcessorPil"}), + ("siglip", {"torchvision": "SiglipImageProcessor", "pil": "SiglipImageProcessorPil"}), + ("siglip2", {"torchvision": "Siglip2ImageProcessor", "pil": "Siglip2ImageProcessorPil"}), + ("slanext", {"torchvision": "SLANeXtImageProcessor"}), + ("smolvlm", {"torchvision": "SmolVLMImageProcessor", "pil": "SmolVLMImageProcessorPil"}), + ("superglue", {"torchvision": "SuperGlueImageProcessor", "pil": "SuperGlueImageProcessorPil"}), + ("superpoint", {"torchvision": "SuperPointImageProcessor", "pil": "SuperPointImageProcessorPil"}), + ("swiftformer", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), + ("swin", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), + ("swin2sr", {"torchvision": "Swin2SRImageProcessor", "pil": "Swin2SRImageProcessorPil"}), + ("swinv2", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), + ("t5gemma2", {"torchvision": "Gemma3ImageProcessor", "pil": "Gemma3ImageProcessorPil"}), + ("t5gemma2_encoder", {"torchvision": "Gemma3ImageProcessor", "pil": "Gemma3ImageProcessorPil"}), + ("table-transformer", {"torchvision": "DetrImageProcessor", "pil": "DetrImageProcessorPil"}), + ("textnet", {"torchvision": "TextNetImageProcessor", "pil": "TextNetImageProcessorPil"}), + ("timesformer", {"pil": "VideoMAEImageProcessorPil", "torchvision": "VideoMAEImageProcessor"}), + ("timm_wrapper", {"pil": "TimmWrapperImageProcessor"}), + ("trocr", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), + ("tvp", {"torchvision": "TvpImageProcessor", "pil": "TvpImageProcessorPil"}), + ("udop", {"torchvision": "LayoutLMv3ImageProcessor", "pil": "LayoutLMv3ImageProcessorPil"}), + ("upernet", {"torchvision": "SegformerImageProcessor", "pil": "SegformerImageProcessorPil"}), + ("uvdoc", {"torchvision": "UVDocImageProcessor"}), + ("video_llama_3", {"torchvision": "VideoLlama3ImageProcessor", "pil": "VideoLlama3ImageProcessorPil"}), + ("video_llava", {"pil": "VideoLlavaImageProcessor"}), + ("videomae", {"torchvision": "VideoMAEImageProcessor", "pil": "VideoMAEImageProcessorPil"}), + ("vilt", {"torchvision": "ViltImageProcessor", "pil": "ViltImageProcessorPil"}), + ("vipllava", 
{"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), + ("vit", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), + ("vit_mae", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), + ("vit_msn", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), + ("vitmatte", {"torchvision": "VitMatteImageProcessor", "pil": "VitMatteImageProcessorPil"}), + ("vitpose", {"torchvision": "VitPoseImageProcessor", "pil": "VitPoseImageProcessorPil"}), + ("xclip", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), + ("yolos", {"torchvision": "YolosImageProcessor", "pil": "YolosImageProcessorPil"}), + ("zoedepth", {"torchvision": "ZoeDepthImageProcessor", "pil": "ZoeDepthImageProcessorPil"}), ] ) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index a38c8f1a571f..814b84df09be 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -99,6 +99,7 @@ ("instructblip", "InstructBlipProcessor"), ("instructblipvideo", "InstructBlipVideoProcessor"), ("internvl", "InternVLProcessor"), + ("isaac", "IsaacProcessor"), ("janus", "JanusProcessor"), ("kosmos-2", "Kosmos2Processor"), ("kosmos-2.5", "Kosmos2_5Processor"), diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index 28fc96790ca3..2f0bdbeccdc2 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -18,12 +18,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ...configuration_utils import PreTrainedConfig, PretrainedConfig, layer_type_validation + +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig, PretrainedConfig from ...modeling_rope_utils import RopeParameters -from ...utils import auto_docstring -@auto_docstring(checkpoint="google/isaac-base-patch16-naflex") +@strict(accept_kwargs=True) class IsaacVisionConfig(PreTrainedConfig): """Vision configuration for Isaac with Pixel Shuffle support. 
@@ -37,38 +39,22 @@ class IsaacVisionConfig(PreTrainedConfig): model_type = "isaac_vision" base_config_key = "vision_config" - def __init__( - self, - hidden_size=768, - intermediate_size=3072, - num_hidden_layers=12, - num_attention_heads=12, - num_channels=3, - num_patches=256, - patch_size=16, - hidden_act="gelu_pytorch_tanh", - layer_norm_eps=1e-6, - attention_dropout=0.0, - pixel_shuffle_scale_factor=1, - **kwargs, - ): - super().__init__(**kwargs) + hidden_size: int = 768 + intermediate_size: int = 3072 + num_hidden_layers: int = 12 + num_attention_heads: int = 12 + num_channels: int = 3 + patch_size: int | list[int] | tuple[int, int] = 16 + hidden_act: str = "gelu_pytorch_tanh" + layer_norm_eps: float = 1e-6 + attention_dropout: float | int = 0.0 + + num_patches: int = 256 + + pixel_shuffle_scale_factor: int = 1 + - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_channels = num_channels - self.patch_size = patch_size - self.attention_dropout = attention_dropout - self.layer_norm_eps = layer_norm_eps - self.hidden_act = hidden_act - self.num_patches = num_patches - # Add our custom fields - self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor - - -@auto_docstring(checkpoint="Qwen/IsaacText-8B") +@strict(accept_kwargs=True) class IsaacTextConfig(PreTrainedConfig): r""" max_window_layers (`int`, *optional*, defaults to 28): @@ -109,57 +95,36 @@ class IsaacTextConfig(PreTrainedConfig): "norm": (["hidden_states"], ["hidden_states"]), } - def __init__( - self, - vocab_size: int | None = 151936, - hidden_size: int | None = 4096, - intermediate_size: int | None = 22016, - num_hidden_layers: int | None = 32, - num_attention_heads: int | None = 32, - num_key_value_heads: int | None = 32, - head_dim: int | None = 128, - hidden_act: str | None = "silu", - max_position_embeddings: int | None = 32768, - initializer_range: float | None = 0.02, - rms_norm_eps: float | None = 1e-6, - use_cache: bool | None = True, - tie_word_embeddings: bool | None = False, - rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None, - attention_bias: bool | None = False, - use_sliding_window: bool | None = False, - sliding_window: int | None = 4096, - max_window_layers: int | None = 28, - layer_types: list[str] | None = None, - attention_dropout: float | None = 0.0, - pad_token_id: int | None = None, - bos_token_id: int | None = None, - eos_token_id: int | None = None, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window if self.use_sliding_window else None - self.max_window_layers = max_window_layers - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.head_dim = head_dim - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - - self.layer_types = layer_types + vocab_size: int = 151936 + hidden_size: int = 4096 + intermediate_size: int = 22016 + num_hidden_layers: int = 32 + 
num_attention_heads: int = 32 + num_key_value_heads: int | None = 32 + head_dim: int = 128 + hidden_act: str = "silu" + max_position_embeddings: int = 32768 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + use_cache: bool = True + tie_word_embeddings: bool = False + rope_parameters: RopeParameters | dict | None = None + attention_bias: bool = False + use_sliding_window: bool = False + sliding_window: int | None = 4096 + max_window_layers: int = 28 + layer_types: list[str] | None = None + attention_dropout: float | int = 0.0 + pad_token_id: int | None = None + bos_token_id: int | None = None + eos_token_id: int | list[int] | None = None + ignore_keys_at_rope_validation = {"mrope_section", "mrope_interleaved"} + + def __post_init__(self, **kwargs): + self.sliding_window = self.sliding_window if self.use_sliding_window else None + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + if self.layer_types is None: self.layer_types = [ "sliding_attention" @@ -167,15 +132,8 @@ def __init__( else "full_attention" for i in range(self.num_hidden_layers) ] - layer_type_validation(self.layer_types, self.num_hidden_layers) - - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.tie_word_embeddings = tie_word_embeddings - self.rope_parameters = rope_parameters - - super().__init__(**kwargs) + super().__post_init__(**kwargs) + self.validate_layer_type() class IsaacConfig(PretrainedConfig): @@ -209,6 +167,9 @@ def __init__( vision_token: str = "", **kwargs, ): + for key in ("use_cache", "rope_theta", "max_position_embeddings"): + kwargs.pop(key, None) + if isinstance(text_config, dict): self.text_config = self.sub_configs["text_config"](**text_config) elif isinstance(text_config, IsaacTextConfig): @@ -225,12 +186,6 @@ def __init__( super().__init__(**kwargs) - # Mirror frequently accessed composite-level attributes. - self.use_cache = self.text_config.use_cache - self.rope_theta = self.text_config.rope_parameters["rope_theta"] - self.max_position_embeddings = getattr(self.text_config, "max_position_embeddings", max_sequence_length) - self.text_config.max_position_embeddings = self.max_position_embeddings - # Vision normalization parameters self.vision_rescale_factor = float(vision_rescale_factor) diff --git a/src/transformers/models/isaac/image_processing_isaac.py b/src/transformers/models/isaac/image_processing_isaac.py index 755690da955a..a4c3275f98cb 100644 --- a/src/transformers/models/isaac/image_processing_isaac.py +++ b/src/transformers/models/isaac/image_processing_isaac.py @@ -4,8 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_isaac.py file directly. One of our CI enforces this. # ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# coding=utf-8 -# Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. +# Copyright 2026 Perceptron, Inc and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,12 +18,342 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import math
+from collections.abc import Sequence
+from typing import Any
 
-from ...image_processing_utils_fast import ImagesKwargs
+from ... import TorchvisionBackend
+from ...feature_extraction_utils import BatchFeature
+from ...image_processing_utils_fast import ImagesKwargs, SizeDict, group_images_by_shape, reorder_images
+from ...image_utils import ImageInput, PILImageResampling, make_nested_list_of_images
+from ...utils import TensorType, auto_docstring
+from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN
+from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD
+from ...utils.import_utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+    import torch.nn.functional as F
 
 
 class IsaacImageProcessorKwargs(ImagesKwargs, total=False):
-    patch_size: int | None
-    max_num_patches: int | None
-    min_num_patches: int | None
-    pixel_shuffle_scale: int | None
+    """
+    patch_size (`int`, *optional*):
+        Side length (in pixels) for square patches extracted from resized images.
+    max_num_patches (`int`, *optional*):
+        Upper bound on extracted patches per image after resizing.
+    min_num_patches (`int`, *optional*):
+        Lower bound on extracted patches per image after resizing.
+    pixel_shuffle_scale (`int`, *optional*):
+        Pixel-shuffle reduction factor applied in the vision tower.
+    """
+
+    patch_size: int
+    max_num_patches: int
+    min_num_patches: int
+    pixel_shuffle_scale: int
+
+
+# Disable as it causes issues with torch.compile
+@torch.compiler.disable
+def torch_extract_patches(image_tensor, patch_height, patch_width):
+    """
+    Extract patches from image tensor. Returns tensor of shape (batch, rows, columns, patch_height*patch_width*channels).
+
+    Args:
+        image_tensor (`torch.Tensor`):
+            Image tensor of shape (batch, channels, height, width).
+        patch_height (`int`):
+            Height of patches to extract.
+        patch_width (`int`):
+            Width of patches to extract.
+    """
+    batch_size, channels, height, width = image_tensor.shape
+    patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width))
+    patches = patches.reshape(batch_size, channels, patch_height, patch_width, -1)
+    patches = patches.permute(0, 4, 2, 3, 1).reshape(
+        batch_size, height // patch_height, width // patch_width, channels * patch_height * patch_width
+    )
+    return patches
+
+
+def get_scaled_image_size(
+    scale: float,
+    original_size: int,
+    patch_size: int,
+    pixel_shuffle_scale: int,
+) -> int:
+    scaled_size = scale * original_size
+    divisor = patch_size * pixel_shuffle_scale
+    scaled_size = math.ceil(scaled_size / divisor) * divisor
+    scaled_size = max(divisor, scaled_size)
+    return int(scaled_size)
+
+
+def get_image_size_for_max_num_patches(
+    image_height: int,
+    image_width: int,
+    patch_size: int,
+    max_num_patches: int,
+    min_num_patches: int | None = None,
+    eps: float = 1e-5,
+    pixel_shuffle_scale: int = 1,
+) -> tuple[int, int]:
+    r"""Compute a target resolution whose patch grid satisfies the patching constraints.
+
+    Args:
+        image_height (`int`):
+            Height in pixels of the source image prior to any resizing.
+        image_width (`int`):
+            Width in pixels of the source image prior to any resizing.
+        patch_size (`int`):
+            Size of the square patch used by the vision encoder.
+        max_num_patches (`int`):
+            Upper bound on `(height / patch_size) * (width / patch_size)` after resizing.
+        min_num_patches (`int`, *optional*):
+            Lower bound on the number of patches. When provided, the image will be scaled up if necessary.
+    eps (`float`, *optional*, defaults to 1e-5):
+        Convergence tolerance for the internal binary search to determine the target dimensions.
+    pixel_shuffle_scale (`int`, *optional*, defaults to 1):
+        Additional stride multiplier applied when pixel shuffle later reduces spatial resolution.
+
+    Returns:
+        `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale`
+        and respect both the maximum and optional minimum patch-count constraints.
+    """
+
+    # Ensure divisibility
+    divisor = patch_size * pixel_shuffle_scale
+    adjusted_height = math.ceil(image_height / divisor) * divisor
+    adjusted_height = max(divisor, adjusted_height)
+    adjusted_width = math.ceil(image_width / divisor) * divisor
+    adjusted_width = max(divisor, adjusted_width)
+
+    num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size)
+
+    if min_num_patches is not None and num_patches < min_num_patches:
+        # Scale up via binary search to satisfy the minimum patch budget while
+        # preserving divisibility by patch_size * pixel_shuffle_scale.
+        scale_min, scale_max = 1.0, 100.0
+        while (scale_max - scale_min) >= eps:
+            scale = (scale_min + scale_max) / 2
+            target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+            target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+            num_patches = (target_height / patch_size) * (target_width / patch_size)
+            if num_patches >= min_num_patches:
+                scale_max = scale
+            else:
+                scale_min = scale
+        scale = scale_max
+        target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+        target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+        return target_height, target_width
+    elif num_patches <= max_num_patches:
+        return adjusted_height, adjusted_width
+    else:
+        # Scale down
+        scale_min, scale_max = eps / 10, 1.0
+        while (scale_max - scale_min) >= eps:
+            scale = (scale_min + scale_max) / 2
+            target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+            target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+            num_patches = (target_height / patch_size) * (target_width / patch_size)
+            if num_patches <= max_num_patches:
+                scale_min = scale
+            else:
+                scale_max = scale
+        scale = scale_min
+        target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+        target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+        return target_height, target_width
+
+
+@auto_docstring
+class IsaacImageProcessor(TorchvisionBackend):
+    MAX_PIXELS = 60_000_000  # 60‑megapixel ceiling ≈ 8200 × 7300 px
+
+    resample = PILImageResampling.BILINEAR
+    model_input_names = [
+        "vision_patches",
+        "vision_patch_attention_mask",
+        "vision_token_grids",
+    ]
+    valid_kwargs = IsaacImageProcessorKwargs
+
+    do_resize = True
+    do_center_crop = False
+    patch_size: int | None = 16
+    max_num_patches: int | None = 256
+    min_num_patches: int | None = None
+    pixel_shuffle_scale: int | None = 1
+    do_pad = True
+    do_rescale = True
+    do_normalize = True
+    image_mean = list(VISION_MEAN)
+    image_std = list(VISION_STD)
+    do_convert_rgb = True
+    disable_grouping = False
+
+    def _validate_preprocess_kwargs(self, **kwargs):
+        # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for.
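+        # Isaac derives the target resolution from patch_size/max_num_patches instead of a
+        # `size` dict, so a missing `do_resize` entry here is expected rather than an error.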
+ kwargs.pop("do_resize", None) + return super()._validate_preprocess_kwargs(**kwargs) + + def _prepare_images_structure( + self, + images: ImageInput, + expected_ndims: int = 3, + ) -> ImageInput: + images = self.fetch_images(images) + return make_nested_list_of_images(images, expected_ndims=expected_ndims) + + def resize( + self, + image: torch.Tensor, + size: SizeDict, + **kwargs, + ) -> torch.Tensor: + return F.interpolate(image, size=(size.height, size.width), mode="bilinear", align_corners=False) + + def pad( + self, + vision_patches: list[list[torch.Tensor]], + vision_token_grids: list[list[torch.Tensor]], + ) -> dict[str, torch.Tensor]: + batch_size = len(vision_patches) + first_patch = next(patches for sample_patches in vision_patches for patches in sample_patches) + max_images = max(len(sample_patches) for sample_patches in vision_patches) + max_patches = max(patches.shape[0] for sample_patches in vision_patches for patches in sample_patches) + patch_dim = first_patch.shape[-1] + patch_dtype = first_patch.dtype + patch_device = first_patch.device + + tensors = { + "vision_patches": torch.zeros( + (batch_size, max_images, max_patches, patch_dim), device=patch_device, dtype=patch_dtype + ), + "vision_patch_attention_mask": torch.zeros( + (batch_size, max_images, max_patches), device=patch_device, dtype=torch.long + ), + "vision_token_grids": torch.zeros((batch_size, max_images, 2), device=patch_device, dtype=torch.long), + } + + for batch_idx, (sample_patches, sample_token_grids) in enumerate( + zip(vision_patches, vision_token_grids, strict=True) + ): + for image_idx, (patches, token_grid) in enumerate(zip(sample_patches, sample_token_grids, strict=True)): + patch_count = int(patches.shape[0]) + tensors["vision_patches"][batch_idx, image_idx, :patch_count] = patches + tensors["vision_patch_attention_mask"][batch_idx, image_idx, :patch_count] = 1 + tensors["vision_token_grids"][batch_idx, image_idx] = token_grid + + return tensors + + def _preprocess( + self, + images: list[list[torch.Tensor]], + do_resize: bool, + resample: Any | None, + do_rescale: bool | None, + rescale_factor: float | None, + do_normalize: bool | None, + image_mean: float | Sequence[float] | None, + image_std: float | Sequence[float] | None, + do_pad: bool | None = None, + disable_grouping: bool | None = None, + return_tensors: str | TensorType | None = None, + patch_size: int | None = None, + max_num_patches: int | None = None, + min_num_patches: int | None = None, + pixel_shuffle_scale: int | None = None, + **kwargs, + ) -> BatchFeature: + resample = kwargs.pop("interpolation", resample) + batch_size = len(images) + # IsaacProcessor routes text-only calls here as an empty image list per sample. + # This returns empty vision tensors to preserve the multimodal output schema; + # image-token/image-count mismatches are validated earlier in processor's _preprocess call. 
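+        # For example, two text-only samples yield vision_patches of shape (2, 0, 0, 0) and
+        # vision_token_grids of shape (2, 0, 2), so downstream code can always rely on the
+        # leading batch dimension even when no images were provided.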
+ if all(len(sample_images) == 0 for sample_images in images): + tensors = { + "vision_patches": torch.zeros((batch_size, 0, 0, 0), dtype=torch.float32), + "vision_patch_attention_mask": torch.zeros((batch_size, 0, 0), dtype=torch.long), + "vision_token_grids": torch.zeros((batch_size, 0, 2), dtype=torch.long), + } + return BatchFeature(data=tensors, tensor_type=return_tensors) + + grouped_images, grouped_images_index = group_images_by_shape( + images, disable_grouping=disable_grouping, is_nested=True + ) + + grouped_outputs = {} + + for shape, stacked_images in grouped_images.items(): + grouped_batch_size, channels, original_height, original_width = stacked_images.shape + target_height, target_width = get_image_size_for_max_num_patches( + original_height, + original_width, + patch_size, + max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + ) + if do_resize: + image_batch = self.resize( + stacked_images, SizeDict(height=target_height, width=target_width), resample=resample + ) + else: + if (original_height % patch_size) or (original_width % patch_size): + raise ValueError( + f"Image dimensions (h={original_height}, w={original_width}) must be divisible by patch_size={patch_size} when resize is disabled; enable resizing or adjust the input resolution." + ) + image_batch, target_height, target_width = stacked_images, original_height, original_width + + image_batch = self.rescale_and_normalize( + image_batch, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) + + patches = torch_extract_patches(image_batch, patch_size, patch_size) + _, height_tokens, width_tokens, patch_dim = patches.shape + + token_grid = ( + torch.tensor([height_tokens, width_tokens], device=patches.device).long().expand(grouped_batch_size, 2) + ) + + if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): + raise ValueError( + f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale}; adjust resize/patch parameters or disable pixel shuffle." + ) + + grouped_outputs[shape] = ( + patches.reshape(grouped_batch_size, -1, patch_dim), + token_grid, + ) + + keys = ("vision_patches", "vision_token_grids") + nested_outputs = { + key: reorder_images( + {shape: values[i] for shape, values in grouped_outputs.items()}, + grouped_images_index, + is_nested=True, + ) + for i, key in enumerate(keys) + } + + if not do_pad: + raise ValueError("IsaacImageProcessor doesn't support `do_pad=False` mode.") + + tensors = self.pad( + vision_patches=nested_outputs["vision_patches"], + vision_token_grids=nested_outputs["vision_token_grids"], + ) + + return BatchFeature(data=tensors, tensor_type=return_tensors) + + +__all__ = ["IsaacImageProcessor"] diff --git a/src/transformers/models/isaac/image_processing_isaac_fast.py b/src/transformers/models/isaac/image_processing_isaac_fast.py deleted file mode 100644 index db1a8b52a756..000000000000 --- a/src/transformers/models/isaac/image_processing_isaac_fast.py +++ /dev/null @@ -1,330 +0,0 @@ -# ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# This file was automatically generated from src/transformers/models/isaac/modular_isaac.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. 
If any change should be done, please apply the change to the -# modular_isaac.py file directly. One of our CI enforces this. -# ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ๐Ÿšจ -# Copyright 2026 Perceptron, Inc and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from collections.abc import Sequence -from typing import Any - -from ...feature_extraction_utils import BatchFeature -from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict, group_images_by_shape, reorder_images -from ...image_utils import ImageInput, PILImageResampling, make_nested_list_of_images -from ...utils import TensorType, auto_docstring -from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN -from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD -from ...utils.import_utils import is_torch_available -from .modeling_isaac import IsaacImageProcessorFastKwargs - - -if is_torch_available(): - import torch - import torch.nn.functional as F - - -# Disable as it causes issues with torch.compile -@torch.compiler.disable -def torch_extract_patches(image_tensor, patch_height, patch_width): - """ - Extract patches from image tensor. Returns tensor of shape (batch, rows, columns, patch_height*patch_width*channels). - - Args: - image_tensor (`torch.Tensor`): - Image tensor of shape (batch, channels, height, width). - patch_height (`int`): - Height of patches to extract. - patch_width (`int`): - Width of patches to extract. - """ - batch_size, channels, height, width = image_tensor.shape - patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width)) - patches = patches.reshape(batch_size, channels, patch_height, patch_width, -1) - patches = patches.permute(0, 4, 2, 3, 1).reshape( - batch_size, height // patch_height, width // patch_width, channels * patch_height * patch_width - ) - return patches - - -def get_scaled_image_size( - scale: float, - original_size: int, - patch_size: int, - pixel_shuffle_scale: int, -) -> int: - scaled_size = scale * original_size - divisor = patch_size * pixel_shuffle_scale - scaled_size = math.ceil(scaled_size / divisor) * divisor - scaled_size = max(divisor, scaled_size) - return int(scaled_size) - - -def get_image_size_for_max_num_patches( - image_height: int, - image_width: int, - patch_size: int, - max_num_patches: int, - min_num_patches: int | None = None, - eps: float = 1e-5, - pixel_shuffle_scale: int = 1, -) -> tuple[int, int]: - r"""Compute a target resolution whose patch grid satisfies patching parametrization. - - Args: - image_height (`int`): - Height in pixels of the source image prior to any resizing. - image_width (`int`): - Width in pixels of the source image prior to any resizing. - patch_size (`int`): - Size of the square patch used by the vision encoder. 
- max_num_patches (`int`): - Upper bound on `(height / patch_size) * (width / patch_size)` after resizing. - min_num_patches (`int`, *optional*): - Lower bound on the number of patches. When provided the image will be scaled up if necessary. - eps (`float`, *optional*, defaults to 1e-5): - Convergence tolerance for the internal binary search to determing the target dimensions. - pixel_shuffle_scale (`int`, *optional*, defaults to 1): - Additional stride multiplier applied when pixel shuffle later reduces spatial resolution. - - Returns: - `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale` - and respect both the maximum and optional minimum patch-count constraints. - """ - - # Ensure divisibility - divisor = patch_size * pixel_shuffle_scale - adjusted_height = math.ceil(image_height / divisor) * divisor - adjusted_height = max(divisor, adjusted_height) - adjusted_width = math.ceil(image_width / divisor) * divisor - adjusted_width = max(divisor, adjusted_width) - - num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size) - - if min_num_patches is not None and num_patches < min_num_patches: - # Scale up via binary search to satisfy the minimum patch budget while - # preserving divisibility by patch_size * pixel_shuffle_scale. - scale_min, scale_max = 1.0, 100.0 - while (scale_max - scale_min) >= eps: - scale = (scale_min + scale_max) / 2 - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) - num_patches = (target_height / patch_size) * (target_width / patch_size) - if num_patches >= min_num_patches: - scale_max = scale - else: - scale_min = scale - scale = scale_max - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) - return target_height, target_width - elif num_patches <= max_num_patches: - return adjusted_height, adjusted_width - else: - # Scale down - scale_min, scale_max = eps / 10, 1.0 - while (scale_max - scale_min) >= eps: - scale = (scale_min + scale_max) / 2 - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) - num_patches = (target_height / patch_size) * (target_width / patch_size) - if num_patches <= max_num_patches: - scale_min = scale - else: - scale_max = scale - scale = scale_min - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) - return target_height, target_width - - -@auto_docstring -class IsaacImageProcessorFast(BaseImageProcessorFast): - MAX_PIXELS = 60_000_000 # 60โ€‘megapixel ceiling โ‰ˆ 8200 ร— 7300 px - - resample = PILImageResampling.BILINEAR - model_input_names = [ - "vision_patches", - "vision_patch_attention_mask", - "vision_token_grids", - ] - valid_kwargs = IsaacImageProcessorFastKwargs - unused_kwargs = ["size", "do_center_crop", "crop_size", "pad_size", "do_pad"] - - do_resize = True - do_center_crop = False - patch_size: int | None = 16 - max_num_patches: int | None = 256 - min_num_patches: int | None = None - pixel_shuffle_scale: int | None = 1 - do_pad = False - do_rescale = True - do_normalize = True - image_mean = list(VISION_MEAN) - image_std = 
list(VISION_STD) - do_convert_rgb = True - disable_grouping = False - - def _validate_preprocess_kwargs(self, **kwargs): - # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for. - kwargs.pop("do_resize", None) - return super()._validate_preprocess_kwargs(**kwargs) - - def _prepare_images_structure( - self, - images: ImageInput, - expected_ndims: int = 3, - ) -> ImageInput: - images = self.fetch_images(images) - return make_nested_list_of_images(images, expected_ndims=expected_ndims) - - def resize( - self, - image: torch.Tensor, - size: SizeDict, - **kwargs, - ) -> torch.Tensor: - return F.interpolate(image, size=(size.height, size.width), mode="bilinear", align_corners=False) - - def _preprocess( - self, - images: list[list[torch.Tensor]], - do_resize: bool, - interpolation: Any | None, - do_rescale: bool | None, - rescale_factor: float | None, - do_normalize: bool | None, - image_mean: float | Sequence[float] | None, - image_std: float | Sequence[float] | None, - disable_grouping: bool | None = None, - return_tensors: str | TensorType | None = None, - patch_size: int | None = None, - max_num_patches: int | None = None, - min_num_patches: int | None = None, - pixel_shuffle_scale: int | None = None, - **kwargs, - ) -> BatchFeature: - batch_size = len(images) - if all(len(sample_images) == 0 for sample_images in images): - tensors = { - "vision_patches": torch.zeros((batch_size, 0, 0, 0), dtype=torch.float32), - "vision_patch_attention_mask": torch.zeros((batch_size, 0, 0), dtype=torch.long), - "vision_token_grids": torch.zeros((batch_size, 0, 2), dtype=torch.long), - } - return BatchFeature(data=tensors, tensor_type=return_tensors) - - grouped_images, grouped_images_index = group_images_by_shape( - images, disable_grouping=disable_grouping, is_nested=True - ) - - grouped_outputs = {} - - for shape, stacked_images in grouped_images.items(): - grouped_batch_size, channels, original_height, original_width = stacked_images.shape - target_height, target_width = get_image_size_for_max_num_patches( - original_height, - original_width, - patch_size, - max_num_patches, - min_num_patches=min_num_patches, - pixel_shuffle_scale=pixel_shuffle_scale, - ) - if do_resize: - image_batch = self.resize( - stacked_images, SizeDict(height=target_height, width=target_width), interpolation=interpolation - ) - else: - if (original_height % patch_size) or (original_width % patch_size): - raise ValueError( - f"Image dimensions (h={original_height}, w={original_width}) must be divisible by patch_size={patch_size} when resize is disabled; enable resizing or adjust the input resolution." - ) - image_batch, target_height, target_width = stacked_images, original_height, original_width - - image_batch = self.rescale_and_normalize( - image_batch, - do_rescale=do_rescale, - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - ) - - patches = torch_extract_patches(image_batch, patch_size, patch_size) - _, height_tokens, width_tokens, patch_dim = patches.shape - - token_grid = ( - torch.tensor([height_tokens, width_tokens], device=patches.device).long().expand(grouped_batch_size, 2) - ) - - if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): - raise ValueError( - f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale}; adjust resize/patch parameters or disable pixel shuffle." 
- ) - - grouped_outputs[shape] = ( - patches.reshape(grouped_batch_size, -1, patch_dim), - token_grid, - ) - - keys = ("vision_patches", "vision_token_grids") - nested_outputs = { - key: reorder_images( - {shape: values[i] for shape, values in grouped_outputs.items()}, - grouped_images_index, - is_nested=True, - ) - for i, key in enumerate(keys) - } - - first_patch = next( - patches for sample_patches in nested_outputs["vision_patches"] for patches in sample_patches - ) - max_images = max(len(sample_patches) for sample_patches in nested_outputs["vision_patches"]) - patch_dim = first_patch.shape[-1] - patch_dtype = first_patch.dtype - patch_device = first_patch.device - max_patches = max( - patches.shape[0] for sample_patches in nested_outputs["vision_patches"] for patches in sample_patches - ) - - tensors = { - "vision_patches": torch.zeros( - (batch_size, max_images, max_patches, patch_dim), device=patch_device, dtype=patch_dtype - ), - "vision_patch_attention_mask": torch.zeros( - (batch_size, max_images, max_patches), device=patch_device, dtype=torch.long - ), - "vision_token_grids": torch.zeros((batch_size, max_images, 2), device=patch_device, dtype=torch.long), - } - - for batch_idx, sample_patches in enumerate(nested_outputs["vision_patches"]): - sample_image_count = len(sample_patches) - if sample_image_count == 0: - continue - - for image_idx, patches in enumerate(sample_patches): - patch_count = int(patches.shape[0]) - - tensors["vision_patches"][batch_idx, image_idx, :patch_count] = patches - tensors["vision_patch_attention_mask"][batch_idx, image_idx, :patch_count] = 1 - tensors["vision_token_grids"][batch_idx, image_idx] = nested_outputs["vision_token_grids"][batch_idx][ - image_idx - ] - - return BatchFeature(data=tensors, tensor_type=return_tensors) - - -__all__ = ["IsaacImageProcessorFast"] diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index e1236be08020..e843db2ce2b4 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -18,8 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import copy from collections.abc import Callable from typing import Any, NamedTuple, Optional @@ -27,7 +25,6 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation.utils import GenerationMixin -from ...image_processing_utils_fast import ImagesKwargs from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func from ...masking_utils import create_bidirectional_mask, create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -72,22 +69,10 @@ class BoundingBox(NamedTuple): t: float | None = None -class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False): - """ - patch_size (`int`, *optional*): - Side length (in pixels) for square patches extracted from resized images. - max_num_patches (`int`, *optional*): - Upper bound on extracted patches per image after resizing. - min_num_patches (`int`, *optional*): - Lower bound on extracted patches per image after resizing. - pixel_shuffle_scale (`int`, *optional*): - Pixel-shuffle reduction factor applied in the vision tower. - """ - - patch_size: int - max_num_patches: int - min_num_patches: int - pixel_shuffle_scale: int +class Polygon(NamedTuple): + points: tuple[SinglePoint, ...] 
+ mention: str | None = None + t: float | None = None class IsaacVisionEmbeddings(nn.Module): @@ -443,6 +428,7 @@ class IsaacVisionTransformer(PreTrainedModel): """ + config: IsaacVisionConfig _supports_sdpa = True _supports_flash_attn = True _can_record_outputs = { @@ -537,12 +523,9 @@ def forward(self, image_features): class IsaacRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` - def __init__(self, config: IsaacConfig | IsaacTextConfig, device=None): + def __init__(self, config: IsaacTextConfig, device=None): super().__init__() - rope_source_cfg = config.get_text_config() - config_for_rope = copy.copy(rope_source_cfg) - rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} - config_for_rope.rope_scaling = rope_scaling + rope_parameters = config.rope_parameters self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings @@ -557,8 +540,8 @@ def __init__(self, config: IsaacConfig | IsaacTextConfig, device=None): self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) - self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), self.inv_freq.shape[0]) - self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) + self.mrope_section = self._resolve_mrope_section(rope_parameters.get("mrope_section"), self.inv_freq.shape[0]) + self.hidden_size = config.hidden_size @staticmethod def compute_default_rope_parameters( @@ -1104,7 +1087,6 @@ def __init__(self, config: IsaacConfig, layer_idx: int): self.mlp = IsaacMLP(config) self.input_layernorm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -1155,6 +1137,7 @@ class IsaacModel(PreTrainedModel): "attentions": IsaacAttention, } _tied_weights_keys = {} + _input_embed_layer = "text_model.embed_tokens" def __init__(self, config: IsaacConfig): Qwen3PreTrainedModel.__init__(self, config) @@ -1169,12 +1152,6 @@ def __init__(self, config: IsaacConfig): self.post_init() - def get_input_embeddings(self) -> nn.Module: - return self.text_model.get_input_embeddings() - - def set_input_embeddings(self, value: nn.Module) -> None: - self.text_model.set_input_embeddings(value) - @can_return_tuple @auto_docstring def get_image_features( @@ -1202,64 +1179,57 @@ def get_image_features( image_attention_mask (`torch.Tensor`, *optional*): Mask indicating which image slots are populated, shaped `(batch_size, max_images)`. 
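+
+            Example (an illustrative sketch of the shapes involved, assuming one 224x224 image,
+            `patch_size=16`, and `pixel_shuffle_scale=1`): `pixel_values` arrives as `(1, 1, 196, 768)`
+            with `image_token_grids` equal to `[[14, 14]]`; the returned `last_hidden_state` has shape
+            `(1, 1, max_tokens, hidden_size)`, and the valid tokens for each image are recovered via
+            `image_token_offsets` and `image_token_lengths`.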
""" - device = self.text_model.embed_tokens.weight.device - pixel_values = pixel_values.to(device=device) - image_token_grids = image_token_grids.to(device=device, dtype=torch.long) - patch_attention_mask = image_patch_attention_mask.to(device=device, dtype=torch.long) + image_token_grids = image_token_grids.to(dtype=torch.long) + patch_attention_mask = image_patch_attention_mask.to(dtype=torch.long) if image_attention_mask is None: if image_token_lengths is not None: - image_attention_mask = image_token_lengths.to(device=device, dtype=torch.long) > 0 + image_attention_mask = image_token_lengths > 0 else: image_attention_mask = image_token_grids.any(dim=-1) else: - image_attention_mask = image_attention_mask.to(device=device, dtype=torch.bool) + image_attention_mask = image_attention_mask.to(dtype=torch.bool) + + torch_compilable_check( + image_attention_mask.any(), + "IsaacModel.get_image_features expects at least one active image slot; text-only inputs should skip this method.", + ) batch_size, max_images = pixel_values.shape[:2] hidden_size = self.config.get_text_config().hidden_size - if image_attention_mask.any(): - vision_kwargs = { - key: value - for key in ("output_hidden_states", "output_attentions") - if (value := kwargs.get(key)) is not None - } - vision_outputs = self.vision_tower( - vision_patches=pixel_values[image_attention_mask], - vision_token_grids=image_token_grids[image_attention_mask], - vision_patch_attention_mask=patch_attention_mask[image_attention_mask], - return_dict=True, - **vision_kwargs, - ) - flat_projected_features = self.multimodal_projector(vision_outputs.last_hidden_state) - max_tokens = flat_projected_features.shape[1] - projected_features = flat_projected_features.new_zeros((batch_size, max_images, max_tokens, hidden_size)) - projected_features[image_attention_mask] = flat_projected_features - offsets = ( - image_token_offsets.to(device=device, dtype=torch.long) - if image_token_offsets is not None - else torch.zeros((batch_size, max_images), device=device, dtype=torch.long) - ) - lengths = ( - image_token_lengths.to(device=device, dtype=torch.long) - if image_token_lengths is not None - else torch.full((batch_size, max_images), max_tokens, device=device, dtype=torch.long) - ) - flat_offsets = offsets[image_attention_mask] - flat_lengths = lengths[image_attention_mask] - token_positions = torch.arange(flat_lengths.max(), device=device, dtype=torch.long) - gather_positions = flat_offsets[:, None] + token_positions[None, :] - gather_mask = token_positions[None, :] < flat_lengths[:, None] - image_features = flat_projected_features[ - torch.arange(flat_projected_features.shape[0], device=device, dtype=torch.long)[:, None], - gather_positions, - ][gather_mask] - hidden_states = vision_outputs.hidden_states - attentions = vision_outputs.attentions - else: - projected_features = pixel_values.new_zeros((batch_size, max_images, 0, hidden_size)) - image_features = pixel_values.new_zeros((0, hidden_size)) - hidden_states = None - attentions = None + vision_outputs = self.vision_tower( + vision_patches=pixel_values[image_attention_mask], + vision_token_grids=image_token_grids[image_attention_mask], + vision_patch_attention_mask=patch_attention_mask[image_attention_mask], + return_dict=True, + **kwargs, + ) + flat_projected_features = self.multimodal_projector(vision_outputs.last_hidden_state) + max_tokens = flat_projected_features.shape[1] + projected_features = flat_projected_features.new_zeros((batch_size, max_images, max_tokens, hidden_size)) + 
projected_features[image_attention_mask] = flat_projected_features + feature_device = flat_projected_features.device + offsets = ( + image_token_offsets.to(dtype=torch.long) + if image_token_offsets is not None + else torch.zeros((batch_size, max_images), device=feature_device, dtype=torch.long) + ) + lengths = ( + image_token_lengths.to(dtype=torch.long) + if image_token_lengths is not None + else torch.full((batch_size, max_images), max_tokens, device=feature_device, dtype=torch.long) + ) + flat_offsets = offsets[image_attention_mask] + flat_lengths = lengths[image_attention_mask] + token_positions = torch.arange(flat_lengths.max(), device=feature_device, dtype=torch.long) + gather_positions = flat_offsets[:, None] + token_positions[None, :] + gather_mask = token_positions[None, :] < flat_lengths[:, None] + image_features = flat_projected_features[ + torch.arange(flat_projected_features.shape[0], device=feature_device, dtype=torch.long)[:, None], + gather_positions, + ][gather_mask] + hidden_states = vision_outputs.hidden_states + attentions = vision_outputs.attentions return BaseModelOutputWithPooling( last_hidden_state=projected_features, @@ -1274,7 +1244,7 @@ def get_placeholder_mask( inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor, ) -> torch.BoolTensor: - image_token_mask = mm_token_type_ids.to(device=inputs_embeds.device, dtype=torch.long) == 1 + image_token_mask = mm_token_type_ids.to(dtype=torch.long) == 1 n_image_tokens = image_token_mask.sum() image_token_mask = image_token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) torch_compilable_check( @@ -1321,11 +1291,11 @@ def get_rope_index( device = attention_mask.device batch_size, seq_len = attention_mask.shape - mm_token_type_ids = mm_token_type_ids.to(device=device, dtype=torch.long) - image_token_grids = image_token_grids.to(device=device, dtype=torch.long) - image_token_offsets = image_token_offsets.to(device=device, dtype=torch.long) - image_token_lengths = image_token_lengths.to(device=device, dtype=torch.long) - attention_mask = attention_mask.to(device=device, dtype=torch.long) + mm_token_type_ids = mm_token_type_ids.to(dtype=torch.long) + image_token_grids = image_token_grids.to(dtype=torch.long) + image_token_offsets = image_token_offsets.to(dtype=torch.long) + image_token_lengths = image_token_lengths.to(dtype=torch.long) + attention_mask = attention_mask.to(dtype=torch.long) image_attention_mask = image_token_lengths > 0 position_ids = torch.zeros((3, batch_size, seq_len), device=device, dtype=torch.long) @@ -1384,7 +1354,6 @@ def compute_3d_position_ids( image_token_grids: torch.Tensor | None = None, image_token_offsets: torch.Tensor | None = None, image_token_lengths: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, past_key_values: torch.Tensor | None = None, ) -> torch.Tensor: past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() @@ -1401,13 +1370,6 @@ def compute_3d_position_ids( self.rope_deltas = rope_deltas return position_ids - if position_ids is not None and past_seen_tokens == 0: - position_ids = position_ids.to(device=inputs_embeds.device) - if position_ids.ndim == 2: - return position_ids.view(1, position_ids.shape[0], -1).expand(3, -1, -1) - if position_ids.ndim == 3 and position_ids.shape[0] in (1, 4): - return position_ids - if self.rope_deltas is None: return None @@ -1443,7 +1405,6 @@ def compute_3d_position_ids( """, ) @can_return_tuple - @merge_with_config_defaults def forward( self, input_ids: 
torch.LongTensor | None = None, @@ -1494,7 +1455,7 @@ def forward( batch_size, seq_len = inputs_embeds.shape[:2] mm_token_type_ids = torch.full((batch_size, seq_len), 0, device=inputs_embeds.device, dtype=torch.long) else: - mm_token_type_ids = mm_token_type_ids.to(device=inputs_embeds.device, dtype=torch.long) + mm_token_type_ids = mm_token_type_ids.to(dtype=torch.long) image_token_mask = mm_token_type_ids == 1 if created_inputs_embeds and torch.any(image_token_mask): @@ -1516,19 +1477,27 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(scatter_mask, image_features) if isinstance(attention_mask, dict): - attention_mask = attention_mask.get("full_attention", next(iter(attention_mask.values()))) + attention_mask = attention_mask["full_attention"] - position_ids = self.compute_3d_position_ids( + past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() + computed_position_ids = self.compute_3d_position_ids( input_ids=input_ids, inputs_embeds=inputs_embeds, mm_token_type_ids=mm_token_type_ids, image_token_grids=vision_token_grids, image_token_offsets=vision_token_offsets, image_token_lengths=vision_token_lengths, - position_ids=position_ids, attention_mask=attention_mask, past_key_values=past_key_values, ) + if computed_position_ids is not None: + position_ids = computed_position_ids + elif past_seen_tokens > 0: + position_ids = None + elif position_ids is not None and past_seen_tokens == 0: + position_ids = position_ids.to(device=inputs_embeds.device) + if position_ids.ndim == 2: + position_ids = position_ids.view(1, position_ids.shape[0], -1).expand(3, -1, -1) text_model_outputs = self.text_model( attention_mask=attention_mask, @@ -1667,19 +1636,6 @@ def prepare_inputs_for_generation( image_patch_attention_mask if vision_patch_attention_mask is None else vision_patch_attention_mask ) vision_token_grids = image_token_grids if vision_token_grids is None else vision_token_grids - if position_ids is None or position_ids.ndim == 2: - position_ids = self._prepare_position_ids_for_generation( - input_ids, - { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "mm_token_type_ids": mm_token_type_ids, - "vision_token_grids": vision_token_grids, - "vision_token_offsets": vision_token_offsets, - "vision_token_lengths": vision_token_lengths, - }, - ) model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, @@ -1717,7 +1673,7 @@ def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs): inputs_tensor = model_kwargs["input_ids"] if ( - model_kwargs.get("image_token_lengths") is not None + model_kwargs.get("vision_token_lengths") is not None and len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long] ): @@ -1756,8 +1712,5 @@ def _expand_inputs_for_generation( model_kwargs["position_ids"] = position_ids.repeat_interleave(expand_size, dim=dim) return input_ids, model_kwargs - def get_input_embeddings(self) -> nn.Module: - return self.model.get_input_embeddings() - __all__ = ["IsaacTextModel", "IsaacModel", "IsaacPreTrainedModel", "IsaacForConditionalGeneration"] diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 0264896bcfd4..40ec4b2d6e6d 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -20,12 +20,14 @@ from collections.abc import Sequence from typing import Any, NamedTuple +from huggingface_hub.dataclasses 
import strict
+
+from ... import TorchvisionBackend
 from ... import initialization as init
 from ...configuration_utils import PretrainedConfig
 from ...feature_extraction_utils import BatchFeature
 from ...generation.utils import GenerationMixin
 from ...image_processing_utils_fast import (
-    BaseImageProcessorFast,
     ImagesKwargs,
     SizeDict,
     group_images_by_shape,
@@ -80,7 +82,7 @@
 else:
     Image = None
 if is_torchvision_available():
-    from ..pix2struct.image_processing_pix2struct_fast import torch_extract_patches
+    from ..pix2struct.image_processing_pix2struct import torch_extract_patches
 
 
 class SinglePoint(NamedTuple):
@@ -97,8 +99,17 @@ class BoundingBox(NamedTuple):
     t: float | None = None
 
 
-_POINT_OR_BOX_TAG = re.compile(
-    r"<(?P<tag>point|point_box)(?P<attrs>[^>]*)>(?P<body>[\s\S]*?)</(?P=tag)>", re.IGNORECASE
+class Polygon(NamedTuple):
+    points: tuple[SinglePoint, ...]
+    mention: str | None = None
+    t: float | None = None
+
+
+IsaacAnnotation = SinglePoint | BoundingBox | Polygon
+
+
+_POINT_BOX_OR_POLYGON_TAG = re.compile(
+    r"<(?P<tag>point|point_box|polygon)(?P<attrs>[^>]*)>(?P<body>[\s\S]*?)</(?P=tag)>", re.IGNORECASE
 )
 _ATTR_RE = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))")
 _COORD_RE = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")
@@ -149,12 +160,21 @@ def _parse_box_body(body: str, mention: str | None = None, t: str | None = None)
     return BoundingBox(top_left=top_left, bottom_right=bottom_right, mention=mention, t=_maybe_float(t))
 
 
+def _parse_polygon_body(body: str, mention: str | None = None, t: str | None = None) -> Polygon:
+    coords = list(_COORD_RE.finditer(body))
+    if len(coords) < 3:
+        raise ValueError(f"Malformed tag: {body!r}")
+
+    points = tuple(SinglePoint(x=int(coord.group(1)), y=int(coord.group(2))) for coord in coords)
+    return Polygon(points=points, mention=mention, t=_maybe_float(t))
+
+
 def clean_text_and_extract_points(
     text: str,
     expected: str | None = None,
-) -> tuple[str, list[SinglePoint | BoundingBox]]:
-    results = []
-    for match in _POINT_OR_BOX_TAG.finditer(text or ""):
+) -> tuple[str, list[IsaacAnnotation]]:
+    results: list[IsaacAnnotation] = []
+    for match in _POINT_BOX_OR_POLYGON_TAG.finditer(text or ""):
         tag = match.group("tag").lower()
         attrs = _parse_attrs(match.group("attrs"))
         mention = attrs.get("mention")
@@ -163,15 +183,20 @@
             if expected not in (None, "point"):
                 continue
             results.append(_parse_point_body(match.group("body"), mention=mention, t=t))
-        else:
+        elif tag == "point_box":
             if expected not in (None, "box"):
                 continue
             results.append(_parse_box_body(match.group("body"), mention=mention, t=t))
+        else:
+            if expected not in (None, "polygon"):
+                continue
+            results.append(_parse_polygon_body(match.group("body"), mention=mention, t=t))
 
-    clean_text = re.sub(r"\s+", " ", _POINT_OR_BOX_TAG.sub("", text or "")).strip()
+    clean_text = re.sub(r"\s+", " ", _POINT_BOX_OR_POLYGON_TAG.sub("", text or "")).strip()
     return clean_text, results
 
 
+@strict(accept_kwargs=True)
 class IsaacVisionConfig(Siglip2VisionConfig):
     """Vision configuration for Isaac with Pixel Shuffle support.
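For reference, a minimal sketch of how the extended parser above handles the new polygon tags (the coordinates and mention below are invented for illustration):

    text = 'the mug <polygon mention="mug">(10,12) (40,12) (25,48)</polygon> on the table'
    clean, annotations = clean_text_and_extract_points(text, expected="polygon")
    # clean == "the mug on the table"
    # annotations == [Polygon(points=(SinglePoint(x=10, y=12), SinglePoint(x=40, y=12),
    #                                 SinglePoint(x=25, y=48)), mention="mug", t=None)]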
@@ -185,24 +210,21 @@ class IsaacVisionConfig(Siglip2VisionConfig): model_type = "isaac_vision" base_config_key = "vision_config" - def __init__( - self, - pixel_shuffle_scale_factor=1, - **super_kwargs, - ): - super().__init__(**super_kwargs) - # Add our custom fields - self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor + pixel_shuffle_scale_factor: int = 1 +@strict(accept_kwargs=True) class IsaacTextConfig(Qwen3Config): model_type = "isaac_text" + ignore_keys_at_rope_validation = {"mrope_section", "mrope_interleaved"} + max_position_embeddings: int = 32768 - def __init__(self, **super_kwargs): - super().__init__(ignore_keys_at_rope_validation={"mrope_section", "mrope_interleaved"}, **super_kwargs) + def __post_init__(self, **kwargs): + super().__post_init__(**kwargs) + self.validate_layer_type() -class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False): +class IsaacImageProcessorKwargs(ImagesKwargs, total=False): """ patch_size (`int`, *optional*): Side length (in pixels) for square patches extracted from resized images. @@ -221,7 +243,7 @@ class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False): @auto_docstring -class IsaacImageProcessorFast(BaseImageProcessorFast): +class IsaacImageProcessor(TorchvisionBackend): MAX_PIXELS = 60_000_000 # 60โ€‘megapixel ceiling โ‰ˆ 8200 ร— 7300 px resample = PILImageResampling.BILINEAR @@ -230,8 +252,7 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): "vision_patch_attention_mask", "vision_token_grids", ] - valid_kwargs = IsaacImageProcessorFastKwargs - unused_kwargs = ["size", "do_center_crop", "crop_size", "pad_size", "do_pad"] + valid_kwargs = IsaacImageProcessorKwargs do_resize = True do_center_crop = False @@ -239,7 +260,7 @@ class IsaacImageProcessorFast(BaseImageProcessorFast): max_num_patches: int | None = 256 min_num_patches: int | None = None pixel_shuffle_scale: int | None = 1 - do_pad = False + do_pad = True do_rescale = True do_normalize = True image_mean = list(VISION_MEAN) @@ -268,16 +289,51 @@ def resize( ) -> torch.Tensor: return F.interpolate(image, size=(size.height, size.width), mode="bilinear", align_corners=False) + def pad( + self, + vision_patches: list[list[torch.Tensor]], + vision_token_grids: list[list[torch.Tensor]], + ) -> dict[str, torch.Tensor]: + batch_size = len(vision_patches) + first_patch = next(patches for sample_patches in vision_patches for patches in sample_patches) + max_images = max(len(sample_patches) for sample_patches in vision_patches) + max_patches = max(patches.shape[0] for sample_patches in vision_patches for patches in sample_patches) + patch_dim = first_patch.shape[-1] + patch_dtype = first_patch.dtype + patch_device = first_patch.device + + tensors = { + "vision_patches": torch.zeros( + (batch_size, max_images, max_patches, patch_dim), device=patch_device, dtype=patch_dtype + ), + "vision_patch_attention_mask": torch.zeros( + (batch_size, max_images, max_patches), device=patch_device, dtype=torch.long + ), + "vision_token_grids": torch.zeros((batch_size, max_images, 2), device=patch_device, dtype=torch.long), + } + + for batch_idx, (sample_patches, sample_token_grids) in enumerate( + zip(vision_patches, vision_token_grids, strict=True) + ): + for image_idx, (patches, token_grid) in enumerate(zip(sample_patches, sample_token_grids, strict=True)): + patch_count = int(patches.shape[0]) + tensors["vision_patches"][batch_idx, image_idx, :patch_count] = patches + tensors["vision_patch_attention_mask"][batch_idx, image_idx, :patch_count] = 1 + 
tensors["vision_token_grids"][batch_idx, image_idx] = token_grid + + return tensors + def _preprocess( self, images: list[list[torch.Tensor]], do_resize: bool, - interpolation: Any | None, + resample: Any | None, do_rescale: bool | None, rescale_factor: float | None, do_normalize: bool | None, image_mean: float | Sequence[float] | None, image_std: float | Sequence[float] | None, + do_pad: bool | None = None, disable_grouping: bool | None = None, return_tensors: str | TensorType | None = None, patch_size: int | None = None, @@ -286,7 +342,11 @@ def _preprocess( pixel_shuffle_scale: int | None = None, **kwargs, ) -> BatchFeature: + resample = kwargs.pop("interpolation", resample) batch_size = len(images) + # IsaacProcessor routes text-only calls here as an empty image list per sample. + # This returns empty vision tensors to preserve the multimodal output schema; + # image-token/image-count mismatches are validated earlier in processor's _preprocess call. if all(len(sample_images) == 0 for sample_images in images): tensors = { "vision_patches": torch.zeros((batch_size, 0, 0, 0), dtype=torch.float32), @@ -313,7 +373,7 @@ def _preprocess( ) if do_resize: image_batch = self.resize( - stacked_images, SizeDict(height=target_height, width=target_width), interpolation=interpolation + stacked_images, SizeDict(height=target_height, width=target_width), resample=resample ) else: if (original_height % patch_size) or (original_width % patch_size): @@ -358,40 +418,13 @@ def _preprocess( for i, key in enumerate(keys) } - first_patch = next( - patches for sample_patches in nested_outputs["vision_patches"] for patches in sample_patches - ) - max_images = max(len(sample_patches) for sample_patches in nested_outputs["vision_patches"]) - patch_dim = first_patch.shape[-1] - patch_dtype = first_patch.dtype - patch_device = first_patch.device - max_patches = max( - patches.shape[0] for sample_patches in nested_outputs["vision_patches"] for patches in sample_patches - ) - - tensors = { - "vision_patches": torch.zeros( - (batch_size, max_images, max_patches, patch_dim), device=patch_device, dtype=patch_dtype - ), - "vision_patch_attention_mask": torch.zeros( - (batch_size, max_images, max_patches), device=patch_device, dtype=torch.long - ), - "vision_token_grids": torch.zeros((batch_size, max_images, 2), device=patch_device, dtype=torch.long), - } - - for batch_idx, sample_patches in enumerate(nested_outputs["vision_patches"]): - sample_image_count = len(sample_patches) - if sample_image_count == 0: - continue - - for image_idx, patches in enumerate(sample_patches): - patch_count = int(patches.shape[0]) + if not do_pad: + raise ValueError("IsaacImageProcessor doesn't support `do_pad=False` mode.") - tensors["vision_patches"][batch_idx, image_idx, :patch_count] = patches - tensors["vision_patch_attention_mask"][batch_idx, image_idx, :patch_count] = 1 - tensors["vision_token_grids"][batch_idx, image_idx] = nested_outputs["vision_token_grids"][batch_idx][ - image_idx - ] + tensors = self.pad( + vision_patches=nested_outputs["vision_patches"], + vision_token_grids=nested_outputs["vision_token_grids"], + ) return BatchFeature(data=tensors, tensor_type=return_tensors) @@ -527,6 +560,7 @@ class IsaacVisionTransformer(PreTrainedModel): """ + config: IsaacVisionConfig _supports_sdpa = True _supports_flash_attn = True _can_record_outputs = { @@ -740,6 +774,9 @@ def __init__( vision_token: str = "", **kwargs, ): + for key in ("use_cache", "rope_theta", "max_position_embeddings"): + kwargs.pop(key, None) + if 
isinstance(text_config, dict): self.text_config = self.sub_configs["text_config"](**text_config) elif isinstance(text_config, IsaacTextConfig): @@ -756,12 +793,6 @@ def __init__( super().__init__(**kwargs) - # Mirror frequently accessed composite-level attributes. - self.use_cache = self.text_config.use_cache - self.rope_theta = self.text_config.rope_parameters["rope_theta"] - self.max_position_embeddings = getattr(self.text_config, "max_position_embeddings", max_sequence_length) - self.text_config.max_position_embeddings = self.max_position_embeddings - # Vision normalization parameters self.vision_rescale_factor = float(vision_rescale_factor) @@ -805,13 +836,43 @@ def __init__( self.vision_token = vision_token self.max_sequence_length = max_sequence_length - def _build_batch( + def post_process_generation( + self, + text: str, + expected: str | None = None, + cleanup_and_extract: bool = True, + ) -> str | tuple[str, list[IsaacAnnotation]]: + if cleanup_and_extract: + return clean_text_and_extract_points(text, expected=expected) + return text + + def post_process_image_text_to_text( + self, + generated_outputs, + skip_special_tokens: bool = True, + cleanup_and_extract: bool = False, + expected: str | None = None, + **kwargs, + ): + generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs) + return [ + self.post_process_generation(text, expected=expected, cleanup_and_extract=cleanup_and_extract) + for text in generated_texts + ] + + def __call__( self, text: str | list[str], images: ImageInput | None = None, - text_kwargs: dict[str, Any] | None = None, - ) -> dict[str, torch.Tensor | None]: - text_kwargs = copy.deepcopy(text_kwargs) if text_kwargs is not None else {} + return_tensors: str | TensorType | None = TensorType.PYTORCH, + **kwargs, + ) -> BatchFeature: + output_kwargs = self._merge_kwargs( + IsaacProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + text_kwargs = copy.deepcopy(output_kwargs["text_kwargs"]) truncation = text_kwargs.pop("truncation", None) max_length = text_kwargs.pop("max_length", None) padding = text_kwargs.pop("padding", True) @@ -870,15 +931,15 @@ def _build_batch( expanded_text += (self.image_token * segment_length) + segments[image_idx + 1] expanded_texts.append(expanded_text) - text_inputs = self.tokenizer(expanded_texts, return_tensors=None, **text_kwargs) - self._check_special_mm_tokens(expanded_texts, text_inputs, modalities=["image"]) + tokenized_text_inputs = self.tokenizer(expanded_texts, return_tensors=None, **text_kwargs) + self._check_special_mm_tokens(expanded_texts, tokenized_text_inputs, modalities=["image"]) effective_max_length = self.max_sequence_length if truncation and max_length is not None: effective_max_length = max_length for batch_idx, (expected_image_lengths, sample_input_ids_list) in enumerate( - zip(expected_image_lengths_per_sample, text_inputs["input_ids"], strict=True) + zip(expected_image_lengths_per_sample, tokenized_text_inputs["input_ids"], strict=True) ): sample_input = torch.tensor(sample_input_ids_list, dtype=torch.long) image_positions = sample_input.eq(self.image_pad_token_id).nonzero(as_tuple=False).flatten() @@ -902,7 +963,8 @@ def _build_batch( vision_token_offsets[batch_idx, image_idx] = kept_start - image_start vision_token_lengths[batch_idx, image_idx] = kept_end - kept_start - text_inputs = self.tokenizer.pad( + # Pad only after Isaac-specific truncation so image span offsets and lengths stay aligned. 
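+        # (Padding never removes tokens, so the per-image vision_token_offsets/lengths computed
+        # in the truncation pass above remain valid for the padded batch.)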
+ padded_text_inputs = self.tokenizer.pad( {"input_ids": [sample_input.tolist() for sample_input in sample_input_ids]}, padding=padding, max_length=max_length if padding == "max_length" else None, @@ -911,8 +973,8 @@ def _build_batch( return_attention_mask=return_attention_mask, return_tensors=TensorType.PYTORCH, ) - input_ids = text_inputs["input_ids"] - attention_mask = text_inputs.get("attention_mask") + input_ids = padded_text_inputs["input_ids"] + attention_mask = padded_text_inputs.get("attention_mask") if attention_mask is None: attention_mask = input_ids.ne(self.pad_token_id).to(dtype=torch.long) @@ -922,74 +984,33 @@ def _build_batch( vision_patches = image_inputs["vision_patches"] vision_patch_attention_mask = image_inputs["vision_patch_attention_mask"] - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "mm_token_type_ids": mm_token_type_ids, - "vision_patches": vision_patches, - "vision_patch_attention_mask": vision_patch_attention_mask, - "vision_token_grids": vision_token_grids, - "vision_token_offsets": vision_token_offsets, - "vision_token_lengths": vision_token_lengths, - "vision_image_attention_mask": vision_image_attention_mask, - } - - def post_process_generation( - self, - text: str, - expected: str | None = None, - cleanup_and_extract: bool = True, - ) -> str | tuple[str, list[SinglePoint | BoundingBox]]: - if cleanup_and_extract: - return clean_text_and_extract_points(text, expected=expected) - return text - - def post_process_image_text_to_text( - self, - generated_outputs, - skip_special_tokens: bool = True, - cleanup_and_extract: bool = False, - expected: str | None = None, - **kwargs, - ): - generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs) - return [ - self.post_process_generation(text, expected=expected, cleanup_and_extract=cleanup_and_extract) - for text in generated_texts - ] - - def __call__( - self, - text: str | list[str], - images: ImageInput | None = None, - return_tensors: str | TensorType | None = TensorType.PYTORCH, - **kwargs, - ) -> BatchFeature: - output_kwargs = self._merge_kwargs( - IsaacProcessorKwargs, - tokenizer_init_kwargs=self.tokenizer.init_kwargs, - **kwargs, - ) return BatchFeature( - data=self._build_batch(text=text, images=images, text_kwargs=output_kwargs["text_kwargs"]), + data={ + "input_ids": input_ids, + "attention_mask": attention_mask, + "mm_token_type_ids": mm_token_type_ids, + "vision_patches": vision_patches, + "vision_patch_attention_mask": vision_patch_attention_mask, + "vision_token_grids": vision_token_grids, + "vision_token_offsets": vision_token_offsets, + "vision_token_lengths": vision_token_lengths, + "vision_image_attention_mask": vision_image_attention_mask, + }, tensor_type=return_tensors, ) class IsaacRotaryEmbedding(Qwen3VLTextRotaryEmbedding): - def __init__(self, config: IsaacConfig | IsaacTextConfig, device=None): - rope_source_cfg = config.get_text_config() - config_for_rope = copy.copy(rope_source_cfg) - rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} - config_for_rope.rope_scaling = rope_scaling + def __init__(self, config: IsaacTextConfig, device=None): + rope_parameters = config.rope_parameters super().__init__( - config_for_rope, + config, device=device if device is not None and getattr(device, "type", None) != "meta" else None, ) - self.mrope_section = self._resolve_mrope_section(rope_scaling.get("mrope_section"), self.inv_freq.shape[0]) - self.hidden_size = getattr(rope_source_cfg, "hidden_size", None) + 
self.mrope_section = self._resolve_mrope_section(rope_parameters.get("mrope_section"), self.inv_freq.shape[0]) + self.hidden_size = config.hidden_size @staticmethod def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> list[int]: @@ -1027,6 +1048,7 @@ class IsaacModel(Qwen3PreTrainedModel): _can_compile_fullgraph = False _supports_flex_attn = False _tied_weights_keys = {} + _input_embed_layer = "text_model.embed_tokens" def __init__(self, config: IsaacConfig): Qwen3PreTrainedModel.__init__(self, config) @@ -1041,12 +1063,6 @@ def __init__(self, config: IsaacConfig): self.post_init() - def get_input_embeddings(self) -> nn.Module: - return self.text_model.get_input_embeddings() - - def set_input_embeddings(self, value: nn.Module) -> None: - self.text_model.set_input_embeddings(value) - @can_return_tuple @auto_docstring def get_image_features( @@ -1074,64 +1090,57 @@ def get_image_features( image_attention_mask (`torch.Tensor`, *optional*): Mask indicating which image slots are populated, shaped `(batch_size, max_images)`. """ - device = self.text_model.embed_tokens.weight.device - pixel_values = pixel_values.to(device=device) - image_token_grids = image_token_grids.to(device=device, dtype=torch.long) - patch_attention_mask = image_patch_attention_mask.to(device=device, dtype=torch.long) + image_token_grids = image_token_grids.to(dtype=torch.long) + patch_attention_mask = image_patch_attention_mask.to(dtype=torch.long) if image_attention_mask is None: if image_token_lengths is not None: - image_attention_mask = image_token_lengths.to(device=device, dtype=torch.long) > 0 + image_attention_mask = image_token_lengths > 0 else: image_attention_mask = image_token_grids.any(dim=-1) else: - image_attention_mask = image_attention_mask.to(device=device, dtype=torch.bool) + image_attention_mask = image_attention_mask.to(dtype=torch.bool) + + torch_compilable_check( + image_attention_mask.any(), + "IsaacModel.get_image_features expects at least one active image slot; text-only inputs should skip this method.", + ) batch_size, max_images = pixel_values.shape[:2] hidden_size = self.config.get_text_config().hidden_size - if image_attention_mask.any(): - vision_kwargs = { - key: value - for key in ("output_hidden_states", "output_attentions") - if (value := kwargs.get(key)) is not None - } - vision_outputs = self.vision_tower( - vision_patches=pixel_values[image_attention_mask], - vision_token_grids=image_token_grids[image_attention_mask], - vision_patch_attention_mask=patch_attention_mask[image_attention_mask], - return_dict=True, - **vision_kwargs, - ) - flat_projected_features = self.multimodal_projector(vision_outputs.last_hidden_state) - max_tokens = flat_projected_features.shape[1] - projected_features = flat_projected_features.new_zeros((batch_size, max_images, max_tokens, hidden_size)) - projected_features[image_attention_mask] = flat_projected_features - offsets = ( - image_token_offsets.to(device=device, dtype=torch.long) - if image_token_offsets is not None - else torch.zeros((batch_size, max_images), device=device, dtype=torch.long) - ) - lengths = ( - image_token_lengths.to(device=device, dtype=torch.long) - if image_token_lengths is not None - else torch.full((batch_size, max_images), max_tokens, device=device, dtype=torch.long) - ) - flat_offsets = offsets[image_attention_mask] - flat_lengths = lengths[image_attention_mask] - token_positions = torch.arange(flat_lengths.max(), device=device, dtype=torch.long) - gather_positions = flat_offsets[:, None] + 
token_positions[None, :] - gather_mask = token_positions[None, :] < flat_lengths[:, None] - image_features = flat_projected_features[ - torch.arange(flat_projected_features.shape[0], device=device, dtype=torch.long)[:, None], - gather_positions, - ][gather_mask] - hidden_states = vision_outputs.hidden_states - attentions = vision_outputs.attentions - else: - projected_features = pixel_values.new_zeros((batch_size, max_images, 0, hidden_size)) - image_features = pixel_values.new_zeros((0, hidden_size)) - hidden_states = None - attentions = None + vision_outputs = self.vision_tower( + vision_patches=pixel_values[image_attention_mask], + vision_token_grids=image_token_grids[image_attention_mask], + vision_patch_attention_mask=patch_attention_mask[image_attention_mask], + return_dict=True, + **kwargs, + ) + flat_projected_features = self.multimodal_projector(vision_outputs.last_hidden_state) + max_tokens = flat_projected_features.shape[1] + projected_features = flat_projected_features.new_zeros((batch_size, max_images, max_tokens, hidden_size)) + projected_features[image_attention_mask] = flat_projected_features + feature_device = flat_projected_features.device + offsets = ( + image_token_offsets.to(dtype=torch.long) + if image_token_offsets is not None + else torch.zeros((batch_size, max_images), device=feature_device, dtype=torch.long) + ) + lengths = ( + image_token_lengths.to(dtype=torch.long) + if image_token_lengths is not None + else torch.full((batch_size, max_images), max_tokens, device=feature_device, dtype=torch.long) + ) + flat_offsets = offsets[image_attention_mask] + flat_lengths = lengths[image_attention_mask] + token_positions = torch.arange(flat_lengths.max(), device=feature_device, dtype=torch.long) + gather_positions = flat_offsets[:, None] + token_positions[None, :] + gather_mask = token_positions[None, :] < flat_lengths[:, None] + image_features = flat_projected_features[ + torch.arange(flat_projected_features.shape[0], device=feature_device, dtype=torch.long)[:, None], + gather_positions, + ][gather_mask] + hidden_states = vision_outputs.hidden_states + attentions = vision_outputs.attentions return BaseModelOutputWithPooling( last_hidden_state=projected_features, @@ -1146,7 +1155,7 @@ def get_placeholder_mask( inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor, ) -> torch.BoolTensor: - image_token_mask = mm_token_type_ids.to(device=inputs_embeds.device, dtype=torch.long) == 1 + image_token_mask = mm_token_type_ids.to(dtype=torch.long) == 1 n_image_tokens = image_token_mask.sum() image_token_mask = image_token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) torch_compilable_check( @@ -1193,11 +1202,11 @@ def get_rope_index( device = attention_mask.device batch_size, seq_len = attention_mask.shape - mm_token_type_ids = mm_token_type_ids.to(device=device, dtype=torch.long) - image_token_grids = image_token_grids.to(device=device, dtype=torch.long) - image_token_offsets = image_token_offsets.to(device=device, dtype=torch.long) - image_token_lengths = image_token_lengths.to(device=device, dtype=torch.long) - attention_mask = attention_mask.to(device=device, dtype=torch.long) + mm_token_type_ids = mm_token_type_ids.to(dtype=torch.long) + image_token_grids = image_token_grids.to(dtype=torch.long) + image_token_offsets = image_token_offsets.to(dtype=torch.long) + image_token_lengths = image_token_lengths.to(dtype=torch.long) + attention_mask = attention_mask.to(dtype=torch.long) image_attention_mask = image_token_lengths > 0 position_ids = 
torch.zeros((3, batch_size, seq_len), device=device, dtype=torch.long) @@ -1256,7 +1265,6 @@ def compute_3d_position_ids( image_token_grids: torch.Tensor | None = None, image_token_offsets: torch.Tensor | None = None, image_token_lengths: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, past_key_values: torch.Tensor | None = None, ) -> torch.Tensor: past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() @@ -1273,13 +1281,6 @@ def compute_3d_position_ids( self.rope_deltas = rope_deltas return position_ids - if position_ids is not None and past_seen_tokens == 0: - position_ids = position_ids.to(device=inputs_embeds.device) - if position_ids.ndim == 2: - return position_ids.view(1, position_ids.shape[0], -1).expand(3, -1, -1) - if position_ids.ndim == 3 and position_ids.shape[0] in (1, 4): - return position_ids - if self.rope_deltas is None: return None @@ -1315,7 +1316,6 @@ def compute_3d_position_ids( """, ) @can_return_tuple - @merge_with_config_defaults def forward( self, input_ids: torch.LongTensor | None = None, @@ -1366,7 +1366,7 @@ def forward( batch_size, seq_len = inputs_embeds.shape[:2] mm_token_type_ids = torch.full((batch_size, seq_len), 0, device=inputs_embeds.device, dtype=torch.long) else: - mm_token_type_ids = mm_token_type_ids.to(device=inputs_embeds.device, dtype=torch.long) + mm_token_type_ids = mm_token_type_ids.to(dtype=torch.long) image_token_mask = mm_token_type_ids == 1 if created_inputs_embeds and torch.any(image_token_mask): @@ -1388,19 +1388,27 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(scatter_mask, image_features) if isinstance(attention_mask, dict): - attention_mask = attention_mask.get("full_attention", next(iter(attention_mask.values()))) + attention_mask = attention_mask["full_attention"] - position_ids = self.compute_3d_position_ids( + past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() + computed_position_ids = self.compute_3d_position_ids( input_ids=input_ids, inputs_embeds=inputs_embeds, mm_token_type_ids=mm_token_type_ids, image_token_grids=vision_token_grids, image_token_offsets=vision_token_offsets, image_token_lengths=vision_token_lengths, - position_ids=position_ids, attention_mask=attention_mask, past_key_values=past_key_values, ) + if computed_position_ids is not None: + position_ids = computed_position_ids + elif past_seen_tokens > 0: + position_ids = None + elif position_ids is not None and past_seen_tokens == 0: + position_ids = position_ids.to(device=inputs_embeds.device) + if position_ids.ndim == 2: + position_ids = position_ids.view(1, position_ids.shape[0], -1).expand(3, -1, -1) text_model_outputs = self.text_model( attention_mask=attention_mask, @@ -1534,19 +1542,6 @@ def prepare_inputs_for_generation( image_patch_attention_mask if vision_patch_attention_mask is None else vision_patch_attention_mask ) vision_token_grids = image_token_grids if vision_token_grids is None else vision_token_grids - if position_ids is None or position_ids.ndim == 2: - position_ids = self._prepare_position_ids_for_generation( - input_ids, - { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "mm_token_type_ids": mm_token_type_ids, - "vision_token_grids": vision_token_grids, - "vision_token_offsets": vision_token_offsets, - "vision_token_lengths": vision_token_lengths, - }, - ) model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, @@ -1584,7 +1579,7 @@ def 
_prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs):
         inputs_tensor = model_kwargs["input_ids"]
         if (
-            model_kwargs.get("image_token_lengths") is not None
+            model_kwargs.get("vision_token_lengths") is not None
             and len(inputs_tensor.shape) == 2
             and inputs_tensor.dtype in [torch.int, torch.long]
         ):
@@ -1623,9 +1618,6 @@ def _expand_inputs_for_generation(
             model_kwargs["position_ids"] = position_ids.repeat_interleave(expand_size, dim=dim)
         return input_ids, model_kwargs
 
-    def get_input_embeddings(self) -> nn.Module:
-        return self.model.get_input_embeddings()
-
 
 __all__ = [
     "IsaacConfig",
@@ -1635,6 +1627,6 @@ def get_input_embeddings(self) -> nn.Module:
     "IsaacModel",
     "IsaacPreTrainedModel",  # noqa: F822
     "IsaacForConditionalGeneration",
-    "IsaacImageProcessorFast",
+    "IsaacImageProcessor",
     "IsaacProcessor",
 ]
diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py
index 93e991702649..8b1696a7b501 100644
--- a/src/transformers/models/isaac/processing_isaac.py
+++ b/src/transformers/models/isaac/processing_isaac.py
@@ -21,14 +21,13 @@
 
 import copy
 import re
-from typing import Any
 
 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, make_nested_list_of_images
 from ...processing_utils import ProcessingKwargs, ProcessorMixin
 from ...utils import TensorType, auto_docstring
 from ...utils.import_utils import is_torch_available
-from .modeling_isaac import BoundingBox, SinglePoint
+from .modeling_isaac import BoundingBox, Polygon, SinglePoint
 
 
 if is_torch_available():
@@ -44,8 +43,11 @@ class IsaacProcessorKwargs(ProcessingKwargs, total=False):
     }
 
 
-_POINT_OR_BOX_TAG = re.compile(
-    r"<(?P<tag>point|point_box)(?P<attrs>[^>]*)>(?P<body>[\s\S]*?)</(?P=tag)>", re.IGNORECASE
+IsaacAnnotation = SinglePoint | BoundingBox | Polygon
+
+
+_POINT_BOX_OR_POLYGON_TAG = re.compile(
+    r"<(?P<tag>point|point_box|polygon)(?P<attrs>[^>]*)>(?P<body>[\s\S]*?)</(?P=tag)>", re.IGNORECASE
 )
 _ATTR_RE = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))")
 _COORD_RE = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")
@@ -87,12 +89,21 @@ def _parse_box_body(body: str, mention: str | None = None, t: str | None = None)
     return BoundingBox(top_left=top_left, bottom_right=bottom_right, mention=mention, t=_maybe_float(t))
 
 
+def _parse_polygon_body(body: str, mention: str | None = None, t: str | None = None) -> Polygon:
+    coords = list(_COORD_RE.finditer(body))
+    if len(coords) < 3:
+        raise ValueError(f"Malformed tag: {body!r}")
+
+    points = tuple(SinglePoint(x=int(coord.group(1)), y=int(coord.group(2))) for coord in coords)
+    return Polygon(points=points, mention=mention, t=_maybe_float(t))
+
+
 def clean_text_and_extract_points(
     text: str,
     expected: str | None = None,
-) -> tuple[str, list[SinglePoint | BoundingBox]]:
-    results = []
-    for match in _POINT_OR_BOX_TAG.finditer(text or ""):
+) -> tuple[str, list[IsaacAnnotation]]:
+    results: list[IsaacAnnotation] = []
+    for match in _POINT_BOX_OR_POLYGON_TAG.finditer(text or ""):
         tag = match.group("tag").lower()
         attrs = _parse_attrs(match.group("attrs"))
         mention = attrs.get("mention")
@@ -101,12 +112,16 @@ def clean_text_and_extract_points(
             if expected not in (None, "point"):
                 continue
             results.append(_parse_point_body(match.group("body"), mention=mention, t=t))
-        else:
+        elif tag == "point_box":
             if expected not in (None, "box"):
                 continue
             results.append(_parse_box_body(match.group("body"), mention=mention, t=t))
+        else:
+            if expected not in (None, "polygon"):
+                continue
+            results.append(_parse_polygon_body(match.group("body"), mention=mention,
t=t)) - clean_text = re.sub(r"\s+", " ", _POINT_OR_BOX_TAG.sub("", text or "")).strip() + clean_text = re.sub(r"\s+", " ", _POINT_BOX_OR_POLYGON_TAG.sub("", text or "")).strip() return clean_text, results @@ -145,13 +160,43 @@ def __init__( self.vision_token = vision_token self.max_sequence_length = max_sequence_length - def _build_batch( + def post_process_generation( + self, + text: str, + expected: str | None = None, + cleanup_and_extract: bool = True, + ) -> str | tuple[str, list[IsaacAnnotation]]: + if cleanup_and_extract: + return clean_text_and_extract_points(text, expected=expected) + return text + + def post_process_image_text_to_text( + self, + generated_outputs, + skip_special_tokens: bool = True, + cleanup_and_extract: bool = False, + expected: str | None = None, + **kwargs, + ): + generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs) + return [ + self.post_process_generation(text, expected=expected, cleanup_and_extract=cleanup_and_extract) + for text in generated_texts + ] + + def __call__( self, text: str | list[str], images: ImageInput | None = None, - text_kwargs: dict[str, Any] | None = None, - ) -> dict[str, torch.Tensor | None]: - text_kwargs = copy.deepcopy(text_kwargs) if text_kwargs is not None else {} + return_tensors: str | TensorType | None = TensorType.PYTORCH, + **kwargs, + ) -> BatchFeature: + output_kwargs = self._merge_kwargs( + IsaacProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + text_kwargs = copy.deepcopy(output_kwargs["text_kwargs"]) truncation = text_kwargs.pop("truncation", None) max_length = text_kwargs.pop("max_length", None) padding = text_kwargs.pop("padding", True) @@ -210,15 +255,15 @@ def _build_batch( expanded_text += (self.image_token * segment_length) + segments[image_idx + 1] expanded_texts.append(expanded_text) - text_inputs = self.tokenizer(expanded_texts, return_tensors=None, **text_kwargs) - self._check_special_mm_tokens(expanded_texts, text_inputs, modalities=["image"]) + tokenized_text_inputs = self.tokenizer(expanded_texts, return_tensors=None, **text_kwargs) + self._check_special_mm_tokens(expanded_texts, tokenized_text_inputs, modalities=["image"]) effective_max_length = self.max_sequence_length if truncation and max_length is not None: effective_max_length = max_length for batch_idx, (expected_image_lengths, sample_input_ids_list) in enumerate( - zip(expected_image_lengths_per_sample, text_inputs["input_ids"], strict=True) + zip(expected_image_lengths_per_sample, tokenized_text_inputs["input_ids"], strict=True) ): sample_input = torch.tensor(sample_input_ids_list, dtype=torch.long) image_positions = sample_input.eq(self.image_pad_token_id).nonzero(as_tuple=False).flatten() @@ -242,7 +287,8 @@ def _build_batch( vision_token_offsets[batch_idx, image_idx] = kept_start - image_start vision_token_lengths[batch_idx, image_idx] = kept_end - kept_start - text_inputs = self.tokenizer.pad( + # Pad only after Isaac-specific truncation so image span offsets and lengths stay aligned. 
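+            # With the default right padding side, pad tokens land after each kept sequence,
+            # e.g. a sample [bos, text, img, img, img] with offset 2 / length 3 pads to
+            # [bos, text, img, img, img, pad, pad], so both recorded values stay valid
+            # (illustrative layout, assuming the tokenizer pads on the right).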
+ padded_text_inputs = self.tokenizer.pad( {"input_ids": [sample_input.tolist() for sample_input in sample_input_ids]}, padding=padding, max_length=max_length if padding == "max_length" else None, @@ -251,8 +297,8 @@ def _build_batch( return_attention_mask=return_attention_mask, return_tensors=TensorType.PYTORCH, ) - input_ids = text_inputs["input_ids"] - attention_mask = text_inputs.get("attention_mask") + input_ids = padded_text_inputs["input_ids"] + attention_mask = padded_text_inputs.get("attention_mask") if attention_mask is None: attention_mask = input_ids.ne(self.pad_token_id).to(dtype=torch.long) @@ -262,56 +308,18 @@ def _build_batch( vision_patches = image_inputs["vision_patches"] vision_patch_attention_mask = image_inputs["vision_patch_attention_mask"] - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "mm_token_type_ids": mm_token_type_ids, - "vision_patches": vision_patches, - "vision_patch_attention_mask": vision_patch_attention_mask, - "vision_token_grids": vision_token_grids, - "vision_token_offsets": vision_token_offsets, - "vision_token_lengths": vision_token_lengths, - "vision_image_attention_mask": vision_image_attention_mask, - } - - def post_process_generation( - self, - text: str, - expected: str | None = None, - cleanup_and_extract: bool = True, - ) -> str | tuple[str, list[SinglePoint | BoundingBox]]: - if cleanup_and_extract: - return clean_text_and_extract_points(text, expected=expected) - return text - - def post_process_image_text_to_text( - self, - generated_outputs, - skip_special_tokens: bool = True, - cleanup_and_extract: bool = False, - expected: str | None = None, - **kwargs, - ): - generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs) - return [ - self.post_process_generation(text, expected=expected, cleanup_and_extract=cleanup_and_extract) - for text in generated_texts - ] - - def __call__( - self, - text: str | list[str], - images: ImageInput | None = None, - return_tensors: str | TensorType | None = TensorType.PYTORCH, - **kwargs, - ) -> BatchFeature: - output_kwargs = self._merge_kwargs( - IsaacProcessorKwargs, - tokenizer_init_kwargs=self.tokenizer.init_kwargs, - **kwargs, - ) return BatchFeature( - data=self._build_batch(text=text, images=images, text_kwargs=output_kwargs["text_kwargs"]), + data={ + "input_ids": input_ids, + "attention_mask": attention_mask, + "mm_token_type_ids": mm_token_type_ids, + "vision_patches": vision_patches, + "vision_patch_attention_mask": vision_patch_attention_mask, + "vision_token_grids": vision_token_grids, + "vision_token_offsets": vision_token_offsets, + "vision_token_lengths": vision_token_lengths, + "vision_image_attention_mask": vision_image_attention_mask, + }, tensor_type=return_tensors, ) diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py index df4ee148e6c3..374ddb1a1d78 100644 --- a/tests/models/isaac/test_modeling_isaac.py +++ b/tests/models/isaac/test_modeling_isaac.py @@ -29,16 +29,16 @@ from tests.test_configuration_common import ConfigTester from tests.test_pipeline_mixin import PipelineTesterMixin from transformers import ( - AutoTokenizer, IsaacConfig, IsaacForConditionalGeneration, IsaacModel, PythonBackend, + Qwen2Tokenizer, is_torch_available, ) from transformers.image_utils import load_image from transformers.masking_utils import create_bidirectional_mask -from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast +from 
transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessor from transformers.models.isaac.modeling_isaac import ( IsaacVisionAttention, IsaacVisionConfig, @@ -218,7 +218,7 @@ def create_isaac_processor( processor_image = image_processor if processor_image is None: - processor_image = IsaacImageProcessorFast( + processor_image = IsaacImageProcessor( patch_size=params["vision_patch_size"], max_num_patches=params["vision_max_num_patches"], min_num_patches=params["vision_min_num_patches"], @@ -636,33 +636,6 @@ def test_pixel_shuffle_padded_zero_grid(self): self.assertEqual(hidden.shape, (1, 0, 32)) -@require_torch -class IsaacPixelShufflePaddedTest(unittest.TestCase): - def test_pixel_shuffle_padded_matches_reference_no_attention_mask(self): - x = torch.arange(2 * 16 * 4, device=torch_device, dtype=torch.float32).view(2, 16, 4) - token_grids = torch.tensor([[4, 4], [2, 4]], device=torch_device, dtype=torch.long) - expected_hidden, expected_mask, expected_lengths = _pixel_shuffle_reference(x, token_grids, scale_factor=2) - - hidden = pixel_shuffle_padded(x=x, token_grids=token_grids, scale_factor=2) - - torch.testing.assert_close(hidden, expected_hidden) - - def test_pixel_shuffle_padded_raises_on_non_divisible_grid(self): - x = torch.randn(1, 15, 8, device=torch_device) - token_grids = torch.tensor([[3, 5]], device=torch_device, dtype=torch.long) - - with pytest.raises(ValueError, match="divisible"): - pixel_shuffle_padded(x=x, token_grids=token_grids, scale_factor=2) - - def test_pixel_shuffle_padded_zero_grid(self): - x = torch.randn(1, 4, 8, device=torch_device) - token_grids = torch.tensor([[0, 0]], device=torch_device, dtype=torch.long) - - hidden = pixel_shuffle_padded(x=x, token_grids=token_grids, scale_factor=2) - - self.assertEqual(hidden.shape, (1, 0, 32)) - - @require_torch @require_flash_attn class IsaacAttentionDtypeTest(unittest.TestCase): @@ -790,8 +763,8 @@ def setUp(self): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.checkpoint = _base_reference_checkpoint_or_skip() self.hf_config = IsaacConfig.from_pretrained(self.checkpoint, revision=BASE_MODEL_REVISION) - self.tokenizer = AutoTokenizer.from_pretrained( - self.checkpoint, trust_remote_code=True, use_fast=False, revision=BASE_MODEL_REVISION + self.tokenizer = Qwen2Tokenizer.from_pretrained( + self.checkpoint, trust_remote_code=False, use_fast=False, revision=BASE_MODEL_REVISION ) self.processor = create_isaac_processor(self.tokenizer, self.hf_config) self.hf_config.vision_config._attn_implementation = "flash_attention_2" @@ -1108,7 +1081,7 @@ def setUp(self): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.checkpoint = _reference_checkpoint_or_skip() self.hf_config = IsaacConfig.from_pretrained(self.checkpoint, revision=MODEL_REVISION) - self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer = Qwen2Tokenizer.from_pretrained( self.checkpoint, trust_remote_code=True, use_fast=False, revision=MODEL_REVISION ) self.processor = create_isaac_processor(self.tokenizer, self.hf_config) @@ -1169,3 +1142,62 @@ def test_hf_generate_box_points(self): assert first_point.top_left.y == 247 assert first_point.bottom_right.x == 863 assert first_point.bottom_right.y == 386 + + def test_hf_generate_polygon_points(self): + document = [ + { + "type": "text", + "content": "POLYGON", + "role": "user", + }, + { + "type": "image", + "content": 
"https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp", + "role": "user", + }, + { + "type": "text", + "content": "Determine whether it is safe to cross the street. Look for signage and moving traffic.", + "role": "user", + }, + ] + messages, images = document_to_messages(document, vision_token=self.hf_config.vision_token) + prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip() + processor_output = self.processor(text=prompt, images=images, return_tensors="pt") + input_ids = processor_output["input_ids"].to(self.device) + prompt_len = input_ids.shape[1] + multimodal_inputs = to_model_multimodal_inputs(processor_output, self.device) + + with torch.no_grad(): + outputs = self.model.generate( + input_ids=input_ids, + **multimodal_inputs, + max_new_tokens=self.max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + ) + + generated_ids = outputs.sequences + hf_generated_tail = generated_ids[:, prompt_len:] + hf_generated_text = self.tokenizer.decode(hf_generated_tail[0], skip_special_tokens=True) + _, polygons = self.processor.post_process_generation(hf_generated_text, expected="polygon") + assert len(polygons) == 1 + first_polygon = polygons[0] + xs = [point.x for point in first_polygon.points] + ys = [point.y for point in first_polygon.points] + expected_left, expected_top, expected_right, expected_bottom = 808, 247, 863, 386 + + assert len(first_polygon.points) >= 3 + assert first_polygon.mention == "traffic light" + assert min(xs) >= expected_left - 4 + assert max(xs) <= expected_right + 4 + assert min(ys) >= expected_top - 4 + assert max(ys) <= expected_bottom + 4 + assert max(xs) - min(xs) >= 35 + assert max(ys) - min(ys) >= 100 + assert any(abs(x - expected_left) <= 12 for x in xs) + assert any(abs(x - expected_right) <= 12 for x in xs) + assert any(abs(y - expected_top) <= 12 for y in ys) + assert any(abs(y - expected_bottom) <= 12 for y in ys) diff --git a/tests/models/isaac/test_processing_isaac.py b/tests/models/isaac/test_processing_isaac.py index 659ce2bfbb07..daee6affcb97 100644 --- a/tests/models/isaac/test_processing_isaac.py +++ b/tests/models/isaac/test_processing_isaac.py @@ -484,6 +484,44 @@ def test_post_process_generation_extracts_boxes_and_cleans_text(): assert box.bottom_right.y == 386 +@require_torch +def test_post_process_generation_extracts_polygons_and_filters_by_expected_type(): + processor = _make_post_process_processor() + + generated_text = ( + 'Point (1, 2) ' + 'Box (3, 4), (5, 6) ' + 'Polygon (10, 20), (30, 40), (50, 60)' + ) + + clean_text, annotations = processor.post_process_generation(generated_text, expected="polygon") + + assert clean_text == "Point Box Polygon" + assert len(annotations) == 1 + polygon = annotations[0] + assert polygon.mention == "lane" + assert polygon.t == pytest.approx(0.25) + assert len(polygon.points) == 3 + assert polygon.points[0].x == 10 + assert polygon.points[0].y == 20 + assert polygon.points[1].x == 30 + assert polygon.points[1].y == 40 + assert polygon.points[2].x == 50 + assert polygon.points[2].y == 60 + + _, boxes = processor.post_process_generation(generated_text, expected="box") + assert len(boxes) == 1 + assert boxes[0].mention == "sign" + + +@require_torch +def test_post_process_generation_rejects_polygons_with_fewer_than_three_points(): + processor = _make_post_process_processor() + + with 
+
+@require_torch
+def test_post_process_generation_rejects_polygons_with_fewer_than_three_points():
+    processor = _make_post_process_processor()
+
+    with pytest.raises(ValueError, match=r"Malformed tag"):
+        processor.post_process_generation('<polygon>(10, 20), (30, 40)</polygon>', expected="polygon")
+
+
 @require_torch
 @require_vision
 def test_single_image_returns_offsets_and_lengths(isaac_processor):
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 7366845c4d78..3ac8ff1573f8 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -3246,6 +3246,67 @@ def test_vision_language_model(self):
     assert dec is model.model.language_model, f"LLaVA get_decoder() should return language_model, got {type(dec)}"
 
 
+class TestEmbeddingAccessMixin(unittest.TestCase):
+    def test_get_input_embeddings_supports_dotted_input_embed_layer(self):
+        class NestedEmbeddingModel(PreTrainedModel):
+            config_class = PreTrainedConfig
+            _input_embed_layer = "text_model.embed_tokens"
+
+            def __init__(self, config):
+                super().__init__(config)
+                self.text_model = nn.Module()
+                self.text_model.embed_tokens = nn.Embedding(8, 4)
+
+            def forward(self, input_ids=None):
+                return input_ids
+
+        model = NestedEmbeddingModel(PreTrainedConfig())
+
+        assert model.get_input_embeddings() is model.text_model.embed_tokens
+
+    def test_set_input_embeddings_supports_dotted_input_embed_layer(self):
+        class NestedEmbeddingModel(PreTrainedModel):
+            config_class = PreTrainedConfig
+            _input_embed_layer = "text_model.embed_tokens"
+
+            def __init__(self, config):
+                super().__init__(config)
+                self.text_model = nn.Module()
+                self.text_model.embed_tokens = nn.Embedding(8, 4)
+
+            def forward(self, input_ids=None):
+                return input_ids
+
+        model = NestedEmbeddingModel(PreTrainedConfig())
+        new_embeddings = nn.Embedding(10, 4)
+
+        model.set_input_embeddings(new_embeddings)
+
+        assert model.get_input_embeddings() is new_embeddings
+        assert model.text_model.embed_tokens is new_embeddings
+
+    def test_invalid_dotted_input_embed_layer_raises(self):
+        class NestedEmbeddingModel(PreTrainedModel):
+            config_class = PreTrainedConfig
+            _input_embed_layer = "text_model.missing_embed_tokens"
+
+            def __init__(self, config):
+                super().__init__(config)
+                self.text_model = nn.Module()
+                self.text_model.embed_tokens = nn.Embedding(8, 4)
+
+            def forward(self, input_ids=None):
+                return input_ids
+
+        model = NestedEmbeddingModel(PreTrainedConfig())
+
+        with self.assertRaises(NotImplementedError):
+            model.get_input_embeddings()
+
+        with self.assertRaises(NotImplementedError):
+            model.set_input_embeddings(nn.Embedding(10, 4))
+
+
 class TestGetEncoder(unittest.TestCase):
     def test_seq2seq_lm_get_encoder_returns_encoder(self):
         cfg = BartConfig(
From 67ae690ad1b4105394a5821ba3a62c2aac2c8fcf Mon Sep 17 00:00:00 2001
From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com>
Date: Wed, 25 Mar 2026 00:58:24 +0400
Subject: [PATCH 66/77] style: cleanup (#17)

* fix: use torchvisionbackend
* fix: import IsaacImageProcessor
* fix: resample not interpolation
* style: organize import
* chore: auto processing auto from main
* feat: register isaac image processor according to new convention
* fix: update to new config style
* fix: correct pix2struct import
* docs: initial doc update
* feat: re-register isaac processor to auto
* refactor: move max_position_embeddings to isaac config
* TEMP pop!
* docs: update date
* style: remove removed attr
* style: add config attr for completeness
* style: drop redundant merge_with_config_defaults
* style: remove redundant position ids handling
* refactor: rely on base class for setting embeddings
* fix: always use full attention
* style: clarify padding logic
* chore: remove stale artifact
* fix: kwargs name!
* refactor: isolate custom padding to image processor pad method
* feat: no device movement
* style: align with transformers standard for loading rope params
* refactor: drop unneeded arg filter
* feat: compile check image presence instead
* docs: add clarifying comment for keeping empty tensors
* refactor: move broadcasting to forward WIP
* style: use new layer validation functionality
* feat: update embedding access mixin to support nested paths!
* chore: convert artifacts
* refactor: inline build batch
* style: drop duplicate test
* test: polygons
* feat: polygon extraction
* test: polygon generation test
* style: align with new config implementation style
* style: add date
* chore: remove all isaac image processor fast
* test: new image processor test setup
* feat: special path for uint8 interp
* test: update image processor test for torchvision backend
* fix: don't mutate nested outputs; copy image index
* style: drop redundant copy
* chore: make fix-repo
* chore: convert artifacts
* chore: fix docstring
---
 docs/source/en/model_doc/isaac.md             |   6 +-
 .../models/isaac/configuration_isaac.py       | 121 ++++++++--------
 .../models/isaac/image_processing_isaac.py    |  13 +-
 .../models/isaac/modular_isaac.py             | 131 ++++++++++--------
 .../models/isaac/processing_isaac.py          |  23 ++-
 .../isaac/test_image_processing_isaac.py      |  30 ++--
 .../isaac/test_post_processing_isaac.py       |   4 +-
 tests/models/isaac/test_processing_isaac.py   |  10 +-
 8 files changed, 177 insertions(+), 161 deletions(-)

diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md
index eff6634c4b29..27bccc1dd670 100644
--- a/docs/source/en/model_doc/isaac.md
+++ b/docs/source/en/model_doc/isaac.md
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
 rendered properly in your Markdown viewer.
 -->
 
-*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-30.*
+*This model was released on {release_date} and added to Hugging Face Transformers on 2026-03-24.*
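For orientation alongside this docs change, the polygon annotations introduced in this series round-trip through the processor as sketched below; a minimal sketch, assuming a loaded `IsaacProcessor` (the checkpoint name matches the `auto_docstring` default used in the configs, and the generated string is illustrative):

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("PerceptronAI/Isaac-0.1-Base")

# An illustrative generation carrying the <polygon> tag that post-processing now parses.
generated = 'Signal ahead <polygon mention="traffic light">(808, 247), (863, 250), (860, 386)</polygon>.'

clean_text, polygons = processor.post_process_generation(generated, expected="polygon")
print(clean_text)                                # "Signal ahead ."
print(polygons[0].mention)                       # "traffic light"
print([(p.x, p.y) for p in polygons[0].points])  # [(808, 247), (863, 250), (860, 386)]
```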
@@ -153,7 +153,3 @@ Set `expected="point"` to extract point annotations, or leave `expected=None` to ## IsaacImageProcessor [[autodoc]] IsaacImageProcessor - -## IsaacImageProcessorFast - -[[autodoc]] IsaacImageProcessorFast diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index 2f0bdbeccdc2..944f57d37ae5 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -23,17 +23,19 @@ from ...configuration_utils import PreTrainedConfig, PretrainedConfig from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring +@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") @strict(accept_kwargs=True) class IsaacVisionConfig(PreTrainedConfig): - """Vision configuration for Isaac with Pixel Shuffle support. - - Extends Siglip2VisionConfig with additional fields for pixel shuffle. - - Args: - pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1): - Spatial factor applied before pixel shuffle reduces the resolution. + r""" + num_patches (`int`, *optional*, defaults to 256): + The number of patches in the image with the size of (`patch_size`, `patch_size`). The image is resized to + fill a maximum of this number of patches while preserving the aspect ratio. If the resulting number of patches + is lower, the image is padded in the patch dimension. + pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1): + Spatial factor applied before pixel shuffle reduces the resolution. """ model_type = "isaac_vision" @@ -54,25 +56,20 @@ class IsaacVisionConfig(PreTrainedConfig): pixel_shuffle_scale_factor: int = 1 +@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") @strict(accept_kwargs=True) class IsaacTextConfig(PreTrainedConfig): r""" - max_window_layers (`int`, *optional*, defaults to 28): - The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any - additional layer afterwards will use SWA (Sliding Window Attention). + Example: ```python - >>> from transformers import IsaacTextModel, IsaacTextConfig + >>> from transformers import IsaacTextConfig, IsaacTextModel - >>> # Initializing a IsaacText style configuration >>> configuration = IsaacTextConfig() - - >>> # Initializing a model from the IsaacText-8B style configuration >>> model = IsaacTextModel(configuration) - - >>> # Accessing the model configuration >>> configuration = model.config - ```""" + ``` + """ model_type = "isaac_text" keys_to_ignore_at_inference = ["past_key_values"] @@ -136,62 +133,66 @@ def __post_init__(self, **kwargs): self.validate_layer_type() +@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") +@strict(accept_kwargs=True) class IsaacConfig(PretrainedConfig): - """Configuration class for Isaac multimodal model. - - This configuration corresponds to checkpoints such as - [Perceptron/isaac-base](https://huggingface.co/Perceptron/isaac-base). - - Args: - vision_config (`IsaacVisionConfig`, *optional*): - Configuration for the Isaac vision tower. If unset, the default [`IsaacVisionConfig`] is used. - text_config (`IsaacTextConfig` or `dict`, *optional*): - Configuration for the text backbone. Dictionaries are converted to [`IsaacTextConfig`]. - vision_rescale_factor (`float`, *optional*, defaults to 1 / 255): - Rescale factor applied by the image processor before normalization. 
- max_sequence_length (`int`, *optional*, defaults to 16384): - Maximum multimodal sequence length produced by the processor and expected by the model. - vision_token (`str`, *optional*, defaults to `""`): - Placeholder string inserted into text prompts to mark image positions. + r""" + vision_config (`IsaacVisionConfig` or `dict`, *optional*): + Configuration for the Isaac vision tower. Dictionaries are converted to [`IsaacVisionConfig`]. If unset, + the default [`IsaacVisionConfig`] is used. + text_config (`IsaacTextConfig` or `dict`, *optional*): + Configuration for the text backbone. Dictionaries are converted to [`IsaacTextConfig`]. + vision_rescale_factor (`float`, *optional*, defaults to 1 / 255): + Rescale factor applied by the image processor before normalization. + max_sequence_length (`int`, *optional*, defaults to 16384): + Maximum multimodal sequence length produced by the processor and expected by the model. + vision_token (`str`, *optional*, defaults to `""`): + Placeholder string inserted into text prompts to mark image positions. + + Example: + + ```python + >>> from transformers import IsaacConfig, IsaacModel + + >>> configuration = IsaacConfig() + >>> model = IsaacModel(configuration) + >>> configuration = model.config + ``` """ model_type = "isaac" sub_configs = {"vision_config": IsaacVisionConfig, "text_config": IsaacTextConfig} + vision_config: IsaacVisionConfig | dict | None = None + text_config: IsaacTextConfig | dict | None = None + vision_rescale_factor: float = 1 / 255 + max_sequence_length: int = 16384 + vision_token: str = "" - def __init__( - self, - vision_config: IsaacVisionConfig | None = None, - text_config: IsaacTextConfig | dict | None = None, - vision_rescale_factor: float = 1 / 255, - max_sequence_length: int = 16384, - vision_token: str = "", - **kwargs, - ): + def __post_init__(self, **kwargs): for key in ("use_cache", "rope_theta", "max_position_embeddings"): kwargs.pop(key, None) - if isinstance(text_config, dict): - self.text_config = self.sub_configs["text_config"](**text_config) - elif isinstance(text_config, IsaacTextConfig): - self.text_config = text_config - elif text_config is None: + if isinstance(self.text_config, dict): + self.text_config = self.sub_configs["text_config"](**self.text_config) + elif self.text_config is None: self.text_config = self.sub_configs["text_config"]() - - if isinstance(vision_config, dict): - self.vision_config = self.sub_configs["vision_config"](**vision_config) - elif isinstance(vision_config, IsaacVisionConfig): - self.vision_config = vision_config - elif vision_config is None: + elif not isinstance(self.text_config, IsaacTextConfig): + raise TypeError( + f"text_config must be a dict or an IsaacTextConfig instance, got {type(self.text_config).__name__}." + ) + + if isinstance(self.vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**self.vision_config) + elif self.vision_config is None: self.vision_config = self.sub_configs["vision_config"]() + elif not isinstance(self.vision_config, IsaacVisionConfig): + raise TypeError( + f"vision_config must be a dict or an IsaacVisionConfig instance, got {type(self.vision_config).__name__}." 
+ ) - super().__init__(**kwargs) + self.vision_rescale_factor = float(self.vision_rescale_factor) - # Vision normalization parameters - self.vision_rescale_factor = float(vision_rescale_factor) - - # Processing parameters - self.max_sequence_length = max_sequence_length - self.vision_token = vision_token + super().__post_init__(**kwargs) __all__ = ["IsaacConfig", "IsaacTextConfig", "IsaacVisionConfig"] diff --git a/src/transformers/models/isaac/image_processing_isaac.py b/src/transformers/models/isaac/image_processing_isaac.py index a4c3275f98cb..8573e5392e70 100644 --- a/src/transformers/models/isaac/image_processing_isaac.py +++ b/src/transformers/models/isaac/image_processing_isaac.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import math from collections.abc import Sequence from typing import Any @@ -214,6 +215,9 @@ def resize( size: SizeDict, **kwargs, ) -> torch.Tensor: + if image.dtype == torch.uint8: + image = F.interpolate(image.float(), size=(size.height, size.width), mode="bilinear", align_corners=False) + return image.clamp(0, 255).round().to(torch.uint8) return F.interpolate(image, size=(size.height, size.width), mode="bilinear", align_corners=False) def pad( @@ -336,14 +340,13 @@ def _preprocess( ) keys = ("vision_patches", "vision_token_grids") - nested_outputs = { - key: reorder_images( + nested_outputs = {} + for i, key in enumerate(keys): + nested_outputs[key] = reorder_images( {shape: values[i] for shape, values in grouped_outputs.items()}, - grouped_images_index, + dict(grouped_images_index), is_nested=True, ) - for i, key in enumerate(keys) - } if not do_pad: raise ValueError("IsaacImageProcessor doesn't support `do_pad=False` mode.") diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 40ec4b2d6e6d..8398c64fdc5b 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -14,7 +14,6 @@ from __future__ import annotations -import copy import math import re from collections.abc import Sequence @@ -196,15 +195,16 @@ def clean_text_and_extract_points( return clean_text, results +@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") @strict(accept_kwargs=True) class IsaacVisionConfig(Siglip2VisionConfig): - """Vision configuration for Isaac with Pixel Shuffle support. - - Extends Siglip2VisionConfig with additional fields for pixel shuffle. - - Args: - pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1): - Spatial factor applied before pixel shuffle reduces the resolution. + r""" + num_patches (`int`, *optional*, defaults to 256): + The number of patches in the image with the size of (`patch_size`, `patch_size`). The image is resized to + fill a maximum of this number of patches while preserving the aspect ratio. If the resulting number of patches + is lower, the image is padded in the patch dimension. + pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1): + Spatial factor applied before pixel shuffle reduces the resolution. 
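+        For example, with `patch_size=16` a 64x64 image yields a 4x4 patch grid, and a scale
+        factor of 2 folds it into a 2x2 token grid while widening the hidden size by a factor
+        of 4 (the token grid must be divisible by the scale factor).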
""" model_type = "isaac_vision" @@ -213,8 +213,21 @@ class IsaacVisionConfig(Siglip2VisionConfig): pixel_shuffle_scale_factor: int = 1 +@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") @strict(accept_kwargs=True) class IsaacTextConfig(Qwen3Config): + r""" + Example: + + ```python + >>> from transformers import IsaacTextConfig, IsaacTextModel + + >>> configuration = IsaacTextConfig() + >>> model = IsaacTextModel(configuration) + >>> configuration = model.config + ``` + """ + model_type = "isaac_text" ignore_keys_at_rope_validation = {"mrope_section", "mrope_interleaved"} max_position_embeddings: int = 32768 @@ -287,6 +300,9 @@ def resize( size: SizeDict, **kwargs, ) -> torch.Tensor: + if image.dtype == torch.uint8: + image = F.interpolate(image.float(), size=(size.height, size.width), mode="bilinear", align_corners=False) + return image.clamp(0, 255).round().to(torch.uint8) return F.interpolate(image, size=(size.height, size.width), mode="bilinear", align_corners=False) def pad( @@ -409,14 +425,13 @@ def _preprocess( ) keys = ("vision_patches", "vision_token_grids") - nested_outputs = { - key: reorder_images( + nested_outputs = {} + for i, key in enumerate(keys): + nested_outputs[key] = reorder_images( {shape: values[i] for shape, values in grouped_outputs.items()}, - grouped_images_index, + dict(grouped_images_index), is_nested=True, ) - for i, key in enumerate(keys) - } if not do_pad: raise ValueError("IsaacImageProcessor doesn't support `do_pad=False` mode.") @@ -743,62 +758,66 @@ def get_image_size_for_max_num_patches( return target_height, target_width +@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") +@strict(accept_kwargs=True) class IsaacConfig(PretrainedConfig): - """Configuration class for Isaac multimodal model. - - This configuration corresponds to checkpoints such as - [Perceptron/isaac-base](https://huggingface.co/Perceptron/isaac-base). - - Args: - vision_config (`IsaacVisionConfig`, *optional*): - Configuration for the Isaac vision tower. If unset, the default [`IsaacVisionConfig`] is used. - text_config (`IsaacTextConfig` or `dict`, *optional*): - Configuration for the text backbone. Dictionaries are converted to [`IsaacTextConfig`]. - vision_rescale_factor (`float`, *optional*, defaults to 1 / 255): - Rescale factor applied by the image processor before normalization. - max_sequence_length (`int`, *optional*, defaults to 16384): - Maximum multimodal sequence length produced by the processor and expected by the model. - vision_token (`str`, *optional*, defaults to `""`): - Placeholder string inserted into text prompts to mark image positions. + r""" + vision_config (`IsaacVisionConfig` or `dict`, *optional*): + Configuration for the Isaac vision tower. Dictionaries are converted to [`IsaacVisionConfig`]. If unset, + the default [`IsaacVisionConfig`] is used. + text_config (`IsaacTextConfig` or `dict`, *optional*): + Configuration for the text backbone. Dictionaries are converted to [`IsaacTextConfig`]. + vision_rescale_factor (`float`, *optional*, defaults to 1 / 255): + Rescale factor applied by the image processor before normalization. + max_sequence_length (`int`, *optional*, defaults to 16384): + Maximum multimodal sequence length produced by the processor and expected by the model. + vision_token (`str`, *optional*, defaults to `""`): + Placeholder string inserted into text prompts to mark image positions. 
+ + Example: + + ```python + >>> from transformers import IsaacConfig, IsaacModel + + >>> configuration = IsaacConfig() + >>> model = IsaacModel(configuration) + >>> configuration = model.config + ``` """ model_type = "isaac" sub_configs = {"vision_config": IsaacVisionConfig, "text_config": IsaacTextConfig} + vision_config: IsaacVisionConfig | dict | None = None + text_config: IsaacTextConfig | dict | None = None + vision_rescale_factor: float = 1 / 255 + max_sequence_length: int = 16384 + vision_token: str = "" - def __init__( - self, - vision_config: IsaacVisionConfig | None = None, - text_config: IsaacTextConfig | dict | None = None, - vision_rescale_factor: float = 1 / 255, - max_sequence_length: int = 16384, - vision_token: str = "", - **kwargs, - ): + def __post_init__(self, **kwargs): for key in ("use_cache", "rope_theta", "max_position_embeddings"): kwargs.pop(key, None) - if isinstance(text_config, dict): - self.text_config = self.sub_configs["text_config"](**text_config) - elif isinstance(text_config, IsaacTextConfig): - self.text_config = text_config - elif text_config is None: + if isinstance(self.text_config, dict): + self.text_config = self.sub_configs["text_config"](**self.text_config) + elif self.text_config is None: self.text_config = self.sub_configs["text_config"]() + elif not isinstance(self.text_config, IsaacTextConfig): + raise TypeError( + f"text_config must be a dict or an IsaacTextConfig instance, got {type(self.text_config).__name__}." + ) - if isinstance(vision_config, dict): - self.vision_config = self.sub_configs["vision_config"](**vision_config) - elif isinstance(vision_config, IsaacVisionConfig): - self.vision_config = vision_config - elif vision_config is None: + if isinstance(self.vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**self.vision_config) + elif self.vision_config is None: self.vision_config = self.sub_configs["vision_config"]() + elif not isinstance(self.vision_config, IsaacVisionConfig): + raise TypeError( + f"vision_config must be a dict or an IsaacVisionConfig instance, got {type(self.vision_config).__name__}." + ) - super().__init__(**kwargs) - - # Vision normalization parameters - self.vision_rescale_factor = float(vision_rescale_factor) + self.vision_rescale_factor = float(self.vision_rescale_factor) - # Processing parameters - self.max_sequence_length = max_sequence_length - self.vision_token = vision_token + super().__post_init__(**kwargs) @auto_docstring @@ -872,7 +891,7 @@ def __call__( tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) - text_kwargs = copy.deepcopy(output_kwargs["text_kwargs"]) + text_kwargs = output_kwargs["text_kwargs"] truncation = text_kwargs.pop("truncation", None) max_length = text_kwargs.pop("max_length", None) padding = text_kwargs.pop("padding", True) diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 8b1696a7b501..92c111f38845 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -18,8 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import copy import re from ...feature_extraction_utils import BatchFeature @@ -136,16 +134,15 @@ def __init__( max_sequence_length: int = 16384, rescale_factor: float | None = None, ): - """ - Args: - chat_template (`str` or `dict[str, str]`, *optional*): - Chat template override forwarded to [`~processing_utils.ProcessorMixin`]. 
- vision_token (`str`, *optional*, defaults to `""`): - Placeholder token used inside text prompts to mark image positions. - max_sequence_length (`int`, *optional*, defaults to 16384): - Maximum packed multimodal sequence length produced by the processor. - rescale_factor (`float`, *optional*): - Deprecated compatibility argument accepted for backward compatibility. + r""" + chat_template (`str` or `dict[str, str]`, *optional*): + Chat template override forwarded to [`~processing_utils.ProcessorMixin`]. + vision_token (`str`, *optional*, defaults to `""`): + Placeholder token used inside text prompts to mark image positions. + max_sequence_length (`int`, *optional*, defaults to 16384): + Maximum packed multimodal sequence length produced by the processor. + rescale_factor (`float`, *optional*): + Deprecated compatibility argument accepted for backward compatibility. """ if chat_template is None: chat_template = getattr(tokenizer, "chat_template", None) @@ -196,7 +193,7 @@ def __call__( tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) - text_kwargs = copy.deepcopy(output_kwargs["text_kwargs"]) + text_kwargs = output_kwargs["text_kwargs"] truncation = text_kwargs.pop("truncation", None) max_length = text_kwargs.pop("max_length", None) padding = text_kwargs.pop("padding", True) diff --git a/tests/models/isaac/test_image_processing_isaac.py b/tests/models/isaac/test_image_processing_isaac.py index 1b968087b277..27ee2448f8d3 100644 --- a/tests/models/isaac/test_image_processing_isaac.py +++ b/tests/models/isaac/test_image_processing_isaac.py @@ -37,7 +37,7 @@ from PIL import Image if is_torchvision_available(): - from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast + from transformers.models.isaac.image_processing_isaac import IsaacImageProcessor def _make_dummy_image(size=(32, 32), color=(255, 0, 0)): @@ -122,7 +122,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F @require_vision class IsaacImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = None - fast_image_processing_class = IsaacImageProcessorFast if is_torchvision_available() else None + fast_image_processing_class = IsaacImageProcessor if is_torchvision_available() else None test_slow_image_processor = False def setUp(self): @@ -185,7 +185,7 @@ def _assert_encoding_close(self, eager_encoding, compiled_encoding): torch.testing.assert_close(eager_encoding["vision_token_grids"], compiled_encoding["vision_token_grids"]) def test_image_processor_properties(self): - for image_processing_class in self.image_processor_list: + for image_processing_class in self.image_processing_classes.values(): image_processor = image_processing_class(**self.image_processor_dict) self.assertTrue(hasattr(image_processor, "do_resize")) self.assertTrue(hasattr(image_processor, "do_rescale")) @@ -200,7 +200,7 @@ def test_image_processor_properties(self): self.assertTrue(hasattr(image_processor, "do_convert_rgb")) def test_call_pil(self): - for image_processing_class in self.image_processor_list: + for image_processing_class in self.image_processing_classes.values(): image_processor = image_processing_class(**self.image_processor_dict) image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) @@ -224,7 +224,7 @@ def test_call_pil(self): ) def test_call_numpy(self): - for image_processing_class in self.image_processor_list: + for image_processing_class in self.image_processing_classes.values(): image_processor = 
image_processing_class(**self.image_processor_dict) image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) @@ -248,7 +248,7 @@ def test_call_numpy(self): ) def test_call_pytorch(self): - for image_processing_class in self.image_processor_list: + for image_processing_class in self.image_processing_classes.values(): image_processor = image_processing_class(**self.image_processor_dict) image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) @@ -276,7 +276,7 @@ def test_call_numpy_4_channels(self): pass def test_nested_multi_image_batch_preserves_grids_and_padding(self): - for image_processing_class in self.image_processor_list: + for image_processing_class in self.image_processing_classes.values(): image_processor = image_processing_class( **{ **self.image_processor_dict, @@ -322,7 +322,7 @@ def test_nested_multi_image_batch_preserves_grids_and_padding(self): torch.testing.assert_close(encoding["vision_patch_attention_mask"].sum(dim=-1), expected_patch_counts) def test_all_empty_images_returns_zero_sized_tensors(self): - for image_processing_class in self.image_processor_list: + for image_processing_class in self.image_processing_classes.values(): image_processor = image_processing_class(**self.image_processor_dict) encoding = image_processor([[], []], return_tensors="pt") @@ -337,7 +337,7 @@ def test_all_empty_images_returns_zero_sized_tensors(self): self.assertEqual(encoding["vision_token_grids"].dtype, torch.long) def test_do_resize_false_requires_patch_divisibility(self): - for image_processing_class in self.image_processor_list: + for image_processing_class in self.image_processing_classes.values(): image_processor = image_processing_class( **{ **self.image_processor_dict, @@ -350,7 +350,7 @@ def test_do_resize_false_requires_patch_divisibility(self): image_processor([[_make_dummy_image(size=(31, 32))]], return_tensors="pt") def test_pixel_shuffle_scale_requires_divisible_token_grid(self): - for image_processing_class in self.image_processor_list: + for image_processing_class in self.image_processing_classes.values(): image_processor = image_processing_class( **{ **self.image_processor_dict, @@ -364,7 +364,7 @@ def test_pixel_shuffle_scale_requires_divisible_token_grid(self): image_processor([[_make_dummy_image(size=(32, 16))]], return_tensors="pt") def test_cast_dtype_device(self): - for image_processing_class in self.image_processor_list: + for image_processing_class in self.image_processing_classes.values(): image_processor = image_processing_class(**self.image_processor_dict) image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) @@ -403,13 +403,13 @@ def test_cast_dtype_device(self): @require_torch_accelerator @require_vision @pytest.mark.torch_compile_test - def test_can_compile_fast_image_processor(self): - if self.fast_image_processing_class is None: - self.skipTest("Skipping compilation test as fast image processor is not defined") + def test_can_compile_torchvision_backend(self): + if "torchvision" not in self.image_processing_classes: + self.skipTest("Skipping compilation test as torchvision backend is not available") torch.compiler.reset() input_image = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8) - image_processor = self.fast_image_processing_class(**self.image_processor_dict) + image_processor = self.image_processing_classes["torchvision"](**self.image_processor_dict) output_eager = image_processor(input_image, 
device=torch_device, return_tensors="pt")

         image_processor = torch.compile(image_processor, mode="reduce-overhead")

diff --git a/tests/models/isaac/test_post_processing_isaac.py b/tests/models/isaac/test_post_processing_isaac.py
index 613ce3a9732f..32c52c175ffd 100644
--- a/tests/models/isaac/test_post_processing_isaac.py
+++ b/tests/models/isaac/test_post_processing_isaac.py
@@ -17,7 +17,7 @@ import pytest

 from transformers import PythonBackend
-from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast
+from transformers.models.isaac.image_processing_isaac import IsaacImageProcessor
 from transformers.models.isaac.processing_isaac import IsaacProcessor
 from transformers.testing_utils import require_torch
@@ -77,7 +77,7 @@ def save_vocabulary(self, save_directory, filename_prefix=None):


 def _make_processor():
-    return IsaacProcessor(image_processor=IsaacImageProcessorFast(), tokenizer=SimpleIsaacTokenizer())
+    return IsaacProcessor(image_processor=IsaacImageProcessor(), tokenizer=SimpleIsaacTokenizer())


 @require_torch
diff --git a/tests/models/isaac/test_processing_isaac.py b/tests/models/isaac/test_processing_isaac.py
index daee6affcb97..3a67d064b8c9 100644
--- a/tests/models/isaac/test_processing_isaac.py
+++ b/tests/models/isaac/test_processing_isaac.py
@@ -24,7 +24,7 @@ from huggingface_hub import is_offline_mode

 from transformers import IsaacConfig, PythonBackend
-from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessorFast
+from transformers.models.isaac.image_processing_isaac import IsaacImageProcessor
 from transformers.models.isaac.processing_isaac import IsaacProcessor
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_vision_available
@@ -202,7 +202,7 @@ def _make_processor_with_max_len(tokenizer, base_config, max_len):
     config = IsaacConfig(**base_config.to_dict())
     config.max_sequence_length = max_len
     vision_config = config.vision_config
-    image_processor = IsaacImageProcessorFast(
+    image_processor = IsaacImageProcessor(
         patch_size=vision_config.patch_size,
         max_num_patches=vision_config.num_patches,
         pixel_shuffle_scale=vision_config.pixel_shuffle_scale_factor,
@@ -221,12 +221,12 @@ def _run_processor(processor, text, images=None):


 def _make_post_process_processor():
-    return IsaacProcessor(image_processor=IsaacImageProcessorFast(), tokenizer=SimpleIsaacTokenizer())
+    return IsaacProcessor(image_processor=IsaacImageProcessor(), tokenizer=SimpleIsaacTokenizer())


 def test_processor_prefers_named_image_pad_token():
     processor = IsaacProcessor(
-        image_processor=IsaacImageProcessorFast(), tokenizer=SimpleIsaacTokenizerWithNamedImagePad()
+        image_processor=IsaacImageProcessor(), tokenizer=SimpleIsaacTokenizerWithNamedImagePad()
     )

     assert processor.image_token == ""
@@ -346,7 +346,7 @@ def isaac_tokenizer():

 @pytest.fixture
 def isaac_processor(isaac_tokenizer, isaac_tiny_config):
     vision_config = isaac_tiny_config.vision_config
-    image_processor = IsaacImageProcessorFast(
+    image_processor = IsaacImageProcessor(
         patch_size=vision_config.patch_size,
         max_num_patches=vision_config.num_patches,
         pixel_shuffle_scale=vision_config.pixel_shuffle_scale_factor,

From bbd8289a5a092cf4431dae62e11b31a42f49ce48 Mon Sep 17 00:00:00 2001
From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com>
Date: Wed, 25 Mar 2026 01:57:55 +0400
Subject: [PATCH 67/77] style: unify image attention mask + import update (#18)

* fix: use torchvisionbackend
* fix: import IsaacImageProcessor
* fix: resample not interpolation
* style: organize import
* chore: auto processing auto from main
* feat: register isaac image processor according to new convention
* fix: update to new config style
* fix: correct pix2struct import
* docs: initial doc update
* feat: re-register isaac processor to auto
* refactor: move max_position_embeddings to isaac config
* TEMP pop!
* docs: update date
* style: remove removed attr
* style: add config attr for completeness
* style: drop redundant merge_with_config_defaults
* style: remove redundant positions ids handling
* refactor: rely on base class for setting embeddings
* fix: always use full attention
* style: clarify padding logic
* chore: remove stale artifact
* fix: kwargs name!
* refactor: isolate custom padding to image processor pad method
* feat: no device movement
* style: align with transformers standard for loading rope params
* refactor: drop unneeded arg filter
* feat: compile check image presence instead
* docs: add clarifying comment for keeping empty tensors
* refactor: move broadcasting to forward WIP
* style: use new layer validation functionality
* feat: update embedding access mixin to support nested paths!
* chore: convert artifacts
* refactor: inline build batch
* style: drop duplicate test
* test: polygons
* feat: polygon extraction
* test: polygon generation test
* style: align with new config implementation style
* style: add date
* chore: remove all isaac image processor fast
* test: new image processor test setup
* feat: special path for uint8 interp
* test: update image processor test for torchvision backend
* fix: don't mutate nested outputs; copy image index
* style: drop redundant copy
* chore: make fix-repo
* chore: convert artifacts
* chore: fix docstring
* style: update imports
* refactor: unify vision attention mask
  test: unified vision attention mask
  fix: correct arg name
* fix: standardize on the image attention mask name
---
 .../models/isaac/image_processing_isaac.py  | 13 +++--
 .../models/isaac/modeling_isaac.py          | 30 ++++------
 .../models/isaac/modular_isaac.py           | 57 +++++++------------
 .../models/isaac/processing_isaac.py        |  4 +-
 tests/models/isaac/test_modeling_isaac.py   |  2 +-
 tests/models/isaac/test_processing_isaac.py | 10 ++--
 6 files changed, 44 insertions(+), 72 deletions(-)
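For orientation before the hunks: after this patch the processor emits the per-patch mask as `image_patch_attention_mask`, and the model consumes it under the same name, so processor output can be forwarded unchanged. A minimal usage sketch, illustrative only; the checkpoint id mirrors the default this series' tests use, and `vision_token` is assumed to be the placeholder attribute `IsaacProcessor` exposes:

import torch
from PIL import Image
from transformers import AutoProcessor, IsaacForConditionalGeneration

checkpoint = "PerceptronAI/Isaac-0.1-Base"  # placeholder; any Isaac checkpoint shipping this processor works
processor = AutoProcessor.from_pretrained(checkpoint)
model = IsaacForConditionalGeneration.from_pretrained(checkpoint)

image = Image.new("RGB", (32, 32), color=(255, 0, 0))
inputs = processor(text=f"describe {processor.vision_token}", images=[[image]], return_tensors="pt")

# The per-patch validity mask now travels under a single name end to end.
assert "image_patch_attention_mask" in inputs
assert "vision_patch_attention_mask" not in inputs

generated = model.generate(**inputs, max_new_tokens=8)
print(processor.batch_decode(generated, skip_special_tokens=True))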
diff --git a/src/transformers/models/isaac/image_processing_isaac.py b/src/transformers/models/isaac/image_processing_isaac.py
index 8573e5392e70..7beaca45be08 100644
--- a/src/transformers/models/isaac/image_processing_isaac.py
+++ b/src/transformers/models/isaac/image_processing_isaac.py
@@ -25,8 +25,9 @@
 from ... import TorchvisionBackend
 from ...feature_extraction_utils import BatchFeature
-from ...image_processing_utils_fast import ImagesKwargs, SizeDict, group_images_by_shape, reorder_images
-from ...image_utils import ImageInput, PILImageResampling, make_nested_list_of_images
+from ...image_transforms import group_images_by_shape, reorder_images
+from ...image_utils import ImageInput, PILImageResampling, SizeDict, make_nested_list_of_images
+from ...processing_utils import ImagesKwargs
 from ...utils import TensorType, auto_docstring
 from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN
 from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD
@@ -177,7 +178,7 @@ class IsaacImageProcessor(TorchvisionBackend):
     resample = PILImageResampling.BILINEAR
     model_input_names = [
         "vision_patches",
-        "vision_patch_attention_mask",
+        "image_patch_attention_mask",
         "vision_token_grids",
     ]
     valid_kwargs = IsaacImageProcessorKwargs
@@ -237,7 +238,7 @@ def pad(
             "vision_patches": torch.zeros(
                 (batch_size, max_images, max_patches, patch_dim), device=patch_device, dtype=patch_dtype
             ),
-            "vision_patch_attention_mask": torch.zeros(
+            "image_patch_attention_mask": torch.zeros(
                 (batch_size, max_images, max_patches), device=patch_device, dtype=torch.long
             ),
             "vision_token_grids": torch.zeros((batch_size, max_images, 2), device=patch_device, dtype=torch.long),
@@ -249,7 +250,7 @@
         for image_idx, (patches, token_grid) in enumerate(zip(sample_patches, sample_token_grids, strict=True)):
             patch_count = int(patches.shape[0])
             tensors["vision_patches"][batch_idx, image_idx, :patch_count] = patches
-            tensors["vision_patch_attention_mask"][batch_idx, image_idx, :patch_count] = 1
+            tensors["image_patch_attention_mask"][batch_idx, image_idx, :patch_count] = 1
             tensors["vision_token_grids"][batch_idx, image_idx] = token_grid

         return tensors
@@ -281,7 +282,7 @@ def _preprocess(
         if all(len(sample_images) == 0 for sample_images in images):
             tensors = {
                 "vision_patches": torch.zeros((batch_size, 0, 0, 0), dtype=torch.float32),
-                "vision_patch_attention_mask": torch.zeros((batch_size, 0, 0), dtype=torch.long),
+                "image_patch_attention_mask": torch.zeros((batch_size, 0, 0), dtype=torch.long),
                 "vision_token_grids": torch.zeros((batch_size, 0, 2), dtype=torch.long),
             }
             return BatchFeature(data=tensors, tensor_type=return_tensors)
diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py
index e843db2ce2b4..dd468f76e41a 100644
--- a/src/transformers/models/isaac/modeling_isaac.py
+++ b/src/transformers/models/isaac/modeling_isaac.py
@@ -458,7 +458,7 @@ def forward(
         self,
         vision_patches: torch.Tensor,
         vision_token_grids: torch.Tensor,
-        vision_patch_attention_mask: torch.Tensor,
+        image_patch_attention_mask: torch.Tensor,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPooling:
         """
@@ -467,7 +467,7 @@
                 Patches shaped `(num_images, max_patches, patch_dim)`.
             vision_token_grids (`torch.Tensor`):
                 Token grids shaped `(num_images, 2)` with per-image `(H_tokens, W_tokens)`.
-            vision_patch_attention_mask (`torch.Tensor`):
+            image_patch_attention_mask (`torch.Tensor`):
                 Patch mask shaped `(num_images, max_patches)`.
Returns: @@ -476,13 +476,13 @@ def forward( hidden_states = self.embeddings( vision_patches, vision_token_grids, - attention_mask=vision_patch_attention_mask, + attention_mask=image_patch_attention_mask, ) encoder_attention_mask = create_bidirectional_mask( config=self.config, inputs_embeds=hidden_states, - attention_mask=vision_patch_attention_mask, + attention_mask=image_patch_attention_mask, ) encoder_outputs = self.encoder(inputs_embeds=hidden_states, attention_mask=encoder_attention_mask, **kwargs) hidden_states = self.post_layernorm(encoder_outputs.last_hidden_state) @@ -1200,7 +1200,7 @@ def get_image_features( vision_outputs = self.vision_tower( vision_patches=pixel_values[image_attention_mask], vision_token_grids=image_token_grids[image_attention_mask], - vision_patch_attention_mask=patch_attention_mask[image_attention_mask], + image_patch_attention_mask=patch_attention_mask[image_attention_mask], return_dict=True, **kwargs, ) @@ -1410,7 +1410,6 @@ def forward( input_ids: torch.LongTensor | None = None, mm_token_type_ids: torch.LongTensor | None = None, vision_patches: torch.Tensor | None = None, - vision_patch_attention_mask: torch.Tensor | None = None, image_patch_attention_mask: torch.Tensor | None = None, vision_token_grids: torch.LongTensor | None = None, image_token_grids: torch.LongTensor | None = None, @@ -1431,10 +1430,8 @@ def forward( follows the standard convention `0 -> text`, `1 -> image`. Treated as text-only when omitted. vision_patches (`torch.FloatTensor`, *optional*): Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`. - vision_patch_attention_mask (`torch.LongTensor`, *optional*): - Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. image_patch_attention_mask (`torch.LongTensor`, *optional*): - Alias for `vision_patch_attention_mask`. + Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. vision_token_grids (`torch.LongTensor`, *optional*): Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`. image_token_grids (`torch.LongTensor`, *optional*): @@ -1462,7 +1459,7 @@ def forward( image_outputs = self.get_image_features( pixel_values=vision_patches, image_token_grids=vision_token_grids, - image_patch_attention_mask=vision_patch_attention_mask, + image_patch_attention_mask=image_patch_attention_mask, image_token_offsets=vision_token_offsets, image_token_lengths=vision_token_lengths, image_attention_mask=image_attention_mask, @@ -1541,7 +1538,6 @@ def forward( mm_token_type_ids: torch.LongTensor | None = None, vision_patches: torch.Tensor | None = None, pixel_values: torch.Tensor | None = None, - vision_patch_attention_mask: torch.Tensor | None = None, image_patch_attention_mask: torch.Tensor | None = None, vision_token_grids: torch.LongTensor | None = None, image_token_grids: torch.LongTensor | None = None, @@ -1564,10 +1560,8 @@ def forward( Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`. pixel_values (`torch.FloatTensor`, *optional*): Alias for `vision_patches` accepted by generic image-feature and generation helpers. - vision_patch_attention_mask (`torch.LongTensor`, *optional*): - Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. image_patch_attention_mask (`torch.LongTensor`, *optional*): - Alias for `vision_patch_attention_mask`. 
+ Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. vision_token_grids (`torch.LongTensor`, *optional*): Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`. image_token_grids (`torch.LongTensor`, *optional*): @@ -1584,7 +1578,7 @@ def forward( input_ids=input_ids, mm_token_type_ids=mm_token_type_ids, vision_patches=vision_patches, - vision_patch_attention_mask=vision_patch_attention_mask, + image_patch_attention_mask=image_patch_attention_mask, vision_token_grids=vision_token_grids, vision_token_offsets=vision_token_offsets, vision_token_lengths=vision_token_lengths, @@ -1619,7 +1613,6 @@ def prepare_inputs_for_generation( mm_token_type_ids: torch.LongTensor | None = None, vision_patches: torch.Tensor | None = None, pixel_values: torch.Tensor | None = None, - vision_patch_attention_mask: torch.Tensor | None = None, image_patch_attention_mask: torch.Tensor | None = None, vision_token_grids: torch.LongTensor | None = None, image_token_grids: torch.LongTensor | None = None, @@ -1632,9 +1625,6 @@ def prepare_inputs_for_generation( **kwargs, ) -> dict[str, Any]: if vision_patches is None: - vision_patch_attention_mask = ( - image_patch_attention_mask if vision_patch_attention_mask is None else vision_patch_attention_mask - ) vision_token_grids = image_token_grids if vision_token_grids is None else vision_token_grids model_inputs = super().prepare_inputs_for_generation( input_ids, @@ -1649,7 +1639,7 @@ def prepare_inputs_for_generation( multimodal_inputs = { "mm_token_type_ids": mm_token_type_ids, "vision_patches": vision_patches, - "vision_patch_attention_mask": vision_patch_attention_mask, + "image_patch_attention_mask": image_patch_attention_mask, "vision_token_grids": vision_token_grids, "vision_token_offsets": vision_token_offsets, "vision_token_lengths": vision_token_lengths, diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 8398c64fdc5b..1536e1ddc1d5 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -26,17 +26,8 @@ from ...configuration_utils import PretrainedConfig from ...feature_extraction_utils import BatchFeature from ...generation.utils import GenerationMixin -from ...image_processing_utils_fast import ( - ImagesKwargs, - SizeDict, - group_images_by_shape, - reorder_images, -) -from ...image_utils import ( - ImageInput, - PILImageResampling, - make_nested_list_of_images, -) +from ...image_transforms import group_images_by_shape, reorder_images +from ...image_utils import ImageInput, PILImageResampling, SizeDict, make_nested_list_of_images from ...masking_utils import create_bidirectional_mask from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel @@ -45,7 +36,7 @@ Qwen3ForCausalLM, Qwen3PreTrainedModel, ) -from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack from ...utils import TensorType, auto_docstring, torch_compilable_check from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD @@ -262,7 +253,7 @@ class IsaacImageProcessor(TorchvisionBackend): resample = PILImageResampling.BILINEAR model_input_names = [ "vision_patches", - "vision_patch_attention_mask", + "image_patch_attention_mask", 
"vision_token_grids", ] valid_kwargs = IsaacImageProcessorKwargs @@ -322,7 +313,7 @@ def pad( "vision_patches": torch.zeros( (batch_size, max_images, max_patches, patch_dim), device=patch_device, dtype=patch_dtype ), - "vision_patch_attention_mask": torch.zeros( + "image_patch_attention_mask": torch.zeros( (batch_size, max_images, max_patches), device=patch_device, dtype=torch.long ), "vision_token_grids": torch.zeros((batch_size, max_images, 2), device=patch_device, dtype=torch.long), @@ -334,7 +325,7 @@ def pad( for image_idx, (patches, token_grid) in enumerate(zip(sample_patches, sample_token_grids, strict=True)): patch_count = int(patches.shape[0]) tensors["vision_patches"][batch_idx, image_idx, :patch_count] = patches - tensors["vision_patch_attention_mask"][batch_idx, image_idx, :patch_count] = 1 + tensors["image_patch_attention_mask"][batch_idx, image_idx, :patch_count] = 1 tensors["vision_token_grids"][batch_idx, image_idx] = token_grid return tensors @@ -366,7 +357,7 @@ def _preprocess( if all(len(sample_images) == 0 for sample_images in images): tensors = { "vision_patches": torch.zeros((batch_size, 0, 0, 0), dtype=torch.float32), - "vision_patch_attention_mask": torch.zeros((batch_size, 0, 0), dtype=torch.long), + "image_patch_attention_mask": torch.zeros((batch_size, 0, 0), dtype=torch.long), "vision_token_grids": torch.zeros((batch_size, 0, 2), dtype=torch.long), } return BatchFeature(data=tensors, tensor_type=return_tensors) @@ -605,7 +596,7 @@ def forward( self, vision_patches: torch.Tensor, vision_token_grids: torch.Tensor, - vision_patch_attention_mask: torch.Tensor, + image_patch_attention_mask: torch.Tensor, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: """ @@ -614,7 +605,7 @@ def forward( Patches shaped `(num_images, max_patches, patch_dim)`. vision_token_grids (`torch.Tensor`): Token grids shaped `(num_images, 2)` with per-image `(H_tokens, W_tokens)`. - vision_patch_attention_mask (`torch.Tensor`): + image_patch_attention_mask (`torch.Tensor`): Patch mask shaped `(num_images, max_patches)`. 
Returns: @@ -623,13 +614,13 @@ def forward( hidden_states = self.embeddings( vision_patches, vision_token_grids, - attention_mask=vision_patch_attention_mask, + attention_mask=image_patch_attention_mask, ) encoder_attention_mask = create_bidirectional_mask( config=self.config, inputs_embeds=hidden_states, - attention_mask=vision_patch_attention_mask, + attention_mask=image_patch_attention_mask, ) encoder_outputs = self.encoder(inputs_embeds=hidden_states, attention_mask=encoder_attention_mask, **kwargs) hidden_states = self.post_layernorm(encoder_outputs.last_hidden_state) @@ -1001,7 +992,7 @@ def __call__( vision_image_attention_mask = vision_token_lengths.gt(0).to(dtype=torch.long) vision_patches = image_inputs["vision_patches"] - vision_patch_attention_mask = image_inputs["vision_patch_attention_mask"] + image_patch_attention_mask = image_inputs["image_patch_attention_mask"] return BatchFeature( data={ @@ -1009,7 +1000,7 @@ def __call__( "attention_mask": attention_mask, "mm_token_type_ids": mm_token_type_ids, "vision_patches": vision_patches, - "vision_patch_attention_mask": vision_patch_attention_mask, + "image_patch_attention_mask": image_patch_attention_mask, "vision_token_grids": vision_token_grids, "vision_token_offsets": vision_token_offsets, "vision_token_lengths": vision_token_lengths, @@ -1130,7 +1121,7 @@ def get_image_features( vision_outputs = self.vision_tower( vision_patches=pixel_values[image_attention_mask], vision_token_grids=image_token_grids[image_attention_mask], - vision_patch_attention_mask=patch_attention_mask[image_attention_mask], + image_patch_attention_mask=patch_attention_mask[image_attention_mask], return_dict=True, **kwargs, ) @@ -1340,7 +1331,6 @@ def forward( input_ids: torch.LongTensor | None = None, mm_token_type_ids: torch.LongTensor | None = None, vision_patches: torch.Tensor | None = None, - vision_patch_attention_mask: torch.Tensor | None = None, image_patch_attention_mask: torch.Tensor | None = None, vision_token_grids: torch.LongTensor | None = None, image_token_grids: torch.LongTensor | None = None, @@ -1361,10 +1351,8 @@ def forward( follows the standard convention `0 -> text`, `1 -> image`. Treated as text-only when omitted. vision_patches (`torch.FloatTensor`, *optional*): Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`. - vision_patch_attention_mask (`torch.LongTensor`, *optional*): - Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. image_patch_attention_mask (`torch.LongTensor`, *optional*): - Alias for `vision_patch_attention_mask`. + Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. vision_token_grids (`torch.LongTensor`, *optional*): Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`. 
image_token_grids (`torch.LongTensor`, *optional*): @@ -1392,7 +1380,7 @@ def forward( image_outputs = self.get_image_features( pixel_values=vision_patches, image_token_grids=vision_token_grids, - image_patch_attention_mask=vision_patch_attention_mask, + image_patch_attention_mask=image_patch_attention_mask, image_token_offsets=vision_token_offsets, image_token_lengths=vision_token_lengths, image_attention_mask=image_attention_mask, @@ -1466,7 +1454,6 @@ def forward( mm_token_type_ids: torch.LongTensor | None = None, vision_patches: torch.Tensor | None = None, pixel_values: torch.Tensor | None = None, - vision_patch_attention_mask: torch.Tensor | None = None, image_patch_attention_mask: torch.Tensor | None = None, vision_token_grids: torch.LongTensor | None = None, image_token_grids: torch.LongTensor | None = None, @@ -1489,10 +1476,8 @@ def forward( Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`. pixel_values (`torch.FloatTensor`, *optional*): Alias for `vision_patches` accepted by generic image-feature and generation helpers. - vision_patch_attention_mask (`torch.LongTensor`, *optional*): - Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. image_patch_attention_mask (`torch.LongTensor`, *optional*): - Alias for `vision_patch_attention_mask`. + Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. vision_token_grids (`torch.LongTensor`, *optional*): Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`. image_token_grids (`torch.LongTensor`, *optional*): @@ -1509,7 +1494,7 @@ def forward( input_ids=input_ids, mm_token_type_ids=mm_token_type_ids, vision_patches=vision_patches, - vision_patch_attention_mask=vision_patch_attention_mask, + image_patch_attention_mask=image_patch_attention_mask, vision_token_grids=vision_token_grids, vision_token_offsets=vision_token_offsets, vision_token_lengths=vision_token_lengths, @@ -1544,7 +1529,6 @@ def prepare_inputs_for_generation( mm_token_type_ids: torch.LongTensor | None = None, vision_patches: torch.Tensor | None = None, pixel_values: torch.Tensor | None = None, - vision_patch_attention_mask: torch.Tensor | None = None, image_patch_attention_mask: torch.Tensor | None = None, vision_token_grids: torch.LongTensor | None = None, image_token_grids: torch.LongTensor | None = None, @@ -1557,9 +1541,6 @@ def prepare_inputs_for_generation( **kwargs, ) -> dict[str, Any]: if vision_patches is None: - vision_patch_attention_mask = ( - image_patch_attention_mask if vision_patch_attention_mask is None else vision_patch_attention_mask - ) vision_token_grids = image_token_grids if vision_token_grids is None else vision_token_grids model_inputs = super().prepare_inputs_for_generation( input_ids, @@ -1574,7 +1555,7 @@ def prepare_inputs_for_generation( multimodal_inputs = { "mm_token_type_ids": mm_token_type_ids, "vision_patches": vision_patches, - "vision_patch_attention_mask": vision_patch_attention_mask, + "image_patch_attention_mask": image_patch_attention_mask, "vision_token_grids": vision_token_grids, "vision_token_offsets": vision_token_offsets, "vision_token_lengths": vision_token_lengths, diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 92c111f38845..028a35f13cbb 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -303,7 +303,7 @@ def __call__( 
vision_image_attention_mask = vision_token_lengths.gt(0).to(dtype=torch.long) vision_patches = image_inputs["vision_patches"] - vision_patch_attention_mask = image_inputs["vision_patch_attention_mask"] + image_patch_attention_mask = image_inputs["image_patch_attention_mask"] return BatchFeature( data={ @@ -311,7 +311,7 @@ def __call__( "attention_mask": attention_mask, "mm_token_type_ids": mm_token_type_ids, "vision_patches": vision_patches, - "vision_patch_attention_mask": vision_patch_attention_mask, + "image_patch_attention_mask": image_patch_attention_mask, "vision_token_grids": vision_token_grids, "vision_token_offsets": vision_token_offsets, "vision_token_lengths": vision_token_lengths, diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py index 374ddb1a1d78..01897fc7fe2f 100644 --- a/tests/models/isaac/test_modeling_isaac.py +++ b/tests/models/isaac/test_modeling_isaac.py @@ -244,7 +244,7 @@ def to_model_multimodal_inputs(processor_output, device): keys = ( "mm_token_type_ids", "vision_patches", - "vision_patch_attention_mask", + "image_patch_attention_mask", "vision_token_grids", "vision_token_offsets", "vision_token_lengths", diff --git a/tests/models/isaac/test_processing_isaac.py b/tests/models/isaac/test_processing_isaac.py index 3a67d064b8c9..13275095fef6 100644 --- a/tests/models/isaac/test_processing_isaac.py +++ b/tests/models/isaac/test_processing_isaac.py @@ -43,7 +43,7 @@ "attention_mask", "mm_token_type_ids", "vision_patches", - "vision_patch_attention_mask", + "image_patch_attention_mask", "vision_token_grids", "vision_token_offsets", "vision_token_lengths", @@ -241,7 +241,7 @@ def _assert_common(outputs, batch_size=1): attention_mask = outputs["attention_mask"] mm_token_type_ids = outputs["mm_token_type_ids"] vision_patches = outputs["vision_patches"] - vision_patch_attention_mask = outputs["vision_patch_attention_mask"] + image_patch_attention_mask = outputs["image_patch_attention_mask"] vision_token_grids = outputs["vision_token_grids"] vision_token_offsets = outputs["vision_token_offsets"] vision_token_lengths = outputs["vision_token_lengths"] @@ -254,7 +254,7 @@ def _assert_common(outputs, batch_size=1): assert attention_mask.dtype == torch.long assert mm_token_type_ids.dtype == torch.long - assert vision_patches.shape[:2] == vision_patch_attention_mask.shape[:2] + assert vision_patches.shape[:2] == image_patch_attention_mask.shape[:2] assert vision_patches.shape[0] == batch_size assert vision_token_grids.shape == (batch_size, vision_patches.shape[1], 2) assert vision_token_offsets.shape == (batch_size, vision_patches.shape[1]) @@ -265,7 +265,7 @@ def _assert_common(outputs, batch_size=1): def _assert_no_vision(outputs, batch_index=0): - assert outputs["vision_patch_attention_mask"][batch_index].sum().item() == 0 + assert outputs["image_patch_attention_mask"][batch_index].sum().item() == 0 assert outputs["vision_token_grids"][batch_index].sum().item() == 0 assert outputs["vision_token_offsets"][batch_index].sum().item() == 0 assert outputs["vision_token_lengths"][batch_index].sum().item() == 0 @@ -277,7 +277,7 @@ def _assert_vision_segments(outputs, expected_segments, batch_index=0): active_segments = int(outputs["vision_image_attention_mask"][batch_index].sum().item()) assert active_segments == expected_segments assert torch.all(outputs["vision_token_lengths"][batch_index, :expected_segments] > 0) - assert torch.all(outputs["vision_patch_attention_mask"][batch_index, :expected_segments].sum(dim=-1) > 0) + assert 
torch.all(outputs["image_patch_attention_mask"][batch_index, :expected_segments].sum(dim=-1) > 0)


 def _count_modality(outputs, modality_value, batch_index=0):

From ed8fc0a18d30dda3db7267577264f72aa5090d18 Mon Sep 17 00:00:00 2001
From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com>
Date: Wed, 25 Mar 2026 02:35:10 +0400
Subject: [PATCH 68/77] style: further mask threading simplification + processing docstring (#19)

* style: remove fast import from modeling test
* refactor: drop image_attention_mask from external interface
* style: drop rescale factor from overall processor attrs; no longer used
---
 .../models/isaac/modeling_isaac.py        | 23 +++-------------
 .../models/isaac/modular_isaac.py         | 26 +++----------------
 .../models/isaac/processing_isaac.py      | 18 ++++++-------
 tests/models/isaac/test_modeling_isaac.py |  8 ++----
 4 files changed, 16 insertions(+), 59 deletions(-)
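One note before the hunks: dropping `image_attention_mask` from the external interface means the populated-slot mask is always derived inside `get_image_features`. A self-contained sketch of the rule the diff below hard-codes:

import torch

def derive_image_attention_mask(image_token_lengths, image_token_grids):
    # Populated image slots are inferred from per-image token lengths when
    # available, otherwise from a non-empty (h, w) token grid.
    if image_token_lengths is not None:
        return image_token_lengths > 0
    return image_token_grids.any(dim=-1)

lengths = torch.tensor([[2, 1], [1, 0]])           # batch of 2, up to 2 images each
print(derive_image_attention_mask(lengths, None))  # tensor([[ True,  True], [ True, False]])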
""" created_inputs_embeds = inputs_embeds is None if created_inputs_embeds: @@ -1462,7 +1452,6 @@ def forward( image_patch_attention_mask=image_patch_attention_mask, image_token_offsets=vision_token_offsets, image_token_lengths=vision_token_lengths, - image_attention_mask=image_attention_mask, return_dict=True, ) image_features = image_outputs.pooler_output.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype) @@ -1543,7 +1532,6 @@ def forward( image_token_grids: torch.LongTensor | None = None, vision_token_offsets: torch.LongTensor | None = None, vision_token_lengths: torch.LongTensor | None = None, - image_attention_mask: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: list[torch.FloatTensor] | None = None, @@ -1570,9 +1558,6 @@ def forward( Start offsets inside the per-image vision embedding sequence, shape `(batch_size, max_images)`. vision_token_lengths (`torch.LongTensor`, *optional*): Number of vision tokens to consume per image, shape `(batch_size, max_images)`. - image_attention_mask (`torch.LongTensor`, *optional*): - Backward-compatible override for populated image slots. When omitted, the model derives it from - `vision_token_lengths > 0`. """ outputs = self.model( input_ids=input_ids, @@ -1582,7 +1567,6 @@ def forward( vision_token_grids=vision_token_grids, vision_token_offsets=vision_token_offsets, vision_token_lengths=vision_token_lengths, - image_attention_mask=image_attention_mask, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -1618,7 +1602,6 @@ def prepare_inputs_for_generation( image_token_grids: torch.LongTensor | None = None, vision_token_offsets: torch.LongTensor | None = None, vision_token_lengths: torch.LongTensor | None = None, - image_attention_mask: torch.LongTensor | None = None, position_ids: torch.LongTensor | None = None, is_first_iteration=False, use_cache=True, diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 1536e1ddc1d5..f37d82f6f23a 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -820,7 +820,6 @@ def __init__( chat_template: str | dict[str, str] | None = None, vision_token: str = "", max_sequence_length: int = 16384, - rescale_factor: float | None = None, ): """ Args: @@ -830,8 +829,6 @@ def __init__( Placeholder token used inside text prompts to mark image positions. max_sequence_length (`int`, *optional*, defaults to 16384): Maximum packed multimodal sequence length produced by the processor. - rescale_factor (`float`, *optional*): - Deprecated compatibility argument accepted for backward compatibility. """ if chat_template is None: chat_template = getattr(tokenizer, "chat_template", None) @@ -1082,7 +1079,6 @@ def get_image_features( image_patch_attention_mask: torch.Tensor | None = None, image_token_offsets: torch.Tensor | None = None, image_token_lengths: torch.Tensor | None = None, - image_attention_mask: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ @@ -1097,18 +1093,13 @@ def get_image_features( Start offsets inside each per-image embedding sequence, shaped `(batch_size, max_images)`. image_token_lengths (`torch.Tensor`, *optional*): Number of image tokens to gather per image for placeholder scattering, shaped `(batch_size, max_images)`. 
-        image_attention_mask (`torch.Tensor`, *optional*):
-            Mask indicating which image slots are populated, shaped `(batch_size, max_images)`.
         """
         image_token_grids = image_token_grids.to(dtype=torch.long)
         patch_attention_mask = image_patch_attention_mask.to(dtype=torch.long)
-        if image_attention_mask is None:
-            if image_token_lengths is not None:
-                image_attention_mask = image_token_lengths > 0
-            else:
-                image_attention_mask = image_token_grids.any(dim=-1)
+        if image_token_lengths is not None:
+            image_attention_mask = image_token_lengths > 0
         else:
-            image_attention_mask = image_attention_mask.to(dtype=torch.bool)
+            image_attention_mask = image_token_grids.any(dim=-1)

         torch_compilable_check(
             image_attention_mask.any(),
@@ -1336,7 +1327,6 @@ def forward(
         image_token_grids: torch.LongTensor | None = None,
         vision_token_offsets: torch.LongTensor | None = None,
         vision_token_lengths: torch.LongTensor | None = None,
-        image_attention_mask: torch.LongTensor | None = None,
         attention_mask: torch.Tensor | None = None,
         position_ids: torch.LongTensor | None = None,
         past_key_values: list[torch.FloatTensor] | None = None,
@@ -1361,9 +1351,6 @@
             Start offsets inside the per-image vision embedding sequence, shape `(batch_size, max_images)`.
         vision_token_lengths (`torch.LongTensor`, *optional*):
             Number of vision tokens to consume per image, shape `(batch_size, max_images)`.
-        image_attention_mask (`torch.LongTensor`, *optional*):
-            Backward-compatible override for populated image slots. When omitted, the model derives it from
-            `vision_token_lengths > 0`.
         """
         created_inputs_embeds = inputs_embeds is None
         if created_inputs_embeds:
@@ -1383,7 +1370,6 @@
             image_patch_attention_mask=image_patch_attention_mask,
             image_token_offsets=vision_token_offsets,
             image_token_lengths=vision_token_lengths,
-            image_attention_mask=image_attention_mask,
             return_dict=True,
         )
         image_features = image_outputs.pooler_output.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
@@ -1459,7 +1445,6 @@
         image_token_grids: torch.LongTensor | None = None,
         vision_token_offsets: torch.LongTensor | None = None,
         vision_token_lengths: torch.LongTensor | None = None,
-        image_attention_mask: torch.LongTensor | None = None,
         attention_mask: torch.Tensor | None = None,
         position_ids: torch.LongTensor | None = None,
         past_key_values: list[torch.FloatTensor] | None = None,
@@ -1486,9 +1471,6 @@
             Start offsets inside the per-image vision embedding sequence, shape `(batch_size, max_images)`.
         vision_token_lengths (`torch.LongTensor`, *optional*):
             Number of vision tokens to consume per image, shape `(batch_size, max_images)`.
-        image_attention_mask (`torch.LongTensor`, *optional*):
-            Backward-compatible override for populated image slots. When omitted, the model derives it from
-            `vision_token_lengths > 0`.
         """
         outputs = self.model(
             input_ids=input_ids,
@@ -1498,7 +1480,6 @@
             vision_token_grids=vision_token_grids,
             vision_token_offsets=vision_token_offsets,
             vision_token_lengths=vision_token_lengths,
-            image_attention_mask=image_attention_mask,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
@@ -1534,7 +1515,6 @@ prepare_inputs_for_generation(
         image_token_grids: torch.LongTensor | None = None,
         vision_token_offsets: torch.LongTensor | None = None,
         vision_token_lengths: torch.LongTensor | None = None,
-        image_attention_mask: torch.LongTensor | None = None,
         position_ids: torch.LongTensor | None = None,
         is_first_iteration=False,
         use_cache=True,
diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py
index 028a35f13cbb..ddc36bc2abac 100644
--- a/src/transformers/models/isaac/processing_isaac.py
+++ b/src/transformers/models/isaac/processing_isaac.py
@@ -132,17 +132,15 @@ def __init__(
         chat_template: str | dict[str, str] | None = None,
         vision_token: str = "",
         max_sequence_length: int = 16384,
-        rescale_factor: float | None = None,
     ):
-        r"""
-        chat_template (`str` or `dict[str, str]`, *optional*):
-            Chat template override forwarded to [`~processing_utils.ProcessorMixin`].
-        vision_token (`str`, *optional*, defaults to `""`):
-            Placeholder token used inside text prompts to mark image positions.
-        max_sequence_length (`int`, *optional*, defaults to 16384):
-            Maximum packed multimodal sequence length produced by the processor.
-        rescale_factor (`float`, *optional*):
-            Deprecated compatibility argument accepted for backward compatibility.
+        """
+        Args:
+            chat_template (`str` or `dict[str, str]`, *optional*):
+                Chat template override forwarded to [`~processing_utils.ProcessorMixin`].
+            vision_token (`str`, *optional*, defaults to `""`):
+                Placeholder token used inside text prompts to mark image positions.
+            max_sequence_length (`int`, *optional*, defaults to 16384):
+                Maximum packed multimodal sequence length produced by the processor.
""" if chat_template is None: chat_template = getattr(tokenizer, "chat_template", None) diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py index 01897fc7fe2f..1c19f79474a6 100644 --- a/tests/models/isaac/test_modeling_isaac.py +++ b/tests/models/isaac/test_modeling_isaac.py @@ -38,7 +38,7 @@ ) from transformers.image_utils import load_image from transformers.masking_utils import create_bidirectional_mask -from transformers.models.isaac.image_processing_isaac_fast import IsaacImageProcessor +from transformers.models.isaac.image_processing_isaac import IsaacImageProcessor from transformers.models.isaac.modeling_isaac import ( IsaacVisionAttention, IsaacVisionConfig, @@ -230,7 +230,6 @@ def create_isaac_processor( processor_params = { "vision_token": isaac_config.vision_token, "max_sequence_length": isaac_config.max_sequence_length, - "rescale_factor": isaac_config.vision_rescale_factor, } return IsaacProcessor( @@ -439,7 +438,6 @@ def prepare_config_and_inputs_for_common(self): (self.batch_size, 1, num_image_patches), device=torch_device, dtype=torch.long ), "image_token_grids": torch.tensor([[[2, 2]]] * self.batch_size, device=torch_device, dtype=torch.long), - "image_attention_mask": torch.ones((self.batch_size, 1), device=torch_device, dtype=torch.long), } if labels is not None: inputs_dict["labels"] = labels @@ -534,7 +532,6 @@ def test_get_image_features_pooler_output_is_scatter_ready(self): dtype=torch.long, ) image_patch_attention_mask = torch.ones((2, 2, 4), device=torch_device, dtype=torch.long) - image_attention_mask = torch.tensor([[1, 1], [1, 0]], device=torch_device, dtype=torch.long) image_token_offsets = torch.tensor([[1, 0], [2, 0]], device=torch_device, dtype=torch.long) image_token_lengths = torch.tensor([[2, 1], [1, 0]], device=torch_device, dtype=torch.long) @@ -543,7 +540,6 @@ def test_get_image_features_pooler_output_is_scatter_ready(self): pixel_values=pixel_values, image_token_grids=image_token_grids, image_patch_attention_mask=image_patch_attention_mask, - image_attention_mask=image_attention_mask, image_token_offsets=image_token_offsets, image_token_lengths=image_token_lengths, return_dict=True, @@ -1082,7 +1078,7 @@ def setUp(self): self.checkpoint = _reference_checkpoint_or_skip() self.hf_config = IsaacConfig.from_pretrained(self.checkpoint, revision=MODEL_REVISION) self.tokenizer = Qwen2Tokenizer.from_pretrained( - self.checkpoint, trust_remote_code=True, use_fast=False, revision=MODEL_REVISION + self.checkpoint, trust_remote_code=False, use_fast=False, revision=MODEL_REVISION ) self.processor = create_isaac_processor(self.tokenizer, self.hf_config) self.hf_config.vision_config._attn_implementation = "flash_attention_2" From caf377c5d68e6aa7898e33f48fcf7fa2c6e4f71f Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Wed, 25 Mar 2026 02:41:42 +0400 Subject: [PATCH 69/77] test: update tests --- .../isaac/test_image_processing_isaac.py | 32 +- tests/models/isaac/test_processing_isaac.py | 726 ------------------ 2 files changed, 16 insertions(+), 742 deletions(-) delete mode 100644 tests/models/isaac/test_processing_isaac.py diff --git a/tests/models/isaac/test_image_processing_isaac.py b/tests/models/isaac/test_image_processing_isaac.py index 27ee2448f8d3..5884a8bf1e8f 100644 --- a/tests/models/isaac/test_image_processing_isaac.py +++ b/tests/models/isaac/test_image_processing_isaac.py @@ -143,15 +143,15 @@ def _assert_output_contract( ): self.assertEqual( set(encoding.keys()), - {"vision_patches", 
"vision_patch_attention_mask", "vision_token_grids"}, + {"vision_patches", "image_patch_attention_mask", "vision_token_grids"}, ) vision_patches = encoding["vision_patches"] - vision_patch_attention_mask = encoding["vision_patch_attention_mask"] + image_patch_attention_mask = encoding["image_patch_attention_mask"] vision_token_grids = encoding["vision_token_grids"] self.assertEqual(vision_patches.dtype, torch.float32) - self.assertEqual(vision_patch_attention_mask.dtype, torch.long) + self.assertEqual(image_patch_attention_mask.dtype, torch.long) self.assertEqual(vision_token_grids.dtype, torch.long) if expected_batch_size is not None: @@ -161,13 +161,13 @@ def _assert_output_contract( if expected_patch_dim is not None: self.assertEqual(vision_patches.shape[-1], expected_patch_dim) - self.assertEqual(tuple(vision_patch_attention_mask.shape), tuple(vision_patches.shape[:-1])) + self.assertEqual(tuple(image_patch_attention_mask.shape), tuple(vision_patches.shape[:-1])) self.assertEqual(tuple(vision_token_grids.shape), tuple(vision_patches.shape[:2]) + (2,)) expected_patch_counts = torch.prod(vision_token_grids, dim=-1) - torch.testing.assert_close(vision_patch_attention_mask.sum(dim=-1), expected_patch_counts) + torch.testing.assert_close(image_patch_attention_mask.sum(dim=-1), expected_patch_counts) - padded_patch_rows = vision_patches[vision_patch_attention_mask == 0] + padded_patch_rows = vision_patches[image_patch_attention_mask == 0] if padded_patch_rows.numel() > 0: self.assertTrue(torch.all(padded_patch_rows == 0)) @@ -179,8 +179,8 @@ def _assert_encoding_close(self, eager_encoding, compiled_encoding): rtol=1e-4, ) torch.testing.assert_close( - eager_encoding["vision_patch_attention_mask"], - compiled_encoding["vision_patch_attention_mask"], + eager_encoding["image_patch_attention_mask"], + compiled_encoding["image_patch_attention_mask"], ) torch.testing.assert_close(eager_encoding["vision_token_grids"], compiled_encoding["vision_token_grids"]) @@ -319,7 +319,7 @@ def test_nested_multi_image_batch_preserves_grids_and_padding(self): ) torch.testing.assert_close(encoding["vision_token_grids"], expected_grids) - torch.testing.assert_close(encoding["vision_patch_attention_mask"].sum(dim=-1), expected_patch_counts) + torch.testing.assert_close(encoding["image_patch_attention_mask"].sum(dim=-1), expected_patch_counts) def test_all_empty_images_returns_zero_sized_tensors(self): for image_processing_class in self.image_processing_classes.values(): @@ -327,13 +327,13 @@ def test_all_empty_images_returns_zero_sized_tensors(self): encoding = image_processor([[], []], return_tensors="pt") self.assertEqual( - set(encoding.keys()), {"vision_patches", "vision_patch_attention_mask", "vision_token_grids"} + set(encoding.keys()), {"vision_patches", "image_patch_attention_mask", "vision_token_grids"} ) self.assertEqual(tuple(encoding["vision_patches"].shape), (2, 0, 0, 0)) - self.assertEqual(tuple(encoding["vision_patch_attention_mask"].shape), (2, 0, 0)) + self.assertEqual(tuple(encoding["image_patch_attention_mask"].shape), (2, 0, 0)) self.assertEqual(tuple(encoding["vision_token_grids"].shape), (2, 0, 2)) self.assertEqual(encoding["vision_patches"].dtype, torch.float32) - self.assertEqual(encoding["vision_patch_attention_mask"].dtype, torch.long) + self.assertEqual(encoding["image_patch_attention_mask"].dtype, torch.long) self.assertEqual(encoding["vision_token_grids"].dtype, torch.long) def test_do_resize_false_requires_patch_divisibility(self): @@ -371,19 +371,19 @@ def 
test_cast_dtype_device(self): encoding = image_processor(image_inputs, return_tensors="pt") self.assertEqual(encoding["vision_patches"].device, torch.device("cpu")) self.assertEqual(encoding["vision_patches"].dtype, torch.float32) - self.assertEqual(encoding["vision_patch_attention_mask"].dtype, torch.long) + self.assertEqual(encoding["image_patch_attention_mask"].dtype, torch.long) self.assertEqual(encoding["vision_token_grids"].dtype, torch.long) encoding = image_processor(image_inputs, return_tensors="pt").to(torch.float16) self.assertEqual(encoding["vision_patches"].device, torch.device("cpu")) self.assertEqual(encoding["vision_patches"].dtype, torch.float16) - self.assertEqual(encoding["vision_patch_attention_mask"].dtype, torch.long) + self.assertEqual(encoding["image_patch_attention_mask"].dtype, torch.long) self.assertEqual(encoding["vision_token_grids"].dtype, torch.long) encoding = image_processor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16) self.assertEqual(encoding["vision_patches"].device, torch.device("cpu")) self.assertEqual(encoding["vision_patches"].dtype, torch.bfloat16) - self.assertEqual(encoding["vision_patch_attention_mask"].dtype, torch.long) + self.assertEqual(encoding["image_patch_attention_mask"].dtype, torch.long) self.assertEqual(encoding["vision_token_grids"].dtype, torch.long) with self.assertRaises(TypeError): @@ -395,7 +395,7 @@ def test_cast_dtype_device(self): self.assertEqual(encoding["vision_patches"].device, torch.device("cpu")) self.assertEqual(encoding["vision_patches"].dtype, torch.float16) - self.assertEqual(encoding["vision_patch_attention_mask"].dtype, torch.long) + self.assertEqual(encoding["image_patch_attention_mask"].dtype, torch.long) self.assertEqual(encoding["vision_token_grids"].dtype, torch.long) self.assertEqual(encoding["input_ids"].dtype, torch.long) diff --git a/tests/models/isaac/test_processing_isaac.py b/tests/models/isaac/test_processing_isaac.py deleted file mode 100644 index 13275095fef6..000000000000 --- a/tests/models/isaac/test_processing_isaac.py +++ /dev/null @@ -1,726 +0,0 @@ -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Testing suite for the Isaac processor.""" - -import os -import re -import unittest -from pathlib import Path - -import pytest -import torch -from huggingface_hub import is_offline_mode - -from transformers import IsaacConfig, PythonBackend -from transformers.models.isaac.image_processing_isaac import IsaacImageProcessor -from transformers.models.isaac.processing_isaac import IsaacProcessor -from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_vision_available - -from ...test_processing_common import ProcessorTesterMixin - - -if is_vision_available(): - from PIL import Image -else: - Image = None - - -ISAAC_OUTPUT_KEYS = { - "input_ids", - "attention_mask", - "mm_token_type_ids", - "vision_patches", - "image_patch_attention_mask", - "vision_token_grids", - "vision_token_offsets", - "vision_token_lengths", - "vision_image_attention_mask", -} - - -class SimpleIsaacTokenizer(PythonBackend): - vocab_files_names = {} - model_input_names = ["input_ids"] - - def __init__(self): - self._vocab = { - "": 0, - "": 1, - "": 2, - "": 3, - "": 4, - "<|image_pad|>": 5, - } - self._ids_to_tokens = {idx: tok for tok, idx in self._vocab.items()} - super().__init__( - bos_token="", - eos_token="", - pad_token="", - unk_token="", - additional_special_tokens=[""], - extra_special_tokens={"image_pad_token": "<|image_pad|>"}, - model_max_length=512, - ) - - def get_vocab(self): - return dict(self._vocab) - - def _tokenize(self, text): - clean = text.replace("\n", " ").strip() - if not clean: - return [] - - special_tokens = sorted( - (token for token in self._vocab if token.startswith("<") and token.endswith(">")), - key=len, - reverse=True, - ) - if not special_tokens: - return [token for token in clean.split(" ") if token] - - split_pattern = "(" + "|".join(re.escape(token) for token in special_tokens) + ")" - tokens = [] - for chunk in re.split(split_pattern, clean): - if not chunk or chunk.isspace(): - continue - if chunk in self._vocab: - tokens.append(chunk) - else: - tokens.extend(token for token in chunk.split(" ") if token) - return tokens - - def _convert_token_to_id(self, token): - if token not in self._vocab: - next_id = len(self._vocab) - self._vocab[token] = next_id - self._ids_to_tokens[next_id] = token - return self._vocab[token] - - def _convert_id_to_token(self, index): - return self._ids_to_tokens.get(index, self.unk_token) - - @property - def vocab_size(self) -> int: - return len(self._vocab) - - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - if token_ids_1 is not None: - token_ids_0 = token_ids_0 + token_ids_1 - return [self.bos_token_id] + list(token_ids_0) + [self.eos_token_id] - - def save_vocabulary(self, save_directory, filename_prefix=None): - return () - - -class SimpleIsaacTokenizerWithNamedImagePad(PythonBackend): - vocab_files_names = {} - model_input_names = ["input_ids"] - - def __init__(self): - self._vocab = { - "": 0, - "": 1, - "": 2, - "": 3, - "": 4, - "": 5, - "<|image_pad|>": 6, - } - self._ids_to_tokens = {idx: tok for tok, idx in self._vocab.items()} - super().__init__( - bos_token="", - eos_token="", - pad_token="", - unk_token="", - extra_special_tokens={"image_pad_token": ""}, - model_max_length=512, - ) - - def get_vocab(self): - return dict(self._vocab) - - def _tokenize(self, text): - clean = text.replace("\n", " ").strip() - if not clean: - return [] - - special_tokens = sorted( - (token for token in self._vocab if token.startswith("<") and token.endswith(">")), - key=len, - 
reverse=True, - ) - split_pattern = "(" + "|".join(re.escape(token) for token in special_tokens) + ")" - tokens = [] - for chunk in re.split(split_pattern, clean): - if not chunk or chunk.isspace(): - continue - if chunk in self._vocab: - tokens.append(chunk) - else: - tokens.extend(token for token in chunk.split(" ") if token) - return tokens - - def _convert_token_to_id(self, token): - return self._vocab.get(token, self._vocab[""]) - - def _convert_id_to_token(self, index): - return self._ids_to_tokens.get(index, self.unk_token) - - @property - def vocab_size(self) -> int: - return len(self._vocab) - - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - if token_ids_1 is not None: - token_ids_0 = token_ids_0 + token_ids_1 - return [self.bos_token_id] + list(token_ids_0) + [self.eos_token_id] - - def save_vocabulary(self, save_directory, filename_prefix=None): - return () - - -def _make_dummy_image(size=(32, 32), color=(255, 0, 0)): - if Image is None: - raise RuntimeError("PIL.Image is not available in this environment.") - return Image.new("RGB", size, color=color) - - -def _make_processor_with_max_len(tokenizer, base_config, max_len): - config = IsaacConfig(**base_config.to_dict()) - config.max_sequence_length = max_len - vision_config = config.vision_config - image_processor = IsaacImageProcessor( - patch_size=vision_config.patch_size, - max_num_patches=vision_config.num_patches, - pixel_shuffle_scale=vision_config.pixel_shuffle_scale_factor, - rescale_factor=config.vision_rescale_factor, - ) - return IsaacProcessor( - image_processor=image_processor, - tokenizer=tokenizer, - vision_token=config.vision_token, - max_sequence_length=config.max_sequence_length, - ) - - -def _run_processor(processor, text, images=None): - return processor(text=text, images=images, return_tensors="pt") - - -def _make_post_process_processor(): - return IsaacProcessor(image_processor=IsaacImageProcessor(), tokenizer=SimpleIsaacTokenizer()) - - -def test_processor_prefers_named_image_pad_token(): - processor = IsaacProcessor( - image_processor=IsaacImageProcessor(), tokenizer=SimpleIsaacTokenizerWithNamedImagePad() - ) - - assert processor.image_token == "" - assert processor.image_pad_token_id == processor.tokenizer.image_pad_token_id - assert processor.image_pad_token_id != processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") - - -def _assert_common(outputs, batch_size=1): - assert set(outputs.keys()) == ISAAC_OUTPUT_KEYS - - input_ids = outputs["input_ids"] - attention_mask = outputs["attention_mask"] - mm_token_type_ids = outputs["mm_token_type_ids"] - vision_patches = outputs["vision_patches"] - image_patch_attention_mask = outputs["image_patch_attention_mask"] - vision_token_grids = outputs["vision_token_grids"] - vision_token_offsets = outputs["vision_token_offsets"] - vision_token_lengths = outputs["vision_token_lengths"] - vision_image_attention_mask = outputs["vision_image_attention_mask"] - - assert input_ids.shape[0] == batch_size - assert attention_mask.shape == input_ids.shape - assert mm_token_type_ids.shape == input_ids.shape - assert input_ids.dtype == torch.long - assert attention_mask.dtype == torch.long - assert mm_token_type_ids.dtype == torch.long - - assert vision_patches.shape[:2] == image_patch_attention_mask.shape[:2] - assert vision_patches.shape[0] == batch_size - assert vision_token_grids.shape == (batch_size, vision_patches.shape[1], 2) - assert vision_token_offsets.shape == (batch_size, vision_patches.shape[1]) - assert vision_token_lengths.shape 
== (batch_size, vision_patches.shape[1]) - assert vision_image_attention_mask.shape == (batch_size, vision_patches.shape[1]) - - return outputs - - -def _assert_no_vision(outputs, batch_index=0): - assert outputs["image_patch_attention_mask"][batch_index].sum().item() == 0 - assert outputs["vision_token_grids"][batch_index].sum().item() == 0 - assert outputs["vision_token_offsets"][batch_index].sum().item() == 0 - assert outputs["vision_token_lengths"][batch_index].sum().item() == 0 - assert outputs["vision_image_attention_mask"][batch_index].sum().item() == 0 - assert not outputs["mm_token_type_ids"][batch_index].eq(1).any() - - -def _assert_vision_segments(outputs, expected_segments, batch_index=0): - active_segments = int(outputs["vision_image_attention_mask"][batch_index].sum().item()) - assert active_segments == expected_segments - assert torch.all(outputs["vision_token_lengths"][batch_index, :expected_segments] > 0) - assert torch.all(outputs["image_patch_attention_mask"][batch_index, :expected_segments].sum(dim=-1) > 0) - - -def _count_modality(outputs, modality_value, batch_index=0): - return int( - (outputs["attention_mask"][batch_index].bool() & outputs["mm_token_type_ids"][batch_index].eq(modality_value)) - .sum() - .item() - ) - - -def _get_active_vision_grids(outputs, batch_index=0): - mask = outputs["vision_image_attention_mask"][batch_index].bool() - return outputs["vision_token_grids"][batch_index][mask] - - -def _get_active_vision_lengths(outputs, batch_index=0): - mask = outputs["vision_image_attention_mask"][batch_index].bool() - return outputs["vision_token_lengths"][batch_index][mask] - - -@pytest.fixture -def isaac_tiny_config(): - text_config = { - "bos_token_id": 0, - "eos_token_id": 1, - "pad_token_id": 2, - "hidden_act": "silu", - "head_dim": 32 // 4, - "hidden_size": 32, - "vocab_size": 99, - "intermediate_size": 32 * 3, - "max_position_embeddings": 128, - "model_type": "qwen3", - "num_attention_heads": 4, - "num_hidden_layers": 2, - "num_key_value_heads": 4, - "rope_parameters": {"rope_type": "default", "mrope_section": [2, 1, 1], "mrope_interleaved": True}, - "tie_word_embeddings": True, - } - - vision_config = { - "hidden_size": 32, - "intermediate_size": 32 * 2, - "num_hidden_layers": 1, - "num_attention_heads": 4, - "num_channels": 3, - "num_patches": 64, - "patch_size": 4, - "pixel_shuffle_scale_factor": 1, - "attention_dropout": 0.0, - "layer_norm_eps": 1e-6, - } - - config = IsaacConfig(text_config=text_config, vision_config=vision_config) - config._attn_implementation = "sdpa" - config.text_config._attn_implementation = "sdpa" - config.vision_attn_implementation = "sdpa" - return config - - -@pytest.fixture -def isaac_tokenizer(): - return SimpleIsaacTokenizer() - - -@pytest.fixture -def isaac_processor(isaac_tokenizer, isaac_tiny_config): - vision_config = isaac_tiny_config.vision_config - image_processor = IsaacImageProcessor( - patch_size=vision_config.patch_size, - max_num_patches=vision_config.num_patches, - pixel_shuffle_scale=vision_config.pixel_shuffle_scale_factor, - rescale_factor=isaac_tiny_config.vision_rescale_factor, - ) - return IsaacProcessor( - image_processor=image_processor, - tokenizer=isaac_tokenizer, - vision_token=isaac_tiny_config.vision_token, - max_sequence_length=isaac_tiny_config.max_sequence_length, - ) - - -BASE_MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base") -BASE_MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/3") or None -LOCAL_CHECKPOINT = 
os.environ.get("ISAAC_TEST_MODEL_PATH") - - -def _checkpoint_or_skip(model_id=BASE_MODEL_ID): - if LOCAL_CHECKPOINT: - resolved = Path(LOCAL_CHECKPOINT).expanduser() - if not resolved.exists(): - pytest.skip(f"Local checkpoint path {resolved} does not exist.") - return str(resolved) - if is_offline_mode(): - pytest.skip("Offline mode: set ISAAC_TEST_MODEL_PATH to a local checkpoint to run these tests.") - return model_id - - -@require_torch -@require_vision -class IsaacProcessorTest(ProcessorTesterMixin, unittest.TestCase): - processor_class = IsaacProcessor - model_id = BASE_MODEL_ID - images_input_name = "vision_patches" - - @classmethod - def _setup_from_pretrained(cls, model_id, **kwargs): - checkpoint = _checkpoint_or_skip(model_id) - return super()._setup_from_pretrained( - checkpoint, - revision=BASE_MODEL_REVISION, - patch_size=4, - max_num_patches=4, - **kwargs, - ) - - @classmethod - def _setup_test_attributes(cls, processor): - cls.image_token = processor.vision_token - cls.pad_token_id = processor.tokenizer.pad_token_id - cls.image_pad_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") - - def prepare_image_inputs(self, batch_size: int | None = None, nested: bool = False): - if batch_size is None: - return _make_dummy_image(size=(16, 16)) - images = [_make_dummy_image(size=(16, 16), color=(50 * (i + 1), 0, 0)) for i in range(batch_size)] - if nested: - return [[image] for image in images] - return images - - def test_model_input_names(self): - processor = self.get_processor() - inputs = processor( - text=self.prepare_text_inputs(modalities="image"), - images=self.prepare_image_inputs(), - return_tensors="pt", - ) - - expected_input_names = set(processor.model_input_names) | { - "mm_token_type_ids", - "vision_token_offsets", - "vision_token_lengths", - "vision_image_attention_mask", - } - self.assertSetEqual(set(inputs.keys()), expected_input_names) - - @unittest.skip("IsaacProcessor expands image placeholders into image pad tokens before tokenization") - def test_tokenizer_defaults(self): - pass - - @unittest.skip("IsaacProcessor does not return offset mappings needed for assistant masks") - def test_apply_chat_template_assistant_mask(self): - pass - - def test_single_vs_batched_consistency(self): - processor = self.get_processor() - prompt = f"hello {processor.vision_token} world" - image = self.prepare_image_inputs() - - single = _assert_common(processor(text=prompt, images=[image], return_tensors="pt")) - batch = _assert_common( - processor(text=[prompt, "short"], images=[[image], []], return_tensors="pt"), batch_size=2 - ) - - single_ids = single["input_ids"].squeeze(0) - batch_ids = batch["input_ids"][0] - self.assertTrue(torch.equal(batch_ids[-single_ids.size(0) :], single_ids)) - - image_positions = batch["mm_token_type_ids"][0].eq(1) - if image_positions.any(): - self.assertTrue(torch.all(batch_ids[image_positions] == self.image_pad_token_id)) - self.assertTrue(torch.all(batch["attention_mask"][0][image_positions] == 1)) - - _assert_vision_segments(batch, expected_segments=1, batch_index=0) - _assert_no_vision(batch, batch_index=1) - - -@require_torch -@require_vision -def test_text_only_has_no_vision_fields(isaac_processor): - outputs = _assert_common(_run_processor(isaac_processor, text="Hello, how are you?", images=None)) - _assert_no_vision(outputs) - - -@require_torch -def test_post_process_generation_extracts_boxes_and_cleans_text(): - processor = _make_post_process_processor() - - generated_text = ( - "No, it is not safe to cross the street. 
" - '(808, 247), (863, 386)' - ) - - clean_text, annotations = processor.post_process_generation(generated_text) - - assert clean_text == "No, it is not safe to cross the street." - assert len(annotations) == 1 - box = annotations[0] - assert box.mention == "traffic light" - assert box.t == pytest.approx(0.5) - assert box.top_left.x == 808 - assert box.top_left.y == 247 - assert box.bottom_right.x == 863 - assert box.bottom_right.y == 386 - - -@require_torch -def test_post_process_generation_extracts_polygons_and_filters_by_expected_type(): - processor = _make_post_process_processor() - - generated_text = ( - 'Point (1, 2) ' - 'Box (3, 4), (5, 6) ' - 'Polygon (10, 20), (30, 40), (50, 60)' - ) - - clean_text, annotations = processor.post_process_generation(generated_text, expected="polygon") - - assert clean_text == "Point Box Polygon" - assert len(annotations) == 1 - polygon = annotations[0] - assert polygon.mention == "lane" - assert polygon.t == pytest.approx(0.25) - assert len(polygon.points) == 3 - assert polygon.points[0].x == 10 - assert polygon.points[0].y == 20 - assert polygon.points[1].x == 30 - assert polygon.points[1].y == 40 - assert polygon.points[2].x == 50 - assert polygon.points[2].y == 60 - - _, boxes = processor.post_process_generation(generated_text, expected="box") - assert len(boxes) == 1 - assert boxes[0].mention == "sign" - - -@require_torch -def test_post_process_generation_rejects_polygons_with_fewer_than_three_points(): - processor = _make_post_process_processor() - - with pytest.raises(ValueError, match=r"Malformed tag"): - processor.post_process_generation('(10, 20), (30, 40)', expected="polygon") - - -@require_torch -@require_vision -def test_single_image_returns_offsets_and_lengths(isaac_processor): - vision_token = isaac_processor.vision_token - outputs = _assert_common( - _run_processor( - isaac_processor, text=f"Look at this {vision_token} and describe it.", images=[_make_dummy_image()] - ) - ) - _assert_vision_segments(outputs, expected_segments=1) - - grid_tokens = torch.prod(_get_active_vision_grids(outputs), dim=-1) - torch.testing.assert_close(_get_active_vision_lengths(outputs), grid_tokens) - torch.testing.assert_close( - outputs["vision_token_offsets"][0, :1], torch.zeros_like(outputs["vision_token_offsets"][0, :1]) - ) - - -@require_torch -@require_vision -def test_multiple_images_have_matching_offsets_lengths_and_grids(isaac_processor): - vision_token = isaac_processor.vision_token - images = [_make_dummy_image(color=(255, 0, 0)), _make_dummy_image(color=(0, 255, 0))] - - outputs = _assert_common( - _run_processor(isaac_processor, text=f"First {vision_token} then {vision_token}", images=images) - ) - _assert_vision_segments(outputs, expected_segments=2) - - grid_tokens = torch.prod(_get_active_vision_grids(outputs), dim=-1) - torch.testing.assert_close(_get_active_vision_lengths(outputs), grid_tokens) - torch.testing.assert_close( - outputs["vision_token_offsets"][0, :2], torch.zeros_like(outputs["vision_token_offsets"][0, :2]) - ) - - -@require_torch -@require_vision -def test_error_on_image_mismatch(isaac_processor): - vision_token = isaac_processor.vision_token - with pytest.raises(ValueError, match="one image per"): - _run_processor(isaac_processor, text=f"{vision_token} {vision_token}", images=[_make_dummy_image()]) - - -@require_torch -@require_vision -def test_consecutive_vision_tokens_allow_empty_text_segments(isaac_processor): - vision_token = isaac_processor.vision_token - images = [_make_dummy_image(), _make_dummy_image(color=(0, 0, 
255))] - - outputs = _assert_common( - _run_processor(isaac_processor, text=f"prefix {vision_token}{vision_token} suffix", images=images) - ) - _assert_vision_segments(outputs, expected_segments=2) - - torch.testing.assert_close( - outputs["vision_token_offsets"][0, :2], torch.zeros_like(outputs["vision_token_offsets"][0, :2]) - ) - grid_tokens = torch.prod(_get_active_vision_grids(outputs), dim=-1) - torch.testing.assert_close(_get_active_vision_lengths(outputs), grid_tokens) - - -@require_torch -@require_vision -def test_device_and_dtype_consistency(isaac_processor): - vision_token = isaac_processor.vision_token - outputs = _assert_common( - _run_processor(isaac_processor, text=f"Describe this {vision_token}", images=[_make_dummy_image()]) - ) - _assert_vision_segments(outputs, expected_segments=1) - - tensors = [ - outputs["input_ids"], - outputs["attention_mask"], - outputs["mm_token_type_ids"], - outputs["vision_token_offsets"], - outputs["vision_token_lengths"], - outputs["vision_token_grids"], - ] - devices = {tensor.device for tensor in tensors} - assert len(devices) == 1 - for tensor in tensors: - assert tensor.dtype == torch.long - - -@require_torch -@require_vision -def test_no_crop_when_total_below_max(isaac_processor): - vision_token = isaac_processor.vision_token - outputs = _assert_common( - _run_processor(isaac_processor, text=f"hello {vision_token} world", images=[_make_dummy_image()]) - ) - _assert_vision_segments(outputs, expected_segments=1) - - grid_tokens = torch.prod(_get_active_vision_grids(outputs), dim=-1) - text_tokens = _count_modality(outputs, 0) - assert outputs["input_ids"].shape[1] == grid_tokens.item() + text_tokens - - -@require_torch -@require_vision -def test_exact_fit_keeps_all_tokens(isaac_processor, isaac_tokenizer, isaac_tiny_config): - vision_token = isaac_processor.vision_token - text = f"hey {vision_token} there" - image = _make_dummy_image() - - base_outputs = _assert_common(_run_processor(isaac_processor, text=text, images=[image])) - base_length = base_outputs["input_ids"].shape[1] - base_vision_length = _get_active_vision_lengths(base_outputs).item() - - processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, base_length) - outputs = _assert_common(_run_processor(processor, text=text, images=[image])) - - _assert_vision_segments(outputs, expected_segments=1) - assert outputs["input_ids"].shape[1] == base_length - assert _get_active_vision_lengths(outputs).item() == base_vision_length - - -@require_torch -@require_vision -def test_crop_truncates_text_segment_only(isaac_processor, isaac_tokenizer, isaac_tiny_config): - vision_token = isaac_processor.vision_token - text_prefix_tokens = " ".join([f"t{i}" for i in range(8)]) - text = f"{text_prefix_tokens} {vision_token} tail end" - image = _make_dummy_image() - - base_outputs = _assert_common(_run_processor(isaac_processor, text=text, images=[image])) - full_text_tokens = _count_modality(base_outputs, 0) - vision_length = _get_active_vision_lengths(base_outputs).item() - - max_len = base_outputs["input_ids"].shape[1] - 4 - processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, max_len) - outputs = _assert_common(_run_processor(processor, text=text, images=[image])) - - _assert_vision_segments(outputs, expected_segments=1) - assert outputs["input_ids"].shape[1] == max_len - assert _count_modality(outputs, 0) == full_text_tokens - 4 - torch.testing.assert_close( - outputs["vision_token_offsets"][0, :1], torch.zeros_like(outputs["vision_token_offsets"][0, 
:1]) - ) - assert _get_active_vision_lengths(outputs).item() == vision_length - - -@require_torch -@require_vision -def test_crop_cuts_through_image_segment(isaac_processor, isaac_tokenizer, isaac_tiny_config): - vision_token = isaac_processor.vision_token - text_before = "hi" - text_after = "bye" - text = f"{text_before} {vision_token} {text_after}" - image = _make_dummy_image() - - base_outputs = _assert_common(_run_processor(isaac_processor, text=text, images=[image])) - vision_full = _get_active_vision_lengths(base_outputs).item() - text_before_len = len(isaac_tokenizer.encode(text_before, add_special_tokens=False)) - text_after_len = len(isaac_tokenizer.encode(text_after, add_special_tokens=False)) - total_length = vision_full + text_before_len + text_after_len - - max_len = 40 - start = total_length - max_len - expected_offset = max(0, start - text_before_len) - expected_length = vision_full - expected_offset - - processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, max_len) - outputs = _assert_common(_run_processor(processor, text=text, images=[image])) - - _assert_vision_segments(outputs, expected_segments=1) - assert outputs["input_ids"].shape[1] == max_len - assert outputs["vision_token_offsets"][0, 0].item() == expected_offset - assert _get_active_vision_lengths(outputs).item() == expected_length - assert _count_modality(outputs, 0) == text_after_len - - -@require_torch -@require_vision -def test_batch_outputs_match_individual_calls(isaac_processor): - texts = ["hi", "this one is longer"] - - per_sample = [_assert_common(_run_processor(isaac_processor, text=text, images=None)) for text in texts] - batch_outputs = _assert_common(_run_processor(isaac_processor, text=texts, images=None), batch_size=len(texts)) - - pad_id = isaac_processor.pad_token_id - for index, single_output in enumerate(per_sample): - single_ids = single_output["input_ids"].squeeze(0) - single_mask = single_output["attention_mask"].squeeze(0) - single_mm = single_output["mm_token_type_ids"].squeeze(0) - - batch_ids = batch_outputs["input_ids"][index] - batch_mask = batch_outputs["attention_mask"][index] - batch_mm = batch_outputs["mm_token_type_ids"][index] - - single_len = single_ids.shape[0] - assert torch.equal(batch_ids[-single_len:], single_ids) - assert torch.equal(batch_mask[-single_len:], single_mask) - assert torch.equal(batch_mm[-single_len:], single_mm) - - if single_len < batch_ids.shape[0]: - pad_span = batch_ids[: batch_ids.shape[0] - single_len] - assert torch.all(pad_span == pad_id) - assert not torch.any(batch_mask[: batch_ids.shape[0] - single_len]) - - _assert_no_vision(batch_outputs, batch_index=index) From 8b96e5f2d74f6e6bb0cf96bd4104a861c0c05de3 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Tue, 31 Mar 2026 14:21:32 +0400 Subject: [PATCH 70/77] Squash merge pg/additional_cleanup into main --- docs/source/en/model_doc/isaac.md | 41 +- src/transformers/conversion_mapping.py | 4 + .../models/isaac/configuration_isaac.py | 22 +- .../models/isaac/image_processing_isaac.py | 74 +- .../models/isaac/modeling_isaac.py | 1063 ++++++++++------- .../models/isaac/modular_isaac.py | 886 +++++++------- .../models/isaac/processing_isaac.py | 192 +-- .../isaac/test_image_processing_isaac.py | 129 +- tests/models/isaac/test_modeling_isaac.py | 654 ++++++++-- .../isaac/test_post_processing_isaac.py | 102 -- tests/models/isaac/test_processing_isaac.py | 879 ++++++++++++++ 11 files changed, 2764 insertions(+), 1282 deletions(-) delete mode 100644 
tests/models/isaac/test_post_processing_isaac.py create mode 100644 tests/models/isaac/test_processing_isaac.py diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md index 27bccc1dd670..ce2ee88866d8 100644 --- a/docs/source/en/model_doc/isaac.md +++ b/docs/source/en/model_doc/isaac.md @@ -37,7 +37,7 @@ weights before using them in commercial settings. ## Usage -Isaac uses explicit image placeholders in the rendered prompt. Every occurrence of `processor.vision_token` (usually +Isaac uses explicit image placeholders in the rendered prompt. Every occurrence of `processor.image_token` (usually ``) must have a matching image in the `images` argument. ```py @@ -57,7 +57,7 @@ model = IsaacForConditionalGeneration.from_pretrained( images = [Image.open("chart.png"), Image.open("panel.jpg")] messages = [ {"role": "user", "content": "Compare the two figures and explain what changed."}, - {"role": "user", "content": f"{processor.vision_token}{processor.vision_token}"}, + {"role": "user", "content": f"{processor.image_token}{processor.image_token}"}, ] prompt = processor.apply_chat_template( @@ -67,17 +67,11 @@ prompt = processor.apply_chat_template( ).strip() inputs = processor(text=prompt, images=images, return_tensors="pt") -multimodal_keys = ( - "input_ids", - "attention_mask", - "mm_token_type_ids", - "vision_patches", - "vision_patch_attention_mask", - "vision_token_grids", - "vision_token_offsets", - "vision_token_lengths", -) -model_inputs = {key: inputs[key].to(model.device) for key in multimodal_keys} +model_inputs = { + key: value.to(model.device) + for key, value in inputs.items() + if value is not None +} with torch.inference_mode(): generated_ids = model.generate( @@ -93,18 +87,25 @@ response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] print(response) ``` -`IsaacProcessor` returns standard multimodal tensors that can be passed directly to the model, including `input_ids`, -`attention_mask`, `mm_token_type_ids`, `vision_patches`, `vision_patch_attention_mask`, `vision_token_grids`, -`vision_token_offsets`, `vision_token_lengths`, and `vision_image_attention_mask`. +`IsaacProcessor` returns the standard text tensors plus Isaac's batch-major visual tensors: + +- `pixel_values`: `(batch_size, max_images, max_patches, patch_dim)` +- `image_grid_thw`: `(batch_size, max_images, 3)` +- `image_metadata`: `(batch_size, max_images, 2)` storing `(offset, length)` for each image slot +- `mm_token_type_ids`: `(batch_size, sequence_length)` Important notes: - Pass the full processor output to `generate()`. Isaac uses the multimodal tensors during prefill and handles cached decoding internally. -- Batched inputs can mix text-only and multimodal samples. For batched multimodal inputs, pass images as a nested list such - as `[[image_a], [image_b, image_c], []]`. -- If truncation is enabled, the processor keeps the rightmost part of the packed multimodal sequence and updates - `vision_token_offsets` and `vision_token_lengths` automatically. +- For fully text-only batches, `pixel_values`, `image_grid_thw`, and `image_metadata` are `None`. When moving inputs to + the model, keep only non-`None` values as shown above. +- Batched inputs can mix text-only and multimodal samples. For direct processor/model batching, pass images as a nested + list such as `[[], [image_a], [image_b, image_c]]`. +- `image_grid_thw[batch_idx, image_slot] == (0, 0, 0)` marks a padded empty slot. Real image slots have + `(T=1, H>0, W>0)`. 
+- If truncation is enabled, the processor keeps the rightmost part of the multimodal prompt and updates the slot-local + `image_metadata[..., 0]` and `image_metadata[..., 1]` values automatically. ### Post-processing grounded outputs diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index acca0cc4316c..d029459feab7 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -104,6 +104,10 @@ def _build_checkpoint_conversion_mapping(): source_patterns=r"(?= self.max_window_layers - else "full_attention" - for i in range(self.num_hidden_layers) - ] - super().__post_init__(**kwargs) + self.layer_types = ["full_attention" for _ in range(self.num_hidden_layers)] + + PretrainedConfig.__post_init__(self, **kwargs) self.validate_layer_type() @auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") -@strict(accept_kwargs=True) +@strict class IsaacConfig(PretrainedConfig): r""" vision_config (`IsaacVisionConfig` or `dict`, *optional*): @@ -146,9 +140,6 @@ class IsaacConfig(PretrainedConfig): Rescale factor applied by the image processor before normalization. max_sequence_length (`int`, *optional*, defaults to 16384): Maximum multimodal sequence length produced by the processor and expected by the model. - vision_token (`str`, *optional*, defaults to `""`): - Placeholder string inserted into text prompts to mark image positions. - Example: ```python @@ -166,7 +157,6 @@ class IsaacConfig(PretrainedConfig): text_config: IsaacTextConfig | dict | None = None vision_rescale_factor: float = 1 / 255 max_sequence_length: int = 16384 - vision_token: str = "" def __post_init__(self, **kwargs): for key in ("use_cache", "rope_theta", "max_position_embeddings"): diff --git a/src/transformers/models/isaac/image_processing_isaac.py b/src/transformers/models/isaac/image_processing_isaac.py index 7beaca45be08..460ea04452d6 100644 --- a/src/transformers/models/isaac/image_processing_isaac.py +++ b/src/transformers/models/isaac/image_processing_isaac.py @@ -177,9 +177,8 @@ class IsaacImageProcessor(TorchvisionBackend): resample = PILImageResampling.BILINEAR model_input_names = [ - "vision_patches", - "image_patch_attention_mask", - "vision_token_grids", + "pixel_values", + "image_grid_thw", ] valid_kwargs = IsaacImageProcessorKwargs @@ -221,37 +220,65 @@ def resize( return image.clamp(0, 255).round().to(torch.uint8) return F.interpolate(image, size=(size.height, size.width), mode="bilinear", align_corners=False) - def pad( + def get_number_of_image_patches( + self, + image_height: int, + image_width: int, + images_kwargs: dict[str, Any] | None = None, + ) -> int: + images_kwargs = images_kwargs or {} + patch_size = images_kwargs.get("patch_size", self.patch_size) + max_num_patches = images_kwargs.get("max_num_patches", self.max_num_patches) + min_num_patches = images_kwargs.get("min_num_patches", self.min_num_patches) + pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale", self.pixel_shuffle_scale) + + target_height, target_width = get_image_size_for_max_num_patches( + image_height, + image_width, + patch_size, + max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + ) + return (target_height // patch_size) * (target_width // patch_size) + + def pack_images( self, vision_patches: list[list[torch.Tensor]], vision_token_grids: list[list[torch.Tensor]], - ) -> dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor | None]: batch_size = len(vision_patches) - first_patch = next(patches for 
sample_patches in vision_patches for patches in sample_patches) - max_images = max(len(sample_patches) for sample_patches in vision_patches) - max_patches = max(patches.shape[0] for sample_patches in vision_patches for patches in sample_patches) + max_images = max((len(sample_patches) for sample_patches in vision_patches), default=0) + flat_patches = [patches for sample_patches in vision_patches for patches in sample_patches] + if max_images == 0 or not flat_patches: + return { + "pixel_values": None, + "image_grid_thw": None, + } + + first_patch = flat_patches[0] + max_patches = max(patches.shape[0] for patches in flat_patches) patch_dim = first_patch.shape[-1] patch_dtype = first_patch.dtype patch_device = first_patch.device tensors = { - "vision_patches": torch.zeros( - (batch_size, max_images, max_patches, patch_dim), device=patch_device, dtype=patch_dtype - ), - "image_patch_attention_mask": torch.zeros( - (batch_size, max_images, max_patches), device=patch_device, dtype=torch.long + "pixel_values": torch.zeros( + (batch_size, max_images, max_patches, patch_dim), + device=patch_device, + dtype=patch_dtype, ), - "vision_token_grids": torch.zeros((batch_size, max_images, 2), device=patch_device, dtype=torch.long), + "image_grid_thw": torch.zeros((batch_size, max_images, 3), device=patch_device, dtype=torch.long), } for batch_idx, (sample_patches, sample_token_grids) in enumerate( zip(vision_patches, vision_token_grids, strict=True) ): - for image_idx, (patches, token_grid) in enumerate(zip(sample_patches, sample_token_grids, strict=True)): + for image_slot, (patches, token_grid) in enumerate(zip(sample_patches, sample_token_grids, strict=True)): patch_count = int(patches.shape[0]) - tensors["vision_patches"][batch_idx, image_idx, :patch_count] = patches - tensors["image_patch_attention_mask"][batch_idx, image_idx, :patch_count] = 1 - tensors["vision_token_grids"][batch_idx, image_idx] = token_grid + tensors["pixel_values"][batch_idx, image_slot, :patch_count] = patches + tensors["image_grid_thw"][batch_idx, image_slot, 0] = 1 + tensors["image_grid_thw"][batch_idx, image_slot, 1:] = token_grid return tensors @@ -275,15 +302,12 @@ def _preprocess( **kwargs, ) -> BatchFeature: resample = kwargs.pop("interpolation", resample) - batch_size = len(images) # IsaacProcessor routes text-only calls here as an empty image list per sample. - # This returns empty vision tensors to preserve the multimodal output schema; - # image-token/image-count mismatches are validated earlier in processor's _preprocess call. + # Return `None` visual fields so text-only batches skip multimodal codepaths like other VLMs. 
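+        # Illustrative sketch of that text-only contract (call abbreviated, not the full signature):
+        #     feats = image_processor(images=[[], []], return_tensors="pt")
+        #     feats["pixel_values"] is None and feats["image_grid_thw"] is None  -> True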
if all(len(sample_images) == 0 for sample_images in images): tensors = { - "vision_patches": torch.zeros((batch_size, 0, 0, 0), dtype=torch.float32), - "image_patch_attention_mask": torch.zeros((batch_size, 0, 0), dtype=torch.long), - "vision_token_grids": torch.zeros((batch_size, 0, 2), dtype=torch.long), + "pixel_values": None, + "image_grid_thw": None, } return BatchFeature(data=tensors, tensor_type=return_tensors) @@ -352,7 +376,7 @@ def _preprocess( if not do_pad: raise ValueError("IsaacImageProcessor doesn't support `do_pad=False` mode.") - tensors = self.pad( + tensors = self.pack_images( vision_patches=nested_outputs["vision_patches"], vision_token_grids=nested_outputs["vision_token_grids"], ) diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index c81564eab759..01afede9a528 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -19,6 +19,7 @@ # limitations under the License. from collections.abc import Callable +from dataclasses import dataclass from typing import Any, NamedTuple, Optional from ... import initialization as init @@ -29,15 +30,9 @@ from ...masking_utils import create_bidirectional_mask, create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPast, - BaseModelOutputWithPooling, - CausalLMOutputWithPast, -) +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...models.qwen3.modeling_qwen3 import Qwen3PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, torch_compilable_check from ...utils.generic import TransformersKwargs, can_return_tuple, maybe_autocast, merge_with_config_defaults @@ -169,7 +164,7 @@ def resize_positional_embeddings( def forward( self, pixel_values: torch.Tensor, - spatial_shapes: torch.Tensor, + image_grid_thw: torch.Tensor, attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: """ @@ -185,13 +180,16 @@ def forward( resized_positional_embeddings = self.resize_positional_embeddings( self.position_embedding, - spatial_shapes, + image_grid_thw[:, 1:], max_length=pixel_values.shape[1], ) + resized_positional_embeddings = resized_positional_embeddings.to( + device=patch_embeds.device, dtype=patch_embeds.dtype + ) embeddings = patch_embeds + resized_positional_embeddings if attention_mask is not None: - embeddings = embeddings * attention_mask.unsqueeze(-1).to(dtype=embeddings.dtype) + embeddings = embeddings * attention_mask.unsqueeze(-1).to(device=embeddings.device, dtype=embeddings.dtype) return embeddings @@ -361,7 +359,7 @@ def forward( def pixel_shuffle_padded( - x: torch.Tensor, + hidden_states: torch.Tensor, token_grids: torch.Tensor, scale_factor: int = 1, ) -> torch.Tensor: @@ -379,10 +377,10 @@ def pixel_shuffle_padded( `torch.Tensor`: Pixel-shuffled embeddings of shape `(num_images, max_tokens, hidden_size * scale_factor**2)`. 
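+
+    Example (a shape-only sketch, assuming one image with a 4x4 token grid and `scale_factor=2`):
+
+        >>> x = torch.randn(1, 16, 8)  # (num_images, max_patches, embed_dim)
+        >>> grids = torch.tensor([[4, 4]])  # per-image (H_tokens, W_tokens)
+        >>> pixel_shuffle_padded(x, grids, scale_factor=2).shape
+        torch.Size([1, 4, 32])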
""" - num_images, max_patches, embed_dim = x.shape + num_images, max_patches, embed_dim = hidden_states.shape output_dim = embed_dim * scale_factor * scale_factor - token_grids = token_grids.to(device=x.device, dtype=torch.long) + token_grids = token_grids.to(device=hidden_states.device, dtype=torch.long) heights = token_grids[:, 0] widths = token_grids[:, 1] full_lengths = heights * widths @@ -397,9 +395,11 @@ def pixel_shuffle_padded( output_lengths = (heights // scale_factor) * (widths // scale_factor) max_output_tokens = output_lengths.max() - shuffled_4d = x.new_zeros((num_images, max_output_tokens, scale_factor * scale_factor, embed_dim)) + shuffled_4d = hidden_states.new_zeros((num_images, max_output_tokens, scale_factor * scale_factor, embed_dim)) - token_positions = torch.arange(max_patches, device=x.device, dtype=torch.long).unsqueeze(0).expand(num_images, -1) + token_positions = ( + torch.arange(max_patches, device=hidden_states.device, dtype=torch.long).unsqueeze(0).expand(num_images, -1) + ) valid_token_mask = token_positions < full_lengths.unsqueeze(1) safe_widths = torch.where(widths > 0, widths, torch.ones_like(widths)) @@ -411,10 +411,12 @@ def pixel_shuffle_padded( output_index = output_index + col_index.div(scale_factor, rounding_mode="floor") sub_index = row_index.remainder(scale_factor) * scale_factor + col_index.remainder(scale_factor) - batch_index = torch.arange(num_images, device=x.device, dtype=torch.long).unsqueeze(1).expand_as(token_positions) - shuffled_4d[batch_index[valid_token_mask], output_index[valid_token_mask], sub_index[valid_token_mask]] = x[ - valid_token_mask - ] + batch_index = ( + torch.arange(num_images, device=hidden_states.device, dtype=torch.long).unsqueeze(1).expand_as(token_positions) + ) + shuffled_4d[batch_index[valid_token_mask], output_index[valid_token_mask], sub_index[valid_token_mask]] = ( + hidden_states[valid_token_mask] + ) shuffled = shuffled_4d.view(num_images, max_output_tokens, output_dim) return shuffled @@ -456,26 +458,28 @@ def _init_weights(self, module): @capture_outputs(tie_last_hidden_states=False) def forward( self, - vision_patches: torch.Tensor, - vision_token_grids: torch.Tensor, - image_patch_attention_mask: torch.Tensor, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: """ Inputs: - vision_patches (`torch.Tensor`): + pixel_values (`torch.Tensor`): Patches shaped `(num_images, max_patches, patch_dim)`. - vision_token_grids (`torch.Tensor`): - Token grids shaped `(num_images, 2)` with per-image `(H_tokens, W_tokens)`. - image_patch_attention_mask (`torch.Tensor`): - Patch mask shaped `(num_images, max_patches)`. + image_grid_thw (`torch.Tensor`): + Grid tensor shaped `(num_images, 3)` with per-image `(T=1, H_tokens, W_tokens)`. Returns: `BaseModelOutputWithPooling` with pixel-shuffled embeddings in `last_hidden_state`. 
""" + vision_token_grids = image_grid_thw[:, 1:].to(dtype=torch.long) + full_lengths = vision_token_grids[:, 0] * vision_token_grids[:, 1] + token_positions = torch.arange(pixel_values.shape[1], device=pixel_values.device, dtype=torch.long) + image_patch_attention_mask = token_positions.unsqueeze(0) < full_lengths.unsqueeze(1) + image_patch_attention_mask = image_patch_attention_mask.to(dtype=torch.long) hidden_states = self.embeddings( - vision_patches, - vision_token_grids, + pixel_values, + image_grid_thw, attention_mask=image_patch_attention_mask, ) @@ -488,7 +492,7 @@ def forward( hidden_states = self.post_layernorm(encoder_outputs.last_hidden_state) hidden_states = pixel_shuffle_padded( - x=hidden_states, + hidden_states=hidden_states, token_grids=vision_token_grids, scale_factor=self.pixel_shuffle_scale_factor, ) @@ -983,325 +987,152 @@ def _deepstack_process( return hidden_states -@use_kernel_forward_from_hub("RMSNorm") -class IsaacRMSNorm(nn.Module): - def __init__(self, hidden_size, eps: float = 1e-6) -> None: - """ - IsaacRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -@use_kernelized_func(apply_rotary_pos_emb) -class IsaacAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: IsaacConfig, layer_idx: int): - super().__init__() - self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) - self.q_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
- self.k_norm = IsaacRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape - self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: torch.Tensor | None, - past_key_values: Cache | None = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) - key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_values is not None: - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) - - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( - self.config._attn_implementation, eager_attention_forward - ) - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - sliding_window=self.sliding_window, # diff with Llama - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - -class IsaacDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: IsaacConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = IsaacAttention(config=config, layer_idx=layer_idx) - - self.mlp = IsaacMLP(config) - self.input_layernorm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = IsaacRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - use_cache: bool | None = False, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, - **kwargs: Unpack[TransformersKwargs], - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - # Self Attention - hidden_states, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=use_cache, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states - - @auto_docstring -class IsaacModel(PreTrainedModel): - config: IsaacConfig +class IsaacModel(IsaacPreTrainedModel): base_model_prefix = "model" + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: IsaacConfig + _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"] + input_modalities = ("image", "text") supports_gradient_checkpointing = True - _no_split_modules = ["IsaacDecoderLayer"] - 
_skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn = True - _supports_sdpa = True - _supports_flex_attn = False _can_compile_fullgraph = False - _supports_attention_backend = True - _can_record_outputs = { - "hidden_states": IsaacDecoderLayer, - "attentions": IsaacAttention, - } + _supports_flex_attn = False _tied_weights_keys = {} - _input_embed_layer = "text_model.embed_tokens" + _input_embed_layer = "language_model.embed_tokens" def __init__(self, config: IsaacConfig): - Qwen3PreTrainedModel.__init__(self, config) - self.text_model = IsaacTextModel._from_config(config.text_config) - - self.vision_tower = IsaacVisionTransformer(config.vision_config) + super().__init__(config) + self.language_model = IsaacTextModel._from_config(config.text_config) + self.visual = IsaacVisionTransformer(config.vision_config) self.multimodal_projector = IsaacMultiModalProjector(config) self.max_sequence_length = config.max_sequence_length self.vision_rescale_factor = config.vision_rescale_factor - self.vision_token = config.vision_token self.rope_deltas = None self.post_init() - @can_return_tuple - @auto_docstring - def get_image_features( - self, - pixel_values: torch.Tensor, - image_token_grids: torch.Tensor, - image_patch_attention_mask: torch.Tensor | None = None, - image_token_offsets: torch.Tensor | None = None, - image_token_lengths: torch.Tensor | None = None, - **kwargs: Unpack[TransformersKwargs], - ) -> tuple | BaseModelOutputWithPooling: - """ - Args: - pixel_values (`torch.Tensor`): - Padded per-image patch vectors with shape `(batch_size, max_images, max_patches, patch_dim)`. - image_token_grids (`torch.Tensor`): - Per-image token grids shaped `(batch_size, max_images, 2)` with `(height, width)` entries. - image_patch_attention_mask (`torch.Tensor`, *optional*): - Mask for valid patch rows in `pixel_values`, shaped `(batch_size, max_images, max_patches)`. - image_token_offsets (`torch.Tensor`, *optional*): - Start offsets inside each per-image embedding sequence, shaped `(batch_size, max_images)`. - image_token_lengths (`torch.Tensor`, *optional*): - Number of image tokens to gather per image for placeholder scattering, shaped `(batch_size, max_images)`. 
- """ - image_token_grids = image_token_grids.to(dtype=torch.long) - patch_attention_mask = image_patch_attention_mask.to(dtype=torch.long) - if image_token_lengths is not None: - image_attention_mask = image_token_lengths > 0 - else: - image_attention_mask = image_token_grids.any(dim=-1) - - torch_compilable_check( - image_attention_mask.any(), - "IsaacModel.get_image_features expects at least one active image slot; text-only inputs should skip this method.", - ) - - batch_size, max_images = pixel_values.shape[:2] - hidden_size = self.config.get_text_config().hidden_size - - vision_outputs = self.vision_tower( - vision_patches=pixel_values[image_attention_mask], - vision_token_grids=image_token_grids[image_attention_mask], - image_patch_attention_mask=patch_attention_mask[image_attention_mask], - return_dict=True, - **kwargs, - ) - flat_projected_features = self.multimodal_projector(vision_outputs.last_hidden_state) - max_tokens = flat_projected_features.shape[1] - projected_features = flat_projected_features.new_zeros((batch_size, max_images, max_tokens, hidden_size)) - projected_features[image_attention_mask] = flat_projected_features - feature_device = flat_projected_features.device - offsets = ( - image_token_offsets.to(dtype=torch.long) - if image_token_offsets is not None - else torch.zeros((batch_size, max_images), device=feature_device, dtype=torch.long) - ) - lengths = ( - image_token_lengths.to(dtype=torch.long) - if image_token_lengths is not None - else torch.full((batch_size, max_images), max_tokens, device=feature_device, dtype=torch.long) - ) - flat_offsets = offsets[image_attention_mask] - flat_lengths = lengths[image_attention_mask] - token_positions = torch.arange(flat_lengths.max(), device=feature_device, dtype=torch.long) - gather_positions = flat_offsets[:, None] + token_positions[None, :] - gather_mask = token_positions[None, :] < flat_lengths[:, None] - image_features = flat_projected_features[ - torch.arange(flat_projected_features.shape[0], device=feature_device, dtype=torch.long)[:, None], - gather_positions, - ][gather_mask] - hidden_states = vision_outputs.hidden_states - attentions = vision_outputs.attentions - - return BaseModelOutputWithPooling( - last_hidden_state=projected_features, - pooler_output=image_features, - hidden_states=hidden_states, - attentions=attentions, - ) + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() - def get_placeholder_mask( - self, - mm_token_type_ids: torch.LongTensor, - inputs_embeds: torch.FloatTensor, - image_features: torch.FloatTensor, - ) -> torch.BoolTensor: - image_token_mask = mm_token_type_ids.to(dtype=torch.long) == 1 - n_image_tokens = image_token_mask.sum() - image_token_mask = image_token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - torch_compilable_check( - inputs_embeds[image_token_mask].numel() == image_features.numel(), - f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", - ) - return image_token_mask + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) def get_vision_position_ids( self, start_position: int, - grid_hw: torch.LongTensor, - token_offset: int, - token_length: int, + grid_thw: torch.LongTensor, + image_metadata: torch.LongTensor, ) -> torch.LongTensor: - height, width = grid_hw[0].item(), grid_hw[1].item() - token_positions = torch.arange(height * width, device=grid_hw.device, dtype=torch.long) + """ + Compute 3D positional indices for 
vision tokens derived from a single image input.
+
+        The positions are generated from the input grid defined by temporal (T), height (H), and
+        width (W) dimensions. The spatial dimensions are downscaled by the vision backbone's pixel
+        shuffle scale factor, and the resulting positions are offset by `start_position`.
+
+        Args:
+            start_position (`int`):
+                Offset added to all computed positional indices.
+            grid_thw (`torch.Tensor` of shape `(3,)`):
+                The (T, H, W) grid representing the feature layout of the current image after patch embedding.
+            image_metadata (`torch.LongTensor` of shape `(2,)`):
+                The `(offset, length)` pair selecting which slice of this image's position ids is kept after
+                truncation.
+
+        Returns:
+            torch.LongTensor of shape (3, sequence_length):
+                Positional indices for temporal, height, and width dimensions,
+                flattened into sequence form and offset by `start_position`.
+        """
+        pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor
+        height = grid_thw[1].div(pixel_shuffle_scale, rounding_mode="floor").item()
+        width = grid_thw[2].div(pixel_shuffle_scale, rounding_mode="floor").item()
+        token_positions = torch.arange(height * width, device=grid_thw.device, dtype=torch.long)
         vision_position_ids = torch.stack(
             (
-                torch.full((token_positions.shape[0],), start_position, device=grid_hw.device, dtype=torch.long),
+                torch.full((token_positions.shape[0],), start_position, device=grid_thw.device, dtype=torch.long),
                 token_positions.div(width, rounding_mode="floor"),
                 token_positions.remainder(width),
             ),
             dim=0,
         )
+        token_offset = int(image_metadata[0].item())
+        token_length = int(image_metadata[1].item())
         return vision_position_ids[:, token_offset : token_offset + token_length]
 
     def get_rope_index(
         self,
+        input_ids: torch.LongTensor | None,
         mm_token_type_ids: torch.Tensor,
-        image_token_grids: torch.Tensor,
-        image_token_offsets: torch.Tensor,
-        image_token_lengths: torch.Tensor,
-        attention_mask: torch.Tensor,
-        inputs_embeds: torch.Tensor | None = None,
+        image_grid_thw: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        image_metadata: torch.Tensor | None = None,
+        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """Prepare multimodal RoPE positions for the current prefill sequence.
+        """
+        Difference from Qwen2VL/Qwen2.5VL's `get_rope_index`:
+            - Isaac stores images batch-major and may truncate an image's placeholder span, so the per-slot
+              `image_metadata` `(offset, length)` pairs select which slice of each image's grid positions is used.
 
-        Unlike vanilla 1D RoPE, Isaac builds 3-axis indices for text and vision tokens.
-        If callers do not supply positions, we synthesize text-style positions from
-        `attention_mask`. The returned `rope_deltas` capture any custom offset between
-        the attended sequence length and Isaac's multimodal positions so decode steps can
-        keep counting forward from the cached prefix.
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+                it.
+            mm_token_type_ids (`torch.IntTensor` of shape `(batch_size, sequence_length)`):
+                Token type ids matching each modality to a different value in the input sequence, i.e. text (0) and image (1).
+            image_grid_thw (`torch.LongTensor` of shape `(batch_size, max_images, 3)`, *optional*):
+                Batch-major per-slot grids with `(T=1, H, W)` entries; `(0, 0, 0)` marks a padded empty slot.
+            image_metadata (`torch.LongTensor` of shape `(batch_size, max_images, 2)`, *optional*):
+                Batch-major per-slot `(offset, length)` pairs selecting the visible slice of each image's tokens.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+        Returns:
+            position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
+            mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
+        """
+        if image_grid_thw is None or image_metadata is None:
+            raise ValueError("Isaac multimodal RoPE requires both `image_grid_thw` and `image_metadata`.")
+
+        if attention_mask is None:
+            if input_ids is None:
+                attention_mask = mm_token_type_ids.new_ones(mm_token_type_ids.shape, dtype=torch.long)
+            else:
+                attention_mask = input_ids.new_ones(input_ids.shape, dtype=torch.long)
+
+        if input_ids is None:
+            batch_size, seq_len = attention_mask.shape
+            position_dtype = torch.long
+        else:
+            batch_size, seq_len = input_ids.shape
+            position_dtype = input_ids.dtype
         device = attention_mask.device
-        batch_size, seq_len = attention_mask.shape
 
         mm_token_type_ids = mm_token_type_ids.to(dtype=torch.long)
-        image_token_grids = image_token_grids.to(dtype=torch.long)
-        image_token_offsets = image_token_offsets.to(dtype=torch.long)
-        image_token_lengths = image_token_lengths.to(dtype=torch.long)
+        image_grid_thw = image_grid_thw.to(dtype=torch.long)
+        image_metadata = image_metadata.to(dtype=torch.long)
         attention_mask = attention_mask.to(dtype=torch.long)
-        image_attention_mask = image_token_lengths > 0
+        active_slot_mask = image_grid_thw[..., 0].eq(1)
 
-        position_ids = torch.zeros((3, batch_size, seq_len), device=device, dtype=torch.long)
+        position_ids = torch.zeros((3, batch_size, seq_len), device=device, dtype=position_dtype)
         rope_deltas = torch.zeros((batch_size, 1), device=device, dtype=torch.long)
-        pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor
 
         for batch_idx in range(batch_size):
             sample_attention_mask = attention_mask[batch_idx].bool()
             sample_token_types = mm_token_type_ids[batch_idx][sample_attention_mask]
-            sample_grids = image_token_grids[batch_idx][image_attention_mask[batch_idx]]
-            sample_offsets = image_token_offsets[batch_idx][image_attention_mask[batch_idx]]
-            sample_lengths = image_token_lengths[batch_idx][image_attention_mask[batch_idx]]
+            sample_grids = image_grid_thw[batch_idx]
+            sample_metadata = image_metadata[batch_idx]
+            sample_active_slots = active_slot_mask[batch_idx]
 
             current_pos = 0
             image_idx = 0
@@ -1310,56 +1141,191 @@
             while seq_pos < sample_token_types.shape[0]:
                 modality_type = int(sample_token_types[seq_pos].item())
-                group_end = seq_pos + 1
-                while group_end < sample_token_types.shape[0] and sample_token_types[group_end] == modality_type:
-                    group_end += 1
-
-                group_length = group_end - seq_pos
                 if modality_type == 0:
+                    group_end = seq_pos + 1
+                    while group_end < sample_token_types.shape[0] and sample_token_types[group_end] == 0:
+                        group_end += 1
+                    group_length = group_end - seq_pos
                     llm_pos_ids_list.append(
torch.arange(group_length, device=device, dtype=torch.long).view(1, -1).expand(3, -1) + current_pos ) current_pos += group_length + seq_pos = group_end else: - grid_hw = sample_grids[image_idx].div(pixel_shuffle_scale, rounding_mode="floor") - token_offset = int(sample_offsets[image_idx].item()) - token_length = int(sample_lengths[image_idx].item()) + while image_idx < sample_metadata.shape[0] and ( + not bool(sample_active_slots[image_idx].item()) or sample_metadata[image_idx, 1].item() == 0 + ): + image_idx += 1 + torch_compilable_check( + image_idx < sample_metadata.shape[0], + "Isaac multimodal sequence has more visible image tokens than batch-major image metadata slots.", + ) + token_length = int(sample_metadata[image_idx, 1].item()) + torch_compilable_check( + token_length <= sample_token_types.shape[0] - seq_pos, + "Isaac image metadata length exceeds the remaining multimodal placeholder span.", + ) llm_pos_ids_list.append( - self.get_vision_position_ids(current_pos, grid_hw, token_offset, token_length) + self.get_vision_position_ids(current_pos, sample_grids[image_idx], sample_metadata[image_idx]) ) current_pos += 1 + seq_pos += token_length image_idx += 1 - seq_pos = group_end - - llm_positions = torch.cat(llm_pos_ids_list, dim=1) + llm_positions = ( + torch.cat(llm_pos_ids_list, dim=1) + if llm_pos_ids_list + else torch.zeros((3, 0), device=device, dtype=torch.long) + ) position_ids[:, batch_idx, sample_attention_mask] = llm_positions - rope_deltas[batch_idx, 0] = llm_positions.max() + 1 - sample_token_types.shape[0] + rope_deltas[batch_idx, 0] = ( + llm_positions.max() + 1 - sample_token_types.shape[0] if llm_positions.numel() > 0 else 0 + ) return position_ids, rope_deltas + @can_return_tuple + @auto_docstring + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + raise ValueError("Isaac is image-only and does not support `pixel_values_videos` or `video_grid_thw`.") + + @can_return_tuple + @auto_docstring + def get_image_features( + self, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + image_metadata: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + """ + Args: + pixel_values (`torch.Tensor`): + Batch-major patch vectors with shape `(batch_size, max_images, max_patches, patch_dim)`. + image_grid_thw (`torch.Tensor`): + Batch-major grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. + image_metadata (`torch.Tensor`, *optional*): + Batch-major per-slot metadata `(offset, length)` shaped `(batch_size, max_images, 2)`. 
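+
+        Example (a shape-only sketch, assuming `pixel_shuffle_scale_factor=1` and that `model` is an
+        `IsaacModel`; the second sample's only slot is a padded empty slot):
+
+            >>> pixel_values = torch.randn(2, 1, 64, 48)
+            >>> image_grid_thw = torch.tensor([[[1, 8, 8]], [[0, 0, 0]]])
+            >>> image_metadata = torch.tensor([[[0, 64]], [[0, 0]]])
+            >>> out = model.get_image_features(pixel_values, image_grid_thw, image_metadata)
+            >>> len(out.pooler_output)  # one active image slot across the batch
+            1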
+ """ + if pixel_values.shape[0] == 0: + hidden_size = self.config.get_text_config().hidden_size + return BaseModelOutputWithPooling( + last_hidden_state=pixel_values.new_zeros((0, 0, hidden_size)), + pooler_output=(), + hidden_states=None, + attentions=None, + ) + + image_grid_thw = image_grid_thw.to(dtype=torch.long) + active_slot_mask = image_grid_thw[..., 0].eq(1) + if not active_slot_mask.any(): + hidden_size = self.config.get_text_config().hidden_size + return BaseModelOutputWithPooling( + last_hidden_state=pixel_values.new_zeros((0, 0, hidden_size)), + pooler_output=(), + hidden_states=None, + attentions=None, + ) + + flat_pixel_values = pixel_values[active_slot_mask] + flat_image_grid_thw = image_grid_thw[active_slot_mask] + vision_outputs: BaseModelOutputWithPooling = self.visual( + pixel_values=flat_pixel_values, + image_grid_thw=flat_image_grid_thw, + return_dict=True, + **kwargs, + ) + projected_features = self.multimodal_projector(vision_outputs.last_hidden_state) + + pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor + full_lengths = flat_image_grid_thw[:, 1].div(pixel_shuffle_scale, rounding_mode="floor") * flat_image_grid_thw[ + :, 2 + ].div(pixel_shuffle_scale, rounding_mode="floor") + if image_metadata is None: + offsets = torch.zeros_like(full_lengths) + lengths = full_lengths + else: + torch_compilable_check( + image_metadata.shape[:2] == image_grid_thw.shape[:2], + "IsaacModel.get_image_features expects batch-major metadata aligned with `image_grid_thw`.", + ) + active_metadata = image_metadata[active_slot_mask] + offsets = active_metadata[:, 0].to(device=projected_features.device, dtype=torch.long) + lengths = active_metadata[:, 1].to(device=projected_features.device, dtype=torch.long) + + image_features = tuple( + projected_features[image_idx, offset : offset + length] + for image_idx, (offset, length) in enumerate(zip(offsets.tolist(), lengths.tolist(), strict=True)) + ) + + return BaseModelOutputWithPooling( + last_hidden_state=projected_features, + pooler_output=image_features, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + ) + + def get_placeholder_mask( + self, + mm_token_type_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor, + ) -> torch.BoolTensor: + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + image_token_mask = mm_token_type_ids.to(dtype=torch.long) == 1 + n_image_tokens = image_token_mask.sum() + image_token_mask = image_token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[image_token_mask].numel() == image_features.numel(), + f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", + ) + return image_token_mask + def compute_3d_position_ids( self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor | None, mm_token_type_ids: torch.Tensor | None = None, - image_token_grids: torch.Tensor | None = None, - image_token_offsets: torch.Tensor | None = None, - image_token_lengths: torch.Tensor | None = None, - past_key_values: torch.Tensor | None = None, + image_grid_thw: torch.Tensor | None = None, + image_metadata: torch.Tensor | None = None, + past_key_values: Cache | None = None, ) -> torch.Tensor: past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() - if image_token_lengths is not None and image_token_lengths.gt(0).any() and past_seen_tokens == 0: + has_multimodal = ( + image_grid_thw is not None + and image_metadata is not None + and bool(image_grid_thw[..., 0].eq(1).any().item()) + ) + if has_multimodal and mm_token_type_ids is None and input_ids is not None: + raise ValueError( + "Multimodal data was passed (via `image_grid_thw`) but `mm_token_type_ids` is missing. " + "Please pass `mm_token_type_ids` so Isaac can build multimodal RoPE positions." + ) + + if has_multimodal and past_seen_tokens == 0: position_ids, rope_deltas = self.get_rope_index( + input_ids=input_ids, mm_token_type_ids=mm_token_type_ids, - image_token_grids=image_token_grids, - image_token_offsets=image_token_offsets, - image_token_lengths=image_token_lengths, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, attention_mask=attention_mask, - inputs_embeds=inputs_embeds, ) self.rope_deltas = rope_deltas return position_ids @@ -1403,15 +1369,12 @@ def forward( self, input_ids: torch.LongTensor | None = None, mm_token_type_ids: torch.LongTensor | None = None, - vision_patches: torch.Tensor | None = None, - image_patch_attention_mask: torch.Tensor | None = None, - vision_token_grids: torch.LongTensor | None = None, - image_token_grids: torch.LongTensor | None = None, - vision_token_offsets: torch.LongTensor | None = None, - vision_token_lengths: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + image_metadata: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -1419,48 +1382,56 @@ def forward( """ Args: mm_token_type_ids (`torch.LongTensor`, *optional*): - Multimodal token type ids aligned with the embedded sequence, shaped `(batch_size, seq_len)`. Isaac - follows the standard convention `0 -> text`, `1 -> image`. Treated as text-only when omitted. - vision_patches (`torch.FloatTensor`, *optional*): - Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`. 
- image_patch_attention_mask (`torch.LongTensor`, *optional*): - Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. - vision_token_grids (`torch.LongTensor`, *optional*): - Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`. - image_token_grids (`torch.LongTensor`, *optional*): - Alias for `vision_token_grids`. - vision_token_offsets (`torch.LongTensor`, *optional*): - Start offsets inside the per-image vision embedding sequence, shape `(batch_size, max_images)`. - vision_token_lengths (`torch.LongTensor`, *optional*): - Number of vision tokens to consume per image, shape `(batch_size, max_images)`. + Multimodal token type ids aligned with the token sequence, using `0 -> text` and `1 -> image`. + pixel_values (`torch.FloatTensor`, *optional*): + Batch-major patch vectors shaped `(batch_size, max_images, max_patches, patch_dim)`. + image_grid_thw (`torch.LongTensor`, *optional*): + Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. + image_metadata (`torch.LongTensor`, *optional*): + Batch-major per-slot metadata shaped `(batch_size, max_images, 2)` with `(offset, length)`. """ - created_inputs_embeds = inputs_embeds is None - if created_inputs_embeds: - inputs_embeds = self.text_model.embed_tokens(input_ids) + if (input_ids is None) == (inputs_embeds is None): + raise ValueError("You must specify exactly one of `input_ids` or `inputs_embeds`.") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + batch_size, seq_len = inputs_embeds.shape[:2] if mm_token_type_ids is None: - batch_size, seq_len = inputs_embeds.shape[:2] mm_token_type_ids = torch.full((batch_size, seq_len), 0, device=inputs_embeds.device, dtype=torch.long) else: - mm_token_type_ids = mm_token_type_ids.to(dtype=torch.long) - - image_token_mask = mm_token_type_ids == 1 - if created_inputs_embeds and torch.any(image_token_mask): + mm_token_type_ids = mm_token_type_ids.to(device=inputs_embeds.device, dtype=torch.long) + if mm_token_type_ids.shape[1] < seq_len: + padding = mm_token_type_ids.new_zeros((batch_size, seq_len - mm_token_type_ids.shape[1])) + mm_token_type_ids = torch.cat([mm_token_type_ids, padding], dim=1) + elif mm_token_type_ids.shape[1] > seq_len: + mm_token_type_ids = mm_token_type_ids[:, -seq_len:] + + if image_metadata is not None: + image_metadata = image_metadata.to(device=inputs_embeds.device, dtype=torch.long) + + image_mask = None + has_active_images = ( + pixel_values is not None and image_grid_thw is not None and bool(image_grid_thw[..., 0].eq(1).any().item()) + ) + if has_active_images: image_outputs = self.get_image_features( - pixel_values=vision_patches, - image_token_grids=vision_token_grids, - image_patch_attention_mask=image_patch_attention_mask, - image_token_offsets=vision_token_offsets, - image_token_lengths=vision_token_lengths, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, return_dict=True, ) - image_features = image_outputs.pooler_output.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype) - scatter_mask = self.get_placeholder_mask( - mm_token_type_ids=mm_token_type_ids, - inputs_embeds=inputs_embeds, - image_features=image_features, - ) - inputs_embeds = inputs_embeds.masked_scatter(scatter_mask, image_features) + image_embeds = image_outputs.pooler_output + if len(image_embeds) > 0: + image_embeds = torch.cat(image_embeds, dim=0).to( + device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) 
+ image_mask = self.get_placeholder_mask( + mm_token_type_ids=mm_token_type_ids, + inputs_embeds=inputs_embeds, + image_features=image_embeds, + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if isinstance(attention_mask, dict): attention_mask = attention_mask["full_attention"] @@ -1470,9 +1441,8 @@ def forward( input_ids=input_ids, inputs_embeds=inputs_embeds, mm_token_type_ids=mm_token_type_ids, - image_token_grids=vision_token_grids, - image_token_offsets=vision_token_offsets, - image_token_lengths=vision_token_lengths, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, attention_mask=attention_mask, past_key_values=past_key_values, ) @@ -1485,29 +1455,77 @@ def forward( if position_ids.ndim == 2: position_ids = position_ids.view(1, position_ids.shape[0], -1).expand(3, -1, -1) - text_model_outputs = self.text_model( + outputs = self.language_model( + input_ids=None, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, + visual_pos_masks=image_mask[..., 0] if image_mask is not None else None, + deepstack_visual_embeds=None, use_cache=use_cache, **kwargs, ) - return BaseModelOutputWithPast( - last_hidden_state=text_model_outputs.last_hidden_state, - past_key_values=text_model_outputs.past_key_values, - hidden_states=text_model_outputs.hidden_states, - attentions=text_model_outputs.attentions, + outputs_with_rope = BaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) + outputs_with_rope["rope_deltas"] = self.rope_deltas + return outputs_with_rope + + +@dataclass +@auto_docstring +class BaseModelOutputWithDeepstackFeatures(BaseModelOutputWithPooling): + r""" + deepstack_features (`List[torch.FloatTensor]`, *optional*): + List of hidden-states (feature maps) from deepstack layers. + """ + + deepstack_features: list[torch.FloatTensor] | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Isaac causal language model (or autoregressive) outputs. + """ +) +class IsaacCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. 
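The placeholder scatter in `forward` above can be reproduced in isolation — a small sketch with toy sizes, where every position with `mm_token_type_ids == 1` is overwritten, in order, by one row of the concatenated image embeddings:

```python
import torch

hidden = 3
inputs_embeds = torch.zeros(1, 5, hidden)            # batch of one, five tokens
mm_token_type_ids = torch.tensor([[0, 1, 1, 0, 1]])  # three image placeholders
image_embeds = torch.ones(3, hidden)                 # one row per placeholder token

image_mask = (mm_token_type_ids == 1).unsqueeze(-1).expand_as(inputs_embeds)
# masked_scatter consumes `image_embeds` element by element at every True position.
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
assert inputs_embeds[0, 1].tolist() == [1.0, 1.0, 1.0]
```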
+    """
+
+    loss: torch.FloatTensor | None = None
+    logits: torch.FloatTensor | None = None
+    past_key_values: Cache | None = None
+    hidden_states: tuple[torch.FloatTensor] | None = None
+    attentions: tuple[torch.FloatTensor] | None = None
+    rope_deltas: torch.LongTensor | None = None
+
+
 @auto_docstring
 class IsaacForConditionalGeneration(IsaacPreTrainedModel, GenerationMixin):
-    _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"}
-    _tp_plan = {"lm_head": "colwise_gather_output"}
-    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+    # Reference: fix gemma3 grad acc #37208
+    accepts_loss_kwargs = False
+    config: IsaacConfig
     config_class = IsaacConfig
+    input_modalities = ("image", "text")
+    _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"]
     _can_compile_fullgraph = False

     def __init__(self, config: IsaacConfig):
@@ -1515,122 +1533,195 @@ def __init__(self, config: IsaacConfig):
         self.model = IsaacModel(config)
         self.vocab_size = config.get_text_config().vocab_size
         self.lm_head = nn.Linear(config.get_text_config().hidden_size, config.get_text_config().vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
         self.post_init()

+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.model.set_input_embeddings(value)
+
     @auto_docstring
+    def get_video_features(
+        self,
+        pixel_values_videos: torch.FloatTensor,
+        video_grid_thw: torch.LongTensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple | BaseModelOutputWithDeepstackFeatures:
+        r"""
+        pixel_values_videos (`torch.FloatTensor`):
+            The tensors corresponding to the input videos. Isaac is image-only, so the underlying model raises a
+            `ValueError` whenever video inputs are provided.
+        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+            The temporal, height and width of feature shape of each video in LLM.
+        """
+        return self.model.get_video_features(
+            pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs
+        )
+
+    @auto_docstring
+    def get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        image_grid_thw: torch.LongTensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple | BaseModelOutputWithDeepstackFeatures:
+        r"""
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, max_images, max_patches, patch_dim)`):
+            Batch-major patch vectors corresponding to the input images.
+        image_grid_thw (`torch.LongTensor` of shape `(batch_size, max_images, 3)`, *optional*):
+            Batch-major per-slot grids with `(T=1, H, W)` entries for each image in LLM.
+        """
+        return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs)
+
     @can_return_tuple
     def forward(
         self,
-        input_ids: torch.LongTensor | None = None,
-        mm_token_type_ids: torch.LongTensor | None = None,
-        vision_patches: torch.Tensor | None = None,
-        pixel_values: torch.Tensor | None = None,
-        image_patch_attention_mask: torch.Tensor | None = None,
-        vision_token_grids: torch.LongTensor | None = None,
-        image_token_grids: torch.LongTensor | None = None,
-        vision_token_offsets: torch.LongTensor | None = None,
-        vision_token_lengths: torch.LongTensor | None = None,
+        input_ids: torch.LongTensor = None,
         attention_mask: torch.Tensor | None = None,
         position_ids: torch.LongTensor | None = None,
-        past_key_values: list[torch.FloatTensor] | None = None,
+        past_key_values: Cache | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
         labels: torch.LongTensor | None = None,
-        use_cache: bool | None = None,
+        pixel_values: torch.Tensor | None = None,
+        pixel_values_videos: torch.FloatTensor | None = None,
+        image_grid_thw: torch.LongTensor | None = None,
+        video_grid_thw: torch.LongTensor | None = None,
+        mm_token_type_ids: torch.IntTensor | None = None,
+        logits_to_keep: int | torch.Tensor = 0,
         **kwargs: Unpack[TransformersKwargs],
-    ) -> tuple | CausalLMOutputWithPast:
+    ) -> tuple | IsaacCausalLMOutputWithPast:
         r"""
-        mm_token_type_ids (`torch.LongTensor`, *optional*):
-            Multimodal token type ids aligned with the token sequence, shaped `(batch_size, seq_len)`, using
-            `0 -> text` and `1 -> image`.
-        vision_patches (`torch.FloatTensor`, *optional*):
-            Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`.
-        pixel_values (`torch.FloatTensor`, *optional*):
-            Alias for `vision_patches` accepted by generic image-feature and generation helpers.
-        image_patch_attention_mask (`torch.LongTensor`, *optional*):
-            Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`.
-        vision_token_grids (`torch.LongTensor`, *optional*):
-            Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`.
-        image_token_grids (`torch.LongTensor`, *optional*):
-            Alias for `vision_token_grids`.
-        vision_token_offsets (`torch.LongTensor`, *optional*):
-            Start offsets inside the per-image vision embedding sequence, shape `(batch_size, max_images)`.
-        vision_token_lengths (`torch.LongTensor`, *optional*):
-            Number of vision tokens to consume per image, shape `(batch_size, max_images)`.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        image_grid_thw (`torch.LongTensor` of shape `(batch_size, max_images, 3)`, *optional*):
+            Batch-major per-slot grids with `(T=1, H, W)` entries for each image in LLM.
+        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+            The temporal, height and width of feature shape of each video in LLM. Isaac is image-only; passing
+            videos raises a `ValueError`.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoProcessor, IsaacForConditionalGeneration
+
+        >>> model = IsaacForConditionalGeneration.from_pretrained("PerceptronAI/Isaac-0.1-Base")
+        >>> processor = AutoProcessor.from_pretrained("PerceptronAI/Isaac-0.1-Base")
+
+        >>> messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
+                    },
+                    {"type": "text", "text": "Describe the image."},
+                ],
+            }
+        ]
+
+        >>> inputs = processor.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt"
+        )
+
+        >>> # Generate
+        >>> generated_ids = model.generate(**inputs, max_new_tokens=1024)
+        >>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+        >>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        >>> print(output_text)
+        ```
         """
+
         outputs = self.model(
             input_ids=input_ids,
-            mm_token_type_ids=mm_token_type_ids,
-            vision_patches=vision_patches,
-            image_patch_attention_mask=image_patch_attention_mask,
-            vision_token_grids=vision_token_grids,
-            vision_token_offsets=vision_token_offsets,
-            vision_token_lengths=vision_token_lengths,
-            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            pixel_values_videos=pixel_values_videos,
+            image_grid_thw=image_grid_thw,
+            video_grid_thw=video_grid_thw,
             position_ids=position_ids,
+            attention_mask=attention_mask,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
+            mm_token_type_ids=mm_token_type_ids,
             **kwargs,
         )
+
        hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
+
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=logits.shape[-1])
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)

-        return CausalLMOutputWithPast(
+        return IsaacCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
+            rope_deltas=outputs.rope_deltas,
        )

     def prepare_inputs_for_generation(
         self,
-        input_ids: torch.LongTensor,
-        past_key_values: list[torch.FloatTensor] | None = None,
+        input_ids,
+        past_key_values=None,
         attention_mask: torch.Tensor | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
         mm_token_type_ids: torch.LongTensor | None = None,
-        vision_patches: torch.Tensor | None = None,
         pixel_values: torch.Tensor | None = None,
-        image_patch_attention_mask: torch.Tensor | None = None,
-        vision_token_grids: torch.LongTensor | None = None,
-        image_token_grids: torch.LongTensor | None = None,
-        vision_token_offsets: torch.LongTensor | None = None,
-        vision_token_lengths: torch.LongTensor | None = None,
+        image_grid_thw: torch.LongTensor | None = None,
+        image_metadata: torch.LongTensor | None = None,
         position_ids: torch.LongTensor | None = None,
         is_first_iteration=False,
         use_cache=True,
         **kwargs,
     ) -> dict[str, Any]:
-        if vision_patches is None:
-            vision_token_grids = image_token_grids if vision_token_grids is None else vision_token_grids
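The `logits_to_keep` slice in the loss path above exists so decode steps only project the positions they actually need — a sketch with illustrative sizes:

```python
import torch

hidden_states = torch.randn(2, 128, 64)    # (batch, seq_len, hidden)
lm_head = torch.nn.Linear(64, 1000, bias=False)

logits_to_keep = 1                         # a decode step only needs the last position
slice_indices = slice(-logits_to_keep, None)
logits = lm_head(hidden_states[:, slice_indices, :])
assert logits.shape == (2, 1, 1000)        # instead of (2, 128, 1000) for the full projection
```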
model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, position_ids=position_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, is_first_iteration=is_first_iteration, use_cache=use_cache, **kwargs, ) + is_prefill = is_first_iteration or not use_cache multimodal_inputs = { "mm_token_type_ids": mm_token_type_ids, - "vision_patches": vision_patches, - "image_patch_attention_mask": image_patch_attention_mask, - "vision_token_grids": vision_token_grids, - "vision_token_offsets": vision_token_offsets, - "vision_token_lengths": vision_token_lengths, + "pixel_values": pixel_values, + "image_grid_thw": image_grid_thw, + "image_metadata": image_metadata, } - is_prefill = is_first_iteration or not use_cache for key, value in multimodal_inputs.items(): model_inputs[key] = value if is_prefill else None - + if model_inputs["mm_token_type_ids"] is not None: + sequence_length = None + if model_inputs.get("input_ids") is not None: + sequence_length = model_inputs["input_ids"].shape[1] + elif model_inputs.get("inputs_embeds") is not None: + sequence_length = model_inputs["inputs_embeds"].shape[1] + + if sequence_length is not None: + current_length = model_inputs["mm_token_type_ids"].shape[1] + if current_length < sequence_length: + padding = model_inputs["mm_token_type_ids"].new_zeros( + (model_inputs["mm_token_type_ids"].shape[0], sequence_length - current_length) + ) + model_inputs["mm_token_type_ids"] = torch.cat([model_inputs["mm_token_type_ids"], padding], dim=1) + elif current_length > sequence_length: + model_inputs["mm_token_type_ids"] = model_inputs["mm_token_type_ids"][:, -sequence_length:] return model_inputs def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs): @@ -1645,18 +1736,15 @@ def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs): if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0: inputs_tensor = model_kwargs["input_ids"] + is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long] if ( - model_kwargs.get("vision_token_lengths") is not None - and len(inputs_tensor.shape) == 2 - and inputs_tensor.dtype in [torch.int, torch.long] + is_input_ids + and model_kwargs.get("mm_token_type_ids") is not None + and model_kwargs.get("image_grid_thw") is not None + and model_kwargs.get("image_metadata") is not None ): - vision_positions, rope_deltas = self.model.get_rope_index( - mm_token_type_ids=model_kwargs["mm_token_type_ids"], - image_token_grids=model_kwargs["vision_token_grids"], - image_token_offsets=model_kwargs["vision_token_offsets"], - image_token_lengths=model_kwargs["vision_token_lengths"], - attention_mask=model_kwargs.get("attention_mask"), - ) + model_kwargs = {k: v for k, v in model_kwargs.items() if k != "input_ids"} + vision_positions, rope_deltas = self.model.get_rope_index(inputs_tensor, **model_kwargs) self.model.rope_deltas = rope_deltas else: vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1) @@ -1666,6 +1754,57 @@ def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs): return torch.cat([text_positions[None, ...], vision_positions], dim=0) + def _get_image_nums_and_video_nums( + self, + input_ids: torch.LongTensor | None, + inputs_embeds: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample 
tensor.
+        These counts are derived from the token ids themselves rather than passed through the processor, so changes
+        to the processor interface cannot desynchronize them.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary.
+
+        Returns:
+            image_nums (`torch.LongTensor` of shape `(batch_size,)`)
+            video_nums (`torch.LongTensor` of shape `(batch_size,)`)
+        """
+        image_token_id = self.config.image_token_id
+        video_token_id = self.config.video_token_id
+        vision_start_token_id = self.config.vision_start_token_id
+
+        if inputs_embeds is not None:
+            vision_start_mask = (
+                inputs_embeds
+                == self.get_input_embeddings()(
+                    torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            )[..., 0]
+            image_mask = (
+                inputs_embeds
+                == self.get_input_embeddings()(
+                    torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            )[..., 0]
+            video_mask = (
+                inputs_embeds
+                == self.get_input_embeddings()(
+                    torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            )[..., 0]
+        else:
+            vision_start_mask = input_ids == vision_start_token_id
+            image_mask = input_ids == image_token_id
+            video_mask = input_ids == video_token_id
+
+        vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
+        image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
+        video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
+
+        return image_nums, video_nums
+
     def _expand_inputs_for_generation(
         self,
         expand_size: int = 1,
@@ -1674,12 +1813,26 @@ def _expand_inputs_for_generation(
         **model_kwargs,
     ) -> tuple[torch.LongTensor, dict[str, Any]]:
         position_ids = model_kwargs.pop("position_ids", None)
-        input_ids, model_kwargs = super()._expand_inputs_for_generation(
-            expand_size=expand_size,
-            is_encoder_decoder=is_encoder_decoder,
-            input_ids=input_ids,
-            **model_kwargs,
-        )
+        if expand_size == 1:
+            if position_ids is not None:
+                model_kwargs["position_ids"] = position_ids
+            return input_ids, model_kwargs
+
+        visual_keys = ["pixel_values", "image_grid_thw", "image_metadata"]
+        for key in visual_keys:
+            value = model_kwargs.get(key)
+            if value is not None:
+                model_kwargs[key] = value.repeat_interleave(expand_size, dim=0)
+
+        if input_ids is not None:
+            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+        for key, value in list(model_kwargs.items()):
+            if key == "position_ids" and value is not None and value.ndim == 3:
+                model_kwargs[key] = value.repeat_interleave(expand_size, dim=1)
+            elif value is not None and isinstance(value, torch.Tensor) and key not in visual_keys:
+                model_kwargs[key] = value.repeat_interleave(expand_size, dim=0)
+
         if position_ids is not None:
             dim = 1 if position_ids.ndim == 3 else 0
             model_kwargs["position_ids"] = position_ids.repeat_interleave(expand_size, dim=dim)
diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py
index f37d82f6f23a..f9463d5d5b08 100644
--- a/src/transformers/models/isaac/modular_isaac.py
+++ b/src/transformers/models/isaac/modular_isaac.py
@@ -23,20 +23,17 @@
 from ... import TorchvisionBackend
 from ...
import initialization as init +from ...cache_utils import Cache from ...configuration_utils import PretrainedConfig from ...feature_extraction_utils import BatchFeature from ...generation.utils import GenerationMixin from ...image_transforms import group_images_by_shape, reorder_images from ...image_utils import ImageInput, PILImageResampling, SizeDict, make_nested_list_of_images from ...masking_utils import create_bidirectional_mask -from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...models.qwen3.configuration_qwen3 import Qwen3Config -from ...models.qwen3.modeling_qwen3 import ( - Qwen3ForCausalLM, - Qwen3PreTrainedModel, -) -from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...utils import TensorType, auto_docstring, torch_compilable_check from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD @@ -49,6 +46,8 @@ ) from ...utils.output_capturing import capture_outputs from ..qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLForConditionalGeneration, + Qwen3VLModel, Qwen3VLTextAttention, Qwen3VLTextDecoderLayer, Qwen3VLTextModel, @@ -110,6 +109,7 @@ class IsaacProcessorKwargs(ProcessingKwargs, total=False): "text_kwargs": { "padding": True, "return_attention_mask": True, + "return_mm_token_type_ids": True, }, } @@ -187,7 +187,7 @@ def clean_text_and_extract_points( @auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") -@strict(accept_kwargs=True) +@strict class IsaacVisionConfig(Siglip2VisionConfig): r""" num_patches (`int`, *optional*, defaults to 256): @@ -205,7 +205,7 @@ class IsaacVisionConfig(Siglip2VisionConfig): @auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") -@strict(accept_kwargs=True) +@strict class IsaacTextConfig(Qwen3Config): r""" Example: @@ -222,9 +222,16 @@ class IsaacTextConfig(Qwen3Config): model_type = "isaac_text" ignore_keys_at_rope_validation = {"mrope_section", "mrope_interleaved"} max_position_embeddings: int = 32768 + sliding_window = AttributeError() def __post_init__(self, **kwargs): - super().__post_init__(**kwargs) + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.layer_types is None: + self.layer_types = ["full_attention" for _ in range(self.num_hidden_layers)] + + PretrainedConfig.__post_init__(self, **kwargs) self.validate_layer_type() @@ -252,9 +259,8 @@ class IsaacImageProcessor(TorchvisionBackend): resample = PILImageResampling.BILINEAR model_input_names = [ - "vision_patches", - "image_patch_attention_mask", - "vision_token_grids", + "pixel_values", + "image_grid_thw", ] valid_kwargs = IsaacImageProcessorKwargs @@ -296,37 +302,65 @@ def resize( return image.clamp(0, 255).round().to(torch.uint8) return F.interpolate(image, size=(size.height, size.width), mode="bilinear", align_corners=False) - def pad( + def get_number_of_image_patches( + self, + image_height: int, + image_width: int, + images_kwargs: dict[str, Any] | None = None, + ) -> int: + images_kwargs = images_kwargs or {} + patch_size = images_kwargs.get("patch_size", self.patch_size) + max_num_patches = images_kwargs.get("max_num_patches", self.max_num_patches) + min_num_patches = images_kwargs.get("min_num_patches", 
self.min_num_patches) + pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale", self.pixel_shuffle_scale) + + target_height, target_width = get_image_size_for_max_num_patches( + image_height, + image_width, + patch_size, + max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + ) + return (target_height // patch_size) * (target_width // patch_size) + + def pack_images( self, vision_patches: list[list[torch.Tensor]], vision_token_grids: list[list[torch.Tensor]], - ) -> dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor | None]: batch_size = len(vision_patches) - first_patch = next(patches for sample_patches in vision_patches for patches in sample_patches) - max_images = max(len(sample_patches) for sample_patches in vision_patches) - max_patches = max(patches.shape[0] for sample_patches in vision_patches for patches in sample_patches) + max_images = max((len(sample_patches) for sample_patches in vision_patches), default=0) + flat_patches = [patches for sample_patches in vision_patches for patches in sample_patches] + if max_images == 0 or not flat_patches: + return { + "pixel_values": None, + "image_grid_thw": None, + } + + first_patch = flat_patches[0] + max_patches = max(patches.shape[0] for patches in flat_patches) patch_dim = first_patch.shape[-1] patch_dtype = first_patch.dtype patch_device = first_patch.device tensors = { - "vision_patches": torch.zeros( - (batch_size, max_images, max_patches, patch_dim), device=patch_device, dtype=patch_dtype - ), - "image_patch_attention_mask": torch.zeros( - (batch_size, max_images, max_patches), device=patch_device, dtype=torch.long + "pixel_values": torch.zeros( + (batch_size, max_images, max_patches, patch_dim), + device=patch_device, + dtype=patch_dtype, ), - "vision_token_grids": torch.zeros((batch_size, max_images, 2), device=patch_device, dtype=torch.long), + "image_grid_thw": torch.zeros((batch_size, max_images, 3), device=patch_device, dtype=torch.long), } for batch_idx, (sample_patches, sample_token_grids) in enumerate( zip(vision_patches, vision_token_grids, strict=True) ): - for image_idx, (patches, token_grid) in enumerate(zip(sample_patches, sample_token_grids, strict=True)): + for image_slot, (patches, token_grid) in enumerate(zip(sample_patches, sample_token_grids, strict=True)): patch_count = int(patches.shape[0]) - tensors["vision_patches"][batch_idx, image_idx, :patch_count] = patches - tensors["image_patch_attention_mask"][batch_idx, image_idx, :patch_count] = 1 - tensors["vision_token_grids"][batch_idx, image_idx] = token_grid + tensors["pixel_values"][batch_idx, image_slot, :patch_count] = patches + tensors["image_grid_thw"][batch_idx, image_slot, 0] = 1 + tensors["image_grid_thw"][batch_idx, image_slot, 1:] = token_grid return tensors @@ -350,15 +384,12 @@ def _preprocess( **kwargs, ) -> BatchFeature: resample = kwargs.pop("interpolation", resample) - batch_size = len(images) # IsaacProcessor routes text-only calls here as an empty image list per sample. - # This returns empty vision tensors to preserve the multimodal output schema; - # image-token/image-count mismatches are validated earlier in processor's _preprocess call. + # Return `None` visual fields so text-only batches skip multimodal codepaths like other VLMs. 
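To make the arithmetic in `get_number_of_image_patches` concrete, a hedged example assuming a 16-pixel patch and pixel-shuffle scale 2 (plausible values only; the real defaults live on the image processor):

```python
patch_size = 16           # assumed
pixel_shuffle_scale = 2   # assumed

target_height, target_width = 256, 384  # already snapped to the patch grid
num_patches = (target_height // patch_size) * (target_width // patch_size)  # 16 * 24
# Pixel shuffle later merges scale**2 neighbouring patches into one LLM token.
num_image_tokens = num_patches // pixel_shuffle_scale**2
assert (num_patches, num_image_tokens) == (384, 96)
```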
if all(len(sample_images) == 0 for sample_images in images): tensors = { - "vision_patches": torch.zeros((batch_size, 0, 0, 0), dtype=torch.float32), - "image_patch_attention_mask": torch.zeros((batch_size, 0, 0), dtype=torch.long), - "vision_token_grids": torch.zeros((batch_size, 0, 2), dtype=torch.long), + "pixel_values": None, + "image_grid_thw": None, } return BatchFeature(data=tensors, tensor_type=return_tensors) @@ -427,7 +458,7 @@ def _preprocess( if not do_pad: raise ValueError("IsaacImageProcessor doesn't support `do_pad=False` mode.") - tensors = self.pad( + tensors = self.pack_images( vision_patches=nested_outputs["vision_patches"], vision_token_grids=nested_outputs["vision_token_grids"], ) @@ -456,7 +487,7 @@ def __init__(self, config: IsaacVisionConfig): def forward( self, pixel_values: torch.Tensor, - spatial_shapes: torch.Tensor, + image_grid_thw: torch.Tensor, attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: # pixel_values: (num_images, max_patches, patch_dim) @@ -465,13 +496,16 @@ def forward( resized_positional_embeddings = self.resize_positional_embeddings( self.position_embedding, - spatial_shapes, + image_grid_thw[:, 1:], max_length=pixel_values.shape[1], ) + resized_positional_embeddings = resized_positional_embeddings.to( + device=patch_embeds.device, dtype=patch_embeds.dtype + ) embeddings = patch_embeds + resized_positional_embeddings if attention_mask is not None: - embeddings = embeddings * attention_mask.unsqueeze(-1).to(dtype=embeddings.dtype) + embeddings = embeddings * attention_mask.unsqueeze(-1).to(device=embeddings.device, dtype=embeddings.dtype) return embeddings @@ -499,7 +533,7 @@ def __init__(self, config: IsaacVisionConfig): def pixel_shuffle_padded( - x: torch.Tensor, + hidden_states: torch.Tensor, token_grids: torch.Tensor, scale_factor: int = 1, ) -> torch.Tensor: @@ -517,10 +551,10 @@ def pixel_shuffle_padded( `torch.Tensor`: Pixel-shuffled embeddings of shape `(num_images, max_tokens, hidden_size * scale_factor**2)`. 
""" - num_images, max_patches, embed_dim = x.shape + num_images, max_patches, embed_dim = hidden_states.shape output_dim = embed_dim * scale_factor * scale_factor - token_grids = token_grids.to(device=x.device, dtype=torch.long) + token_grids = token_grids.to(device=hidden_states.device, dtype=torch.long) heights = token_grids[:, 0] widths = token_grids[:, 1] full_lengths = heights * widths @@ -535,9 +569,11 @@ def pixel_shuffle_padded( output_lengths = (heights // scale_factor) * (widths // scale_factor) max_output_tokens = output_lengths.max() - shuffled_4d = x.new_zeros((num_images, max_output_tokens, scale_factor * scale_factor, embed_dim)) + shuffled_4d = hidden_states.new_zeros((num_images, max_output_tokens, scale_factor * scale_factor, embed_dim)) - token_positions = torch.arange(max_patches, device=x.device, dtype=torch.long).unsqueeze(0).expand(num_images, -1) + token_positions = ( + torch.arange(max_patches, device=hidden_states.device, dtype=torch.long).unsqueeze(0).expand(num_images, -1) + ) valid_token_mask = token_positions < full_lengths.unsqueeze(1) safe_widths = torch.where(widths > 0, widths, torch.ones_like(widths)) @@ -549,10 +585,12 @@ def pixel_shuffle_padded( output_index = output_index + col_index.div(scale_factor, rounding_mode="floor") sub_index = row_index.remainder(scale_factor) * scale_factor + col_index.remainder(scale_factor) - batch_index = torch.arange(num_images, device=x.device, dtype=torch.long).unsqueeze(1).expand_as(token_positions) - shuffled_4d[batch_index[valid_token_mask], output_index[valid_token_mask], sub_index[valid_token_mask]] = x[ - valid_token_mask - ] + batch_index = ( + torch.arange(num_images, device=hidden_states.device, dtype=torch.long).unsqueeze(1).expand_as(token_positions) + ) + shuffled_4d[batch_index[valid_token_mask], output_index[valid_token_mask], sub_index[valid_token_mask]] = ( + hidden_states[valid_token_mask] + ) shuffled = shuffled_4d.view(num_images, max_output_tokens, output_dim) return shuffled @@ -594,26 +632,28 @@ def _init_weights(self, module): @capture_outputs(tie_last_hidden_states=False) def forward( self, - vision_patches: torch.Tensor, - vision_token_grids: torch.Tensor, - image_patch_attention_mask: torch.Tensor, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: """ Inputs: - vision_patches (`torch.Tensor`): + pixel_values (`torch.Tensor`): Patches shaped `(num_images, max_patches, patch_dim)`. - vision_token_grids (`torch.Tensor`): - Token grids shaped `(num_images, 2)` with per-image `(H_tokens, W_tokens)`. - image_patch_attention_mask (`torch.Tensor`): - Patch mask shaped `(num_images, max_patches)`. + image_grid_thw (`torch.Tensor`): + Grid tensor shaped `(num_images, 3)` with per-image `(T=1, H_tokens, W_tokens)`. Returns: `BaseModelOutputWithPooling` with pixel-shuffled embeddings in `last_hidden_state`. 
""" + vision_token_grids = image_grid_thw[:, 1:].to(dtype=torch.long) + full_lengths = vision_token_grids[:, 0] * vision_token_grids[:, 1] + token_positions = torch.arange(pixel_values.shape[1], device=pixel_values.device, dtype=torch.long) + image_patch_attention_mask = token_positions.unsqueeze(0) < full_lengths.unsqueeze(1) + image_patch_attention_mask = image_patch_attention_mask.to(dtype=torch.long) hidden_states = self.embeddings( - vision_patches, - vision_token_grids, + pixel_values, + image_grid_thw, attention_mask=image_patch_attention_mask, ) @@ -626,7 +666,7 @@ def forward( hidden_states = self.post_layernorm(encoder_outputs.last_hidden_state) hidden_states = pixel_shuffle_padded( - x=hidden_states, + hidden_states=hidden_states, token_grids=vision_token_grids, scale_factor=self.pixel_shuffle_scale_factor, ) @@ -750,7 +790,7 @@ def get_image_size_for_max_num_patches( @auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") -@strict(accept_kwargs=True) +@strict class IsaacConfig(PretrainedConfig): r""" vision_config (`IsaacVisionConfig` or `dict`, *optional*): @@ -762,9 +802,6 @@ class IsaacConfig(PretrainedConfig): Rescale factor applied by the image processor before normalization. max_sequence_length (`int`, *optional*, defaults to 16384): Maximum multimodal sequence length produced by the processor and expected by the model. - vision_token (`str`, *optional*, defaults to `""`): - Placeholder string inserted into text prompts to mark image positions. - Example: ```python @@ -782,7 +819,6 @@ class IsaacConfig(PretrainedConfig): text_config: IsaacTextConfig | dict | None = None vision_rescale_factor: float = 1 / 255 max_sequence_length: int = 16384 - vision_token: str = "" def __post_init__(self, **kwargs): for key in ("use_cache", "rope_theta", "max_position_embeddings"): @@ -818,15 +854,12 @@ def __init__( image_processor, tokenizer, chat_template: str | dict[str, str] | None = None, - vision_token: str = "", max_sequence_length: int = 16384, ): """ Args: chat_template (`str` or `dict[str, str]`, *optional*): Chat template override forwarded to [`~processing_utils.ProcessorMixin`]. - vision_token (`str`, *optional*, defaults to `""`): - Placeholder token used inside text prompts to mark image positions. max_sequence_length (`int`, *optional*, defaults to 16384): Maximum packed multimodal sequence length produced by the processor. 
""" @@ -836,13 +869,33 @@ def __init__( self.image_processor = image_processor super().__init__(image_processor, tokenizer, chat_template=chat_template) self.text_pad_token_id = self.pad_token_id = tokenizer.pad_token_id - self.image_pad_token_id = tokenizer.image_pad_token_id - self.image_token = tokenizer.image_pad_token - self.image_token_id = self.image_pad_token_id + self.image_token = getattr(tokenizer, "image_pad_token", None) or getattr(tokenizer, "image_token", None) + self.image_token_id = getattr(tokenizer, "image_pad_token_id", None) or getattr( + tokenizer, "image_token_id", None + ) - self.vision_token = vision_token self.max_sequence_length = max_sequence_length + @property + def model_input_names(self): + return super().model_input_names + ["mm_token_type_ids", "image_metadata"] + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + vision_data = {} + if image_sizes is not None: + images_kwargs = dict(IsaacProcessorKwargs._defaults.get("images_kwargs", {})) + images_kwargs.update(kwargs) + + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) + for image_size in image_sizes + ] + pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale") or self.image_processor.pixel_shuffle_scale + num_image_tokens = [num_patches // pixel_shuffle_scale**2 for num_patches in num_image_patches] + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + def post_process_generation( self, text: str, @@ -885,8 +938,10 @@ def __call__( padding = text_kwargs.pop("padding", True) padding_side = text_kwargs.pop("padding_side", "left") return_attention_mask = text_kwargs.pop("return_attention_mask", True) + return_mm_token_type_ids = text_kwargs.pop("return_mm_token_type_ids", True) pad_to_multiple_of = text_kwargs.pop("pad_to_multiple_of", None) text_kwargs.pop("return_tensors", None) + text_kwargs.pop("return_overflowing_tokens", None) text_kwargs.setdefault("add_special_tokens", False) texts = [text] if isinstance(text, str) else text @@ -896,7 +951,7 @@ def __call__( fetched_images = self.image_processor.fetch_images(images) batched_images = make_nested_list_of_images(fetched_images) if len(batched_images) != len(texts): - num_images_in_text = [text_value.count(self.vision_token) for text_value in texts] + num_images_in_text = [text_value.count(self.image_token) for text_value in texts] num_images_in_images = [len(sample_images) for sample_images in batched_images] add_message = "" if sum(num_images_in_text) == sum(num_images_in_images): @@ -908,19 +963,23 @@ def __call__( pairs = list(zip(texts, batched_images, strict=True)) image_inputs = self.image_processor(images=batched_images, return_tensors=TensorType.PYTORCH) - vision_token_grids = image_inputs["vision_token_grids"] - vision_segment_lengths = (vision_token_grids[..., 0] // self.image_processor.pixel_shuffle_scale) * ( - vision_token_grids[..., 1] // self.image_processor.pixel_shuffle_scale - ) - vision_token_offsets = torch.zeros_like(vision_segment_lengths) - vision_token_lengths = torch.zeros_like(vision_segment_lengths) + image_grid_thw = image_inputs["image_grid_thw"] + image_metadata = None + vision_segment_lengths = None + if image_grid_thw is not None: + batch_size, max_images = image_grid_thw.shape[:2] + image_metadata = torch.zeros((batch_size, max_images, 2), dtype=torch.long) + grid_heights = image_grid_thw[..., 1] + grid_widths = image_grid_thw[..., 2] + 
vision_segment_lengths = (grid_heights // self.image_processor.pixel_shuffle_scale) * ( + grid_widths // self.image_processor.pixel_shuffle_scale + ) - sample_input_ids: list[torch.Tensor] = [] expanded_texts = [] expected_image_lengths_per_sample = [] for batch_idx, (text_value, sample_images) in enumerate(pairs): - segments = text_value.split(self.vision_token) + segments = text_value.split(self.image_token) num_images = len(segments) - 1 num_provided_images = len(sample_images) if num_images != num_provided_images: @@ -928,81 +987,91 @@ def __call__( f"IsaacProcessor expects one image per image token, got {num_images} tokens and {num_provided_images} images in sample with text {text_value} " ) - expected_image_lengths = [ - int(vision_segment_lengths[batch_idx, image_idx].item()) for image_idx in range(num_images) - ] - expected_image_lengths_per_sample.append(expected_image_lengths) - - expanded_text = segments[0] - for image_idx, segment_length in enumerate(expected_image_lengths): - expanded_text += (self.image_token * segment_length) + segments[image_idx + 1] - expanded_texts.append(expanded_text) + expected_image_lengths = [] + expanded_text_parts = [segments[0]] + for image_idx in range(num_images): + segment_length = int(vision_segment_lengths[batch_idx, image_idx].item()) + expected_image_lengths.append(segment_length) + expanded_text_parts.append(self.image_token * segment_length) + expanded_text_parts.append(segments[image_idx + 1]) - tokenized_text_inputs = self.tokenizer(expanded_texts, return_tensors=None, **text_kwargs) - self._check_special_mm_tokens(expanded_texts, tokenized_text_inputs, modalities=["image"]) + expected_image_lengths_per_sample.append(expected_image_lengths) + expanded_texts.append("".join(expanded_text_parts)) effective_max_length = self.max_sequence_length - if truncation and max_length is not None: + if max_length is not None and (truncation or padding == "max_length"): effective_max_length = max_length - for batch_idx, (expected_image_lengths, sample_input_ids_list) in enumerate( - zip(expected_image_lengths_per_sample, tokenized_text_inputs["input_ids"], strict=True) - ): - sample_input = torch.tensor(sample_input_ids_list, dtype=torch.long) - image_positions = sample_input.eq(self.image_pad_token_id).nonzero(as_tuple=False).flatten() - image_spans = image_positions.split(expected_image_lengths) if expected_image_lengths else () - image_bounds = [] - - for image_idx, (segment_length, image_span) in enumerate( - zip(expected_image_lengths, image_spans, strict=True) - ): - image_start = int(image_span[0].item()) - image_end = int(image_span[-1].item()) + 1 - image_bounds.append((image_start, image_end)) - total = int(sample_input.shape[0]) - start = max(0, total - effective_max_length) - sample_input_ids.append(sample_input[start:]) - - for image_idx, (image_start, image_end) in enumerate(image_bounds): - kept_start = max(start, image_start) - kept_end = min(total, image_end) - if kept_end > kept_start: - vision_token_offsets[batch_idx, image_idx] = kept_start - image_start - vision_token_lengths[batch_idx, image_idx] = kept_end - kept_start - - # Pad only after Isaac-specific truncation so image span offsets and lengths stay aligned. 
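The placeholder expansion above, sketched with an assumed `<image>` pad string (the real token comes from the tokenizer):

```python
image_token = "<image>"        # assumed placeholder string
text = f"Describe {image_token} please."
segment_lengths = [3]          # (H // scale) * (W // scale) for the single image

segments = text.split(image_token)
parts = [segments[0]]
for idx, seg_len in enumerate(segment_lengths):
    parts.append(image_token * seg_len)  # one placeholder per final image token
    parts.append(segments[idx + 1])
assert "".join(parts) == "Describe <image><image><image> please."
```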
- padded_text_inputs = self.tokenizer.pad( - {"input_ids": [sample_input.tolist() for sample_input in sample_input_ids]}, + self.tokenizer.truncation_side = "left" + self.tokenizer.padding_side = padding_side + tokenized_text_inputs = self.tokenizer( + expanded_texts, + truncation=True, + max_length=effective_max_length, padding=padding, - max_length=max_length if padding == "max_length" else None, pad_to_multiple_of=pad_to_multiple_of, - padding_side=padding_side, return_attention_mask=return_attention_mask, - return_tensors=TensorType.PYTORCH, + return_overflowing_tokens=True, + stride=0, + return_tensors=None, + **text_kwargs, ) - input_ids = padded_text_inputs["input_ids"] - attention_mask = padded_text_inputs.get("attention_mask") - if attention_mask is None: - attention_mask = input_ids.ne(self.pad_token_id).to(dtype=torch.long) - mm_token_type_ids = input_ids.eq(self.image_pad_token_id).to(dtype=torch.long) - vision_image_attention_mask = vision_token_lengths.gt(0).to(dtype=torch.long) + kept_input_ids_per_sample: list[list[int] | None] = [None] * len(texts) + overflow_input_ids_per_sample: list[list[list[int]]] = [[] for _ in texts] + overflow_to_sample_mapping = tokenized_text_inputs.get("overflow_to_sample_mapping") + if overflow_to_sample_mapping is None: + overflow_to_sample_mapping = list(range(len(tokenized_text_inputs["input_ids"]))) + + for row_input_ids, sample_idx in zip( + tokenized_text_inputs["input_ids"], overflow_to_sample_mapping, strict=True + ): + sample_idx = int(sample_idx) + if kept_input_ids_per_sample[sample_idx] is None: + kept_input_ids_per_sample[sample_idx] = row_input_ids + else: + overflow_input_ids_per_sample[sample_idx].append(row_input_ids) + + for batch_idx, expected_image_lengths in enumerate(expected_image_lengths_per_sample): + dropped_image_tokens = sum( + overflow_input_ids.count(self.image_token_id) + for overflow_input_ids in overflow_input_ids_per_sample[batch_idx] + ) - vision_patches = image_inputs["vision_patches"] - image_patch_attention_mask = image_inputs["image_patch_attention_mask"] + remaining_dropped = dropped_image_tokens + for image_idx, expected_length in enumerate(expected_image_lengths): + if remaining_dropped <= 0: + offset = 0 + length = expected_length + elif remaining_dropped < expected_length: + offset = remaining_dropped + length = expected_length - offset + remaining_dropped = 0 + else: + offset = 0 + length = 0 + remaining_dropped -= expected_length + + # Record which suffix of this image's placeholder span survives left truncation. + # The model still encodes the full image and uses this window for both feature gathering and vision RoPE. 
+ image_metadata[batch_idx, image_idx, 0] = offset + image_metadata[batch_idx, image_idx, 1] = length + + input_ids = torch.tensor(kept_input_ids_per_sample, dtype=torch.long) + attention_mask = input_ids.ne(self.pad_token_id).to(dtype=torch.long) + + data = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": image_inputs["pixel_values"], + "image_grid_thw": image_grid_thw, + "image_metadata": image_metadata, + } + if return_mm_token_type_ids: + data["mm_token_type_ids"] = input_ids.eq(self.image_token_id).to(dtype=torch.long) return BatchFeature( - data={ - "input_ids": input_ids, - "attention_mask": attention_mask, - "mm_token_type_ids": mm_token_type_ids, - "vision_patches": vision_patches, - "image_patch_attention_mask": image_patch_attention_mask, - "vision_token_grids": vision_token_grids, - "vision_token_offsets": vision_token_offsets, - "vision_token_lengths": vision_token_lengths, - "vision_image_attention_mask": vision_image_attention_mask, - }, + data=data, tensor_type=return_tensors, ) @@ -1050,22 +1119,22 @@ def __init__(self, config: IsaacTextConfig): @auto_docstring -class IsaacModel(Qwen3PreTrainedModel): +class IsaacModel(Qwen3VLModel): + input_modalities = ("image", "text") supports_gradient_checkpointing = True + _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"] _can_compile_fullgraph = False _supports_flex_attn = False _tied_weights_keys = {} - _input_embed_layer = "text_model.embed_tokens" + _input_embed_layer = "language_model.embed_tokens" def __init__(self, config: IsaacConfig): - Qwen3PreTrainedModel.__init__(self, config) - self.text_model = IsaacTextModel._from_config(config.text_config) - - self.vision_tower = IsaacVisionTransformer(config.vision_config) + PreTrainedModel.__init__(self, config) + self.language_model = IsaacTextModel._from_config(config.text_config) + self.visual = IsaacVisionTransformer(config.vision_config) self.multimodal_projector = IsaacMultiModalProjector(config) self.max_sequence_length = config.max_sequence_length self.vision_rescale_factor = config.vision_rescale_factor - self.vision_token = config.vision_token self.rope_deltas = None self.post_init() @@ -1075,79 +1144,75 @@ def __init__(self, config: IsaacConfig): def get_image_features( self, pixel_values: torch.Tensor, - image_token_grids: torch.Tensor, - image_patch_attention_mask: torch.Tensor | None = None, - image_token_offsets: torch.Tensor | None = None, - image_token_lengths: torch.Tensor | None = None, + image_grid_thw: torch.Tensor, + image_metadata: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ Args: pixel_values (`torch.Tensor`): - Padded per-image patch vectors with shape `(batch_size, max_images, max_patches, patch_dim)`. - image_token_grids (`torch.Tensor`): - Per-image token grids shaped `(batch_size, max_images, 2)` with `(height, width)` entries. - image_patch_attention_mask (`torch.Tensor`, *optional*): - Mask for valid patch rows in `pixel_values`, shaped `(batch_size, max_images, max_patches)`. - image_token_offsets (`torch.Tensor`, *optional*): - Start offsets inside each per-image embedding sequence, shaped `(batch_size, max_images)`. - image_token_lengths (`torch.Tensor`, *optional*): - Number of image tokens to gather per image for placeholder scattering, shaped `(batch_size, max_images)`. + Batch-major patch vectors with shape `(batch_size, max_images, max_patches, patch_dim)`. 
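The overflow bookkeeping above charges dropped placeholder tokens to the earliest images first, so each surviving image keeps a suffix window — a sketch with invented counts:

```python
expected_image_lengths = [6, 6]  # placeholder tokens per image before truncation
dropped_image_tokens = 8         # assumed to have overflowed off the left edge

metadata = []
remaining = dropped_image_tokens
for expected in expected_image_lengths:
    if remaining <= 0:
        offset, length = 0, expected             # fully visible
    elif remaining < expected:
        offset, length = remaining, expected - remaining
        remaining = 0                            # only a suffix survives
    else:
        offset, length = 0, 0                    # truncated away entirely
        remaining -= expected
    metadata.append((offset, length))
assert metadata == [(0, 0), (2, 4)]
```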
+ image_grid_thw (`torch.Tensor`): + Batch-major grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. + image_metadata (`torch.Tensor`, *optional*): + Batch-major per-slot metadata `(offset, length)` shaped `(batch_size, max_images, 2)`. """ - image_token_grids = image_token_grids.to(dtype=torch.long) - patch_attention_mask = image_patch_attention_mask.to(dtype=torch.long) - if image_token_lengths is not None: - image_attention_mask = image_token_lengths > 0 - else: - image_attention_mask = image_token_grids.any(dim=-1) - - torch_compilable_check( - image_attention_mask.any(), - "IsaacModel.get_image_features expects at least one active image slot; text-only inputs should skip this method.", - ) + if pixel_values.shape[0] == 0: + hidden_size = self.config.get_text_config().hidden_size + return BaseModelOutputWithPooling( + last_hidden_state=pixel_values.new_zeros((0, 0, hidden_size)), + pooler_output=(), + hidden_states=None, + attentions=None, + ) - batch_size, max_images = pixel_values.shape[:2] - hidden_size = self.config.get_text_config().hidden_size + image_grid_thw = image_grid_thw.to(dtype=torch.long) + active_slot_mask = image_grid_thw[..., 0].eq(1) + if not active_slot_mask.any(): + hidden_size = self.config.get_text_config().hidden_size + return BaseModelOutputWithPooling( + last_hidden_state=pixel_values.new_zeros((0, 0, hidden_size)), + pooler_output=(), + hidden_states=None, + attentions=None, + ) - vision_outputs = self.vision_tower( - vision_patches=pixel_values[image_attention_mask], - vision_token_grids=image_token_grids[image_attention_mask], - image_patch_attention_mask=patch_attention_mask[image_attention_mask], + flat_pixel_values = pixel_values[active_slot_mask] + flat_image_grid_thw = image_grid_thw[active_slot_mask] + vision_outputs: BaseModelOutputWithPooling = self.visual( + pixel_values=flat_pixel_values, + image_grid_thw=flat_image_grid_thw, return_dict=True, **kwargs, ) - flat_projected_features = self.multimodal_projector(vision_outputs.last_hidden_state) - max_tokens = flat_projected_features.shape[1] - projected_features = flat_projected_features.new_zeros((batch_size, max_images, max_tokens, hidden_size)) - projected_features[image_attention_mask] = flat_projected_features - feature_device = flat_projected_features.device - offsets = ( - image_token_offsets.to(dtype=torch.long) - if image_token_offsets is not None - else torch.zeros((batch_size, max_images), device=feature_device, dtype=torch.long) - ) - lengths = ( - image_token_lengths.to(dtype=torch.long) - if image_token_lengths is not None - else torch.full((batch_size, max_images), max_tokens, device=feature_device, dtype=torch.long) + projected_features = self.multimodal_projector(vision_outputs.last_hidden_state) + + pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor + full_lengths = flat_image_grid_thw[:, 1].div(pixel_shuffle_scale, rounding_mode="floor") * flat_image_grid_thw[ + :, 2 + ].div(pixel_shuffle_scale, rounding_mode="floor") + if image_metadata is None: + offsets = torch.zeros_like(full_lengths) + lengths = full_lengths + else: + torch_compilable_check( + image_metadata.shape[:2] == image_grid_thw.shape[:2], + "IsaacModel.get_image_features expects batch-major metadata aligned with `image_grid_thw`.", + ) + active_metadata = image_metadata[active_slot_mask] + offsets = active_metadata[:, 0].to(device=projected_features.device, dtype=torch.long) + lengths = active_metadata[:, 1].to(device=projected_features.device, dtype=torch.long) + + 
image_features = tuple( + projected_features[image_idx, offset : offset + length] + for image_idx, (offset, length) in enumerate(zip(offsets.tolist(), lengths.tolist(), strict=True)) ) - flat_offsets = offsets[image_attention_mask] - flat_lengths = lengths[image_attention_mask] - token_positions = torch.arange(flat_lengths.max(), device=feature_device, dtype=torch.long) - gather_positions = flat_offsets[:, None] + token_positions[None, :] - gather_mask = token_positions[None, :] < flat_lengths[:, None] - image_features = flat_projected_features[ - torch.arange(flat_projected_features.shape[0], device=feature_device, dtype=torch.long)[:, None], - gather_positions, - ][gather_mask] - hidden_states = vision_outputs.hidden_states - attentions = vision_outputs.attentions return BaseModelOutputWithPooling( last_hidden_state=projected_features, pooler_output=image_features, - hidden_states=hidden_states, - attentions=attentions, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, ) def get_placeholder_mask( @@ -1165,61 +1230,77 @@ def get_placeholder_mask( ) return image_token_mask + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + raise ValueError("Isaac is image-only and does not support `pixel_values_videos` or `video_grid_thw`.") + def get_vision_position_ids( self, start_position: int, - grid_hw: torch.LongTensor, - token_offset: int, - token_length: int, + grid_thw: torch.LongTensor, + image_metadata: torch.LongTensor, ) -> torch.LongTensor: - height, width = grid_hw[0].item(), grid_hw[1].item() - token_positions = torch.arange(height * width, device=grid_hw.device, dtype=torch.long) + pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor + height = grid_thw[1].div(pixel_shuffle_scale, rounding_mode="floor").item() + width = grid_thw[2].div(pixel_shuffle_scale, rounding_mode="floor").item() + token_positions = torch.arange(height * width, device=grid_thw.device, dtype=torch.long) vision_position_ids = torch.stack( ( - torch.full((token_positions.shape[0],), start_position, device=grid_hw.device, dtype=torch.long), + torch.full((token_positions.shape[0],), start_position, device=grid_thw.device, dtype=torch.long), token_positions.div(width, rounding_mode="floor"), token_positions.remainder(width), ), dim=0, ) + token_offset = int(image_metadata[0].item()) + token_length = int(image_metadata[1].item()) return vision_position_ids[:, token_offset : token_offset + token_length] def get_rope_index( self, + input_ids: torch.LongTensor | None, mm_token_type_ids: torch.Tensor, - image_token_grids: torch.Tensor, - image_token_offsets: torch.Tensor, - image_token_lengths: torch.Tensor, - attention_mask: torch.Tensor, - inputs_embeds: torch.Tensor | None = None, + image_grid_thw: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + image_metadata: torch.Tensor | None = None, + **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: - """Prepare multimodal RoPE positions for the current prefill sequence. + if image_grid_thw is None or image_metadata is None: + raise ValueError("Isaac multimodal RoPE requires both `image_grid_thw` and `image_metadata`.") - Unlike vanilla 1D RoPE, Isaac builds 3-axis indices for text and vision tokens. - If callers do not supply positions, we synthesize text-style positions from - `attention_mask`. 
The returned `rope_deltas` capture any custom offset between - the attended sequence length and Isaac's multimodal positions so decode steps can - keep counting forward from the cached prefix.""" + if attention_mask is None: + if input_ids is None: + attention_mask = mm_token_type_ids.new_ones(mm_token_type_ids.shape, dtype=torch.long) + else: + attention_mask = input_ids.new_ones(input_ids.shape, dtype=torch.long) + + if input_ids is None: + batch_size, seq_len = attention_mask.shape + position_dtype = torch.long + else: + batch_size, seq_len = input_ids.shape + position_dtype = input_ids.dtype device = attention_mask.device - batch_size, seq_len = attention_mask.shape mm_token_type_ids = mm_token_type_ids.to(dtype=torch.long) - image_token_grids = image_token_grids.to(dtype=torch.long) - image_token_offsets = image_token_offsets.to(dtype=torch.long) - image_token_lengths = image_token_lengths.to(dtype=torch.long) + image_grid_thw = image_grid_thw.to(dtype=torch.long) + image_metadata = image_metadata.to(dtype=torch.long) attention_mask = attention_mask.to(dtype=torch.long) - image_attention_mask = image_token_lengths > 0 + active_slot_mask = image_grid_thw[..., 0].eq(1) - position_ids = torch.zeros((3, batch_size, seq_len), device=device, dtype=torch.long) + position_ids = torch.zeros((3, batch_size, seq_len), device=device, dtype=position_dtype) rope_deltas = torch.zeros((batch_size, 1), device=device, dtype=torch.long) - pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor for batch_idx in range(batch_size): sample_attention_mask = attention_mask[batch_idx].bool() sample_token_types = mm_token_type_ids[batch_idx][sample_attention_mask] - sample_grids = image_token_grids[batch_idx][image_attention_mask[batch_idx]] - sample_offsets = image_token_offsets[batch_idx][image_attention_mask[batch_idx]] - sample_lengths = image_token_lengths[batch_idx][image_attention_mask[batch_idx]] + sample_grids = image_grid_thw[batch_idx] + sample_metadata = image_metadata[batch_idx] + sample_active_slots = active_slot_mask[batch_idx] current_pos = 0 image_idx = 0 @@ -1228,32 +1309,47 @@ def get_rope_index( while seq_pos < sample_token_types.shape[0]: modality_type = int(sample_token_types[seq_pos].item()) - group_end = seq_pos + 1 - while group_end < sample_token_types.shape[0] and sample_token_types[group_end] == modality_type: - group_end += 1 - - group_length = group_end - seq_pos if modality_type == 0: + group_end = seq_pos + 1 + while group_end < sample_token_types.shape[0] and sample_token_types[group_end] == 0: + group_end += 1 + group_length = group_end - seq_pos llm_pos_ids_list.append( torch.arange(group_length, device=device, dtype=torch.long).view(1, -1).expand(3, -1) + current_pos ) current_pos += group_length + seq_pos = group_end else: - grid_hw = sample_grids[image_idx].div(pixel_shuffle_scale, rounding_mode="floor") - token_offset = int(sample_offsets[image_idx].item()) - token_length = int(sample_lengths[image_idx].item()) + while image_idx < sample_metadata.shape[0] and ( + not bool(sample_active_slots[image_idx].item()) or sample_metadata[image_idx, 1].item() == 0 + ): + image_idx += 1 + torch_compilable_check( + image_idx < sample_metadata.shape[0], + "Isaac multimodal sequence has more visible image tokens than batch-major image metadata slots.", + ) + token_length = int(sample_metadata[image_idx, 1].item()) + torch_compilable_check( + token_length <= sample_token_types.shape[0] - seq_pos, + "Isaac image metadata length exceeds the remaining multimodal 
placeholder span.", + ) llm_pos_ids_list.append( - self.get_vision_position_ids(current_pos, grid_hw, token_offset, token_length) + self.get_vision_position_ids(current_pos, sample_grids[image_idx], sample_metadata[image_idx]) ) current_pos += 1 + seq_pos += token_length image_idx += 1 - seq_pos = group_end - - llm_positions = torch.cat(llm_pos_ids_list, dim=1) + llm_positions = ( + torch.cat(llm_pos_ids_list, dim=1) + if llm_pos_ids_list + else torch.zeros((3, 0), device=device, dtype=torch.long) + ) position_ids[:, batch_idx, sample_attention_mask] = llm_positions - rope_deltas[batch_idx, 0] = llm_positions.max() + 1 - sample_token_types.shape[0] + rope_deltas[batch_idx, 0] = ( + llm_positions.max() + 1 - sample_token_types.shape[0] if llm_positions.numel() > 0 else 0 + ) return position_ids, rope_deltas @@ -1263,21 +1359,30 @@ def compute_3d_position_ids( inputs_embeds: torch.Tensor, attention_mask: torch.Tensor | None, mm_token_type_ids: torch.Tensor | None = None, - image_token_grids: torch.Tensor | None = None, - image_token_offsets: torch.Tensor | None = None, - image_token_lengths: torch.Tensor | None = None, - past_key_values: torch.Tensor | None = None, + image_grid_thw: torch.Tensor | None = None, + image_metadata: torch.Tensor | None = None, + past_key_values: Cache | None = None, ) -> torch.Tensor: past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() - if image_token_lengths is not None and image_token_lengths.gt(0).any() and past_seen_tokens == 0: + has_multimodal = ( + image_grid_thw is not None + and image_metadata is not None + and bool(image_grid_thw[..., 0].eq(1).any().item()) + ) + if has_multimodal and mm_token_type_ids is None and input_ids is not None: + raise ValueError( + "Multimodal data was passed (via `image_grid_thw`) but `mm_token_type_ids` is missing. " + "Please pass `mm_token_type_ids` so Isaac can build multimodal RoPE positions." + ) + + if has_multimodal and past_seen_tokens == 0: position_ids, rope_deltas = self.get_rope_index( + input_ids=input_ids, mm_token_type_ids=mm_token_type_ids, - image_token_grids=image_token_grids, - image_token_offsets=image_token_offsets, - image_token_lengths=image_token_lengths, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, attention_mask=attention_mask, - inputs_embeds=inputs_embeds, ) self.rope_deltas = rope_deltas return position_ids @@ -1321,15 +1426,12 @@ def forward( self, input_ids: torch.LongTensor | None = None, mm_token_type_ids: torch.LongTensor | None = None, - vision_patches: torch.Tensor | None = None, - image_patch_attention_mask: torch.Tensor | None = None, - vision_token_grids: torch.LongTensor | None = None, - image_token_grids: torch.LongTensor | None = None, - vision_token_offsets: torch.LongTensor | None = None, - vision_token_lengths: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + image_metadata: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | None = None, + past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], @@ -1337,48 +1439,56 @@ def forward( """ Args: mm_token_type_ids (`torch.LongTensor`, *optional*): - Multimodal token type ids aligned with the embedded sequence, shaped `(batch_size, seq_len)`. 
Isaac - follows the standard convention `0 -> text`, `1 -> image`. Treated as text-only when omitted. - vision_patches (`torch.FloatTensor`, *optional*): - Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`. - image_patch_attention_mask (`torch.LongTensor`, *optional*): - Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. - vision_token_grids (`torch.LongTensor`, *optional*): - Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`. - image_token_grids (`torch.LongTensor`, *optional*): - Alias for `vision_token_grids`. - vision_token_offsets (`torch.LongTensor`, *optional*): - Start offsets inside the per-image vision embedding sequence, shape `(batch_size, max_images)`. - vision_token_lengths (`torch.LongTensor`, *optional*): - Number of vision tokens to consume per image, shape `(batch_size, max_images)`. + Multimodal token type ids aligned with the token sequence, using `0 -> text` and `1 -> image`. + pixel_values (`torch.FloatTensor`, *optional*): + Batch-major patch vectors shaped `(batch_size, max_images, max_patches, patch_dim)`. + image_grid_thw (`torch.LongTensor`, *optional*): + Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. + image_metadata (`torch.LongTensor`, *optional*): + Batch-major per-slot metadata shaped `(batch_size, max_images, 2)` with `(offset, length)`. """ - created_inputs_embeds = inputs_embeds is None - if created_inputs_embeds: - inputs_embeds = self.text_model.embed_tokens(input_ids) + if (input_ids is None) == (inputs_embeds is None): + raise ValueError("You must specify exactly one of `input_ids` or `inputs_embeds`.") + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + batch_size, seq_len = inputs_embeds.shape[:2] if mm_token_type_ids is None: - batch_size, seq_len = inputs_embeds.shape[:2] mm_token_type_ids = torch.full((batch_size, seq_len), 0, device=inputs_embeds.device, dtype=torch.long) else: - mm_token_type_ids = mm_token_type_ids.to(dtype=torch.long) - - image_token_mask = mm_token_type_ids == 1 - if created_inputs_embeds and torch.any(image_token_mask): + mm_token_type_ids = mm_token_type_ids.to(device=inputs_embeds.device, dtype=torch.long) + if mm_token_type_ids.shape[1] < seq_len: + padding = mm_token_type_ids.new_zeros((batch_size, seq_len - mm_token_type_ids.shape[1])) + mm_token_type_ids = torch.cat([mm_token_type_ids, padding], dim=1) + elif mm_token_type_ids.shape[1] > seq_len: + mm_token_type_ids = mm_token_type_ids[:, -seq_len:] + + if image_metadata is not None: + image_metadata = image_metadata.to(device=inputs_embeds.device, dtype=torch.long) + + image_mask = None + has_active_images = ( + pixel_values is not None and image_grid_thw is not None and bool(image_grid_thw[..., 0].eq(1).any().item()) + ) + if has_active_images: image_outputs = self.get_image_features( - pixel_values=vision_patches, - image_token_grids=vision_token_grids, - image_patch_attention_mask=image_patch_attention_mask, - image_token_offsets=vision_token_offsets, - image_token_lengths=vision_token_lengths, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, return_dict=True, ) - image_features = image_outputs.pooler_output.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype) - scatter_mask = self.get_placeholder_mask( - mm_token_type_ids=mm_token_type_ids, - inputs_embeds=inputs_embeds, - image_features=image_features, - ) - inputs_embeds 
= inputs_embeds.masked_scatter(scatter_mask, image_features) + image_embeds = image_outputs.pooler_output + if len(image_embeds) > 0: + image_embeds = torch.cat(image_embeds, dim=0).to( + device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + image_mask = self.get_placeholder_mask( + mm_token_type_ids=mm_token_type_ids, + inputs_embeds=inputs_embeds, + image_features=image_embeds, + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if isinstance(attention_mask, dict): attention_mask = attention_mask["full_attention"] @@ -1388,9 +1498,8 @@ def forward( input_ids=input_ids, inputs_embeds=inputs_embeds, mm_token_type_ids=mm_token_type_ids, - image_token_grids=vision_token_grids, - image_token_offsets=vision_token_offsets, - image_token_lengths=vision_token_lengths, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, attention_mask=attention_mask, past_key_values=past_key_values, ) @@ -1403,151 +1512,99 @@ def forward( if position_ids.ndim == 2: position_ids = position_ids.view(1, position_ids.shape[0], -1).expand(3, -1, -1) - text_model_outputs = self.text_model( + outputs = self.language_model( + input_ids=None, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, + visual_pos_masks=image_mask[..., 0] if image_mask is not None else None, + deepstack_visual_embeds=None, use_cache=use_cache, **kwargs, ) - return BaseModelOutputWithPast( - last_hidden_state=text_model_outputs.last_hidden_state, - past_key_values=text_model_outputs.past_key_values, - hidden_states=text_model_outputs.hidden_states, - attentions=text_model_outputs.attentions, + outputs_with_rope = BaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) + outputs_with_rope["rope_deltas"] = self.rope_deltas + return outputs_with_rope @auto_docstring -class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin): +class IsaacForConditionalGeneration(Qwen3VLForConditionalGeneration, GenerationMixin): config_class = IsaacConfig + input_modalities = ("image", "text") + _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"] _can_compile_fullgraph = False - _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"} + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: IsaacConfig): - super().__init__(config) + PreTrainedModel.__init__(self, config) self.model = IsaacModel(config) self.vocab_size = config.get_text_config().vocab_size self.lm_head = nn.Linear(config.get_text_config().hidden_size, config.get_text_config().vocab_size, bias=False) - - @auto_docstring - @can_return_tuple - def forward( - self, - input_ids: torch.LongTensor | None = None, - mm_token_type_ids: torch.LongTensor | None = None, - vision_patches: torch.Tensor | None = None, - pixel_values: torch.Tensor | None = None, - image_patch_attention_mask: torch.Tensor | None = None, - vision_token_grids: torch.LongTensor | None = None, - image_token_grids: torch.LongTensor | None = None, - vision_token_offsets: torch.LongTensor | None = None, - vision_token_lengths: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = 
None, - use_cache: bool | None = None, - **kwargs: Unpack[TransformersKwargs], - ) -> tuple | CausalLMOutputWithPast: - r""" - mm_token_type_ids (`torch.LongTensor`, *optional*): - Multimodal token type ids aligned with the token sequence, shaped `(batch_size, seq_len)`, using - `0 -> text` and `1 -> image`. - vision_patches (`torch.FloatTensor`, *optional*): - Padded per-image patch vectors of shape `(batch_size, max_images, max_patches, patch_dim)`. - pixel_values (`torch.FloatTensor`, *optional*): - Alias for `vision_patches` accepted by generic image-feature and generation helpers. - image_patch_attention_mask (`torch.LongTensor`, *optional*): - Mask for valid patch entries in `vision_patches`, shaped `(batch_size, max_images, max_patches)`. - vision_token_grids (`torch.LongTensor`, *optional*): - Per-image patch grids `(h, w)` with shape `(batch_size, max_images, 2)`. - image_token_grids (`torch.LongTensor`, *optional*): - Alias for `vision_token_grids`. - vision_token_offsets (`torch.LongTensor`, *optional*): - Start offsets inside the per-image vision embedding sequence, shape `(batch_size, max_images)`. - vision_token_lengths (`torch.LongTensor`, *optional*): - Number of vision tokens to consume per image, shape `(batch_size, max_images)`. - """ - outputs = self.model( - input_ids=input_ids, - mm_token_type_ids=mm_token_type_ids, - vision_patches=vision_patches, - image_patch_attention_mask=image_patch_attention_mask, - vision_token_grids=vision_token_grids, - vision_token_offsets=vision_token_offsets, - vision_token_lengths=vision_token_lengths, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - **kwargs, - ) - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=logits.shape[-1]) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) + self.post_init() def prepare_inputs_for_generation( self, - input_ids: torch.LongTensor, - past_key_values: list[torch.FloatTensor] | None = None, + input_ids, + past_key_values=None, attention_mask: torch.Tensor | None = None, inputs_embeds: torch.FloatTensor | None = None, mm_token_type_ids: torch.LongTensor | None = None, - vision_patches: torch.Tensor | None = None, pixel_values: torch.Tensor | None = None, - image_patch_attention_mask: torch.Tensor | None = None, - vision_token_grids: torch.LongTensor | None = None, - image_token_grids: torch.LongTensor | None = None, - vision_token_offsets: torch.LongTensor | None = None, - vision_token_lengths: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + image_metadata: torch.LongTensor | None = None, position_ids: torch.LongTensor | None = None, is_first_iteration=False, use_cache=True, **kwargs, ) -> dict[str, Any]: - if vision_patches is None: - vision_token_grids = image_token_grids if vision_token_grids is None else vision_token_grids model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, position_ids=position_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, is_first_iteration=is_first_iteration, use_cache=use_cache, **kwargs, ) + is_prefill = is_first_iteration or not use_cache multimodal_inputs = { 
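+ # These tensors are only forwarded on the prefill step; later decode steps null them out
+ # below and rely on the cached vision state in past_key_values.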
"mm_token_type_ids": mm_token_type_ids, - "vision_patches": vision_patches, - "image_patch_attention_mask": image_patch_attention_mask, - "vision_token_grids": vision_token_grids, - "vision_token_offsets": vision_token_offsets, - "vision_token_lengths": vision_token_lengths, + "pixel_values": pixel_values, + "image_grid_thw": image_grid_thw, + "image_metadata": image_metadata, } - is_prefill = is_first_iteration or not use_cache for key, value in multimodal_inputs.items(): model_inputs[key] = value if is_prefill else None - + if model_inputs["mm_token_type_ids"] is not None: + sequence_length = None + if model_inputs.get("input_ids") is not None: + sequence_length = model_inputs["input_ids"].shape[1] + elif model_inputs.get("inputs_embeds") is not None: + sequence_length = model_inputs["inputs_embeds"].shape[1] + + if sequence_length is not None: + current_length = model_inputs["mm_token_type_ids"].shape[1] + if current_length < sequence_length: + padding = model_inputs["mm_token_type_ids"].new_zeros( + (model_inputs["mm_token_type_ids"].shape[0], sequence_length - current_length) + ) + model_inputs["mm_token_type_ids"] = torch.cat([model_inputs["mm_token_type_ids"], padding], dim=1) + elif current_length > sequence_length: + model_inputs["mm_token_type_ids"] = model_inputs["mm_token_type_ids"][:, -sequence_length:] return model_inputs def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs): - text_positions = super()._prepare_position_ids_for_generation(inputs_tensor, model_kwargs) + text_positions = GenerationMixin._prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs) past_length = 0 if (cache := model_kwargs.get("past_key_values")) is not None: @@ -1558,18 +1615,15 @@ def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs): if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0: inputs_tensor = model_kwargs["input_ids"] + is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long] if ( - model_kwargs.get("vision_token_lengths") is not None - and len(inputs_tensor.shape) == 2 - and inputs_tensor.dtype in [torch.int, torch.long] + is_input_ids + and model_kwargs.get("mm_token_type_ids") is not None + and model_kwargs.get("image_grid_thw") is not None + and model_kwargs.get("image_metadata") is not None ): - vision_positions, rope_deltas = self.model.get_rope_index( - mm_token_type_ids=model_kwargs["mm_token_type_ids"], - image_token_grids=model_kwargs["vision_token_grids"], - image_token_offsets=model_kwargs["vision_token_offsets"], - image_token_lengths=model_kwargs["vision_token_lengths"], - attention_mask=model_kwargs.get("attention_mask"), - ) + model_kwargs = {k: v for k, v in model_kwargs.items() if k != "input_ids"} + vision_positions, rope_deltas = self.model.get_rope_index(inputs_tensor, **model_kwargs) self.model.rope_deltas = rope_deltas else: vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1) @@ -1587,12 +1641,26 @@ def _expand_inputs_for_generation( **model_kwargs, ) -> tuple[torch.LongTensor, dict[str, Any]]: position_ids = model_kwargs.pop("position_ids", None) - input_ids, model_kwargs = super()._expand_inputs_for_generation( - expand_size=expand_size, - is_encoder_decoder=is_encoder_decoder, - input_ids=input_ids, - **model_kwargs, - ) + if expand_size == 1: + if position_ids is not None: + model_kwargs["position_ids"] = position_ids + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "image_metadata"] + for key 
in visual_keys: + value = model_kwargs.get(key) + if value is not None: + model_kwargs[key] = value.repeat_interleave(expand_size, dim=0) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + for key, value in list(model_kwargs.items()): + if key == "position_ids" and value is not None and value.ndim == 3: + model_kwargs[key] = value.repeat_interleave(expand_size, dim=1) + elif value is not None and isinstance(value, torch.Tensor) and key not in visual_keys: + model_kwargs[key] = value.repeat_interleave(expand_size, dim=0) + if position_ids is not None: dim = 1 if position_ids.ndim == 3 else 0 model_kwargs["position_ids"] = position_ids.repeat_interleave(expand_size, dim=dim) diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index ddc36bc2abac..a9ec49908530 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -22,7 +22,7 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, make_nested_list_of_images -from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin from ...utils import TensorType, auto_docstring from ...utils.import_utils import is_torch_available from .modeling_isaac import BoundingBox, Polygon, SinglePoint @@ -37,6 +37,7 @@ class IsaacProcessorKwargs(ProcessingKwargs, total=False): "text_kwargs": { "padding": True, "return_attention_mask": True, + "return_mm_token_type_ids": True, }, } @@ -130,15 +131,12 @@ def __init__( image_processor, tokenizer, chat_template: str | dict[str, str] | None = None, - vision_token: str = "", max_sequence_length: int = 16384, ): """ Args: chat_template (`str` or `dict[str, str]`, *optional*): Chat template override forwarded to [`~processing_utils.ProcessorMixin`]. - vision_token (`str`, *optional*, defaults to `""`): - Placeholder token used inside text prompts to mark image positions. max_sequence_length (`int`, *optional*, defaults to 16384): Maximum packed multimodal sequence length produced by the processor. 
""" @@ -148,13 +146,33 @@ def __init__( self.image_processor = image_processor super().__init__(image_processor, tokenizer, chat_template=chat_template) self.text_pad_token_id = self.pad_token_id = tokenizer.pad_token_id - self.image_pad_token_id = tokenizer.image_pad_token_id - self.image_token = tokenizer.image_pad_token - self.image_token_id = self.image_pad_token_id + self.image_token = getattr(tokenizer, "image_pad_token", None) or getattr(tokenizer, "image_token", None) + self.image_token_id = getattr(tokenizer, "image_pad_token_id", None) or getattr( + tokenizer, "image_token_id", None + ) - self.vision_token = vision_token self.max_sequence_length = max_sequence_length + @property + def model_input_names(self): + return super().model_input_names + ["mm_token_type_ids", "image_metadata"] + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + vision_data = {} + if image_sizes is not None: + images_kwargs = dict(IsaacProcessorKwargs._defaults.get("images_kwargs", {})) + images_kwargs.update(kwargs) + + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) + for image_size in image_sizes + ] + pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale") or self.image_processor.pixel_shuffle_scale + num_image_tokens = [num_patches // pixel_shuffle_scale**2 for num_patches in num_image_patches] + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + def post_process_generation( self, text: str, @@ -197,8 +215,10 @@ def __call__( padding = text_kwargs.pop("padding", True) padding_side = text_kwargs.pop("padding_side", "left") return_attention_mask = text_kwargs.pop("return_attention_mask", True) + return_mm_token_type_ids = text_kwargs.pop("return_mm_token_type_ids", True) pad_to_multiple_of = text_kwargs.pop("pad_to_multiple_of", None) text_kwargs.pop("return_tensors", None) + text_kwargs.pop("return_overflowing_tokens", None) text_kwargs.setdefault("add_special_tokens", False) texts = [text] if isinstance(text, str) else text @@ -208,7 +228,7 @@ def __call__( fetched_images = self.image_processor.fetch_images(images) batched_images = make_nested_list_of_images(fetched_images) if len(batched_images) != len(texts): - num_images_in_text = [text_value.count(self.vision_token) for text_value in texts] + num_images_in_text = [text_value.count(self.image_token) for text_value in texts] num_images_in_images = [len(sample_images) for sample_images in batched_images] add_message = "" if sum(num_images_in_text) == sum(num_images_in_images): @@ -220,19 +240,23 @@ def __call__( pairs = list(zip(texts, batched_images, strict=True)) image_inputs = self.image_processor(images=batched_images, return_tensors=TensorType.PYTORCH) - vision_token_grids = image_inputs["vision_token_grids"] - vision_segment_lengths = (vision_token_grids[..., 0] // self.image_processor.pixel_shuffle_scale) * ( - vision_token_grids[..., 1] // self.image_processor.pixel_shuffle_scale - ) - vision_token_offsets = torch.zeros_like(vision_segment_lengths) - vision_token_lengths = torch.zeros_like(vision_segment_lengths) + image_grid_thw = image_inputs["image_grid_thw"] + image_metadata = None + vision_segment_lengths = None + if image_grid_thw is not None: + batch_size, max_images = image_grid_thw.shape[:2] + image_metadata = torch.zeros((batch_size, max_images, 2), dtype=torch.long) + grid_heights = image_grid_thw[..., 1] + grid_widths = image_grid_thw[..., 2] + 
vision_segment_lengths = (grid_heights // self.image_processor.pixel_shuffle_scale) * ( + grid_widths // self.image_processor.pixel_shuffle_scale + ) - sample_input_ids: list[torch.Tensor] = [] expanded_texts = [] expected_image_lengths_per_sample = [] for batch_idx, (text_value, sample_images) in enumerate(pairs): - segments = text_value.split(self.vision_token) + segments = text_value.split(self.image_token) num_images = len(segments) - 1 num_provided_images = len(sample_images) if num_images != num_provided_images: @@ -240,81 +264,91 @@ def __call__( f"IsaacProcessor expects one image per image token, got {num_images} tokens and {num_provided_images} images in sample with text {text_value} " ) - expected_image_lengths = [ - int(vision_segment_lengths[batch_idx, image_idx].item()) for image_idx in range(num_images) - ] - expected_image_lengths_per_sample.append(expected_image_lengths) - - expanded_text = segments[0] - for image_idx, segment_length in enumerate(expected_image_lengths): - expanded_text += (self.image_token * segment_length) + segments[image_idx + 1] - expanded_texts.append(expanded_text) + expected_image_lengths = [] + expanded_text_parts = [segments[0]] + for image_idx in range(num_images): + segment_length = int(vision_segment_lengths[batch_idx, image_idx].item()) + expected_image_lengths.append(segment_length) + expanded_text_parts.append(self.image_token * segment_length) + expanded_text_parts.append(segments[image_idx + 1]) - tokenized_text_inputs = self.tokenizer(expanded_texts, return_tensors=None, **text_kwargs) - self._check_special_mm_tokens(expanded_texts, tokenized_text_inputs, modalities=["image"]) + expected_image_lengths_per_sample.append(expected_image_lengths) + expanded_texts.append("".join(expanded_text_parts)) effective_max_length = self.max_sequence_length - if truncation and max_length is not None: + if max_length is not None and (truncation or padding == "max_length"): effective_max_length = max_length - for batch_idx, (expected_image_lengths, sample_input_ids_list) in enumerate( - zip(expected_image_lengths_per_sample, tokenized_text_inputs["input_ids"], strict=True) - ): - sample_input = torch.tensor(sample_input_ids_list, dtype=torch.long) - image_positions = sample_input.eq(self.image_pad_token_id).nonzero(as_tuple=False).flatten() - image_spans = image_positions.split(expected_image_lengths) if expected_image_lengths else () - image_bounds = [] - - for image_idx, (segment_length, image_span) in enumerate( - zip(expected_image_lengths, image_spans, strict=True) - ): - image_start = int(image_span[0].item()) - image_end = int(image_span[-1].item()) + 1 - image_bounds.append((image_start, image_end)) - total = int(sample_input.shape[0]) - start = max(0, total - effective_max_length) - sample_input_ids.append(sample_input[start:]) - - for image_idx, (image_start, image_end) in enumerate(image_bounds): - kept_start = max(start, image_start) - kept_end = min(total, image_end) - if kept_end > kept_start: - vision_token_offsets[batch_idx, image_idx] = kept_start - image_start - vision_token_lengths[batch_idx, image_idx] = kept_end - kept_start - - # Pad only after Isaac-specific truncation so image span offsets and lengths stay aligned. 
- padded_text_inputs = self.tokenizer.pad( - {"input_ids": [sample_input.tolist() for sample_input in sample_input_ids]}, + self.tokenizer.truncation_side = "left" + self.tokenizer.padding_side = padding_side + tokenized_text_inputs = self.tokenizer( + expanded_texts, + truncation=True, + max_length=effective_max_length, padding=padding, - max_length=max_length if padding == "max_length" else None, pad_to_multiple_of=pad_to_multiple_of, - padding_side=padding_side, return_attention_mask=return_attention_mask, - return_tensors=TensorType.PYTORCH, + return_overflowing_tokens=True, + stride=0, + return_tensors=None, + **text_kwargs, ) - input_ids = padded_text_inputs["input_ids"] - attention_mask = padded_text_inputs.get("attention_mask") - if attention_mask is None: - attention_mask = input_ids.ne(self.pad_token_id).to(dtype=torch.long) - mm_token_type_ids = input_ids.eq(self.image_pad_token_id).to(dtype=torch.long) - vision_image_attention_mask = vision_token_lengths.gt(0).to(dtype=torch.long) + kept_input_ids_per_sample: list[list[int] | None] = [None] * len(texts) + overflow_input_ids_per_sample: list[list[list[int]]] = [[] for _ in texts] + overflow_to_sample_mapping = tokenized_text_inputs.get("overflow_to_sample_mapping") + if overflow_to_sample_mapping is None: + overflow_to_sample_mapping = list(range(len(tokenized_text_inputs["input_ids"]))) - vision_patches = image_inputs["vision_patches"] - image_patch_attention_mask = image_inputs["image_patch_attention_mask"] + for row_input_ids, sample_idx in zip( + tokenized_text_inputs["input_ids"], overflow_to_sample_mapping, strict=True + ): + sample_idx = int(sample_idx) + if kept_input_ids_per_sample[sample_idx] is None: + kept_input_ids_per_sample[sample_idx] = row_input_ids + else: + overflow_input_ids_per_sample[sample_idx].append(row_input_ids) + + for batch_idx, expected_image_lengths in enumerate(expected_image_lengths_per_sample): + dropped_image_tokens = sum( + overflow_input_ids.count(self.image_token_id) + for overflow_input_ids in overflow_input_ids_per_sample[batch_idx] + ) + + remaining_dropped = dropped_image_tokens + for image_idx, expected_length in enumerate(expected_image_lengths): + if remaining_dropped <= 0: + offset = 0 + length = expected_length + elif remaining_dropped < expected_length: + offset = remaining_dropped + length = expected_length - offset + remaining_dropped = 0 + else: + offset = 0 + length = 0 + remaining_dropped -= expected_length + + # Record which suffix of this image's placeholder span survives left truncation. + # The model still encodes the full image and uses this window for both feature gathering and vision RoPE. 
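+ # e.g. expected_length=64 with 16 placeholders dropped -> offset=16, length=48, so the
+ # model later gathers projected_features[i, 16:64] for this image.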
+ image_metadata[batch_idx, image_idx, 0] = offset + image_metadata[batch_idx, image_idx, 1] = length + + input_ids = torch.tensor(kept_input_ids_per_sample, dtype=torch.long) + attention_mask = input_ids.ne(self.pad_token_id).to(dtype=torch.long) + + data = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": image_inputs["pixel_values"], + "image_grid_thw": image_grid_thw, + "image_metadata": image_metadata, + } + if return_mm_token_type_ids: + data["mm_token_type_ids"] = input_ids.eq(self.image_token_id).to(dtype=torch.long) return BatchFeature( - data={ - "input_ids": input_ids, - "attention_mask": attention_mask, - "mm_token_type_ids": mm_token_type_ids, - "vision_patches": vision_patches, - "image_patch_attention_mask": image_patch_attention_mask, - "vision_token_grids": vision_token_grids, - "vision_token_offsets": vision_token_offsets, - "vision_token_lengths": vision_token_lengths, - "vision_image_attention_mask": vision_image_attention_mask, - }, + data=data, tensor_type=return_tensors, ) diff --git a/tests/models/isaac/test_image_processing_isaac.py b/tests/models/isaac/test_image_processing_isaac.py index 5884a8bf1e8f..b03ec3337972 100644 --- a/tests/models/isaac/test_image_processing_isaac.py +++ b/tests/models/isaac/test_image_processing_isaac.py @@ -138,51 +138,57 @@ def _assert_output_contract( encoding, *, expected_batch_size=None, - expected_image_slots=None, + expected_max_images=None, expected_patch_dim=None, ): - self.assertEqual( - set(encoding.keys()), - {"vision_patches", "image_patch_attention_mask", "vision_token_grids"}, - ) + self.assertEqual(set(encoding.keys()), {"pixel_values", "image_grid_thw"}) + + pixel_values = encoding["pixel_values"] + image_grid_thw = encoding["image_grid_thw"] - vision_patches = encoding["vision_patches"] - image_patch_attention_mask = encoding["image_patch_attention_mask"] - vision_token_grids = encoding["vision_token_grids"] + if expected_batch_size is None: + self.assertIsNone(pixel_values) + self.assertIsNone(image_grid_thw) + return - self.assertEqual(vision_patches.dtype, torch.float32) - self.assertEqual(image_patch_attention_mask.dtype, torch.long) - self.assertEqual(vision_token_grids.dtype, torch.long) + self.assertIsNotNone(pixel_values) + self.assertIsNotNone(image_grid_thw) + self.assertEqual(pixel_values.dtype, torch.float32) + self.assertEqual(image_grid_thw.dtype, torch.long) if expected_batch_size is not None: - self.assertEqual(vision_patches.shape[0], expected_batch_size) - if expected_image_slots is not None: - self.assertEqual(vision_patches.shape[1], expected_image_slots) + self.assertEqual(pixel_values.shape[0], expected_batch_size) + self.assertEqual(image_grid_thw.shape[0], expected_batch_size) + if expected_max_images is not None: + self.assertEqual(pixel_values.shape[1], expected_max_images) + self.assertEqual(image_grid_thw.shape[1], expected_max_images) if expected_patch_dim is not None: - self.assertEqual(vision_patches.shape[-1], expected_patch_dim) + self.assertEqual(pixel_values.shape[-1], expected_patch_dim) + + self.assertEqual(tuple(image_grid_thw.shape), (pixel_values.shape[0], pixel_values.shape[1], 3)) - self.assertEqual(tuple(image_patch_attention_mask.shape), tuple(vision_patches.shape[:-1])) - self.assertEqual(tuple(vision_token_grids.shape), tuple(vision_patches.shape[:2]) + (2,)) + active_slots = image_grid_thw[..., 0].eq(1) + self.assertTrue(torch.all(image_grid_thw[~active_slots].eq(0))) + self.assertTrue(torch.all(image_grid_thw[active_slots, 1:] > 0)) - 
expected_patch_counts = torch.prod(vision_token_grids, dim=-1) - torch.testing.assert_close(image_patch_attention_mask.sum(dim=-1), expected_patch_counts) + expected_patch_counts = image_grid_thw[..., 1] * image_grid_thw[..., 2] + token_positions = torch.arange(pixel_values.shape[2], device=pixel_values.device).view(1, 1, -1) + image_patch_attention_mask = active_slots.unsqueeze(-1) & token_positions.lt( + expected_patch_counts.unsqueeze(-1) + ) - padded_patch_rows = vision_patches[image_patch_attention_mask == 0] + padded_patch_rows = pixel_values[~image_patch_attention_mask] if padded_patch_rows.numel() > 0: self.assertTrue(torch.all(padded_patch_rows == 0)) def _assert_encoding_close(self, eager_encoding, compiled_encoding): torch.testing.assert_close( - eager_encoding["vision_patches"], - compiled_encoding["vision_patches"], + eager_encoding["pixel_values"], + compiled_encoding["pixel_values"], atol=1e-4, rtol=1e-4, ) - torch.testing.assert_close( - eager_encoding["image_patch_attention_mask"], - compiled_encoding["image_patch_attention_mask"], - ) - torch.testing.assert_close(eager_encoding["vision_token_grids"], compiled_encoding["vision_token_grids"]) + torch.testing.assert_close(eager_encoding["image_grid_thw"], compiled_encoding["image_grid_thw"]) def test_image_processor_properties(self): for image_processing_class in self.image_processing_classes.values(): @@ -211,7 +217,7 @@ def test_call_pil(self): self._assert_output_contract( single_output, expected_batch_size=1, - expected_image_slots=1, + expected_max_images=1, expected_patch_dim=self.image_processor_tester.patch_dim, ) @@ -219,7 +225,7 @@ def test_call_pil(self): self._assert_output_contract( batched_output, expected_batch_size=self.image_processor_tester.batch_size, - expected_image_slots=1, + expected_max_images=1, expected_patch_dim=self.image_processor_tester.patch_dim, ) @@ -235,7 +241,7 @@ def test_call_numpy(self): self._assert_output_contract( single_output, expected_batch_size=1, - expected_image_slots=1, + expected_max_images=1, expected_patch_dim=self.image_processor_tester.patch_dim, ) @@ -243,7 +249,7 @@ def test_call_numpy(self): self._assert_output_contract( batched_output, expected_batch_size=self.image_processor_tester.batch_size, - expected_image_slots=1, + expected_max_images=1, expected_patch_dim=self.image_processor_tester.patch_dim, ) @@ -259,7 +265,7 @@ def test_call_pytorch(self): self._assert_output_contract( single_output, expected_batch_size=1, - expected_image_slots=1, + expected_max_images=1, expected_patch_dim=self.image_processor_tester.patch_dim, ) @@ -267,7 +273,7 @@ def test_call_pytorch(self): self._assert_output_contract( batched_output, expected_batch_size=self.image_processor_tester.batch_size, - expected_image_slots=1, + expected_max_images=1, expected_patch_dim=self.image_processor_tester.patch_dim, ) @@ -299,42 +305,27 @@ def test_nested_multi_image_batch_preserves_grids_and_padding(self): self._assert_output_contract( encoding, expected_batch_size=2, - expected_image_slots=2, + expected_max_images=2, expected_patch_dim=768, ) + self.assertEqual(tuple(encoding["pixel_values"].shape), (2, 2, 6, 768)) expected_grids = torch.tensor( [ - [[2, 2], [0, 0]], - [[2, 3], [3, 2]], - ], - dtype=torch.long, - ) - expected_patch_counts = torch.tensor( - [ - [4, 0], - [6, 6], + [[1, 2, 2], [0, 0, 0]], + [[1, 2, 3], [1, 3, 2]], ], dtype=torch.long, ) - torch.testing.assert_close(encoding["vision_token_grids"], expected_grids) - 
torch.testing.assert_close(encoding["image_patch_attention_mask"].sum(dim=-1), expected_patch_counts) + torch.testing.assert_close(encoding["image_grid_thw"], expected_grids) - def test_all_empty_images_returns_zero_sized_tensors(self): + def test_all_empty_images_returns_none_visual_fields(self): for image_processing_class in self.image_processing_classes.values(): image_processor = image_processing_class(**self.image_processor_dict) encoding = image_processor([[], []], return_tensors="pt") - self.assertEqual( - set(encoding.keys()), {"vision_patches", "image_patch_attention_mask", "vision_token_grids"} - ) - self.assertEqual(tuple(encoding["vision_patches"].shape), (2, 0, 0, 0)) - self.assertEqual(tuple(encoding["image_patch_attention_mask"].shape), (2, 0, 0)) - self.assertEqual(tuple(encoding["vision_token_grids"].shape), (2, 0, 2)) - self.assertEqual(encoding["vision_patches"].dtype, torch.float32) - self.assertEqual(encoding["image_patch_attention_mask"].dtype, torch.long) - self.assertEqual(encoding["vision_token_grids"].dtype, torch.long) + self._assert_output_contract(encoding, expected_batch_size=None) def test_do_resize_false_requires_patch_divisibility(self): for image_processing_class in self.image_processing_classes.values(): @@ -369,22 +360,19 @@ def test_cast_dtype_device(self): image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) encoding = image_processor(image_inputs, return_tensors="pt") - self.assertEqual(encoding["vision_patches"].device, torch.device("cpu")) - self.assertEqual(encoding["vision_patches"].dtype, torch.float32) - self.assertEqual(encoding["image_patch_attention_mask"].dtype, torch.long) - self.assertEqual(encoding["vision_token_grids"].dtype, torch.long) + self.assertEqual(encoding["pixel_values"].device, torch.device("cpu")) + self.assertEqual(encoding["pixel_values"].dtype, torch.float32) + self.assertEqual(encoding["image_grid_thw"].dtype, torch.long) encoding = image_processor(image_inputs, return_tensors="pt").to(torch.float16) - self.assertEqual(encoding["vision_patches"].device, torch.device("cpu")) - self.assertEqual(encoding["vision_patches"].dtype, torch.float16) - self.assertEqual(encoding["image_patch_attention_mask"].dtype, torch.long) - self.assertEqual(encoding["vision_token_grids"].dtype, torch.long) + self.assertEqual(encoding["pixel_values"].device, torch.device("cpu")) + self.assertEqual(encoding["pixel_values"].dtype, torch.float16) + self.assertEqual(encoding["image_grid_thw"].dtype, torch.long) encoding = image_processor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16) - self.assertEqual(encoding["vision_patches"].device, torch.device("cpu")) - self.assertEqual(encoding["vision_patches"].dtype, torch.bfloat16) - self.assertEqual(encoding["image_patch_attention_mask"].dtype, torch.long) - self.assertEqual(encoding["vision_token_grids"].dtype, torch.long) + self.assertEqual(encoding["pixel_values"].device, torch.device("cpu")) + self.assertEqual(encoding["pixel_values"].dtype, torch.bfloat16) + self.assertEqual(encoding["image_grid_thw"].dtype, torch.long) with self.assertRaises(TypeError): _ = image_processor(image_inputs, return_tensors="pt").to(torch.bfloat16, "cpu") @@ -393,10 +381,9 @@ def test_cast_dtype_device(self): encoding.update({"input_ids": torch.LongTensor([[1, 2, 3], [4, 5, 6]])}) encoding = encoding.to(torch.float16) - self.assertEqual(encoding["vision_patches"].device, torch.device("cpu")) - self.assertEqual(encoding["vision_patches"].dtype, torch.float16) 
- self.assertEqual(encoding["image_patch_attention_mask"].dtype, torch.long) - self.assertEqual(encoding["vision_token_grids"].dtype, torch.long) + self.assertEqual(encoding["pixel_values"].device, torch.device("cpu")) + self.assertEqual(encoding["pixel_values"].dtype, torch.float16) + self.assertEqual(encoding["image_grid_thw"].dtype, torch.long) self.assertEqual(encoding["input_ids"].dtype, torch.long) @slow diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py index 1c19f79474a6..b43c2e183ca3 100644 --- a/tests/models/isaac/test_modeling_isaac.py +++ b/tests/models/isaac/test_modeling_isaac.py @@ -17,9 +17,11 @@ import base64 import io import os +import re import unittest from functools import lru_cache from pathlib import Path +from types import SimpleNamespace from unittest.mock import patch import pytest @@ -33,7 +35,6 @@ IsaacForConditionalGeneration, IsaacModel, PythonBackend, - Qwen2Tokenizer, is_torch_available, ) from transformers.image_utils import load_image @@ -45,6 +46,7 @@ pixel_shuffle_padded, ) from transformers.models.isaac.processing_isaac import IsaacProcessor +from transformers.pipelines import ImageTextToTextPipeline from transformers.testing_utils import ( require_flash_attn, require_torch, @@ -75,10 +77,11 @@ LOCAL_CHECKPOINT = os.environ.get("ISAAC_TEST_MODEL_PATH") RED_DOT_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg==" +ISAAC_IMAGE_TOKEN = "<|image_pad|>" def document_to_messages( - document: list[dict], vision_token: str = "" + document: list[dict], image_token: str = ISAAC_IMAGE_TOKEN ) -> tuple[list[dict[str, str]], list[Image]]: """ Convert a Document to messages format compatible with chat templates. @@ -86,7 +89,7 @@ def document_to_messages( Args: document: list of dicts containing Text and/or Image content - vision_token: Token to use for image placeholder + image_token: Token to use for image placeholder Returns: Tuple of (messages, images) where messages is a list of dicts with 'role' and 'content' @@ -113,13 +116,22 @@ def document_to_messages( messages.append( { "role": item.get("role", "user"), - "content": vision_token, + "content": image_token, } ) return messages, images +def strip_trailing_stop_string(text: str, stop_strings: list[str] | tuple[str, ...] 
| None = None) -> str: + if stop_strings is not None: + for stop_string in stop_strings: + if text.endswith(stop_string): + text = text[: -len(stop_string)] + break + return re.sub(r"^\n{2,}", "\n", text) + + def compute_logits_statistics(tensor: torch.Tensor) -> dict[str, object]: """ Summarize logits with simple statistics that are stable across minor @@ -203,32 +215,32 @@ def create_isaac_processor( **overrides, ): """Helper to construct IsaacProcessor without requiring an IsaacConfig instance.""" + vision_config = isaac_config.vision_config params = { - "vision_token": isaac_config.vision_token, "max_sequence_length": isaac_config.max_sequence_length, - "vision_patch_size": isaac_config.vision_patch_size, - "vision_max_num_patches": isaac_config.vision_max_num_patches, - "vision_min_num_patches": isaac_config.vision_min_num_patches, - "pixel_shuffle_scale": isaac_config.pixel_shuffle_scale, + "vision_patch_size": vision_config.patch_size, + "vision_max_num_patches": vision_config.num_patches, + "vision_min_num_patches": getattr(vision_config, "min_num_patches", None), + "pixel_shuffle_scale": vision_config.pixel_shuffle_scale_factor, "rescale_factor": isaac_config.vision_rescale_factor, - "image_mean": tuple(isaac_config.vision_mean), - "image_std": tuple(isaac_config.vision_std), } params.update(overrides) processor_image = image_processor if processor_image is None: - processor_image = IsaacImageProcessor( - patch_size=params["vision_patch_size"], - max_num_patches=params["vision_max_num_patches"], - min_num_patches=params["vision_min_num_patches"], - pixel_shuffle_scale=params["pixel_shuffle_scale"], - rescale_factor=params["rescale_factor"], - image_mean=params["image_mean"], - image_std=params["image_std"], - ) + image_processor_kwargs = { + "patch_size": params["vision_patch_size"], + "max_num_patches": params["vision_max_num_patches"], + "min_num_patches": params["vision_min_num_patches"], + "pixel_shuffle_scale": params["pixel_shuffle_scale"], + "rescale_factor": params["rescale_factor"], + } + if "image_mean" in params: + image_processor_kwargs["image_mean"] = params["image_mean"] + if "image_std" in params: + image_processor_kwargs["image_std"] = params["image_std"] + processor_image = IsaacImageProcessor(**image_processor_kwargs) processor_params = { - "vision_token": isaac_config.vision_token, "max_sequence_length": isaac_config.max_sequence_length, } @@ -242,11 +254,9 @@ def create_isaac_processor( def to_model_multimodal_inputs(processor_output, device): keys = ( "mm_token_type_ids", - "vision_patches", - "image_patch_attention_mask", - "vision_token_grids", - "vision_token_offsets", - "vision_token_lengths", + "pixel_values", + "image_grid_thw", + "image_metadata", ) return { key: (value.to(device) if isinstance(value, torch.Tensor) else value) @@ -255,6 +265,31 @@ def to_model_multimodal_inputs(processor_output, device): } +def pack_image_inputs(pixel_values, image_token_grids, image_token_offsets=None, image_token_lengths=None): + batch_size, max_images, _, _ = pixel_values.shape + device = pixel_values.device + + if image_token_offsets is None: + image_token_offsets = torch.zeros((batch_size, max_images), device=device, dtype=torch.long) + if image_token_lengths is None: + image_token_lengths = image_token_grids[..., 0] * image_token_grids[..., 1] + + image_grid_thw = torch.zeros((batch_size, max_images, 3), device=device, dtype=torch.long) + active_slots = image_token_grids.prod(dim=-1).gt(0) + image_grid_thw[..., 0] = active_slots.to(dtype=torch.long) + 
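+ # Copy (H, W) into columns 1-2; padded slots have zero grids, so their rows stay all-zero
+ # and are treated as inactive by the model.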
image_grid_thw[..., 1:] = image_token_grids + + image_metadata = torch.stack( + ( + image_token_offsets.to(device=device, dtype=torch.long), + image_token_lengths.to(device=device, dtype=torch.long), + ), + dim=-1, + ) + + return pixel_values, image_grid_thw, image_metadata + + @lru_cache(maxsize=1) def _load_red_dot_image(): if Image is None: @@ -295,7 +330,7 @@ def __init__(self): "": 1, "": 2, "": 3, - "": 4, + ISAAC_IMAGE_TOKEN: 4, } self._ids_to_tokens = {idx: tok for tok, idx in self._vocab.items()} super().__init__( @@ -303,9 +338,11 @@ def __init__(self): eos_token="", pad_token="", unk_token="", - extra_special_tokens=[""], + extra_special_tokens={"image_pad_token": ISAAC_IMAGE_TOKEN}, model_max_length=512, ) + self.image_pad_token = ISAAC_IMAGE_TOKEN + self.image_pad_token_id = self._vocab[self.image_pad_token] self.chat_template = ( "{% for message in messages %}" "{{ message['role'] }}: {{ message['content'] | trim }}\n" @@ -423,21 +460,26 @@ def prepare_config_and_inputs(self): def prepare_config_and_inputs_for_common(self): config, input_ids, attention_mask, labels = self.prepare_config_and_inputs() - position_ids = torch.arange(self.seq_length, device=torch_device).view(1, -1).expand(self.batch_size, -1) patch_size = self.vision_config["patch_size"] patch_dim = self.vision_config["num_channels"] * patch_size * patch_size num_image_patches = 4 + vision_patches = torch.randn( + (self.batch_size, 1, num_image_patches, patch_dim), device=torch_device, dtype=torch.float32 + ) + image_token_grids = torch.tensor([[[2, 2]]] * self.batch_size, device=torch_device, dtype=torch.long) + pixel_values, image_grid_thw, image_metadata = pack_image_inputs( + pixel_values=vision_patches, + image_token_grids=image_token_grids, + ) + mm_token_type_ids = torch.zeros((self.batch_size, self.seq_length), device=torch_device, dtype=torch.long) + mm_token_type_ids[:, :num_image_patches] = 1 inputs_dict = { "input_ids": input_ids, "attention_mask": attention_mask, - "position_ids": position_ids, - "pixel_values": torch.randn( - (self.batch_size, 1, num_image_patches, patch_dim), device=torch_device, dtype=torch.float32 - ), - "image_patch_attention_mask": torch.ones( - (self.batch_size, 1, num_image_patches), device=torch_device, dtype=torch.long - ), - "image_token_grids": torch.tensor([[[2, 2]]] * self.batch_size, device=torch_device, dtype=torch.long), + "mm_token_type_ids": mm_token_type_ids, + "pixel_values": pixel_values, + "image_grid_thw": image_grid_thw, + "image_metadata": image_metadata, } if labels is not None: inputs_dict["labels"] = labels @@ -471,6 +513,40 @@ def test_config(self): self.maxDiff = None self.config_tester.run_common_tests() + def prepare_config_and_inputs_for_generate(self, batch_size=2): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + input_keys_to_ignore = [ + "decoder_input_ids", + "decoder_attention_mask", + "use_cache", + "labels", + ] + + filtered_inputs_dict = { + k: v[:batch_size, ...] 
+ if isinstance(v, torch.Tensor) and k not in ["pixel_values", "image_grid_thw", "image_metadata"] + else v + for k, v in inputs_dict.items() + if k not in input_keys_to_ignore + } + + filtered_inputs_dict["pixel_values"] = inputs_dict["pixel_values"][:batch_size] + filtered_inputs_dict["image_grid_thw"] = inputs_dict["image_grid_thw"][:batch_size] + filtered_inputs_dict["image_metadata"] = inputs_dict["image_metadata"][:batch_size] + + text_gen_config = config.get_text_config(decoder=True) + if text_gen_config.eos_token_id is not None and text_gen_config.pad_token_id is None: + text_gen_config.pad_token_id = ( + text_gen_config.eos_token_id + if isinstance(text_gen_config.eos_token_id, int) + else text_gen_config.eos_token_id[0] + ) + text_gen_config.eos_token_id = None + text_gen_config.forced_eos_token_id = None + + return config, filtered_inputs_dict + @unittest.skip(reason="Assisted decoding not supported; Qwen3 backbone does not implement returning attentions") def test_assisted_decoding_matches_greedy_search_0_random(self): pass @@ -501,8 +577,6 @@ def test_text_only_forward_ignores_metadata_without_vision_patches(self): model.to(torch_device) model.eval() - vision_token_grids = torch.zeros((self.model_tester.batch_size, 0, 2), device=torch_device, dtype=torch.long) - with torch.no_grad(): reference = model(input_ids=input_ids, attention_mask=attention_mask) @@ -511,12 +585,25 @@ def test_text_only_forward_ignores_metadata_without_vision_patches(self): result = model( input_ids=input_ids, attention_mask=attention_mask, - vision_token_grids=vision_token_grids, + image_grid_thw=None, + image_metadata=None, ) mock_get_image_features.assert_not_called() torch.testing.assert_close(result.last_hidden_state, reference.last_hidden_state) + def test_image_text_to_text_pipeline_supports_text_only_inputs(self): + config = self.model_tester.get_config() + model = IsaacForConditionalGeneration(config).to(torch_device).eval() + processor = create_isaac_processor(SimpleIsaacTokenizer(), config) + pipe = ImageTextToTextPipeline(model=model, processor=processor, max_new_tokens=4) + + outputs = pipe(text="What is two plus two?", return_full_text=False) + + self.assertEqual(len(outputs), 1) + self.assertEqual(outputs[0]["input_text"], "What is two plus two?") + self.assertIsInstance(outputs[0]["generated_text"], str) + def test_get_image_features_pooler_output_is_scatter_ready(self): config = self.model_tester.get_config() model = IsaacModel(config) @@ -531,31 +618,264 @@ def test_get_image_features_pooler_output_is_scatter_ready(self): device=torch_device, dtype=torch.long, ) - image_patch_attention_mask = torch.ones((2, 2, 4), device=torch_device, dtype=torch.long) image_token_offsets = torch.tensor([[1, 0], [2, 0]], device=torch_device, dtype=torch.long) image_token_lengths = torch.tensor([[2, 1], [1, 0]], device=torch_device, dtype=torch.long) + pixel_values, image_grid_thw, image_metadata = pack_image_inputs( + pixel_values=pixel_values, + image_token_grids=image_token_grids, + image_token_offsets=image_token_offsets, + image_token_lengths=image_token_lengths, + ) with torch.no_grad(): outputs = model.get_image_features( pixel_values=pixel_values, - image_token_grids=image_token_grids, - image_patch_attention_mask=image_patch_attention_mask, - image_token_offsets=image_token_offsets, - image_token_lengths=image_token_lengths, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, return_dict=True, ) expected = torch.cat( ( - outputs.last_hidden_state[0, 0, 1:3], - 
outputs.last_hidden_state[0, 1, 0:1], - outputs.last_hidden_state[1, 0, 2:3], + outputs.last_hidden_state[0, 1:3], + outputs.last_hidden_state[1, 0:1], + outputs.last_hidden_state[2, 2:3], ), dim=0, ) + pooled_output = torch.cat(outputs.pooler_output, dim=0) + + self.assertEqual(pooled_output.ndim, 2) + torch.testing.assert_close(pooled_output, expected) + + def test_get_rope_index_batch_major_skips_padded_and_fully_truncated_slots(self): + config = self.model_tester.get_config() + model = IsaacModel(config).to(torch_device).eval() + + input_ids = torch.zeros((2, 8), device=torch_device, dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + mm_token_type_ids = torch.tensor( + [ + [0, 0, 1, 1, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0], + ], + device=torch_device, + dtype=torch.long, + ) + image_grid_thw = torch.tensor( + [ + [[1, 2, 2], [1, 2, 2], [1, 2, 2]], + [[1, 2, 2], [0, 0, 0], [0, 0, 0]], + ], + device=torch_device, + dtype=torch.long, + ) + image_metadata = torch.tensor( + [ + [[1, 2], [0, 0], [2, 1]], + [[0, 1], [0, 0], [0, 0]], + ], + device=torch_device, + dtype=torch.long, + ) - self.assertEqual(outputs.pooler_output.ndim, 2) - torch.testing.assert_close(outputs.pooler_output, expected) + position_ids, rope_deltas = model.get_rope_index( + input_ids=input_ids, + mm_token_type_ids=mm_token_type_ids, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, + attention_mask=attention_mask, + ) + + expected_sample0 = torch.tensor( + [ + [0, 1, 2, 2, 3, 4, 5, 6], + [0, 1, 0, 1, 3, 1, 5, 6], + [0, 1, 1, 0, 3, 0, 5, 6], + ], + device=torch_device, + dtype=torch.long, + ) + expected_sample1 = torch.tensor( + [ + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 0, 2, 3, 4, 5, 6, 7], + [0, 0, 2, 3, 4, 5, 6, 7], + ], + device=torch_device, + dtype=torch.long, + ) + + torch.testing.assert_close(position_ids[:, 0], expected_sample0) + torch.testing.assert_close(position_ids[:, 1], expected_sample1) + torch.testing.assert_close( + rope_deltas, + torch.tensor([[-1], [0]], device=torch_device, dtype=torch.long), + ) + + def test_forward_scatters_batch_major_image_features_in_slot_order(self): + config = self.model_tester.get_config() + model = IsaacModel(config).to(torch_device).eval() + + input_ids = torch.randint( + 0, + config.get_text_config().vocab_size, + (2, 6), + device=torch_device, + dtype=torch.long, + ) + mm_token_type_ids = torch.tensor( + [ + [0, 1, 1, 0, 1, 0], + [0, 0, 0, 0, 0, 0], + ], + device=torch_device, + dtype=torch.long, + ) + patch_size = self.model_tester.vision_config["patch_size"] + patch_dim = self.model_tester.vision_config["num_channels"] * patch_size * patch_size + pixel_values = torch.zeros((2, 2, 4, patch_dim), device=torch_device, dtype=torch.float32) + image_grid_thw = torch.tensor( + [ + [[1, 2, 2], [1, 2, 2]], + [[0, 0, 0], [0, 0, 0]], + ], + device=torch_device, + dtype=torch.long, + ) + image_metadata = torch.tensor( + [ + [[0, 2], [1, 1]], + [[0, 0], [0, 0]], + ], + device=torch_device, + dtype=torch.long, + ) + + hidden_size = config.get_text_config().hidden_size + scattered_features = ( + torch.full((2, hidden_size), 11.0, device=torch_device), + torch.full((1, hidden_size), 22.0, device=torch_device), + ) + captured = {} + + def fake_language_model(**kwargs): + captured["inputs_embeds"] = kwargs["inputs_embeds"].detach().clone() + return SimpleNamespace( + last_hidden_state=kwargs["inputs_embeds"], + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + with patch.object( + model, + "get_image_features", + 
return_value=SimpleNamespace(pooler_output=scattered_features), + ) as mock_get_image_features: + with patch.object(model, "compute_3d_position_ids", return_value=None): + with patch.object(model.language_model, "forward", side_effect=fake_language_model): + model( + input_ids=input_ids, + mm_token_type_ids=mm_token_type_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, + ) + + mock_get_image_features.assert_called_once() + call_kwargs = mock_get_image_features.call_args.kwargs + torch.testing.assert_close(call_kwargs["pixel_values"], pixel_values) + torch.testing.assert_close(call_kwargs["image_grid_thw"], image_grid_thw) + torch.testing.assert_close(call_kwargs["image_metadata"], image_metadata) + + scattered = captured["inputs_embeds"][mm_token_type_ids.bool()] + expected = torch.cat(scattered_features, dim=0).to(dtype=scattered.dtype) + torch.testing.assert_close(scattered, expected) + + def test_prepare_position_ids_for_generation_uses_batch_major_rope(self): + config = self.model_tester.get_config() + model = IsaacForConditionalGeneration(config).to(torch_device).eval() + + input_ids = torch.tensor([[4, 5, 6], [7, 8, 9]], device=torch_device, dtype=torch.long) + mm_token_type_ids = torch.tensor([[0, 1, 0], [0, 0, 0]], device=torch_device, dtype=torch.long) + image_grid_thw = torch.tensor( + [ + [[1, 2, 2]], + [[0, 0, 0]], + ], + device=torch_device, + dtype=torch.long, + ) + image_metadata = torch.tensor( + [ + [[0, 1]], + [[0, 0]], + ], + device=torch_device, + dtype=torch.long, + ) + expected_positions = torch.arange(18, device=torch_device, dtype=torch.long).view(3, 2, 3) + expected_deltas = torch.tensor([[0], [1]], device=torch_device, dtype=torch.long) + + with patch.object( + model.model, + "get_rope_index", + return_value=(expected_positions, expected_deltas), + ) as mock_get_rope_index: + position_ids = model._prepare_position_ids_for_generation( + input_ids, + { + "input_ids": input_ids, + "mm_token_type_ids": mm_token_type_ids, + "image_grid_thw": image_grid_thw, + "image_metadata": image_metadata, + "attention_mask": torch.ones_like(input_ids), + }, + ) + + mock_get_rope_index.assert_called_once() + torch.testing.assert_close(position_ids[1:], expected_positions) + torch.testing.assert_close(model.model.rope_deltas, expected_deltas) + + def test_expand_inputs_for_generation_repeats_batch_major_visual_tensors(self): + config = self.model_tester.get_config() + model = IsaacForConditionalGeneration(config).to(torch_device).eval() + + input_ids = torch.tensor([[1, 2], [3, 4]], device=torch_device, dtype=torch.long) + mm_token_type_ids = torch.tensor([[0, 1], [1, 0]], device=torch_device, dtype=torch.long) + pixel_values = torch.arange(2 * 2 * 3 * 4, device=torch_device, dtype=torch.float32).view(2, 2, 3, 4) + image_grid_thw = torch.tensor( + [ + [[1, 2, 2], [0, 0, 0]], + [[1, 2, 2], [1, 2, 2]], + ], + device=torch_device, + dtype=torch.long, + ) + image_metadata = torch.tensor( + [ + [[0, 1], [0, 0]], + [[1, 2], [0, 1]], + ], + device=torch_device, + dtype=torch.long, + ) + + expanded_input_ids, expanded_kwargs = model._expand_inputs_for_generation( + expand_size=2, + input_ids=input_ids, + mm_token_type_ids=mm_token_type_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, + ) + + torch.testing.assert_close(expanded_input_ids, input_ids.repeat_interleave(2, dim=0)) + torch.testing.assert_close(expanded_kwargs["mm_token_type_ids"], mm_token_type_ids.repeat_interleave(2, dim=0)) + 
torch.testing.assert_close(expanded_kwargs["pixel_values"], pixel_values.repeat_interleave(2, dim=0)) + torch.testing.assert_close(expanded_kwargs["image_grid_thw"], image_grid_thw.repeat_interleave(2, dim=0)) + torch.testing.assert_close(expanded_kwargs["image_metadata"], image_metadata.repeat_interleave(2, dim=0)) def test_for_conditional_generation(self): config, input_ids, attention_mask, labels = self.model_tester.prepare_config_and_inputs() @@ -578,7 +898,7 @@ def test_isaac_for_conditional_generation_initialization(self): self.assertTrue(hasattr(model, "model")) self.assertTrue(hasattr(model, "lm_head")) - self.assertTrue(hasattr(model.model, "vision_tower")) + self.assertTrue(hasattr(model.model, "visual")) self.assertTrue(hasattr(model.model, "multimodal_projector")) input_vocab_size = model.get_input_embeddings().num_embeddings @@ -604,6 +924,40 @@ def test_isaac_for_conditional_generation_loss_and_generate_flag(self): self.assertEqual(outputs.loss.ndim, 0) self.assertEqual(outputs.logits.shape, (batch_size, seq_len, output_vocab_size)) + @pytest.mark.generate + def test_left_padding_compatibility(self): + _, inputs_dict = self.prepare_config_and_inputs_for_generate() + mm_token_type_ids = inputs_dict["mm_token_type_ids"] + pad_size = (mm_token_type_ids.shape[0], 32) + padded_mm_token_type_ids = torch.cat( + (torch.zeros(pad_size, dtype=mm_token_type_ids.dtype, device=torch_device), mm_token_type_ids), dim=1 + ) + + super().test_left_padding_compatibility( + unpadded_custom_inputs={"mm_token_type_ids": mm_token_type_ids}, + padded_custom_inputs={"mm_token_type_ids": padded_mm_token_type_ids}, + ) + + @unittest.skip(reason="Isaac is image-only.") + def test_get_video_features_output_0(self): + pass + + @unittest.skip(reason="Isaac is image-only.") + def test_get_video_features_output_1(self): + pass + + @unittest.skip(reason="Isaac is image-only.") + def test_get_video_features_output_2(self): + pass + + @unittest.skip(reason="Isaac is image-only.") + def test_get_video_features_hidden_states(self): + pass + + @unittest.skip(reason="Isaac is image-only.") + def test_get_video_features_attentions(self): + pass + @require_torch class IsaacPixelShufflePaddedTest(unittest.TestCase): @@ -612,7 +966,7 @@ def test_pixel_shuffle_padded_matches_reference_no_attention_mask(self): token_grids = torch.tensor([[4, 4], [2, 4]], device=torch_device, dtype=torch.long) expected_hidden, expected_mask, expected_lengths = _pixel_shuffle_reference(x, token_grids, scale_factor=2) - hidden = pixel_shuffle_padded(x=x, token_grids=token_grids, scale_factor=2) + hidden = pixel_shuffle_padded(hidden_states=x, token_grids=token_grids, scale_factor=2) torch.testing.assert_close(hidden, expected_hidden) @@ -621,13 +975,13 @@ def test_pixel_shuffle_padded_raises_on_non_divisible_grid(self): token_grids = torch.tensor([[3, 5]], device=torch_device, dtype=torch.long) with pytest.raises(ValueError, match="divisible"): - pixel_shuffle_padded(x=x, token_grids=token_grids, scale_factor=2) + pixel_shuffle_padded(hidden_states=x, token_grids=token_grids, scale_factor=2) def test_pixel_shuffle_padded_zero_grid(self): x = torch.randn(1, 4, 8, device=torch_device) token_grids = torch.tensor([[0, 0]], device=torch_device, dtype=torch.long) - hidden = pixel_shuffle_padded(x=x, token_grids=token_grids, scale_factor=2) + hidden = pixel_shuffle_padded(hidden_states=x, token_grids=token_grids, scale_factor=2) self.assertEqual(hidden.shape, (1, 0, 32)) @@ -759,10 +1113,8 @@ def setUp(self): self.device = torch.device("cuda" if 
torch.cuda.is_available() else "cpu") self.checkpoint = _base_reference_checkpoint_or_skip() self.hf_config = IsaacConfig.from_pretrained(self.checkpoint, revision=BASE_MODEL_REVISION) - self.tokenizer = Qwen2Tokenizer.from_pretrained( - self.checkpoint, trust_remote_code=False, use_fast=False, revision=BASE_MODEL_REVISION - ) - self.processor = create_isaac_processor(self.tokenizer, self.hf_config) + self.processor = IsaacProcessor.from_pretrained(self.checkpoint, revision=BASE_MODEL_REVISION, do_pad=True) + self.tokenizer = self.processor.tokenizer self.hf_config.vision_config._attn_implementation = "flash_attention_2" self.hf_config.vision_config.attn_implementation = "flash_attention_2" self.model = IsaacForConditionalGeneration.from_pretrained( @@ -771,7 +1123,7 @@ def setUp(self): self.model = self.model.to(device=self.device, dtype=self.dtype) self.model.eval() - def _generate_from_messages(self, messages, images, num_tokens=None): + def _generate_from_messages(self, messages, images, num_tokens=None, generate_kwargs=None): prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip() processor_output = self.processor(text=prompt, images=images or None, return_tensors="pt") input_ids = processor_output["input_ids"].to(self.device) @@ -784,18 +1136,20 @@ def _generate_from_messages(self, messages, images, num_tokens=None): attention_mask = attention_mask.to(self.device) prompt_len = input_ids.shape[1] multimodal_inputs = to_model_multimodal_inputs(processor_output, self.device) + generate_kwargs = {} if generate_kwargs is None else dict(generate_kwargs) + generate_kwargs.setdefault("max_new_tokens", num_tokens or self.max_new_tokens) + generate_kwargs.setdefault("do_sample", False) + generate_kwargs.setdefault("pad_token_id", self.tokenizer.eos_token_id) + generate_kwargs.setdefault("eos_token_id", self.tokenizer.eos_token_id) + generate_kwargs.setdefault("return_dict_in_generate", True) + generate_kwargs.setdefault("output_logits", True) with torch.no_grad(): outputs = self.model.generate( input_ids=input_ids, attention_mask=attention_mask, **multimodal_inputs, - max_new_tokens=num_tokens or self.max_new_tokens, - do_sample=False, - pad_token_id=self.tokenizer.eos_token_id, - eos_token_id=self.tokenizer.eos_token_id, - return_dict_in_generate=True, - output_logits=True, + **generate_kwargs, ) generated_ids = outputs.sequences @@ -810,7 +1164,7 @@ def test_generate_from_image_text(self): messages = [ {"role": "user", "content": "Describe this image:"}, - {"role": "user", "content": ""}, + {"role": "user", "content": self.processor.image_token}, ] generated_text = self._generate_from_messages(messages, [image]) expected_fragment = "The image is a close-up photograph of a red cross symbol." @@ -842,12 +1196,12 @@ def test_vqa_from_image(self): "role": "user", }, ] - messages, images = document_to_messages(document) + messages, images = document_to_messages(document, image_token=self.processor.image_token) generated_text = self._generate_from_messages(messages, images, num_tokens=256) expected_response = "\nNo, it is not safe to cross the street at this moment. The traffic light for pedestrians is red, indicating that it is not safe to cross." 
assert generated_text == expected_response - def _generate_batch(self, prompts, images_list, num_tokens=None): + def _generate_batch(self, prompts, images_list, num_tokens=None, generate_kwargs=None): processor_output = self.processor(text=prompts, images=images_list, return_tensors="pt") input_ids = processor_output["input_ids"] if input_ids.dim() == 1: @@ -865,17 +1219,19 @@ def _generate_batch(self, prompts, images_list, num_tokens=None): attention_mask = attention_mask.to(self.device) multimodal_inputs = to_model_multimodal_inputs(processor_output, self.device) + generate_kwargs = {} if generate_kwargs is None else dict(generate_kwargs) + generate_kwargs.setdefault("max_new_tokens", num_tokens or self.max_new_tokens) + generate_kwargs.setdefault("do_sample", False) + generate_kwargs.setdefault("pad_token_id", self.tokenizer.eos_token_id) + generate_kwargs.setdefault("eos_token_id", self.tokenizer.eos_token_id) + generate_kwargs.setdefault("return_dict_in_generate", True) with torch.no_grad(): outputs = self.model.generate( input_ids=input_ids, attention_mask=attention_mask, **multimodal_inputs, - max_new_tokens=num_tokens or self.max_new_tokens, - do_sample=False, - pad_token_id=self.tokenizer.eos_token_id, - eos_token_id=self.tokenizer.eos_token_id, - return_dict_in_generate=True, + **generate_kwargs, ) sequences = outputs.sequences generated_texts = [] @@ -897,18 +1253,22 @@ def test_logit_equivalence(self): messages = [ {"role": "user", "content": "Describe this image:"}, - {"role": "user", "content": ""}, + {"role": "user", "content": self.processor.image_token}, ] prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip() processor_output = self.processor(text=prompt, images=images, return_tensors="pt") input_ids = processor_output["input_ids"] device = next(self.model.parameters()).device input_ids = input_ids.to(device) + attention_mask = processor_output.get("attention_mask") + if attention_mask is not None: + attention_mask = attention_mask.to(device) multimodal_inputs = to_model_multimodal_inputs(processor_output, device) with torch.no_grad(): outputs = self.model.generate( input_ids=input_ids, + attention_mask=attention_mask, **multimodal_inputs, max_new_tokens=num_tokens or self.max_new_tokens, do_sample=False, @@ -961,13 +1321,13 @@ def test_batched_generation_matches_individual(self): # Image + text messages_image_text = [ {"role": "user", "content": "Describe this image:"}, - {"role": "user", "content": ""}, + {"role": "user", "content": self.processor.image_token}, ] single_image_text = self._generate_from_messages(messages_image_text, [red_image]) assert single_image_text, "Image-text single generation is empty" # VQA - messages_vqa, images_vqa = document_to_messages(vqa_document) + messages_vqa, images_vqa = document_to_messages(vqa_document, image_token=self.processor.image_token) single_vqa = self._generate_from_messages(messages_vqa, images_vqa, num_tokens=self.max_new_tokens) assert single_vqa, "VQA single generation is empty" @@ -986,10 +1346,10 @@ def test_batched_generation_matches_individual(self): # Input-level sanity assert len(prompts) == len(images_list) == 3 for i, (p, imgs) in enumerate(zip(prompts, images_list)): - expected_tokens = p.count(self.hf_config.vision_token) + expected_tokens = p.count(self.processor.image_token) num_imgs = len(imgs) assert expected_tokens == num_imgs, ( - f"sample {i} vision token/image mismatch: {expected_tokens} vs {num_imgs}" + f"sample {i} image token/image mismatch: 
{expected_tokens} vs {num_imgs}" ) pad_id = self.tokenizer.pad_token_id @@ -1025,23 +1385,34 @@ def test_batched_generation_matches_individual(self): expected_modality[-single_len:] = single_packed["mm_token_type_ids"].squeeze(0) torch.testing.assert_close(batch_modality_row, expected_modality) - if single_packed["vision_patches"] is not None: - expected_image_count = int(single_packed["vision_token_lengths"].gt(0).sum().item()) - batch_image_count = int(batch_packed["vision_token_lengths"][i].gt(0).sum().item()) - assert batch_image_count == expected_image_count - if expected_image_count > 0: - torch.testing.assert_close( - batch_packed["vision_token_grids"][i, :expected_image_count], - single_packed["vision_token_grids"][0, :expected_image_count], - ) - torch.testing.assert_close( - batch_packed["vision_token_offsets"][i, :expected_image_count], - single_packed["vision_token_offsets"][0, :expected_image_count], - ) - torch.testing.assert_close( - batch_packed["vision_token_lengths"][i, :expected_image_count], - single_packed["vision_token_lengths"][0, :expected_image_count], - ) + if batch_packed["image_grid_thw"] is not None: + batch_image_mask = batch_packed["image_grid_thw"][i, :, 0].eq(1) + expected_image_count = int(batch_image_mask.sum().item()) + if single_packed["image_grid_thw"] is None: + assert expected_image_count == 0 + else: + single_image_mask = single_packed["image_grid_thw"][0, :, 0].eq(1) + assert expected_image_count == int(single_image_mask.sum().item()) + if expected_image_count > 0: + batch_image_grid_thw = batch_packed["image_grid_thw"][i, batch_image_mask] + single_image_grid_thw = single_packed["image_grid_thw"][0, single_image_mask] + batch_image_metadata = batch_packed["image_metadata"][i, batch_image_mask] + single_image_metadata = single_packed["image_metadata"][0, single_image_mask] + + torch.testing.assert_close(batch_image_grid_thw, single_image_grid_thw) + torch.testing.assert_close(batch_image_metadata, single_image_metadata) + + for batch_pixel_values, single_pixel_values, grid_thw in zip( + batch_packed["pixel_values"][i, batch_image_mask], + single_packed["pixel_values"][0, single_image_mask], + batch_image_grid_thw, + strict=True, + ): + valid_patch_count = int((grid_thw[1] * grid_thw[2]).item()) + torch.testing.assert_close( + batch_pixel_values[:valid_patch_count], + single_pixel_values[:valid_patch_count], + ) if single_len == max_length: continue @@ -1053,10 +1424,9 @@ def test_batched_generation_matches_individual(self): assert not torch.any(attention_mask[: max_length - single_len]), f"sample {i} mask ones inside left pad" assert torch.all(attention_mask[-single_len:]), f"sample {i} mask zeros inside content" - assert batch_packed["vision_patches"] is not None - assert batch_packed["vision_token_grids"] is not None - assert batch_packed["vision_token_offsets"] is not None - assert batch_packed["vision_token_lengths"] is not None + assert batch_packed["pixel_values"] is not None + assert batch_packed["image_grid_thw"] is not None + assert batch_packed["image_metadata"] is not None batch_texts = self._generate_batch(prompts, images_list, num_tokens=100) assert len(batch_texts) == len(single_texts) == 3 @@ -1064,6 +1434,72 @@ def test_batched_generation_matches_individual(self): for i, (btxt, stxt) in enumerate(zip(batch_texts, single_texts)): assert stxt in btxt, f"batch[{i}] mismatch: {btxt!r} vs single[{i}] {stxt!r}" + def test_batched_beam_generation_matches_individual(self): + red_image = _load_red_dot_image() + if red_image is None: + 
pytest.skip("PIL.Image is required for Isaac generation tests.") + + vqa_document = [ + { + "type": "image", + "content": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp", + "role": "user", + }, + { + "type": "text", + "content": "Is it safe to cross the street at this moment?", + "role": "user", + }, + ] + beam_kwargs = {"num_beams": 2} + + doc_text_only = [{"type": "text", "content": "What is the pythogorean theorem?", "role": "user"}] + messages_text_only, images_text_only = document_to_messages(doc_text_only) + single_text_only = self._generate_from_messages( + messages_text_only, + images_text_only, + num_tokens=self.max_new_tokens, + generate_kwargs=beam_kwargs, + ) + assert single_text_only, "Text-only beam generation is empty" + + messages_image_text = [ + {"role": "user", "content": "Describe this image:"}, + {"role": "user", "content": self.processor.image_token}, + ] + single_image_text = self._generate_from_messages(messages_image_text, [red_image], generate_kwargs=beam_kwargs) + assert single_image_text, "Image-text beam generation is empty" + + messages_vqa, images_vqa = document_to_messages(vqa_document, image_token=self.processor.image_token) + single_vqa = self._generate_from_messages( + messages_vqa, + images_vqa, + num_tokens=self.max_new_tokens, + generate_kwargs=beam_kwargs, + ) + assert single_vqa, "VQA beam generation is empty" + + single_texts = [single_text_only, single_image_text, single_vqa] + prompts = [ + self.processor.apply_chat_template(messages_text_only, tokenize=False, add_generation_prompt=True).strip(), + self.processor.apply_chat_template( + messages_image_text, tokenize=False, add_generation_prompt=True + ).strip(), + self.processor.apply_chat_template(messages_vqa, tokenize=False, add_generation_prompt=True).strip(), + ] + images_list = [images_text_only, [red_image], images_vqa] + + batch_texts = self._generate_batch( + prompts, + images_list, + num_tokens=100, + generate_kwargs=beam_kwargs, + ) + assert len(batch_texts) == len(single_texts) == 3 + + for i, (btxt, stxt) in enumerate(zip(batch_texts, single_texts)): + assert stxt in btxt, f"beam batch[{i}] mismatch: {btxt!r} vs single[{i}] {stxt!r}" + @require_torch @require_vision @@ -1077,10 +1513,9 @@ def setUp(self): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.checkpoint = _reference_checkpoint_or_skip() self.hf_config = IsaacConfig.from_pretrained(self.checkpoint, revision=MODEL_REVISION) - self.tokenizer = Qwen2Tokenizer.from_pretrained( - self.checkpoint, trust_remote_code=False, use_fast=False, revision=MODEL_REVISION - ) - self.processor = create_isaac_processor(self.tokenizer, self.hf_config) + # The current local slow fallback only supports padded packing for this checkpoint. 
+ self.processor = IsaacProcessor.from_pretrained(self.checkpoint, revision=MODEL_REVISION, do_pad=True) + self.tokenizer = self.processor.tokenizer self.hf_config.vision_config._attn_implementation = "flash_attention_2" self.hf_config.vision_config.attn_implementation = "flash_attention_2" self.model = IsaacForConditionalGeneration.from_pretrained( @@ -1107,21 +1542,26 @@ def test_hf_generate_box_points(self): "role": "user", }, ] - messages, images = document_to_messages(document, vision_token=self.hf_config.vision_token) + messages, images = document_to_messages(document, image_token=self.processor.image_token) prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip() processor_output = self.processor(text=prompt, images=images, return_tensors="pt") input_ids = processor_output["input_ids"].to(self.device) + attention_mask = processor_output.get("attention_mask") + if attention_mask is not None: + attention_mask = attention_mask.to(self.device) prompt_len = input_ids.shape[1] multimodal_inputs = to_model_multimodal_inputs(processor_output, self.device) with torch.no_grad(): outputs = self.model.generate( input_ids=input_ids, + attention_mask=attention_mask, **multimodal_inputs, max_new_tokens=self.max_new_tokens, do_sample=False, pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id, + tokenizer=self.tokenizer, return_dict_in_generate=True, ) @@ -1157,16 +1597,20 @@ def test_hf_generate_polygon_points(self): "role": "user", }, ] - messages, images = document_to_messages(document, vision_token=self.hf_config.vision_token) + messages, images = document_to_messages(document, image_token=self.processor.image_token) prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip() processor_output = self.processor(text=prompt, images=images, return_tensors="pt") input_ids = processor_output["input_ids"].to(self.device) + attention_mask = processor_output.get("attention_mask") + if attention_mask is not None: + attention_mask = attention_mask.to(self.device) prompt_len = input_ids.shape[1] multimodal_inputs = to_model_multimodal_inputs(processor_output, self.device) with torch.no_grad(): outputs = self.model.generate( input_ids=input_ids, + attention_mask=attention_mask, **multimodal_inputs, max_new_tokens=self.max_new_tokens, do_sample=False, diff --git a/tests/models/isaac/test_post_processing_isaac.py b/tests/models/isaac/test_post_processing_isaac.py deleted file mode 100644 index 32c52c175ffd..000000000000 --- a/tests/models/isaac/test_post_processing_isaac.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for Isaac processor post-processing helpers.""" - -import pytest - -from transformers import PythonBackend -from transformers.models.isaac.image_processing_isaac import IsaacImageProcessor -from transformers.models.isaac.processing_isaac import IsaacProcessor -from transformers.testing_utils import require_torch - - -class SimpleIsaacTokenizer(PythonBackend): - vocab_files_names = {} - model_input_names = ["input_ids"] - - def __init__(self): - self._vocab = { - "": 0, - "": 1, - "": 2, - "": 3, - "": 4, - } - self._ids_to_tokens = {idx: tok for tok, idx in self._vocab.items()} - super().__init__( - bos_token="", - eos_token="", - pad_token="", - unk_token="", - extra_special_tokens=[""], - model_max_length=512, - ) - - def get_vocab(self): - return dict(self._vocab) - - def _tokenize(self, text): - clean = text.replace("\n", " ").strip() - if not clean: - return [] - return [token for token in clean.split(" ") if token] - - def _convert_token_to_id(self, token): - if token not in self._vocab: - next_id = len(self._vocab) - self._vocab[token] = next_id - self._ids_to_tokens[next_id] = token - return self._vocab[token] - - def _convert_id_to_token(self, index): - return self._ids_to_tokens.get(index, self.unk_token) - - @property - def vocab_size(self) -> int: - return len(self._vocab) - - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - if token_ids_1 is not None: - token_ids_0 = token_ids_0 + token_ids_1 - return [self.bos_token_id] + list(token_ids_0) + [self.eos_token_id] - - def save_vocabulary(self, save_directory, filename_prefix=None): - return () - - -def _make_processor(): - return IsaacProcessor(image_processor=IsaacImageProcessor(), tokenizer=SimpleIsaacTokenizer()) - - -@require_torch -def test_post_process_generation_extracts_boxes_and_cleans_text(): - processor = _make_processor() - - generated_text = ( - "No, it is not safe to cross the street. " - '(808, 247), (863, 386)' - ) - - clean_text, annotations = processor.post_process_generation(generated_text) - - assert clean_text == "No, it is not safe to cross the street." - assert len(annotations) == 1 - box = annotations[0] - assert box.mention == "traffic light" - assert box.t == pytest.approx(0.5) - assert box.top_left.x == 808 - assert box.top_left.y == 247 - assert box.bottom_right.x == 863 - assert box.bottom_right.y == 386 diff --git a/tests/models/isaac/test_processing_isaac.py b/tests/models/isaac/test_processing_isaac.py new file mode 100644 index 000000000000..431850cdbe64 --- /dev/null +++ b/tests/models/isaac/test_processing_isaac.py @@ -0,0 +1,879 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Testing suite for the Isaac processor.""" + +import os +import re +import unittest +from pathlib import Path + +import numpy as np +import pytest +import torch +from huggingface_hub import is_offline_mode + +from transformers import IsaacConfig, PythonBackend +from transformers.models.isaac.image_processing_isaac import IsaacImageProcessor +from transformers.models.isaac.processing_isaac import IsaacProcessor +from transformers.testing_utils import require_torch, require_vision +from transformers.tokenization_utils_base import BatchEncoding +from transformers.utils import is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_vision_available(): + from PIL import Image +else: + Image = None + + +ISAAC_OUTPUT_KEYS = { + "input_ids", + "attention_mask", + "mm_token_type_ids", + "pixel_values", + "image_grid_thw", + "image_metadata", +} + + +def _simple_tokenizer_call( + tokenizer, + text, + padding=False, + truncation=None, + max_length=None, + pad_to_multiple_of=None, + return_attention_mask=True, + return_overflowing_tokens=False, + return_tensors=None, + add_special_tokens=True, + **kwargs, +): + texts = [text] if isinstance(text, str) else list(text) + rows = [] + row_kinds = [] + overflow_to_sample_mapping = [] + + for sample_idx, sample in enumerate(texts): + token_ids = [tokenizer._convert_token_to_id(token) for token in tokenizer._tokenize(sample)] + if add_special_tokens: + token_ids = tokenizer.build_inputs_with_special_tokens(token_ids) + + kept_ids = list(token_ids) + dropped_ids = [] + if truncation and max_length is not None and len(token_ids) > max_length: + if tokenizer.truncation_side == "left": + dropped_ids = token_ids[:-max_length] + kept_ids = token_ids[-max_length:] + else: + kept_ids = token_ids[:max_length] + dropped_ids = token_ids[max_length:] + + rows.append(kept_ids) + row_kinds.append("kept") + overflow_to_sample_mapping.append(sample_idx) + + if return_overflowing_tokens and dropped_ids: + rows.append(dropped_ids) + row_kinds.append("overflow") + overflow_to_sample_mapping.append(sample_idx) + + kept_rows = [row for row, row_kind in zip(rows, row_kinds, strict=True) if row_kind == "kept"] + target_length = None + if padding in (True, "longest"): + target_length = max((len(row) for row in kept_rows), default=0) + elif padding == "max_length": + target_length = max_length + + if target_length is not None and pad_to_multiple_of is not None and target_length % pad_to_multiple_of != 0: + target_length = ((target_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + padded_rows = [] + attention_masks = [] + for row, row_kind in zip(rows, row_kinds, strict=True): + if row_kind == "kept" and target_length is not None: + pad_len = target_length - len(row) + if tokenizer.padding_side == "left": + padded_row = [tokenizer.pad_token_id] * pad_len + row + attention_mask = [0] * pad_len + [1] * len(row) + else: + padded_row = row + [tokenizer.pad_token_id] * pad_len + attention_mask = [1] * len(row) + [0] * pad_len + else: + padded_row = row + attention_mask = [1] * len(row) + + padded_rows.append(padded_row) + attention_masks.append(attention_mask) + + data = {"input_ids": padded_rows} + if return_attention_mask: + data["attention_mask"] = attention_masks + if return_overflowing_tokens: + data["overflow_to_sample_mapping"] = overflow_to_sample_mapping + + return BatchEncoding(data=data, tensor_type=return_tensors) + + +class SimpleIsaacTokenizer(PythonBackend): + vocab_files_names = {} + model_input_names = ["input_ids"] + + 
def __init__(self): + self._vocab = { + "": 0, + "": 1, + "": 2, + "": 3, + "": 4, + "<|image_pad|>": 5, + } + self._ids_to_tokens = {idx: tok for tok, idx in self._vocab.items()} + super().__init__( + bos_token="", + eos_token="", + pad_token="", + unk_token="", + additional_special_tokens=[""], + extra_special_tokens={"image_pad_token": "<|image_pad|>"}, + model_max_length=512, + ) + + def get_vocab(self): + return dict(self._vocab) + + def _tokenize(self, text): + clean = text.replace("\n", " ").strip() + if not clean: + return [] + + special_tokens = sorted( + (token for token in self._vocab if token.startswith("<") and token.endswith(">")), + key=len, + reverse=True, + ) + if not special_tokens: + return [token for token in clean.split(" ") if token] + + split_pattern = "(" + "|".join(re.escape(token) for token in special_tokens) + ")" + tokens = [] + for chunk in re.split(split_pattern, clean): + if not chunk or chunk.isspace(): + continue + if chunk in self._vocab: + tokens.append(chunk) + else: + tokens.extend(token for token in chunk.split(" ") if token) + return tokens + + def _convert_token_to_id(self, token): + if token not in self._vocab: + next_id = len(self._vocab) + self._vocab[token] = next_id + self._ids_to_tokens[next_id] = token + return self._vocab[token] + + def _convert_id_to_token(self, index): + return self._ids_to_tokens.get(index, self.unk_token) + + @property + def vocab_size(self) -> int: + return len(self._vocab) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + if token_ids_1 is not None: + token_ids_0 = token_ids_0 + token_ids_1 + return [self.bos_token_id] + list(token_ids_0) + [self.eos_token_id] + + def save_vocabulary(self, save_directory, filename_prefix=None): + return () + + def __call__(self, text, **kwargs): + return _simple_tokenizer_call(self, text, **kwargs) + + +class SimpleIsaacTokenizerWithNamedImagePad(PythonBackend): + vocab_files_names = {} + model_input_names = ["input_ids"] + + def __init__(self): + self._vocab = { + "": 0, + "": 1, + "": 2, + "": 3, + "": 4, + "": 5, + "<|image_pad|>": 6, + } + self._ids_to_tokens = {idx: tok for tok, idx in self._vocab.items()} + super().__init__( + bos_token="", + eos_token="", + pad_token="", + unk_token="", + extra_special_tokens={"image_pad_token": ""}, + model_max_length=512, + ) + + def get_vocab(self): + return dict(self._vocab) + + def _tokenize(self, text): + clean = text.replace("\n", " ").strip() + if not clean: + return [] + + special_tokens = sorted( + (token for token in self._vocab if token.startswith("<") and token.endswith(">")), + key=len, + reverse=True, + ) + split_pattern = "(" + "|".join(re.escape(token) for token in special_tokens) + ")" + tokens = [] + for chunk in re.split(split_pattern, clean): + if not chunk or chunk.isspace(): + continue + if chunk in self._vocab: + tokens.append(chunk) + else: + tokens.extend(token for token in chunk.split(" ") if token) + return tokens + + def _convert_token_to_id(self, token): + return self._vocab.get(token, self._vocab[""]) + + def _convert_id_to_token(self, index): + return self._ids_to_tokens.get(index, self.unk_token) + + @property + def vocab_size(self) -> int: + return len(self._vocab) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + if token_ids_1 is not None: + token_ids_0 = token_ids_0 + token_ids_1 + return [self.bos_token_id] + list(token_ids_0) + [self.eos_token_id] + + def save_vocabulary(self, save_directory, filename_prefix=None): + return () + + def 
__call__(self, text, **kwargs): + return _simple_tokenizer_call(self, text, **kwargs) + + +class IsaacProcessorTestDouble(IsaacProcessor): + def check_argument_for_proper_class(self, argument_name, argument): + return type(argument) + + +def _make_dummy_image(size=(32, 32), color=(255, 0, 0)): + if Image is None: + raise RuntimeError("PIL.Image is not available in this environment.") + return Image.new("RGB", size, color=color) + + +def _make_processor_with_max_len(tokenizer, base_config, max_len): + config = IsaacConfig(**base_config.to_dict()) + config.max_sequence_length = max_len + vision_config = config.vision_config + image_processor = IsaacImageProcessor( + patch_size=vision_config.patch_size, + max_num_patches=vision_config.num_patches, + pixel_shuffle_scale=vision_config.pixel_shuffle_scale_factor, + rescale_factor=config.vision_rescale_factor, + ) + return IsaacProcessorTestDouble( + image_processor=image_processor, + tokenizer=tokenizer, + max_sequence_length=config.max_sequence_length, + ) + + +def _run_processor(processor, text, images=None): + return processor(text=text, images=images, return_tensors="pt") + + +def _make_post_process_processor(): + return IsaacProcessorTestDouble(image_processor=IsaacImageProcessor(), tokenizer=SimpleIsaacTokenizer()) + + +def test_processor_prefers_named_image_pad_token(): + processor = IsaacProcessorTestDouble( + image_processor=IsaacImageProcessor(), tokenizer=SimpleIsaacTokenizerWithNamedImagePad() + ) + + assert processor.image_token == "" + assert processor.image_token_id == processor.tokenizer.image_pad_token_id + assert processor.image_token_id != processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") + + +def _assert_common(outputs, batch_size=1): + assert set(outputs.keys()) == ISAAC_OUTPUT_KEYS + + input_ids = outputs["input_ids"] + attention_mask = outputs["attention_mask"] + mm_token_type_ids = outputs["mm_token_type_ids"] + pixel_values = outputs["pixel_values"] + image_grid_thw = outputs["image_grid_thw"] + image_metadata = outputs["image_metadata"] + + assert input_ids.shape[0] == batch_size + assert attention_mask.shape == input_ids.shape + assert mm_token_type_ids.shape == input_ids.shape + assert input_ids.dtype == torch.long + assert attention_mask.dtype == torch.long + assert mm_token_type_ids.dtype == torch.long + + if pixel_values is None: + assert image_grid_thw is None + assert image_metadata is None + else: + assert pixel_values.ndim == 4 + assert image_grid_thw.shape == (batch_size, pixel_values.shape[1], 3) + assert image_metadata.shape == (batch_size, pixel_values.shape[1], 2) + assert image_grid_thw.dtype == torch.long + assert image_metadata.dtype == torch.long + + active_slots = image_grid_thw[..., 0].eq(1) + assert torch.all(image_grid_thw[~active_slots].eq(0)) + if active_slots.any(): + assert torch.all(image_grid_thw[active_slots, 1:] > 0) + assert torch.all(image_metadata[active_slots] >= 0) + + return outputs + + +def _get_sample_image_mask(outputs, batch_index=0): + image_grid_thw = outputs["image_grid_thw"] + if image_grid_thw is None: + return torch.zeros((0,), dtype=torch.bool) + return image_grid_thw[batch_index, :, 0].eq(1) + + +def _assert_no_vision(outputs, batch_index=0): + assert not _get_sample_image_mask(outputs, batch_index=batch_index).any() + assert not outputs["mm_token_type_ids"][batch_index].eq(1).any() + + +def _assert_vision_segments(outputs, expected_segments, batch_index=0): + sample_image_mask = _get_sample_image_mask(outputs, batch_index=batch_index) + active_segments = 
int(sample_image_mask.sum().item()) + assert active_segments == expected_segments + assert torch.all(outputs["image_metadata"][batch_index, sample_image_mask, 1] > 0) + assert torch.all(outputs["image_grid_thw"][batch_index, sample_image_mask, 1:].prod(dim=-1) > 0) + + +def _count_modality(outputs, modality_value, batch_index=0): + return int( + (outputs["attention_mask"][batch_index].bool() & outputs["mm_token_type_ids"][batch_index].eq(modality_value)) + .sum() + .item() + ) + + +def _get_active_vision_grids(outputs, batch_index=0): + image_grid_thw = outputs["image_grid_thw"] + if image_grid_thw is None: + return torch.zeros((0, 2), dtype=torch.long) + return image_grid_thw[batch_index, _get_sample_image_mask(outputs, batch_index=batch_index), 1:] + + +def _get_active_vision_offsets(outputs, batch_index=0): + image_metadata = outputs["image_metadata"] + if image_metadata is None: + return torch.zeros((0,), dtype=torch.long) + return image_metadata[batch_index, _get_sample_image_mask(outputs, batch_index=batch_index), 0] + + +def _get_active_vision_lengths(outputs, batch_index=0): + image_metadata = outputs["image_metadata"] + if image_metadata is None: + return torch.zeros((0,), dtype=torch.long) + return image_metadata[batch_index, _get_sample_image_mask(outputs, batch_index=batch_index), 1] + + +def _get_expected_vision_lengths(outputs, pixel_shuffle_scale=1, batch_index=0): + grids = _get_active_vision_grids(outputs, batch_index=batch_index) + if grids.numel() == 0: + return grids.new_zeros((0,)) + return torch.prod(grids, dim=-1) // (pixel_shuffle_scale**2) + + +@pytest.fixture +def isaac_tiny_config(): + text_config = { + "bos_token_id": 0, + "eos_token_id": 1, + "pad_token_id": 2, + "hidden_act": "silu", + "head_dim": 32 // 4, + "hidden_size": 32, + "vocab_size": 99, + "intermediate_size": 32 * 3, + "max_position_embeddings": 128, + "model_type": "qwen3", + "num_attention_heads": 4, + "num_hidden_layers": 2, + "num_key_value_heads": 4, + "rope_parameters": {"rope_type": "default", "mrope_section": [2, 1, 1], "mrope_interleaved": True}, + "tie_word_embeddings": True, + } + + vision_config = { + "hidden_size": 32, + "intermediate_size": 32 * 2, + "num_hidden_layers": 1, + "num_attention_heads": 4, + "num_channels": 3, + "num_patches": 64, + "patch_size": 4, + "pixel_shuffle_scale_factor": 1, + "attention_dropout": 0.0, + "layer_norm_eps": 1e-6, + } + + config = IsaacConfig(text_config=text_config, vision_config=vision_config) + config._attn_implementation = "sdpa" + config.text_config._attn_implementation = "sdpa" + config.vision_attn_implementation = "sdpa" + return config + + +@pytest.fixture +def isaac_tokenizer(): + return SimpleIsaacTokenizer() + + +@pytest.fixture +def isaac_processor(isaac_tokenizer, isaac_tiny_config): + vision_config = isaac_tiny_config.vision_config + image_processor = IsaacImageProcessor( + patch_size=vision_config.patch_size, + max_num_patches=vision_config.num_patches, + pixel_shuffle_scale=vision_config.pixel_shuffle_scale_factor, + rescale_factor=isaac_tiny_config.vision_rescale_factor, + ) + return IsaacProcessorTestDouble( + image_processor=image_processor, + tokenizer=isaac_tokenizer, + max_sequence_length=isaac_tiny_config.max_sequence_length, + ) + + +BASE_MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base") +BASE_MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/3") or None +LOCAL_CHECKPOINT = os.environ.get("ISAAC_TEST_MODEL_PATH") + + +def _checkpoint_or_skip(model_id=BASE_MODEL_ID): + if 
LOCAL_CHECKPOINT: + resolved = Path(LOCAL_CHECKPOINT).expanduser() + if not resolved.exists(): + pytest.skip(f"Local checkpoint path {resolved} does not exist.") + return str(resolved) + if is_offline_mode(): + pytest.skip("Offline mode: set ISAAC_TEST_MODEL_PATH to a local checkpoint to run these tests.") + return model_id + + +@require_torch +@require_vision +class IsaacProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = IsaacProcessorTestDouble + model_id = BASE_MODEL_ID + images_input_name = "pixel_values" + + @classmethod + def _setup_from_pretrained(cls, model_id, **kwargs): + checkpoint = _checkpoint_or_skip(model_id) + return super()._setup_from_pretrained( + checkpoint, + revision=BASE_MODEL_REVISION, + patch_size=4, + max_num_patches=4, + **kwargs, + ) + + @classmethod + def _setup_test_attributes(cls, processor): + cls.image_token = processor.image_token + cls.pad_token_id = processor.tokenizer.pad_token_id + cls.image_pad_token_id = processor.image_token_id + + def prepare_image_inputs(self, batch_size: int | None = None, nested: bool = False): + if batch_size is None: + return _make_dummy_image(size=(16, 16)) + images = [_make_dummy_image(size=(16, 16), color=(50 * (i + 1), 0, 0)) for i in range(batch_size)] + if nested: + return [[image] for image in images] + return images + + def test_model_input_names(self): + processor = self.get_processor() + inputs = processor( + text=self.prepare_text_inputs(modalities="image"), + images=self.prepare_image_inputs(), + return_tensors="pt", + ) + + self.assertSetEqual(set(inputs.keys()), set(processor.model_input_names)) + + @unittest.skip("IsaacProcessor expands image placeholders into image pad tokens before tokenization") + def test_tokenizer_defaults(self): + pass + + @unittest.skip("IsaacProcessor does not return offset mappings needed for assistant masks") + def test_apply_chat_template_assistant_mask(self): + pass + + @unittest.skip("Isaac chat templates emit placeholders but the processor consumes image pad tokens") + def test_apply_chat_template_image_0(self): + pass + + @unittest.skip("Isaac chat templates emit placeholders but the processor consumes image pad tokens") + def test_apply_chat_template_image_1(self): + pass + + def test_get_num_multimodal_tokens_matches_processor_call(self): + processor = self.get_processor() + + image_sizes = [(100, 100), (300, 100), (500, 30), (213, 167)] + image_inputs = [np.random.randint(255, size=(h, w, 3), dtype=np.uint8) for h, w in image_sizes] + + text = [f"This is an image {self.image_token}"] * len(image_inputs) + inputs = processor( + text=text, + images=[[image] for image in image_inputs], + padding=True, + return_mm_token_type_ids=True, + return_tensors="pt", + ) + + num_image_tokens_from_call = inputs.mm_token_type_ids.sum(-1).tolist() + num_image_tokens_from_helper = processor._get_num_multimodal_tokens(image_sizes=image_sizes) + self.assertListEqual(num_image_tokens_from_call, num_image_tokens_from_helper["num_image_tokens"]) + + def test_single_vs_batched_consistency(self): + processor = self.get_processor() + prompt = f"hello {processor.image_token} world" + image = self.prepare_image_inputs() + + single = _assert_common(processor(text=prompt, images=[image], return_tensors="pt")) + batch = _assert_common( + processor(text=[prompt, "short"], images=[[image], []], return_tensors="pt"), batch_size=2 + ) + + single_ids = single["input_ids"].squeeze(0) + batch_ids = batch["input_ids"][0] + self.assertTrue(torch.equal(batch_ids[-single_ids.size(0) :], 
single_ids)) + + image_positions = batch["mm_token_type_ids"][0].eq(1) + if image_positions.any(): + self.assertTrue(torch.all(batch_ids[image_positions] == self.image_pad_token_id)) + self.assertTrue(torch.all(batch["attention_mask"][0][image_positions] == 1)) + + single_image_mask = _get_sample_image_mask(single, batch_index=0) + batch_image_mask = _get_sample_image_mask(batch, batch_index=0) + torch.testing.assert_close( + batch["pixel_values"][0, batch_image_mask], + single["pixel_values"][0, single_image_mask], + ) + torch.testing.assert_close( + batch["image_grid_thw"][0, batch_image_mask], + single["image_grid_thw"][0, single_image_mask], + ) + torch.testing.assert_close( + batch["image_metadata"][0, batch_image_mask], + single["image_metadata"][0, single_image_mask], + ) + + _assert_vision_segments(batch, expected_segments=1, batch_index=0) + _assert_no_vision(batch, batch_index=1) + + +@require_torch +@require_vision +def test_text_only_has_no_vision_fields(isaac_processor): + outputs = _assert_common(_run_processor(isaac_processor, text="Hello, how are you?", images=None)) + assert outputs["pixel_values"] is None + assert outputs["image_grid_thw"] is None + assert outputs["image_metadata"] is None + _assert_no_vision(outputs) + + +@require_torch +def test_post_process_generation_extracts_boxes_and_cleans_text(): + processor = _make_post_process_processor() + + generated_text = ( + "No, it is not safe to cross the street. " + '(808, 247), (863, 386)' + ) + + clean_text, annotations = processor.post_process_generation(generated_text) + + assert clean_text == "No, it is not safe to cross the street." + assert len(annotations) == 1 + box = annotations[0] + assert box.mention == "traffic light" + assert box.t == pytest.approx(0.5) + assert box.top_left.x == 808 + assert box.top_left.y == 247 + assert box.bottom_right.x == 863 + assert box.bottom_right.y == 386 + + +@require_torch +def test_post_process_generation_extracts_polygons_and_filters_by_expected_type(): + processor = _make_post_process_processor() + + generated_text = ( + 'Point (1, 2) ' + 'Box (3, 4), (5, 6) ' + 'Polygon (10, 20), (30, 40), (50, 60)' + ) + + clean_text, annotations = processor.post_process_generation(generated_text, expected="polygon") + + assert clean_text == "Point Box Polygon" + assert len(annotations) == 1 + polygon = annotations[0] + assert polygon.mention == "lane" + assert polygon.t == pytest.approx(0.25) + assert len(polygon.points) == 3 + assert polygon.points[0].x == 10 + assert polygon.points[0].y == 20 + assert polygon.points[1].x == 30 + assert polygon.points[1].y == 40 + assert polygon.points[2].x == 50 + assert polygon.points[2].y == 60 + + _, boxes = processor.post_process_generation(generated_text, expected="box") + assert len(boxes) == 1 + assert boxes[0].mention == "sign" + + +@require_torch +def test_post_process_generation_rejects_polygons_with_fewer_than_three_points(): + processor = _make_post_process_processor() + + with pytest.raises(ValueError, match=r"Malformed tag"): + processor.post_process_generation('(10, 20), (30, 40)', expected="polygon") + + +@require_torch +@require_vision +def test_single_image_returns_offsets_and_lengths(isaac_processor): + image_token = isaac_processor.image_token + outputs = _assert_common( + _run_processor( + isaac_processor, text=f"Look at this {image_token} and describe it.", images=[_make_dummy_image()] + ) + ) + _assert_vision_segments(outputs, expected_segments=1) + + grid_tokens = _get_expected_vision_lengths(outputs, 
isaac_processor.image_processor.pixel_shuffle_scale) + torch.testing.assert_close(_get_active_vision_lengths(outputs), grid_tokens) + torch.testing.assert_close( + _get_active_vision_offsets(outputs), torch.zeros_like(_get_active_vision_offsets(outputs)) + ) + + +@require_torch +@require_vision +def test_multiple_images_have_matching_offsets_lengths_and_grids(isaac_processor): + image_token = isaac_processor.image_token + images = [_make_dummy_image(color=(255, 0, 0)), _make_dummy_image(color=(0, 255, 0))] + + outputs = _assert_common( + _run_processor(isaac_processor, text=f"First {image_token} then {image_token}", images=images) + ) + _assert_vision_segments(outputs, expected_segments=2) + + grid_tokens = _get_expected_vision_lengths(outputs, isaac_processor.image_processor.pixel_shuffle_scale) + torch.testing.assert_close(_get_active_vision_lengths(outputs), grid_tokens) + torch.testing.assert_close( + _get_active_vision_offsets(outputs), torch.zeros_like(_get_active_vision_offsets(outputs)) + ) + + +@require_torch +@require_vision +def test_error_on_image_mismatch(isaac_processor): + image_token = isaac_processor.image_token + with pytest.raises(ValueError, match="one image per"): + _run_processor(isaac_processor, text=f"{image_token} {image_token}", images=[_make_dummy_image()]) + + +@require_torch +@require_vision +def test_consecutive_vision_tokens_allow_empty_text_segments(isaac_processor): + image_token = isaac_processor.image_token + images = [_make_dummy_image(), _make_dummy_image(color=(0, 0, 255))] + + outputs = _assert_common( + _run_processor(isaac_processor, text=f"prefix {image_token}{image_token} suffix", images=images) + ) + _assert_vision_segments(outputs, expected_segments=2) + + torch.testing.assert_close( + _get_active_vision_offsets(outputs), torch.zeros_like(_get_active_vision_offsets(outputs)) + ) + grid_tokens = _get_expected_vision_lengths(outputs, isaac_processor.image_processor.pixel_shuffle_scale) + torch.testing.assert_close(_get_active_vision_lengths(outputs), grid_tokens) + + +@require_torch +@require_vision +def test_device_and_dtype_consistency(isaac_processor): + image_token = isaac_processor.image_token + outputs = _assert_common( + _run_processor(isaac_processor, text=f"Describe this {image_token}", images=[_make_dummy_image()]) + ) + _assert_vision_segments(outputs, expected_segments=1) + + tensors = [ + outputs["input_ids"], + outputs["attention_mask"], + outputs["mm_token_type_ids"], + outputs["image_grid_thw"], + outputs["image_metadata"], + ] + devices = {tensor.device for tensor in tensors} + assert len(devices) == 1 + for tensor in tensors: + assert tensor.dtype == torch.long + + +@require_torch +@require_vision +def test_no_crop_when_total_below_max(isaac_processor): + image_token = isaac_processor.image_token + outputs = _assert_common( + _run_processor(isaac_processor, text=f"hello {image_token} world", images=[_make_dummy_image()]) + ) + _assert_vision_segments(outputs, expected_segments=1) + + grid_tokens = _get_expected_vision_lengths(outputs, isaac_processor.image_processor.pixel_shuffle_scale) + text_tokens = _count_modality(outputs, 0) + assert outputs["input_ids"].shape[1] == grid_tokens.item() + text_tokens + + +@require_torch +@require_vision +def test_exact_fit_keeps_all_tokens(isaac_processor, isaac_tokenizer, isaac_tiny_config): + image_token = isaac_processor.image_token + text = f"hey {image_token} there" + image = _make_dummy_image() + + base_outputs = _assert_common(_run_processor(isaac_processor, text=text, images=[image])) 
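+ # Baseline run with no length cap; the processor rebuilt below reuses this exact
+ # sequence length as its max_sequence_length, so nothing should be cropped.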
+ base_length = base_outputs["input_ids"].shape[1] + base_vision_length = _get_active_vision_lengths(base_outputs).item() + + processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, base_length) + outputs = _assert_common(_run_processor(processor, text=text, images=[image])) + + _assert_vision_segments(outputs, expected_segments=1) + assert outputs["input_ids"].shape[1] == base_length + assert _get_active_vision_lengths(outputs).item() == base_vision_length + + +@require_torch +@require_vision +def test_crop_truncates_text_segment_only(isaac_processor, isaac_tokenizer, isaac_tiny_config): + image_token = isaac_processor.image_token + text_prefix_tokens = " ".join([f"t{i}" for i in range(8)]) + text = f"{text_prefix_tokens} {image_token} tail end" + image = _make_dummy_image() + + base_outputs = _assert_common(_run_processor(isaac_processor, text=text, images=[image])) + full_text_tokens = _count_modality(base_outputs, 0) + vision_length = _get_active_vision_lengths(base_outputs).item() + + max_len = base_outputs["input_ids"].shape[1] - 4 + processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, max_len) + outputs = _assert_common(_run_processor(processor, text=text, images=[image])) + + _assert_vision_segments(outputs, expected_segments=1) + assert outputs["input_ids"].shape[1] == max_len + assert _count_modality(outputs, 0) == full_text_tokens - 4 + torch.testing.assert_close( + _get_active_vision_offsets(outputs), torch.zeros_like(_get_active_vision_offsets(outputs)) + ) + assert _get_active_vision_lengths(outputs).item() == vision_length + + +@require_torch +@require_vision +def test_crop_cuts_through_image_segment(isaac_processor, isaac_tokenizer, isaac_tiny_config): + image_token = isaac_processor.image_token + text_before = "hi" + text_after = "bye" + text = f"{text_before} {image_token} {text_after}" + image = _make_dummy_image() + + base_outputs = _assert_common(_run_processor(isaac_processor, text=text, images=[image])) + vision_full = _get_active_vision_lengths(base_outputs).item() + text_before_len = len(isaac_tokenizer.encode(text_before, add_special_tokens=False)) + text_after_len = len(isaac_tokenizer.encode(text_after, add_special_tokens=False)) + total_length = vision_full + text_before_len + text_after_len + + max_len = 40 + start = total_length - max_len + expected_offset = max(0, start - text_before_len) + expected_length = vision_full - expected_offset + + processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, max_len) + outputs = _assert_common(_run_processor(processor, text=text, images=[image])) + + _assert_vision_segments(outputs, expected_segments=1) + assert outputs["input_ids"].shape[1] == max_len + assert _get_active_vision_offsets(outputs).item() == expected_offset + assert _get_active_vision_lengths(outputs).item() == expected_length + assert _count_modality(outputs, 0) == text_after_len + + +@require_torch +@require_vision +def test_batch_outputs_match_individual_calls(isaac_processor): + texts = ["hi", "this one is longer"] + + per_sample = [_assert_common(_run_processor(isaac_processor, text=text, images=None)) for text in texts] + batch_outputs = _assert_common(_run_processor(isaac_processor, text=texts, images=None), batch_size=len(texts)) + + pad_id = isaac_processor.pad_token_id + for index, single_output in enumerate(per_sample): + single_ids = single_output["input_ids"].squeeze(0) + single_mask = single_output["attention_mask"].squeeze(0) + single_mm = 
single_output["mm_token_type_ids"].squeeze(0) + + batch_ids = batch_outputs["input_ids"][index] + batch_mask = batch_outputs["attention_mask"][index] + batch_mm = batch_outputs["mm_token_type_ids"][index] + + single_len = single_ids.shape[0] + assert torch.equal(batch_ids[-single_len:], single_ids) + assert torch.equal(batch_mask[-single_len:], single_mask) + assert torch.equal(batch_mm[-single_len:], single_mm) + + if single_len < batch_ids.shape[0]: + pad_span = batch_ids[: batch_ids.shape[0] - single_len] + assert torch.all(pad_span == pad_id) + assert not torch.any(batch_mask[: batch_ids.shape[0] - single_len]) + + _assert_no_vision(batch_outputs, batch_index=index) From 81206dbca2c4e6c1799752b7a91218abf6f8c4d6 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Tue, 31 Mar 2026 14:38:04 +0400 Subject: [PATCH 71/77] check repo fixes --- docs/source/en/model_doc/isaac.md | 4 ++++ src/transformers/models/isaac/modeling_isaac.py | 8 +++++++- src/transformers/models/isaac/modular_isaac.py | 1 + utils/check_repo.py | 4 ++++ 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md index ce2ee88866d8..655eef43b6d8 100644 --- a/docs/source/en/model_doc/isaac.md +++ b/docs/source/en/model_doc/isaac.md @@ -132,6 +132,10 @@ Set `expected="point"` to extract point annotations, or leave `expected=None` to [[autodoc]] IsaacConfig +## IsaacVisionTransformer + +[[autodoc]] IsaacVisionTransformer + ## IsaacTextModel [[autodoc]] IsaacTextModel diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 01afede9a528..5f5547417dbf 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -1839,4 +1839,10 @@ def _expand_inputs_for_generation( return input_ids, model_kwargs -__all__ = ["IsaacTextModel", "IsaacModel", "IsaacPreTrainedModel", "IsaacForConditionalGeneration"] +__all__ = [ + "IsaacTextModel", + "IsaacVisionTransformer", + "IsaacModel", + "IsaacPreTrainedModel", + "IsaacForConditionalGeneration", +] diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index f9463d5d5b08..a7d865642d49 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -1672,6 +1672,7 @@ def _expand_inputs_for_generation( "IsaacTextConfig", "IsaacTextModel", "IsaacVisionConfig", + "IsaacVisionTransformer", "IsaacModel", "IsaacPreTrainedModel", # noqa: F822 "IsaacForConditionalGeneration", diff --git a/utils/check_repo.py b/utils/check_repo.py index 3ed1bb0abc12..74c6c3de4760 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -209,6 +209,8 @@ "Qwen3VLMoeTextModel", # Building part of bigger (tested) model. "Qwen3_5TextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3_5ForConditionalGeneration. "Qwen3_5MoeTextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3_5MoeForConditionalGeneration. + "IsaacTextModel", # Building part of bigger (tested) model. Tested implicitly through IsaacForConditionalGeneration. + "IsaacVisionTransformer", # Building part of bigger (tested) model. Tested implicitly through IsaacForConditionalGeneration. "Qwen2_5OmniForConditionalGeneration", # Not a regular model. Testted in Qwen2_5OmniModelIntergrationTest "Qwen2_5OmniTalkerForConditionalGeneration", # Building part of bigger (tested) model. 
Tested implicitly through Qwen2_5OmniModelIntergrationTest. "Qwen2_5OmniTalkerModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest. @@ -459,6 +461,8 @@ "PaddleOCRVisionModel", # Building part of bigger (tested) model "PaddleOCRVisionTransformer", # Building part of bigger (tested) model "PaddleOCRTextModel", # Building part of bigger (tested) model + "IsaacTextModel", # Building part of a bigger model + "IsaacVisionTransformer", # Building part of a bigger model "Qwen2_5OmniTalkerForConditionalGeneration", # Building part of a bigger model "Qwen2_5OmniTalkerModel", # Building part of a bigger model "Qwen2_5OmniThinkerForConditionalGeneration", # Building part of a bigger model From 86235d40e19137394e30a37e748f85bee950a170 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Tue, 31 Mar 2026 15:15:51 +0400 Subject: [PATCH 72/77] add correct date --- docs/source/en/model_doc/isaac.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md index 655eef43b6d8..fc370847f839 100644 --- a/docs/source/en/model_doc/isaac.md +++ b/docs/source/en/model_doc/isaac.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on {release_date} and added to Hugging Face Transformers on 2026-03-24.* +*This model was released on {release_date} and added to Hugging Face Transformers on 2026-03-31.*
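[Editor's note — not part of the patch series. The next patch moves the point/box/polygon tag parsing onto `IsaacProcessor`. As a minimal standalone sketch of the tag grammar that parser handles: the regexes and the whitespace cleanup below mirror `_point_box_or_polygon_tag`, `_coord_re`, and `clean_text_and_extract_points` from the patch, while the sample string and coordinates are illustrative only.]

```python
import re

# Mirrors _point_box_or_polygon_tag / _coord_re from the patch below.
TAG = re.compile(
    r"<(?P<tag>point|point_box|polygon)(?P<attrs>[^>]*)>(?P<body>[\s\S]*?)</(?P=tag)>",
    re.IGNORECASE,
)
COORD = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")

raw = 'The mug is at <point mention="mug" t="1.5">(320, 240)</point> on the desk.'
for m in TAG.finditer(raw):
    coords = [(int(x), int(y)) for x, y in COORD.findall(m.group("body"))]
    print(m.group("tag"), m.group("attrs").strip(), coords)
    # -> point mention="mug" t="1.5" [(320, 240)]

# Stripping the tags and collapsing whitespace yields the clean text,
# as clean_text_and_extract_points does in the patch.
clean = re.sub(r"\s+", " ", TAG.sub("", raw)).strip()
print(clean)  # -> The mug is at on the desk.
```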
From e99bbc1a1edc1d6ce7dc2fea0b837b37044140f0 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Tue, 31 Mar 2026 15:52:01 +0400 Subject: [PATCH 73/77] fix: make the pointing types belong to processor class --- .../models/isaac/modeling_isaac.py | 40 ++-- .../models/isaac/modular_isaac.py | 217 +++++++++--------- .../models/isaac/processing_isaac.py | 199 +++++++++------- 3 files changed, 247 insertions(+), 209 deletions(-) diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 5f5547417dbf..507ae006ddbe 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -50,26 +50,6 @@ import torch.nn.functional as F -class SinglePoint(NamedTuple): - x: int - y: int - mention: str | None = None - t: float | None = None - - -class BoundingBox(NamedTuple): - top_left: SinglePoint - bottom_right: SinglePoint - mention: str | None = None - t: float | None = None - - -class Polygon(NamedTuple): - points: tuple[SinglePoint, ...] - mention: str | None = None - t: float | None = None - - class IsaacVisionEmbeddings(nn.Module): """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences. @@ -524,6 +504,26 @@ def forward(self, image_features): return hidden_states +class _SinglePoint(NamedTuple): + x: int + y: int + mention: str | None = None + t: float | None = None + + +class _BoundingBox(NamedTuple): + top_left: Any + bottom_right: Any + mention: str | None = None + t: float | None = None + + +class _Polygon(NamedTuple): + points: tuple[Any, ...] + mention: str | None = None + t: float | None = None + + class IsaacRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index a7d865642d49..1e3b0b32d27e 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -74,36 +74,6 @@ from ..pix2struct.image_processing_pix2struct import torch_extract_patches -class SinglePoint(NamedTuple): - x: int - y: int - mention: str | None = None - t: float | None = None - - -class BoundingBox(NamedTuple): - top_left: SinglePoint - bottom_right: SinglePoint - mention: str | None = None - t: float | None = None - - -class Polygon(NamedTuple): - points: tuple[SinglePoint, ...] 
-    mention: str | None = None
-    t: float | None = None
-
-
-IsaacAnnotation = SinglePoint | BoundingBox | Polygon
-
-
-_POINT_BOX_OR_POLYGON_TAG = re.compile(
-    r"<(?P<tag>point|point_box|polygon)(?P<attrs>[^>]*)>(?P<body>[\s\S]*?)</(?P=tag)>", re.IGNORECASE
-)
-_ATTR_RE = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))")
-_COORD_RE = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")
-
-
 class IsaacProcessorKwargs(ProcessingKwargs, total=False):
     _defaults = {
         "text_kwargs": {
@@ -114,78 +84,6 @@ class IsaacProcessorKwargs(ProcessingKwargs, total=False):
     }
 
 
-def _maybe_float(value: str | None) -> float | None:
-    if value is None:
-        return None
-    try:
-        return float(value)
-    except ValueError:
-        return None
-
-
-def _parse_attrs(attr_text: str) -> dict[str, str]:
-    attrs = {}
-    for match in _ATTR_RE.finditer(attr_text or ""):
-        key = match.group(1)
-        value = match.group(2) or match.group(3) or ""
-        attrs[key] = value
-    return attrs
-
-
-def _parse_point_body(body: str, mention: str | None = None, t: str | None = None) -> SinglePoint:
-    match = _COORD_RE.search(body)
-    if not match:
-        raise ValueError(f"Malformed <point> tag: {body!r}")
-    x, y = int(match.group(1)), int(match.group(2))
-    return SinglePoint(x=x, y=y, mention=mention, t=_maybe_float(t))
-
-
-def _parse_box_body(body: str, mention: str | None = None, t: str | None = None) -> BoundingBox:
-    coords = list(_COORD_RE.finditer(body))
-    if len(coords) < 2:
-        raise ValueError(f"Malformed <point_box> tag: {body!r}")
-
-    top_left = SinglePoint(x=int(coords[0].group(1)), y=int(coords[0].group(2)))
-    bottom_right = SinglePoint(x=int(coords[1].group(1)), y=int(coords[1].group(2)))
-    return BoundingBox(top_left=top_left, bottom_right=bottom_right, mention=mention, t=_maybe_float(t))
-
-
-def _parse_polygon_body(body: str, mention: str | None = None, t: str | None = None) -> Polygon:
-    coords = list(_COORD_RE.finditer(body))
-    if len(coords) < 3:
-        raise ValueError(f"Malformed <polygon> tag: {body!r}")
-
-    points = tuple(SinglePoint(x=int(coord.group(1)), y=int(coord.group(2))) for coord in coords)
-    return Polygon(points=points, mention=mention, t=_maybe_float(t))
-
-
-def clean_text_and_extract_points(
-    text: str,
-    expected: str | None = None,
-) -> tuple[str, list[IsaacAnnotation]]:
-    results: list[IsaacAnnotation] = []
-    for match in _POINT_BOX_OR_POLYGON_TAG.finditer(text or ""):
-        tag = match.group("tag").lower()
-        attrs = _parse_attrs(match.group("attrs"))
-        mention = attrs.get("mention")
-        t = attrs.get("t")
-        if tag == "point":
-            if expected not in (None, "point"):
-                continue
-            results.append(_parse_point_body(match.group("body"), mention=mention, t=t))
-        elif tag == "point_box":
-            if expected not in (None, "box"):
-                continue
-            results.append(_parse_box_body(match.group("body"), mention=mention, t=t))
-        else:
-            if expected not in (None, "polygon"):
-                continue
-            results.append(_parse_polygon_body(match.group("body"), mention=mention, t=t))
-
-    clean_text = re.sub(r"\s+", " ", _POINT_BOX_OR_POLYGON_TAG.sub("", text or "")).strip()
-    return clean_text, results
-
-
 @auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base")
 @strict
 class IsaacVisionConfig(Siglip2VisionConfig):
@@ -849,6 +747,29 @@ def __post_init__(self, **kwargs):
 
 @auto_docstring
 class IsaacProcessor(ProcessorMixin):
+    class _SinglePoint(NamedTuple):
+        x: int
+        y: int
+        mention: str | None = None
+        t: float | None = None
+
+    class _BoundingBox(NamedTuple):
+        top_left: Any
+        bottom_right: Any
+        mention: str | None = None
+        t: float | None = None
+
+    class _Polygon(NamedTuple):
+        points: tuple[Any, ...]
+        mention: str | None = None
+        t: float | None = None
+
+    _point_box_or_polygon_tag = re.compile(
+        r"<(?P<tag>point|point_box|polygon)(?P<attrs>[^>]*)>(?P<body>[\s\S]*?)</(?P=tag)>", re.IGNORECASE
+    )
+    _attr_re = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))")
+    _coord_re = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")
+
     def __init__(
         self,
         image_processor,
@@ -880,6 +801,94 @@ def __init__(
     def model_input_names(self):
         return super().model_input_names + ["mm_token_type_ids", "image_metadata"]
 
+    @staticmethod
+    def _maybe_float(value: str | None) -> float | None:
+        if value is None:
+            return None
+        try:
+            return float(value)
+        except ValueError:
+            return None
+
+    @classmethod
+    def _parse_attrs(cls, attr_text: str) -> dict[str, str]:
+        attrs = {}
+        for match in cls._attr_re.finditer(attr_text or ""):
+            key = match.group(1)
+            value = match.group(2) or match.group(3) or ""
+            attrs[key] = value
+        return attrs
+
+    @classmethod
+    def _parse_point_body(
+        cls,
+        body: str,
+        mention: str | None = None,
+        t: str | None = None,
+    ) -> Any:
+        match = cls._coord_re.search(body)
+        if not match:
+            raise ValueError(f"Malformed <point> tag: {body!r}")
+        x, y = int(match.group(1)), int(match.group(2))
+        return cls._SinglePoint(x=x, y=y, mention=mention, t=cls._maybe_float(t))
+
+    @classmethod
+    def _parse_box_body(
+        cls,
+        body: str,
+        mention: str | None = None,
+        t: str | None = None,
+    ) -> Any:
+        coords = list(cls._coord_re.finditer(body))
+        if len(coords) < 2:
+            raise ValueError(f"Malformed <point_box> tag: {body!r}")
+
+        top_left = cls._SinglePoint(x=int(coords[0].group(1)), y=int(coords[0].group(2)))
+        bottom_right = cls._SinglePoint(x=int(coords[1].group(1)), y=int(coords[1].group(2)))
+        return cls._BoundingBox(top_left=top_left, bottom_right=bottom_right, mention=mention, t=cls._maybe_float(t))
+
+    @classmethod
+    def _parse_polygon_body(
+        cls,
+        body: str,
+        mention: str | None = None,
+        t: str | None = None,
+    ) -> Any:
+        coords = list(cls._coord_re.finditer(body))
+        if len(coords) < 3:
+            raise ValueError(f"Malformed <polygon> tag: {body!r}")
+
+        points = tuple(cls._SinglePoint(x=int(coord.group(1)), y=int(coord.group(2))) for coord in coords)
+        return cls._Polygon(points=points, mention=mention, t=cls._maybe_float(t))
+
+    @classmethod
+    def clean_text_and_extract_points(
+        cls,
+        text: str,
+        expected: str | None = None,
+    ) -> tuple[str, list[Any]]:
+        results: list[Any] = []
+        for match in cls._point_box_or_polygon_tag.finditer(text or ""):
+            tag = match.group("tag").lower()
+            attrs = cls._parse_attrs(match.group("attrs"))
+            mention = attrs.get("mention")
+            t = attrs.get("t")
+            if tag == "point":
+                if expected not in (None, "point"):
+                    continue
+                results.append(cls._parse_point_body(match.group("body"), mention=mention, t=t))
+            elif tag == "point_box":
+                if expected not in (None, "box"):
+                    continue
+                results.append(cls._parse_box_body(match.group("body"), mention=mention, t=t))
+            else:
+                if expected not in (None, "polygon"):
+                    continue
+                results.append(cls._parse_polygon_body(match.group("body"), mention=mention, t=t))
+
+        clean_text = re.sub(r"\s+", " ", cls._point_box_or_polygon_tag.sub("", text or "")).strip()
+        return clean_text, results
+
     def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
         vision_data = {}
         if image_sizes is not None:
@@ -901,9 +910,9 @@ def post_process_generation(
         text: str,
         expected: str | None = None,
         cleanup_and_extract: bool = True,
-    ) -> str | tuple[str, list[IsaacAnnotation]]:
+    ) -> str | tuple[str, list[Any]]:
         if cleanup_and_extract:
-            return clean_text_and_extract_points(text, expected=expected)
+            return self.clean_text_and_extract_points(text, expected=expected)
         return text
diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py
index a9ec49908530..4ba3b68e5542 100644
--- a/src/transformers/models/isaac/processing_isaac.py
+++ b/src/transformers/models/isaac/processing_isaac.py
@@ -19,13 +19,13 @@
 # limitations under the License.
 
 import re
+from typing import Any, NamedTuple
 
 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, make_nested_list_of_images
 from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin
 from ...utils import TensorType, auto_docstring
 from ...utils.import_utils import is_torch_available
-from .modeling_isaac import BoundingBox, Polygon, SinglePoint
 
 
 if is_torch_available():
     import torch
@@ -42,90 +42,31 @@ class IsaacProcessorKwargs(ProcessingKwargs, total=False):
     }
 
 
-IsaacAnnotation = SinglePoint | BoundingBox | Polygon
-
-
-_POINT_BOX_OR_POLYGON_TAG = re.compile(
-    r"<(?P<tag>point|point_box|polygon)(?P<attrs>[^>]*)>(?P<body>[\s\S]*?)</(?P=tag)>", re.IGNORECASE
-)
-_ATTR_RE = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))")
-_COORD_RE = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")
-
-
-def _maybe_float(value: str | None) -> float | None:
-    if value is None:
-        return None
-    try:
-        return float(value)
-    except ValueError:
-        return None
-
-
-def _parse_attrs(attr_text: str) -> dict[str, str]:
-    attrs = {}
-    for match in _ATTR_RE.finditer(attr_text or ""):
-        key = match.group(1)
-        value = match.group(2) or match.group(3) or ""
-        attrs[key] = value
-    return attrs
-
-
-def _parse_point_body(body: str, mention: str | None = None, t: str | None = None) -> SinglePoint:
-    match = _COORD_RE.search(body)
-    if not match:
-        raise ValueError(f"Malformed <point> tag: {body!r}")
-    x, y = int(match.group(1)), int(match.group(2))
-    return SinglePoint(x=x, y=y, mention=mention, t=_maybe_float(t))
-
-
-def _parse_box_body(body: str, mention: str | None = None, t: str | None = None) -> BoundingBox:
-    coords = list(_COORD_RE.finditer(body))
-    if len(coords) < 2:
-        raise ValueError(f"Malformed <point_box> tag: {body!r}")
-
-    top_left = SinglePoint(x=int(coords[0].group(1)), y=int(coords[0].group(2)))
-    bottom_right = SinglePoint(x=int(coords[1].group(1)), y=int(coords[1].group(2)))
-    return BoundingBox(top_left=top_left, bottom_right=bottom_right, mention=mention, t=_maybe_float(t))
-
-
-def _parse_polygon_body(body: str, mention: str | None = None, t: str | None = None) -> Polygon:
-    coords = list(_COORD_RE.finditer(body))
-    if len(coords) < 3:
-        raise ValueError(f"Malformed <polygon> tag: {body!r}")
-
-    points = tuple(SinglePoint(x=int(coord.group(1)), y=int(coord.group(2))) for coord in coords)
-    return Polygon(points=points, mention=mention, t=_maybe_float(t))
-
-
-def clean_text_and_extract_points(
-    text: str,
-    expected: str | None = None,
-) -> tuple[str, list[IsaacAnnotation]]:
-    results: list[IsaacAnnotation] = []
-    for match in _POINT_BOX_OR_POLYGON_TAG.finditer(text or ""):
-        tag = match.group("tag").lower()
-        attrs = _parse_attrs(match.group("attrs"))
-        mention = attrs.get("mention")
-        t = attrs.get("t")
-        if tag == "point":
-            if expected not in (None, "point"):
-                continue
-            results.append(_parse_point_body(match.group("body"), mention=mention, t=t))
-        elif tag == "point_box":
-            if expected not in (None, "box"):
-                continue
-            results.append(_parse_box_body(match.group("body"), mention=mention, t=t))
-        else:
-            if expected not in (None, "polygon"):
-                continue
-            results.append(_parse_polygon_body(match.group("body"), mention=mention, t=t))
-
-    clean_text = re.sub(r"\s+", " ", _POINT_BOX_OR_POLYGON_TAG.sub("", text or "")).strip()
-    return clean_text, results
-
-
 @auto_docstring
 class IsaacProcessor(ProcessorMixin):
+    class _SinglePoint(NamedTuple):
+        x: int
+        y: int
+        mention: str | None = None
+        t: float | None = None
+
+    class _BoundingBox(NamedTuple):
+        top_left: Any
+        bottom_right: Any
+        mention: str | None = None
+        t: float | None = None
+
+    class _Polygon(NamedTuple):
+        points: tuple[Any, ...]
+        mention: str | None = None
+        t: float | None = None
+
+    _point_box_or_polygon_tag = re.compile(
+        r"<(?P<tag>point|point_box|polygon)(?P<attrs>[^>]*)>(?P<body>[\s\S]*?)</(?P=tag)>", re.IGNORECASE
+    )
+    _attr_re = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))")
+    _coord_re = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")
+
     def __init__(
         self,
         image_processor,
@@ -157,6 +98,94 @@ def __init__(
     def model_input_names(self):
         return super().model_input_names + ["mm_token_type_ids", "image_metadata"]
 
+    @staticmethod
+    def _maybe_float(value: str | None) -> float | None:
+        if value is None:
+            return None
+        try:
+            return float(value)
+        except ValueError:
+            return None
+
+    @classmethod
+    def _parse_attrs(cls, attr_text: str) -> dict[str, str]:
+        attrs = {}
+        for match in cls._attr_re.finditer(attr_text or ""):
+            key = match.group(1)
+            value = match.group(2) or match.group(3) or ""
+            attrs[key] = value
+        return attrs
+
+    @classmethod
+    def _parse_point_body(
+        cls,
+        body: str,
+        mention: str | None = None,
+        t: str | None = None,
+    ) -> Any:
+        match = cls._coord_re.search(body)
+        if not match:
+            raise ValueError(f"Malformed <point> tag: {body!r}")
+        x, y = int(match.group(1)), int(match.group(2))
+        return cls._SinglePoint(x=x, y=y, mention=mention, t=cls._maybe_float(t))
+
+    @classmethod
+    def _parse_box_body(
+        cls,
+        body: str,
+        mention: str | None = None,
+        t: str | None = None,
+    ) -> Any:
+        coords = list(cls._coord_re.finditer(body))
+        if len(coords) < 2:
+            raise ValueError(f"Malformed <point_box> tag: {body!r}")
+
+        top_left = cls._SinglePoint(x=int(coords[0].group(1)), y=int(coords[0].group(2)))
+        bottom_right = cls._SinglePoint(x=int(coords[1].group(1)), y=int(coords[1].group(2)))
+        return cls._BoundingBox(top_left=top_left, bottom_right=bottom_right, mention=mention, t=cls._maybe_float(t))
+
+    @classmethod
+    def _parse_polygon_body(
+        cls,
+        body: str,
+        mention: str | None = None,
+        t: str | None = None,
+    ) -> Any:
+        coords = list(cls._coord_re.finditer(body))
+        if len(coords) < 3:
+            raise ValueError(f"Malformed <polygon> tag: {body!r}")
+
+        points = tuple(cls._SinglePoint(x=int(coord.group(1)), y=int(coord.group(2))) for coord in coords)
+        return cls._Polygon(points=points, mention=mention, t=cls._maybe_float(t))
+
+    @classmethod
+    def clean_text_and_extract_points(
+        cls,
+        text: str,
+        expected: str | None = None,
+    ) -> tuple[str, list[Any]]:
+        results: list[Any] = []
+        for match in cls._point_box_or_polygon_tag.finditer(text or ""):
+            tag = match.group("tag").lower()
+            attrs = cls._parse_attrs(match.group("attrs"))
+            mention = attrs.get("mention")
+            t = attrs.get("t")
+            if tag == "point":
+                if expected not in (None, "point"):
+                    continue
+                results.append(cls._parse_point_body(match.group("body"), mention=mention, t=t))
+            elif tag == "point_box":
+                if expected not in (None, "box"):
+                    continue
+                results.append(cls._parse_box_body(match.group("body"), mention=mention, t=t))
+            else:
+                if expected not in (None, "polygon"):
+                    continue
+                results.append(cls._parse_polygon_body(match.group("body"),
mention=mention, t=t)) + + clean_text = re.sub(r"\s+", " ", cls._point_box_or_polygon_tag.sub("", text or "")).strip() + return clean_text, results + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): vision_data = {} if image_sizes is not None: @@ -178,9 +207,9 @@ def post_process_generation( text: str, expected: str | None = None, cleanup_and_extract: bool = True, - ) -> str | tuple[str, list[IsaacAnnotation]]: + ) -> str | tuple[str, list[Any]]: if cleanup_and_extract: - return clean_text_and_extract_points(text, expected=expected) + return self.clean_text_and_extract_points(text, expected=expected) return text def post_process_image_text_to_text( From 24af7781d2dbb8c318d6a7b0986cc127f816377c Mon Sep 17 00:00:00 2001 From: philippguevorguian <73610213+philippguevorguian@users.noreply.github.com> Date: Mon, 13 Apr 2026 18:22:39 +0400 Subject: [PATCH 74/77] style: pre final review (#20) * clean up and rearrange code * fix: allow cutting through any image span * test: cropping middle of arbitrary image * test: drop stale / redundant tests * test: drop flash attn debug test * test: drop stale helpers * fix: restore isaac processor compatibility * test: use public api in isaac integration tests * fix: restore isaac generation outputs * simplif * style: simplify 2 * test: drop redundant tests * test: drop more low level image processing tests * test: no plan to define 4 channel numpy processing * test: drop image processor properties test * test: focus image processing tests * test: drop unneeded input trimming helper, chat template now omits the newline by default * tests: enable Isaac tokenizer defaults coverage * isaac: support assistant mask chat template tests * tests: cover Isaac image placeholder expansion * tests: patch Isaac chat template for assistant masks * tests: use Isaac default assistant mask template * tests: align Isaac image batching coverage * tests: drop unneeded utilities / low-level tests * style: isaacvisionmodel not isaacvisiontransformer * tests: clean up imports * wip 1 * style: drop now unneeded check_argument_for_proper_class override * test: don't skip where assisted decoding works * style: inherit from closer base class * style: lint * chore: convert artifacts --------- Co-authored-by: raushan --- docs/source/en/model_doc/isaac.md | 89 +- .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 1 + .../models/isaac/configuration_isaac.py | 10 - .../models/isaac/image_processing_isaac.py | 153 +- .../models/isaac/modeling_isaac.py | 479 ++-- .../models/isaac/modular_isaac.py | 2523 ++++++++--------- .../models/isaac/processing_isaac.py | 412 ++- .../isaac/test_image_processing_isaac.py | 314 +- tests/models/isaac/test_modeling_isaac.py | 1420 +++------- tests/models/isaac/test_processing_isaac.py | 821 +----- utils/check_repo.py | 5 +- 12 files changed, 2212 insertions(+), 4017 deletions(-) diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md index fc370847f839..91773a4962ed 100644 --- a/docs/source/en/model_doc/isaac.md +++ b/docs/source/en/model_doc/isaac.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on {release_date} and added to Hugging Face Transformers on 2026-03-31.* +*This model was released on {release_date} and added to Hugging Face Transformers on 2026-04-13.*
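[Editor's note — not part of the patch series. The "Usage tips" hunk below describes a nested-images batching convention; here is a minimal sketch of what it looks like in practice, assuming `processor` is a loaded `IsaacProcessor` and `image_a` is a PIL image — both hypothetical names:]

```python
# One (possibly empty) image list per text sample lets a batch mix
# text-only and multimodal rows, per the tips in the hunk below.
batch_text = [
    "Describe the scene.",                      # text-only sample
    f"{processor.image_token} What is shown?",  # one matching image
]
batch_images = [[], [image_a]]

inputs = processor(text=batch_text, images=batch_images, return_tensors="pt")

# Per the tips, padded empty slots carry an all-zero (T, H, W) grid;
# real slots have T == 1 and positive H, W.
print(inputs["image_grid_thw"][0])  # expected: tensor([[0, 0, 0]])
```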
@@ -35,78 +35,61 @@ long multimodal prompts manageable. For more information, refer to the [technica
 Isaac checkpoints are distributed under Perceptron's Non-Production license; please review the license that ships with
 the weights before using them in commercial settings.
 
-## Usage
+## Usage tips
 
-Isaac uses explicit image placeholders in the rendered prompt. Every occurrence of `processor.image_token` (usually
-`<image>`) must have a matching image in the `images` argument.
+- Batched inputs can mix text-only and multimodal samples. For direct processor/model batching, pass images as a nested
+  list such as `[[], [image_a], [image_b, image_c]]`.
+- `image_grid_thw[batch_idx, image_slot] == (0, 0, 0)` marks a padded empty slot. Real image slots have
+  `(T=1, H>0, W>0)`.
+- If truncation is enabled, the processor keeps the rightmost part of the multimodal prompt and updates the slot-local
+  `image_metadata[..., 0]` and `image_metadata[..., 1]` values automatically.
+
+## Usage example
+
+Isaac uses explicit image placeholders in the rendered prompt. Every occurrence of `processor.image_token` (usually `<image>`) must have a matching image in the `images` argument.
 
 ```py
 import torch
 from PIL import Image
 from transformers import AutoProcessor, IsaacForConditionalGeneration
 
-model_id = "Perceptron/isaac-base"
+model_id = "PerceptronAI/Isaac-0.1"
 processor = AutoProcessor.from_pretrained(model_id)
 model = IsaacForConditionalGeneration.from_pretrained(
     model_id,
-    torch_dtype=torch.bfloat16,
+    dtype=torch.bfloat16,
     device_map="auto",
     attn_implementation="flash_attention_2",
 )
 
-images = [Image.open("chart.png"), Image.open("panel.jpg")]
-messages = [
-    {"role": "user", "content": "Compare the two figures and explain what changed."},
-    {"role": "user", "content": f"{processor.image_token}{processor.image_token}"},
+conversation = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Compare the two figures and explain what changed."},
+            {"type": "image", "path": "first_image.png"},
+            {"type": "image", "path": "second_image.png"},
+        ],
+    },
 ]
 
-prompt = processor.apply_chat_template(
-    messages,
-    tokenize=False,
+inputs = processor.apply_chat_template(
+    conversation,
+    tokenize=True,
+    return_dict=True,
     add_generation_prompt=True,
-).strip()
-
-inputs = processor(text=prompt, images=images, return_tensors="pt")
-model_inputs = {
-    key: value.to(model.device)
-    for key, value in inputs.items()
-    if value is not None
-}
-
-with torch.inference_mode():
-    generated_ids = model.generate(
-        **model_inputs,
-        max_new_tokens=256,
-        do_sample=False,
-        eos_token_id=processor.tokenizer.eos_token_id,
-        pad_token_id=processor.tokenizer.eos_token_id,
-    )
-
-generated_ids = generated_ids[:, model_inputs["input_ids"].shape[1] :]
+    return_tensors="pt",
+).to(model.device)
+
+generated_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)
+
+generated_ids = generated_ids[:, inputs["input_ids"].shape[1] :]
 response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 print(response)
 ```
 
-`IsaacProcessor` returns the standard text tensors plus Isaac's batch-major visual tensors:
-
-- `pixel_values`: `(batch_size, max_images, max_patches, patch_dim)`
-- `image_grid_thw`: `(batch_size, max_images, 3)`
-- `image_metadata`: `(batch_size, max_images, 2)` storing `(offset, length)` for each image slot
-- `mm_token_type_ids`: `(batch_size, sequence_length)`
-
-Important notes:
-
-- Pass the full processor output to `generate()`.
Isaac uses the multimodal tensors during prefill and handles cached - decoding internally. -- For fully text-only batches, `pixel_values`, `image_grid_thw`, and `image_metadata` are `None`. When moving inputs to - the model, keep only non-`None` values as shown above. -- Batched inputs can mix text-only and multimodal samples. For direct processor/model batching, pass images as a nested - list such as `[[], [image_a], [image_b, image_c]]`. -- `image_grid_thw[batch_idx, image_slot] == (0, 0, 0)` marks a padded empty slot. Real image slots have - `(T=1, H>0, W>0)`. -- If truncation is enabled, the processor keeps the rightmost part of the multimodal prompt and updates the slot-local - `image_metadata[..., 0]` and `image_metadata[..., 1]` values automatically. - ### Post-processing grounded outputs Isaac can generate grounded points and boxes in tagged text spans. Use `post_process_generation()` to strip the tags and @@ -132,9 +115,9 @@ Set `expected="point"` to extract point annotations, or leave `expected=None` to [[autodoc]] IsaacConfig -## IsaacVisionTransformer +## IsaacVisionModel -[[autodoc]] IsaacVisionTransformer +[[autodoc]] IsaacVisionModel ## IsaacTextModel diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 149138596c14..305aaf5f4e39 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -239,6 +239,7 @@ ("internvl", "InternVLConfig"), ("internvl_vision", "InternVLVisionConfig"), ("isaac", "IsaacConfig"), + ("isaac_vision", "IsaacVisionConfig"), ("jais2", "Jais2Config"), ("jamba", "JambaConfig"), ("janus", "JanusConfig"), @@ -760,6 +761,7 @@ ("internvl", "InternVL"), ("internvl_vision", "InternVLVision"), ("isaac", "Isaac"), + ("isaac_vision", "IsaacVision"), ("jais2", "Jais2"), ("jamba", "Jamba"), ("janus", "Janus"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 1f2d74bed813..120668f4ac8d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -236,6 +236,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("internvl", "InternVLModel"), ("internvl_vision", "InternVLVisionModel"), ("isaac", "IsaacModel"), + ("isaac_vision", "IsaacVisionModel"), ("jais2", "Jais2Model"), ("jamba", "JambaModel"), ("janus", "JanusModel"), diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index 646230c3b03e..9e4c58954ffc 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -52,7 +52,6 @@ class IsaacVisionConfig(PreTrainedConfig): attention_dropout: float | int = 0.0 num_patches: int = 256 - pixel_shuffle_scale_factor: int = 1 @@ -109,7 +108,6 @@ class IsaacTextConfig(PreTrainedConfig): attention_bias: bool = False use_sliding_window: bool = False max_window_layers: int = 28 - layer_types: list[str] | None = None attention_dropout: float | int = 0.0 pad_token_id: int | None = None bos_token_id: int | None = None @@ -120,11 +118,7 @@ def __post_init__(self, **kwargs): if self.num_key_value_heads is None: self.num_key_value_heads = self.num_attention_heads - if self.layer_types is None: - self.layer_types = ["full_attention" for _ in range(self.num_hidden_layers)] - PretrainedConfig.__post_init__(self, **kwargs) - self.validate_layer_type() 
@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") @@ -159,9 +153,6 @@ class IsaacConfig(PretrainedConfig): max_sequence_length: int = 16384 def __post_init__(self, **kwargs): - for key in ("use_cache", "rope_theta", "max_position_embeddings"): - kwargs.pop(key, None) - if isinstance(self.text_config, dict): self.text_config = self.sub_configs["text_config"](**self.text_config) elif self.text_config is None: @@ -181,7 +172,6 @@ def __post_init__(self, **kwargs): ) self.vision_rescale_factor = float(self.vision_rescale_factor) - super().__post_init__(**kwargs) diff --git a/src/transformers/models/isaac/image_processing_isaac.py b/src/transformers/models/isaac/image_processing_isaac.py index 460ea04452d6..5f5af7905b6b 100644 --- a/src/transformers/models/isaac/image_processing_isaac.py +++ b/src/transformers/models/isaac/image_processing_isaac.py @@ -20,23 +20,25 @@ import math -from collections.abc import Sequence from typing import Any from ... import TorchvisionBackend from ...feature_extraction_utils import BatchFeature from ...image_transforms import group_images_by_shape, reorder_images from ...image_utils import ImageInput, PILImageResampling, SizeDict, make_nested_list_of_images -from ...processing_utils import ImagesKwargs +from ...processing_utils import ImagesKwargs, Unpack from ...utils import TensorType, auto_docstring -from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN -from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD +from ...utils.constants import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD from ...utils.import_utils import is_torch_available if is_torch_available(): import torch import torch.nn.functional as F + from torchvision.transforms.v2 import functional as tvF + + +# --------------------------------Isaac Image Processor-------------------------------- class IsaacImageProcessorKwargs(ImagesKwargs, total=False): @@ -173,29 +175,27 @@ def get_image_size_for_max_num_patches( @auto_docstring class IsaacImageProcessor(TorchvisionBackend): - MAX_PIXELS = 60_000_000 # 60โ€‘megapixel ceiling โ‰ˆ 8200 ร— 7300 px - - resample = PILImageResampling.BILINEAR - model_input_names = [ - "pixel_values", - "image_grid_thw", - ] + model_input_names = ["pixel_values", "image_grid_thw"] valid_kwargs = IsaacImageProcessorKwargs + resample = PILImageResampling.BILINEAR do_resize = True do_center_crop = False - patch_size: int | None = 16 - max_num_patches: int | None = 256 - min_num_patches: int | None = None - pixel_shuffle_scale: int | None = 1 + patch_size = 16 + max_num_patches = 256 + min_num_patches = None + pixel_shuffle_scale = 1 do_pad = True do_rescale = True do_normalize = True - image_mean = list(VISION_MEAN) - image_std = list(VISION_STD) + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD do_convert_rgb = True disable_grouping = False + def __init__(self, **kwargs: Unpack[IsaacImageProcessorKwargs]): + super().__init__(**kwargs) + def _validate_preprocess_kwargs(self, **kwargs): # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for. 
kwargs.pop("do_resize", None) @@ -220,55 +220,28 @@ def resize( return image.clamp(0, 255).round().to(torch.uint8) return F.interpolate(image, size=(size.height, size.width), mode="bilinear", align_corners=False) - def get_number_of_image_patches( - self, - image_height: int, - image_width: int, - images_kwargs: dict[str, Any] | None = None, - ) -> int: - images_kwargs = images_kwargs or {} - patch_size = images_kwargs.get("patch_size", self.patch_size) - max_num_patches = images_kwargs.get("max_num_patches", self.max_num_patches) - min_num_patches = images_kwargs.get("min_num_patches", self.min_num_patches) - pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale", self.pixel_shuffle_scale) - - target_height, target_width = get_image_size_for_max_num_patches( - image_height, - image_width, - patch_size, - max_num_patches, - min_num_patches=min_num_patches, - pixel_shuffle_scale=pixel_shuffle_scale, - ) - return (target_height // patch_size) * (target_width // patch_size) - def pack_images( self, vision_patches: list[list[torch.Tensor]], vision_token_grids: list[list[torch.Tensor]], ) -> dict[str, torch.Tensor | None]: batch_size = len(vision_patches) - max_images = max((len(sample_patches) for sample_patches in vision_patches), default=0) flat_patches = [patches for sample_patches in vision_patches for patches in sample_patches] - if max_images == 0 or not flat_patches: - return { - "pixel_values": None, - "image_grid_thw": None, - } + if len(flat_patches) == 0: + return {"pixel_values": None, "image_grid_thw": None} first_patch = flat_patches[0] max_patches = max(patches.shape[0] for patches in flat_patches) - patch_dim = first_patch.shape[-1] - patch_dtype = first_patch.dtype - patch_device = first_patch.device + max_images = max((len(sample_patches) for sample_patches in vision_patches), default=0) + patch_dim = first_patch.shape[-1] tensors = { "pixel_values": torch.zeros( (batch_size, max_images, max_patches, patch_dim), - device=patch_device, - dtype=patch_dtype, + device=first_patch.device, + dtype=first_patch.dtype, ), - "image_grid_thw": torch.zeros((batch_size, max_images, 3), device=patch_device, dtype=torch.long), + "image_grid_thw": torch.zeros((batch_size, max_images, 3), device=first_patch.device, dtype=torch.long), } for batch_idx, (sample_patches, sample_token_grids) in enumerate( @@ -286,48 +259,39 @@ def _preprocess( self, images: list[list[torch.Tensor]], do_resize: bool, - resample: Any | None, - do_rescale: bool | None, - rescale_factor: float | None, - do_normalize: bool | None, - image_mean: float | Sequence[float] | None, - image_std: float | Sequence[float] | None, - do_pad: bool | None = None, - disable_grouping: bool | None = None, - return_tensors: str | TensorType | None = None, - patch_size: int | None = None, - max_num_patches: int | None = None, - min_num_patches: int | None = None, - pixel_shuffle_scale: int | None = None, + resample: PILImageResampling | tvF.InterpolationMode | int | None, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: float | list[float] | None, + image_std: float | list[float] | None, + do_pad: bool, + patch_size: int, + max_num_patches: int, + min_num_patches: int, + pixel_shuffle_scale: int, + disable_grouping: bool | None, + return_tensors: str | TensorType | None, **kwargs, ) -> BatchFeature: - resample = kwargs.pop("interpolation", resample) - # IsaacProcessor routes text-only calls here as an empty image list per sample. 
- # Return `None` visual fields so text-only batches skip multimodal codepaths like other VLMs. if all(len(sample_images) == 0 for sample_images in images): - tensors = { - "pixel_values": None, - "image_grid_thw": None, - } - return BatchFeature(data=tensors, tensor_type=return_tensors) + return BatchFeature(data={"pixel_values": None, "image_grid_thw": None}, tensor_type=return_tensors) grouped_images, grouped_images_index = group_images_by_shape( images, disable_grouping=disable_grouping, is_nested=True ) - grouped_outputs = {} - for shape, stacked_images in grouped_images.items(): grouped_batch_size, channels, original_height, original_width = stacked_images.shape - target_height, target_width = get_image_size_for_max_num_patches( - original_height, - original_width, - patch_size, - max_num_patches, - min_num_patches=min_num_patches, - pixel_shuffle_scale=pixel_shuffle_scale, - ) if do_resize: + target_height, target_width = get_image_size_for_max_num_patches( + original_height, + original_width, + patch_size, + max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + ) image_batch = self.resize( stacked_images, SizeDict(height=target_height, width=target_width), resample=resample ) @@ -356,7 +320,8 @@ def _preprocess( if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): raise ValueError( - f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale}; adjust resize/patch parameters or disable pixel shuffle." + f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale};" + f" adjust resize/patch parameters or disable pixel shuffle." ) grouped_outputs[shape] = ( @@ -383,5 +348,27 @@ def _preprocess( return BatchFeature(data=tensors, tensor_type=return_tensors) + def get_number_of_image_patches( + self, + image_height: int, + image_width: int, + images_kwargs: dict[str, Any] | None = None, + ) -> int: + images_kwargs = images_kwargs or {} + patch_size = images_kwargs.get("patch_size", self.patch_size) + max_num_patches = images_kwargs.get("max_num_patches", self.max_num_patches) + min_num_patches = images_kwargs.get("min_num_patches", self.min_num_patches) + pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale", self.pixel_shuffle_scale) + + target_height, target_width = get_image_size_for_max_num_patches( + image_height, + image_width, + patch_size, + max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + ) + return (target_height // patch_size) * (target_width // patch_size) + __all__ = ["IsaacImageProcessor"] diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py index 507ae006ddbe..1d293615cc37 100644 --- a/src/transformers/models/isaac/modeling_isaac.py +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -30,16 +30,18 @@ from ...masking_utils import create_bidirectional_mask, create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + CausalLMOutputWithPast, +) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, 
PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, torch_compilable_check from ...utils.generic import TransformersKwargs, can_return_tuple, maybe_autocast, merge_with_config_defaults -from ...utils.import_utils import ( - is_torch_available, - is_torchdynamo_compiling, -) +from ...utils.import_utils import is_torch_available, is_torchdynamo_compiling from ...utils.output_capturing import capture_outputs from .configuration_isaac import IsaacConfig, IsaacTextConfig, IsaacVisionConfig @@ -228,15 +230,12 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor | None]: """Input shape: Batch x Time x Channel""" - batch_size, seq_length, embed_dim = hidden_states.shape - - queries = self.q_proj(hidden_states) - keys = self.k_proj(hidden_states) - values = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] - queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + hidden_shape = (*input_shape, -1, self.head_dim) + queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward @@ -253,7 +252,7 @@ def forward( dropout=0.0 if not self.training else self.dropout, ) - attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.out_proj(attn_output) return attn_output, attn_weights @@ -338,78 +337,8 @@ def forward( return BaseModelOutput(last_hidden_state=hidden_states) -def pixel_shuffle_padded( - hidden_states: torch.Tensor, - token_grids: torch.Tensor, - scale_factor: int = 1, -) -> torch.Tensor: - """Apply pixel shuffle per image on padded batched vision embeddings. - - Args: - x (`torch.Tensor`): - Vision embeddings of shape `(num_images, max_patches, hidden_size)`. - token_grids (`torch.Tensor`): - Grid sizes `(height, width)` per image, shape `(num_images, 2)`. - scale_factor (`int`, *optional*, defaults to 1): - Spatial down-sampling factor. - - Returns: - `torch.Tensor`: Pixel-shuffled embeddings of shape - `(num_images, max_tokens, hidden_size * scale_factor**2)`. 
- """ - num_images, max_patches, embed_dim = hidden_states.shape - output_dim = embed_dim * scale_factor * scale_factor - - token_grids = token_grids.to(device=hidden_states.device, dtype=torch.long) - heights = token_grids[:, 0] - widths = token_grids[:, 1] - full_lengths = heights * widths - - non_empty = full_lengths > 0 - if not is_torchdynamo_compiling(): - divisible = ((heights % scale_factor) == 0) & ((widths % scale_factor) == 0) - torch_compilable_check( - (~non_empty) | divisible, - f"Every non-empty (H, W) grid must be divisible by pixel_shuffle_scale={scale_factor}.", - ) - - output_lengths = (heights // scale_factor) * (widths // scale_factor) - max_output_tokens = output_lengths.max() - shuffled_4d = hidden_states.new_zeros((num_images, max_output_tokens, scale_factor * scale_factor, embed_dim)) - - token_positions = ( - torch.arange(max_patches, device=hidden_states.device, dtype=torch.long).unsqueeze(0).expand(num_images, -1) - ) - valid_token_mask = token_positions < full_lengths.unsqueeze(1) - - safe_widths = torch.where(widths > 0, widths, torch.ones_like(widths)) - row_index = torch.div(token_positions, safe_widths.unsqueeze(1), rounding_mode="floor") - col_index = token_positions.remainder(safe_widths.unsqueeze(1)) - - output_widths = widths.div(scale_factor, rounding_mode="floor") - output_index = row_index.div(scale_factor, rounding_mode="floor") * output_widths.unsqueeze(1) - output_index = output_index + col_index.div(scale_factor, rounding_mode="floor") - sub_index = row_index.remainder(scale_factor) * scale_factor + col_index.remainder(scale_factor) - - batch_index = ( - torch.arange(num_images, device=hidden_states.device, dtype=torch.long).unsqueeze(1).expand_as(token_positions) - ) - shuffled_4d[batch_index[valid_token_mask], output_index[valid_token_mask], sub_index[valid_token_mask]] = ( - hidden_states[valid_token_mask] - ) - - shuffled = shuffled_4d.view(num_images, max_output_tokens, output_dim) - return shuffled - - -class IsaacVisionTransformer(PreTrainedModel): - """Vision tower for padded variable-resolution patches with per-image masks. - - Args: - config (IsaacVisionConfig): Vision configuration with pixel-shuffle and patching parameters. - - """ - +@auto_docstring +class IsaacVisionModel(PreTrainedModel): config: IsaacVisionConfig _supports_sdpa = True _supports_flash_attn = True @@ -420,7 +349,6 @@ class IsaacVisionTransformer(PreTrainedModel): def __init__(self, config: IsaacVisionConfig): super().__init__(config) - self.config = config self.embeddings = IsaacVisionEmbeddings(config) self.encoder = IsaacVisionEncoder(config) self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -434,26 +362,86 @@ def _init_weights(self, module): if isinstance(module, IsaacVisionEmbeddings): init.zeros_(module.position_embedding) + def pixel_shuffle_padded( + self, + hidden_states: torch.Tensor, + token_grids: torch.Tensor, + ) -> torch.Tensor: + """Apply pixel shuffle per image on padded batched vision embeddings. + + Args: + hidden_states (`torch.Tensor`): + Vision embeddings of shape `(num_images, max_patches, hidden_size)`. + token_grids (`torch.Tensor`): + Grid sizes `(height, width)` per image, shape `(num_images, 2)`. + + Returns: + `torch.Tensor`: Pixel-shuffled embeddings of shape + `(num_images, max_tokens, hidden_size * scale_factor**2)`. 
+ """ + scale_factor = self.pixel_shuffle_scale_factor + num_images, max_patches, embed_dim = hidden_states.shape + output_dim = embed_dim * scale_factor * scale_factor + + token_grids = token_grids.to(device=hidden_states.device, dtype=torch.long) + heights = token_grids[:, 0] + widths = token_grids[:, 1] + full_lengths = heights * widths + + non_empty = full_lengths > 0 + if not is_torchdynamo_compiling(): + divisible = ((heights % scale_factor) == 0) & ((widths % scale_factor) == 0) + torch_compilable_check( + (~non_empty) | divisible, + f"Every non-empty (H, W) grid must be divisible by pixel_shuffle_scale={scale_factor}.", + ) + + output_lengths = (heights // scale_factor) * (widths // scale_factor) + max_output_tokens = output_lengths.max() + shuffled_4d = hidden_states.new_zeros((num_images, max_output_tokens, scale_factor * scale_factor, embed_dim)) + + token_positions = ( + torch.arange(max_patches, device=hidden_states.device, dtype=torch.long) + .unsqueeze(0) + .expand(num_images, -1) + ) + valid_token_mask = token_positions < full_lengths.unsqueeze(1) + + safe_widths = torch.where(widths > 0, widths, torch.ones_like(widths)) + row_index = torch.div(token_positions, safe_widths.unsqueeze(1), rounding_mode="floor") + col_index = token_positions.remainder(safe_widths.unsqueeze(1)) + + output_widths = widths.div(scale_factor, rounding_mode="floor") + output_index = row_index.div(scale_factor, rounding_mode="floor") * output_widths.unsqueeze(1) + output_index = output_index + col_index.div(scale_factor, rounding_mode="floor") + sub_index = row_index.remainder(scale_factor) * scale_factor + col_index.remainder(scale_factor) + + batch_index = ( + torch.arange(num_images, device=hidden_states.device, dtype=torch.long) + .unsqueeze(1) + .expand_as(token_positions) + ) + shuffled_4d[batch_index[valid_token_mask], output_index[valid_token_mask], sub_index[valid_token_mask]] = ( + hidden_states[valid_token_mask] + ) + + shuffled = shuffled_4d.view(num_images, max_output_tokens, output_dim) + return shuffled + @merge_with_config_defaults @capture_outputs(tie_last_hidden_states=False) + @auto_docstring def forward( self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: + r""" + image_grid_thw (`torch.LongTensor`, *optional*): + Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. """ - Inputs: - pixel_values (`torch.Tensor`): - Patches shaped `(num_images, max_patches, patch_dim)`. - image_grid_thw (`torch.Tensor`): - Grid tensor shaped `(num_images, 3)` with per-image `(T=1, H_tokens, W_tokens)`. - - Returns: - `BaseModelOutputWithPooling` with pixel-shuffled embeddings in `last_hidden_state`. 
- """ - vision_token_grids = image_grid_thw[:, 1:].to(dtype=torch.long) - full_lengths = vision_token_grids[:, 0] * vision_token_grids[:, 1] + full_lengths = image_grid_thw[:, 1] * image_grid_thw[:, 2] token_positions = torch.arange(pixel_values.shape[1], device=pixel_values.device, dtype=torch.long) image_patch_attention_mask = token_positions.unsqueeze(0) < full_lengths.unsqueeze(1) image_patch_attention_mask = image_patch_attention_mask.to(dtype=torch.long) @@ -471,10 +459,9 @@ def forward( encoder_outputs = self.encoder(inputs_embeds=hidden_states, attention_mask=encoder_attention_mask, **kwargs) hidden_states = self.post_layernorm(encoder_outputs.last_hidden_state) - hidden_states = pixel_shuffle_padded( + hidden_states = self.pixel_shuffle_padded( hidden_states=hidden_states, - token_grids=vision_token_grids, - scale_factor=self.pixel_shuffle_scale_factor, + token_grids=image_grid_thw[:, 1:], ) return BaseModelOutputWithPooling( @@ -485,51 +472,11 @@ def forward( ) -class IsaacMultiModalProjector(nn.Module): - """Maps vision tower outputs to the text hidden size with a SiLU MLP.""" - - def __init__(self, config: IsaacConfig): - super().__init__() - text_config = config.get_text_config() - vision_hidden_size = config.vision_config.hidden_size * (config.vision_config.pixel_shuffle_scale_factor**2) - backbone_hidden_size = text_config.hidden_size - self.linear_1 = nn.Linear(vision_hidden_size, 4 * vision_hidden_size, bias=False) - self.silu = nn.SiLU() - self.linear_2 = nn.Linear(4 * vision_hidden_size, backbone_hidden_size, bias=False) - - def forward(self, image_features): - hidden_states = self.linear_1(image_features) - hidden_states = self.silu(hidden_states) - hidden_states = self.linear_2(hidden_states) - return hidden_states - - -class _SinglePoint(NamedTuple): - x: int - y: int - mention: str | None = None - t: float | None = None - - -class _BoundingBox(NamedTuple): - top_left: Any - bottom_right: Any - mention: str | None = None - t: float | None = None - - -class _Polygon(NamedTuple): - points: tuple[Any, ...] 
- mention: str | None = None - t: float | None = None - - class IsaacRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` def __init__(self, config: IsaacTextConfig, device=None): super().__init__() - rope_parameters = config.rope_parameters self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings @@ -543,9 +490,11 @@ def __init__(self, config: IsaacTextConfig, device=None): self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) - - self.mrope_section = self._resolve_mrope_section(rope_parameters.get("mrope_section"), self.inv_freq.shape[0]) - self.hidden_size = config.hidden_size + self.mrope_section = config.rope_parameters.get("mrope_section") + if self.mrope_section is None: + weights = (2, 1, 1) + self.mrope_section = [self.inv_freq.shape[0] * w // sum(weights) for w in weights] + self.mrope_section[0] += self.inv_freq.shape[0] - sum(self.mrope_section) @staticmethod def compute_default_rope_parameters( @@ -610,17 +559,6 @@ def apply_interleaved_mrope(self, freqs: torch.Tensor, mrope_section: list[int]) chunks = freqs.split(tuple(mrope_section), dim=-1) return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) - @staticmethod - def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> list[int]: - if section is None: - weights = (2, 1, 1) - base = [rotary_half_dim * w // sum(weights) for w in weights] - base[0] += rotary_half_dim - sum(base) - return base - - section = [int(v) for v in section] - return section - @use_kernel_forward_from_hub("RMSNorm") class IsaacTextRMSNorm(nn.Module): @@ -987,6 +925,25 @@ def _deepstack_process( return hidden_states +class IsaacMultiModalProjector(nn.Module): + """Maps vision tower outputs to the text hidden size with a SiLU MLP.""" + + def __init__(self, config: IsaacConfig): + super().__init__() + text_config = config.get_text_config() + vision_hidden_size = config.vision_config.hidden_size * (config.vision_config.pixel_shuffle_scale_factor**2) + backbone_hidden_size = text_config.hidden_size + self.linear_1 = nn.Linear(vision_hidden_size, 4 * vision_hidden_size, bias=False) + self.silu = nn.SiLU() + self.linear_2 = nn.Linear(4 * vision_hidden_size, backbone_hidden_size, bias=False) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.silu(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + @auto_docstring class IsaacModel(IsaacPreTrainedModel): base_model_prefix = "model" @@ -1004,7 +961,7 @@ class IsaacModel(IsaacPreTrainedModel): def __init__(self, config: IsaacConfig): super().__init__(config) self.language_model = IsaacTextModel._from_config(config.text_config) - self.visual = IsaacVisionTransformer(config.vision_config) + self.visual = IsaacVisionModel(config.vision_config) self.multimodal_projector = IsaacMultiModalProjector(config) self.max_sequence_length = config.max_sequence_length self.vision_rescale_factor = config.vision_rescale_factor @@ -1070,11 +1027,11 @@ def get_vision_position_ids( def get_rope_index( self, - input_ids: torch.LongTensor | None, + input_ids: torch.LongTensor, mm_token_type_ids: torch.Tensor, - image_grid_thw: torch.Tensor | None = None, + image_grid_thw: torch.Tensor, + image_metadata: torch.Tensor, attention_mask: torch.Tensor | None = None, - image_metadata: torch.Tensor | None = None, **kwargs, ) -> 
tuple[torch.Tensor, torch.Tensor]: """ @@ -1101,9 +1058,6 @@ def get_rope_index( position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) """ - if image_grid_thw is None or image_metadata is None: - raise ValueError("Isaac multimodal RoPE requires both `image_grid_thw` and `image_metadata`.") - if attention_mask is None: if input_ids is None: attention_mask = mm_token_type_ids.new_ones(mm_token_type_ids.shape, dtype=torch.long) @@ -1185,22 +1139,6 @@ def get_rope_index( return position_ids, rope_deltas - @can_return_tuple - @auto_docstring - def get_video_features( - self, - pixel_values_videos: torch.FloatTensor, - video_grid_thw: torch.LongTensor | None = None, - **kwargs: Unpack[TransformersKwargs], - ) -> tuple | BaseModelOutputWithPooling: - r""" - pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): - The tensors corresponding to the input videos. - video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): - The temporal, height and width of feature shape of each video in LLM. - """ - raise ValueError("Isaac is image-only and does not support `pixel_values_videos` or `video_grid_thw`.") - @can_return_tuple @auto_docstring def get_image_features( @@ -1210,37 +1148,18 @@ def get_image_features( image_metadata: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values (`torch.FloatTensor`, *optional*): + Batch-major patch vectors shaped `(batch_size, max_images, max_patches, patch_dim)`. + image_grid_thw (`torch.LongTensor`, *optional*): + Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. + image_metadata (`torch.Tensor`, *optional*): + Batch-major per-slot metadata `(offset, length)` shaped `(batch_size, max_images, 2)`. """ - Args: - pixel_values (`torch.Tensor`): - Batch-major patch vectors with shape `(batch_size, max_images, max_patches, patch_dim)`. - image_grid_thw (`torch.Tensor`): - Batch-major grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. - image_metadata (`torch.Tensor`, *optional*): - Batch-major per-slot metadata `(offset, length)` shaped `(batch_size, max_images, 2)`. 
- """ - if pixel_values.shape[0] == 0: - hidden_size = self.config.get_text_config().hidden_size - return BaseModelOutputWithPooling( - last_hidden_state=pixel_values.new_zeros((0, 0, hidden_size)), - pooler_output=(), - hidden_states=None, - attentions=None, - ) - - image_grid_thw = image_grid_thw.to(dtype=torch.long) active_slot_mask = image_grid_thw[..., 0].eq(1) - if not active_slot_mask.any(): - hidden_size = self.config.get_text_config().hidden_size - return BaseModelOutputWithPooling( - last_hidden_state=pixel_values.new_zeros((0, 0, hidden_size)), - pooler_output=(), - hidden_states=None, - attentions=None, - ) - flat_pixel_values = pixel_values[active_slot_mask] flat_image_grid_thw = image_grid_thw[active_slot_mask] + vision_outputs: BaseModelOutputWithPooling = self.visual( pixel_values=flat_pixel_values, image_grid_thw=flat_image_grid_thw, @@ -1249,21 +1168,20 @@ def get_image_features( ) projected_features = self.multimodal_projector(vision_outputs.last_hidden_state) - pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor - full_lengths = flat_image_grid_thw[:, 1].div(pixel_shuffle_scale, rounding_mode="floor") * flat_image_grid_thw[ - :, 2 - ].div(pixel_shuffle_scale, rounding_mode="floor") + # Truncate image features using offset and length if image_metadata is None: - offsets = torch.zeros_like(full_lengths) - lengths = full_lengths + pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor + downsampled_height = flat_image_grid_thw[:, 1].div(pixel_shuffle_scale, rounding_mode="floor") + downsampled_width = flat_image_grid_thw[:, 2].div(pixel_shuffle_scale, rounding_mode="floor") + lengths = downsampled_height * downsampled_width + offsets = torch.zeros_like(lengths) else: torch_compilable_check( image_metadata.shape[:2] == image_grid_thw.shape[:2], "IsaacModel.get_image_features expects batch-major metadata aligned with `image_grid_thw`.", ) - active_metadata = image_metadata[active_slot_mask] - offsets = active_metadata[:, 0].to(device=projected_features.device, dtype=torch.long) - lengths = active_metadata[:, 1].to(device=projected_features.device, dtype=torch.long) + offsets = image_metadata[active_slot_mask][:, 0] + lengths = image_metadata[active_slot_mask][:, 1] image_features = tuple( projected_features[image_idx, offset : offset + length] @@ -1307,7 +1225,6 @@ def compute_3d_position_ids( past_key_values: Cache | None = None, ) -> torch.Tensor: past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() - has_multimodal = ( image_grid_thw is not None and image_metadata is not None @@ -1315,8 +1232,9 @@ def compute_3d_position_ids( ) if has_multimodal and mm_token_type_ids is None and input_ids is not None: raise ValueError( - "Multimodal data was passed (via `image_grid_thw`) but `mm_token_type_ids` is missing. " - "Please pass `mm_token_type_ids` so Isaac can build multimodal RoPE positions." + "Multimodal data was passed (via `image_grid_thw` or `image_metadata`) but `mm_token_type_ids` is " + "missing. Please pass `mm_token_type_ids` to the model so that multimodal RoPE (M-RoPE) can be " + "computed correctly. `mm_token_type_ids` is returned by the processor alongside `input_ids`." 
) if has_multimodal and past_seen_tokens == 0: @@ -1379,16 +1297,15 @@ def forward( use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPast: - """ - Args: - mm_token_type_ids (`torch.LongTensor`, *optional*): - Multimodal token type ids aligned with the token sequence, using `0 -> text` and `1 -> image`. - pixel_values (`torch.FloatTensor`, *optional*): - Batch-major patch vectors shaped `(batch_size, max_images, max_patches, patch_dim)`. - image_grid_thw (`torch.LongTensor`, *optional*): - Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. - image_metadata (`torch.LongTensor`, *optional*): - Batch-major per-slot metadata shaped `(batch_size, max_images, 2)` with `(offset, length)`. + r""" + mm_token_type_ids (`torch.LongTensor`, *optional*): + Multimodal token type ids aligned with the token sequence, using `0 -> text` and `1 -> image`. + pixel_values (`torch.FloatTensor`, *optional*): + Batch-major patch vectors shaped `(batch_size, max_images, max_patches, patch_dim)`. + image_grid_thw (`torch.LongTensor`, *optional*): + Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. + image_metadata (`torch.LongTensor`, *optional*): + Batch-major per-slot metadata shaped `(batch_size, max_images, 2)` with `(offset, length)`. """ if (input_ids is None) == (inputs_embeds is None): raise ValueError("You must specify exactly one of `input_ids` or `inputs_embeds`.") @@ -1396,25 +1313,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - batch_size, seq_len = inputs_embeds.shape[:2] - if mm_token_type_ids is None: - mm_token_type_ids = torch.full((batch_size, seq_len), 0, device=inputs_embeds.device, dtype=torch.long) - else: - mm_token_type_ids = mm_token_type_ids.to(device=inputs_embeds.device, dtype=torch.long) - if mm_token_type_ids.shape[1] < seq_len: - padding = mm_token_type_ids.new_zeros((batch_size, seq_len - mm_token_type_ids.shape[1])) - mm_token_type_ids = torch.cat([mm_token_type_ids, padding], dim=1) - elif mm_token_type_ids.shape[1] > seq_len: - mm_token_type_ids = mm_token_type_ids[:, -seq_len:] - - if image_metadata is not None: - image_metadata = image_metadata.to(device=inputs_embeds.device, dtype=torch.long) - image_mask = None - has_active_images = ( - pixel_values is not None and image_grid_thw is not None and bool(image_grid_thw[..., 0].eq(1).any().item()) - ) - if has_active_images: + if pixel_values is not None and image_grid_thw is not None: image_outputs = self.get_image_features( pixel_values=pixel_values, image_grid_thw=image_grid_thw, @@ -1467,54 +1367,36 @@ def forward( **kwargs, ) - outputs_with_rope = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - outputs_with_rope["rope_deltas"] = self.rope_deltas - return outputs_with_rope @dataclass -@auto_docstring -class BaseModelOutputWithDeepstackFeatures(BaseModelOutputWithPooling): - r""" - deepstack_features (`List[torch.FloatTensor]`, *optional*): - List of hidden-states (feature maps) from deepstack layers. +class IsaacCausalLMOutputWithPast(CausalLMOutputWithPast): """ + Base class for Isaac causal language model (or autoregressive) outputs. 
- deepstack_features: list[torch.FloatTensor] | None = None + Args: + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + rope_deltas: torch.LongTensor | None = None @dataclass -@auto_docstring( - custom_intro=""" - Base class for Isaac causal language model (or autoregressive) outputs. - """ -) -class IsaacCausalLMOutputWithPast(ModelOutput): +@auto_docstring +class BaseModelOutputWithDeepstackFeatures(BaseModelOutputWithPooling): r""" - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): - The rope index difference between sequence length and multimodal rope. + deepstack_features (`List[torch.FloatTensor]`, *optional*): + List of hidden-states (feature maps) from deepstack layers. """ - loss: torch.FloatTensor | None = None - logits: torch.FloatTensor | None = None - past_key_values: Cache | None = None - hidden_states: tuple[torch.FloatTensor] | None = None - attentions: tuple[torch.FloatTensor] | None = None - rope_deltas: torch.LongTensor | None = None + deepstack_features: list[torch.FloatTensor] | None = None @auto_docstring @@ -1636,7 +1518,6 @@ def forward( >>> print(output_text) ``` """ - outputs = self.model( input_ids=input_ids, pixel_values=pixel_values, @@ -1652,8 +1533,6 @@ def forward( ) hidden_states = outputs[0] - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -1667,13 +1546,13 @@ def forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - rope_deltas=outputs.rope_deltas, + rope_deltas=self.model.rope_deltas, ) def prepare_inputs_for_generation( self, input_ids, - past_key_values=None, + past_key_values: Cache = None, attention_mask: torch.Tensor | None = None, inputs_embeds: torch.FloatTensor | None = None, mm_token_type_ids: torch.LongTensor | None = None, @@ -1681,8 +1560,8 @@ def prepare_inputs_for_generation( image_grid_thw: torch.LongTensor | None = None, image_metadata: torch.LongTensor | None = None, position_ids: torch.LongTensor | None = None, - is_first_iteration=False, - use_cache=True, + is_first_iteration: bool = False, + use_cache: bool = True, **kwargs, ) -> dict[str, Any]: model_inputs = super().prepare_inputs_for_generation( @@ -1697,13 +1576,14 @@ def prepare_inputs_for_generation( use_cache=use_cache, **kwargs, ) - is_prefill = is_first_iteration or not use_cache + multimodal_inputs = { "mm_token_type_ids": mm_token_type_ids, "pixel_values": pixel_values, "image_grid_thw": image_grid_thw, "image_metadata": 
image_metadata, } + is_prefill = is_first_iteration or not use_cache for key, value in multimodal_inputs.items(): model_inputs[key] = value if is_prefill else None if model_inputs["mm_token_type_ids"] is not None: @@ -1722,6 +1602,7 @@ def prepare_inputs_for_generation( model_inputs["mm_token_type_ids"] = torch.cat([model_inputs["mm_token_type_ids"], padding], dim=1) elif current_length > sequence_length: model_inputs["mm_token_type_ids"] = model_inputs["mm_token_type_ids"][:, -sequence_length:] + return model_inputs def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs): @@ -1839,10 +1720,24 @@ def _expand_inputs_for_generation( return input_ids, model_kwargs -__all__ = [ - "IsaacTextModel", - "IsaacVisionTransformer", - "IsaacModel", - "IsaacPreTrainedModel", - "IsaacForConditionalGeneration", -] +class SinglePoint(NamedTuple): + x: int + y: int + mention: str | None = None + t: float | None = None + + +class BoundingBox(NamedTuple): + top_left: Any + bottom_right: Any + mention: str | None = None + t: float | None = None + + +class Polygon(NamedTuple): + points: tuple[Any, ...] + mention: str | None = None + t: float | None = None + + +__all__ = ["IsaacTextModel", "IsaacVisionModel", "IsaacModel", "IsaacPreTrainedModel", "IsaacForConditionalGeneration"] diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py index 1e3b0b32d27e..7be4e9bb1b00 100644 --- a/src/transformers/models/isaac/modular_isaac.py +++ b/src/transformers/models/isaac/modular_isaac.py @@ -16,7 +16,8 @@ import math import re -from collections.abc import Sequence +from collections import defaultdict +from dataclasses import dataclass from typing import Any, NamedTuple from huggingface_hub.dataclasses import strict @@ -30,25 +31,22 @@ from ...image_transforms import group_images_by_shape, reorder_images from ...image_utils import ImageInput, PILImageResampling, SizeDict, make_nested_list_of_images from ...masking_utils import create_bidirectional_mask -from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel from ...models.qwen3.configuration_qwen3 import Qwen3Config from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...utils import TensorType, auto_docstring, torch_compilable_check -from ...utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN -from ...utils.constants import IMAGENET_STANDARD_STD as VISION_STD +from ...utils.constants import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD from ...utils.generic import TransformersKwargs, can_return_tuple, merge_with_config_defaults from ...utils.import_utils import ( is_torch_available, is_torchdynamo_compiling, is_torchvision_available, - is_vision_available, ) from ...utils.output_capturing import capture_outputs from ..qwen3_vl.modeling_qwen3_vl import ( Qwen3VLForConditionalGeneration, Qwen3VLModel, - Qwen3VLTextAttention, Qwen3VLTextDecoderLayer, Qwen3VLTextModel, Qwen3VLTextRotaryEmbedding, @@ -66,24 +64,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -if is_vision_available(): - from PIL.Image import Image -else: - Image = None + from torchvision.transforms.v2 import functional as tvF + if is_torchvision_available(): from ..pix2struct.image_processing_pix2struct import torch_extract_patches -class 
IsaacProcessorKwargs(ProcessingKwargs, total=False): - _defaults = { - "text_kwargs": { - "padding": True, - "return_attention_mask": True, - "return_mm_token_type_ids": True, - }, - } - - @auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") @strict class IsaacVisionConfig(Siglip2VisionConfig): @@ -98,7 +84,6 @@ class IsaacVisionConfig(Siglip2VisionConfig): model_type = "isaac_vision" base_config_key = "vision_config" - pixel_shuffle_scale_factor: int = 1 @@ -121,247 +106,68 @@ class IsaacTextConfig(Qwen3Config): ignore_keys_at_rope_validation = {"mrope_section", "mrope_interleaved"} max_position_embeddings: int = 32768 sliding_window = AttributeError() + layer_types = AttributeError() def __post_init__(self, **kwargs): if self.num_key_value_heads is None: self.num_key_value_heads = self.num_attention_heads - if self.layer_types is None: - self.layer_types = ["full_attention" for _ in range(self.num_hidden_layers)] - PretrainedConfig.__post_init__(self, **kwargs) - self.validate_layer_type() - - -class IsaacImageProcessorKwargs(ImagesKwargs, total=False): - """ - patch_size (`int`, *optional*): - Side length (in pixels) for square patches extracted from resized images. - max_num_patches (`int`, *optional*): - Upper bound on extracted patches per image after resizing. - min_num_patches (`int`, *optional*): - Lower bound on extracted patches per image after resizing. - pixel_shuffle_scale (`int`, *optional*): - Pixel-shuffle reduction factor applied in the vision tower. - """ - - patch_size: int - max_num_patches: int - min_num_patches: int - pixel_shuffle_scale: int - - -@auto_docstring -class IsaacImageProcessor(TorchvisionBackend): - MAX_PIXELS = 60_000_000 # 60โ€‘megapixel ceiling โ‰ˆ 8200 ร— 7300 px - - resample = PILImageResampling.BILINEAR - model_input_names = [ - "pixel_values", - "image_grid_thw", - ] - valid_kwargs = IsaacImageProcessorKwargs - - do_resize = True - do_center_crop = False - patch_size: int | None = 16 - max_num_patches: int | None = 256 - min_num_patches: int | None = None - pixel_shuffle_scale: int | None = 1 - do_pad = True - do_rescale = True - do_normalize = True - image_mean = list(VISION_MEAN) - image_std = list(VISION_STD) - do_convert_rgb = True - disable_grouping = False - - def _validate_preprocess_kwargs(self, **kwargs): - # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for. 
- kwargs.pop("do_resize", None) - return super()._validate_preprocess_kwargs(**kwargs) - - def _prepare_images_structure( - self, - images: ImageInput, - expected_ndims: int = 3, - ) -> ImageInput: - images = self.fetch_images(images) - return make_nested_list_of_images(images, expected_ndims=expected_ndims) - - def resize( - self, - image: torch.Tensor, - size: SizeDict, - **kwargs, - ) -> torch.Tensor: - if image.dtype == torch.uint8: - image = F.interpolate(image.float(), size=(size.height, size.width), mode="bilinear", align_corners=False) - return image.clamp(0, 255).round().to(torch.uint8) - return F.interpolate(image, size=(size.height, size.width), mode="bilinear", align_corners=False) - - def get_number_of_image_patches( - self, - image_height: int, - image_width: int, - images_kwargs: dict[str, Any] | None = None, - ) -> int: - images_kwargs = images_kwargs or {} - patch_size = images_kwargs.get("patch_size", self.patch_size) - max_num_patches = images_kwargs.get("max_num_patches", self.max_num_patches) - min_num_patches = images_kwargs.get("min_num_patches", self.min_num_patches) - pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale", self.pixel_shuffle_scale) - - target_height, target_width = get_image_size_for_max_num_patches( - image_height, - image_width, - patch_size, - max_num_patches, - min_num_patches=min_num_patches, - pixel_shuffle_scale=pixel_shuffle_scale, - ) - return (target_height // patch_size) * (target_width // patch_size) - - def pack_images( - self, - vision_patches: list[list[torch.Tensor]], - vision_token_grids: list[list[torch.Tensor]], - ) -> dict[str, torch.Tensor | None]: - batch_size = len(vision_patches) - max_images = max((len(sample_patches) for sample_patches in vision_patches), default=0) - flat_patches = [patches for sample_patches in vision_patches for patches in sample_patches] - if max_images == 0 or not flat_patches: - return { - "pixel_values": None, - "image_grid_thw": None, - } - - first_patch = flat_patches[0] - max_patches = max(patches.shape[0] for patches in flat_patches) - patch_dim = first_patch.shape[-1] - patch_dtype = first_patch.dtype - patch_device = first_patch.device - - tensors = { - "pixel_values": torch.zeros( - (batch_size, max_images, max_patches, patch_dim), - device=patch_device, - dtype=patch_dtype, - ), - "image_grid_thw": torch.zeros((batch_size, max_images, 3), device=patch_device, dtype=torch.long), - } - - for batch_idx, (sample_patches, sample_token_grids) in enumerate( - zip(vision_patches, vision_token_grids, strict=True) - ): - for image_slot, (patches, token_grid) in enumerate(zip(sample_patches, sample_token_grids, strict=True)): - patch_count = int(patches.shape[0]) - tensors["pixel_values"][batch_idx, image_slot, :patch_count] = patches - tensors["image_grid_thw"][batch_idx, image_slot, 0] = 1 - tensors["image_grid_thw"][batch_idx, image_slot, 1:] = token_grid - - return tensors - - def _preprocess( - self, - images: list[list[torch.Tensor]], - do_resize: bool, - resample: Any | None, - do_rescale: bool | None, - rescale_factor: float | None, - do_normalize: bool | None, - image_mean: float | Sequence[float] | None, - image_std: float | Sequence[float] | None, - do_pad: bool | None = None, - disable_grouping: bool | None = None, - return_tensors: str | TensorType | None = None, - patch_size: int | None = None, - max_num_patches: int | None = None, - min_num_patches: int | None = None, - pixel_shuffle_scale: int | None = None, - **kwargs, - ) -> BatchFeature: - resample = 
kwargs.pop("interpolation", resample) - # IsaacProcessor routes text-only calls here as an empty image list per sample. - # Return `None` visual fields so text-only batches skip multimodal codepaths like other VLMs. - if all(len(sample_images) == 0 for sample_images in images): - tensors = { - "pixel_values": None, - "image_grid_thw": None, - } - return BatchFeature(data=tensors, tensor_type=return_tensors) - - grouped_images, grouped_images_index = group_images_by_shape( - images, disable_grouping=disable_grouping, is_nested=True - ) - grouped_outputs = {} - for shape, stacked_images in grouped_images.items(): - grouped_batch_size, channels, original_height, original_width = stacked_images.shape - target_height, target_width = get_image_size_for_max_num_patches( - original_height, - original_width, - patch_size, - max_num_patches, - min_num_patches=min_num_patches, - pixel_shuffle_scale=pixel_shuffle_scale, - ) - if do_resize: - image_batch = self.resize( - stacked_images, SizeDict(height=target_height, width=target_width), resample=resample - ) - else: - if (original_height % patch_size) or (original_width % patch_size): - raise ValueError( - f"Image dimensions (h={original_height}, w={original_width}) must be divisible by patch_size={patch_size} when resize is disabled; enable resizing or adjust the input resolution." - ) - image_batch, target_height, target_width = stacked_images, original_height, original_width +@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") +@strict +class IsaacConfig(PretrainedConfig): + r""" + vision_config (`IsaacVisionConfig` or `dict`, *optional*): + Configuration for the Isaac vision tower. Dictionaries are converted to [`IsaacVisionConfig`]. If unset, + the default [`IsaacVisionConfig`] is used. + text_config (`IsaacTextConfig` or `dict`, *optional*): + Configuration for the text backbone. Dictionaries are converted to [`IsaacTextConfig`]. + vision_rescale_factor (`float`, *optional*, defaults to 1 / 255): + Rescale factor applied by the image processor before normalization. + max_sequence_length (`int`, *optional*, defaults to 16384): + Maximum multimodal sequence length produced by the processor and expected by the model. - image_batch = self.rescale_and_normalize( - image_batch, - do_rescale=do_rescale, - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - ) + Example: - patches = torch_extract_patches(image_batch, patch_size, patch_size) - _, height_tokens, width_tokens, patch_dim = patches.shape + ```python + >>> from transformers import IsaacConfig, IsaacModel - token_grid = ( - torch.tensor([height_tokens, width_tokens], device=patches.device).long().expand(grouped_batch_size, 2) - ) + >>> configuration = IsaacConfig() + >>> model = IsaacModel(configuration) + >>> configuration = model.config + ``` + """ - if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): - raise ValueError( - f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale}; adjust resize/patch parameters or disable pixel shuffle." 
- ) + model_type = "isaac" + sub_configs = {"vision_config": IsaacVisionConfig, "text_config": IsaacTextConfig} + vision_config: IsaacVisionConfig | dict | None = None + text_config: IsaacTextConfig | dict | None = None + vision_rescale_factor: float = 1 / 255 + max_sequence_length: int = 16384 - grouped_outputs[shape] = ( - patches.reshape(grouped_batch_size, -1, patch_dim), - token_grid, + def __post_init__(self, **kwargs): + if isinstance(self.text_config, dict): + self.text_config = self.sub_configs["text_config"](**self.text_config) + elif self.text_config is None: + self.text_config = self.sub_configs["text_config"]() + elif not isinstance(self.text_config, IsaacTextConfig): + raise TypeError( + f"text_config must be a dict or an IsaacTextConfig instance, got {type(self.text_config).__name__}." ) - keys = ("vision_patches", "vision_token_grids") - nested_outputs = {} - for i, key in enumerate(keys): - nested_outputs[key] = reorder_images( - {shape: values[i] for shape, values in grouped_outputs.items()}, - dict(grouped_images_index), - is_nested=True, + if isinstance(self.vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**self.vision_config) + elif self.vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + elif not isinstance(self.vision_config, IsaacVisionConfig): + raise TypeError( + f"vision_config must be a dict or an IsaacVisionConfig instance, got {type(self.vision_config).__name__}." ) - if not do_pad: - raise ValueError("IsaacImageProcessor doesn't support `do_pad=False` mode.") - - tensors = self.pack_images( - vision_patches=nested_outputs["vision_patches"], - vision_token_grids=nested_outputs["vision_token_grids"], - ) - - return BatchFeature(data=tensors, tensor_type=return_tensors) + self.vision_rescale_factor = float(self.vision_rescale_factor) + super().__post_init__(**kwargs) class IsaacVisionEmbeddings(Siglip2VisionEmbeddings): @@ -430,78 +236,8 @@ def __init__(self, config: IsaacVisionConfig): self.layers = nn.ModuleList([IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) -def pixel_shuffle_padded( - hidden_states: torch.Tensor, - token_grids: torch.Tensor, - scale_factor: int = 1, -) -> torch.Tensor: - """Apply pixel shuffle per image on padded batched vision embeddings. - - Args: - x (`torch.Tensor`): - Vision embeddings of shape `(num_images, max_patches, hidden_size)`. - token_grids (`torch.Tensor`): - Grid sizes `(height, width)` per image, shape `(num_images, 2)`. - scale_factor (`int`, *optional*, defaults to 1): - Spatial down-sampling factor. - - Returns: - `torch.Tensor`: Pixel-shuffled embeddings of shape - `(num_images, max_tokens, hidden_size * scale_factor**2)`. 
- """ - num_images, max_patches, embed_dim = hidden_states.shape - output_dim = embed_dim * scale_factor * scale_factor - - token_grids = token_grids.to(device=hidden_states.device, dtype=torch.long) - heights = token_grids[:, 0] - widths = token_grids[:, 1] - full_lengths = heights * widths - - non_empty = full_lengths > 0 - if not is_torchdynamo_compiling(): - divisible = ((heights % scale_factor) == 0) & ((widths % scale_factor) == 0) - torch_compilable_check( - (~non_empty) | divisible, - f"Every non-empty (H, W) grid must be divisible by pixel_shuffle_scale={scale_factor}.", - ) - - output_lengths = (heights // scale_factor) * (widths // scale_factor) - max_output_tokens = output_lengths.max() - shuffled_4d = hidden_states.new_zeros((num_images, max_output_tokens, scale_factor * scale_factor, embed_dim)) - - token_positions = ( - torch.arange(max_patches, device=hidden_states.device, dtype=torch.long).unsqueeze(0).expand(num_images, -1) - ) - valid_token_mask = token_positions < full_lengths.unsqueeze(1) - - safe_widths = torch.where(widths > 0, widths, torch.ones_like(widths)) - row_index = torch.div(token_positions, safe_widths.unsqueeze(1), rounding_mode="floor") - col_index = token_positions.remainder(safe_widths.unsqueeze(1)) - - output_widths = widths.div(scale_factor, rounding_mode="floor") - output_index = row_index.div(scale_factor, rounding_mode="floor") * output_widths.unsqueeze(1) - output_index = output_index + col_index.div(scale_factor, rounding_mode="floor") - sub_index = row_index.remainder(scale_factor) * scale_factor + col_index.remainder(scale_factor) - - batch_index = ( - torch.arange(num_images, device=hidden_states.device, dtype=torch.long).unsqueeze(1).expand_as(token_positions) - ) - shuffled_4d[batch_index[valid_token_mask], output_index[valid_token_mask], sub_index[valid_token_mask]] = ( - hidden_states[valid_token_mask] - ) - - shuffled = shuffled_4d.view(num_images, max_output_tokens, output_dim) - return shuffled - - -class IsaacVisionTransformer(PreTrainedModel): - """Vision tower for padded variable-resolution patches with per-image masks. - - Args: - config (IsaacVisionConfig): Vision configuration with pixel-shuffle and patching parameters. - - """ - +@auto_docstring +class IsaacVisionModel(PreTrainedModel): config: IsaacVisionConfig _supports_sdpa = True _supports_flash_attn = True @@ -512,7 +248,6 @@ class IsaacVisionTransformer(PreTrainedModel): def __init__(self, config: IsaacVisionConfig): super().__init__(config) - self.config = config self.embeddings = IsaacVisionEmbeddings(config) self.encoder = IsaacVisionEncoder(config) self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -526,26 +261,86 @@ def _init_weights(self, module): if isinstance(module, IsaacVisionEmbeddings): init.zeros_(module.position_embedding) - @merge_with_config_defaults - @capture_outputs(tie_last_hidden_states=False) - def forward( + def pixel_shuffle_padded( + self, + hidden_states: torch.Tensor, + token_grids: torch.Tensor, + ) -> torch.Tensor: + """Apply pixel shuffle per image on padded batched vision embeddings. + + Args: + hidden_states (`torch.Tensor`): + Vision embeddings of shape `(num_images, max_patches, hidden_size)`. + token_grids (`torch.Tensor`): + Grid sizes `(height, width)` per image, shape `(num_images, 2)`. + + Returns: + `torch.Tensor`: Pixel-shuffled embeddings of shape + `(num_images, max_tokens, hidden_size * scale_factor**2)`. 
+ """ + scale_factor = self.pixel_shuffle_scale_factor + num_images, max_patches, embed_dim = hidden_states.shape + output_dim = embed_dim * scale_factor * scale_factor + + token_grids = token_grids.to(device=hidden_states.device, dtype=torch.long) + heights = token_grids[:, 0] + widths = token_grids[:, 1] + full_lengths = heights * widths + + non_empty = full_lengths > 0 + if not is_torchdynamo_compiling(): + divisible = ((heights % scale_factor) == 0) & ((widths % scale_factor) == 0) + torch_compilable_check( + (~non_empty) | divisible, + f"Every non-empty (H, W) grid must be divisible by pixel_shuffle_scale={scale_factor}.", + ) + + output_lengths = (heights // scale_factor) * (widths // scale_factor) + max_output_tokens = output_lengths.max() + shuffled_4d = hidden_states.new_zeros((num_images, max_output_tokens, scale_factor * scale_factor, embed_dim)) + + token_positions = ( + torch.arange(max_patches, device=hidden_states.device, dtype=torch.long) + .unsqueeze(0) + .expand(num_images, -1) + ) + valid_token_mask = token_positions < full_lengths.unsqueeze(1) + + safe_widths = torch.where(widths > 0, widths, torch.ones_like(widths)) + row_index = torch.div(token_positions, safe_widths.unsqueeze(1), rounding_mode="floor") + col_index = token_positions.remainder(safe_widths.unsqueeze(1)) + + output_widths = widths.div(scale_factor, rounding_mode="floor") + output_index = row_index.div(scale_factor, rounding_mode="floor") * output_widths.unsqueeze(1) + output_index = output_index + col_index.div(scale_factor, rounding_mode="floor") + sub_index = row_index.remainder(scale_factor) * scale_factor + col_index.remainder(scale_factor) + + batch_index = ( + torch.arange(num_images, device=hidden_states.device, dtype=torch.long) + .unsqueeze(1) + .expand_as(token_positions) + ) + shuffled_4d[batch_index[valid_token_mask], output_index[valid_token_mask], sub_index[valid_token_mask]] = ( + hidden_states[valid_token_mask] + ) + + shuffled = shuffled_4d.view(num_images, max_output_tokens, output_dim) + return shuffled + + @merge_with_config_defaults + @capture_outputs(tie_last_hidden_states=False) + @auto_docstring + def forward( self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: + r""" + image_grid_thw (`torch.LongTensor`, *optional*): + Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. """ - Inputs: - pixel_values (`torch.Tensor`): - Patches shaped `(num_images, max_patches, patch_dim)`. - image_grid_thw (`torch.Tensor`): - Grid tensor shaped `(num_images, 3)` with per-image `(T=1, H_tokens, W_tokens)`. - - Returns: - `BaseModelOutputWithPooling` with pixel-shuffled embeddings in `last_hidden_state`. 
- """ - vision_token_grids = image_grid_thw[:, 1:].to(dtype=torch.long) - full_lengths = vision_token_grids[:, 0] * vision_token_grids[:, 1] + full_lengths = image_grid_thw[:, 1] * image_grid_thw[:, 2] token_positions = torch.arange(pixel_values.shape[1], device=pixel_values.device, dtype=torch.long) image_patch_attention_mask = token_positions.unsqueeze(0) < full_lengths.unsqueeze(1) image_patch_attention_mask = image_patch_attention_mask.to(dtype=torch.long) @@ -563,10 +358,9 @@ def forward( encoder_outputs = self.encoder(inputs_embeds=hidden_states, attention_mask=encoder_attention_mask, **kwargs) hidden_states = self.post_layernorm(encoder_outputs.last_hidden_state) - hidden_states = pixel_shuffle_padded( + hidden_states = self.pixel_shuffle_padded( hidden_states=hidden_states, - token_grids=vision_token_grids, - scale_factor=self.pixel_shuffle_scale_factor, + token_grids=image_grid_thw[:, 1:], ) return BaseModelOutputWithPooling( @@ -577,6 +371,30 @@ def forward( ) +class IsaacRotaryEmbedding(Qwen3VLTextRotaryEmbedding): + def __init__(self, config: IsaacTextConfig, device=None): + super().__init__(config, device=device) + self.mrope_section = config.rope_parameters.get("mrope_section") + if self.mrope_section is None: + weights = (2, 1, 1) + self.mrope_section = [self.inv_freq.shape[0] * w // sum(weights) for w in weights] + self.mrope_section[0] += self.inv_freq.shape[0] - sum(self.mrope_section) + + def apply_interleaved_mrope(self, freqs: torch.Tensor, mrope_section: list[int]) -> torch.Tensor: + chunks = freqs.split(tuple(mrope_section), dim=-1) + return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) + + +class IsaacTextDecoderLayer(Qwen3VLTextDecoderLayer): + pass + + +class IsaacTextModel(Qwen3VLTextModel): + def __init__(self, config: IsaacTextConfig): + super().__init__(config) + self.rotary_emb = IsaacRotaryEmbedding(config=config, device=self.device) + + class IsaacMultiModalProjector(nn.Module): """Maps vision tower outputs to the text hidden size with a SiLU MLP.""" @@ -596,1084 +414,1215 @@ def forward(self, image_features): return hidden_states -def get_scaled_image_size( - scale: float, - original_size: int, - patch_size: int, - pixel_shuffle_scale: int, -) -> int: - scaled_size = scale * original_size - divisor = patch_size * pixel_shuffle_scale - scaled_size = math.ceil(scaled_size / divisor) * divisor - scaled_size = max(divisor, scaled_size) - return int(scaled_size) - - -def get_image_size_for_max_num_patches( - image_height: int, - image_width: int, - patch_size: int, - max_num_patches: int, - min_num_patches: int | None = None, - eps: float = 1e-5, - pixel_shuffle_scale: int = 1, -) -> tuple[int, int]: - r"""Compute a target resolution whose patch grid satisfies patching parametrization. - - Args: - image_height (`int`): - Height in pixels of the source image prior to any resizing. - image_width (`int`): - Width in pixels of the source image prior to any resizing. - patch_size (`int`): - Size of the square patch used by the vision encoder. - max_num_patches (`int`): - Upper bound on `(height / patch_size) * (width / patch_size)` after resizing. - min_num_patches (`int`, *optional*): - Lower bound on the number of patches. When provided the image will be scaled up if necessary. - eps (`float`, *optional*, defaults to 1e-5): - Convergence tolerance for the internal binary search to determing the target dimensions. 
- pixel_shuffle_scale (`int`, *optional*, defaults to 1): - Additional stride multiplier applied when pixel shuffle later reduces spatial resolution. - - Returns: - `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale` - and respect both the maximum and optional minimum patch-count constraints. - """ - - # Ensure divisibility - divisor = patch_size * pixel_shuffle_scale - adjusted_height = math.ceil(image_height / divisor) * divisor - adjusted_height = max(divisor, adjusted_height) - adjusted_width = math.ceil(image_width / divisor) * divisor - adjusted_width = max(divisor, adjusted_width) - - num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size) - - if min_num_patches is not None and num_patches < min_num_patches: - # Scale up via binary search to satisfy the minimum patch budget while - # preserving divisibility by patch_size * pixel_shuffle_scale. - scale_min, scale_max = 1.0, 100.0 - while (scale_max - scale_min) >= eps: - scale = (scale_min + scale_max) / 2 - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) - num_patches = (target_height / patch_size) * (target_width / patch_size) - if num_patches >= min_num_patches: - scale_max = scale - else: - scale_min = scale - scale = scale_max - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) - return target_height, target_width - elif num_patches <= max_num_patches: - return adjusted_height, adjusted_width - else: - # Scale down - scale_min, scale_max = eps / 10, 1.0 - while (scale_max - scale_min) >= eps: - scale = (scale_min + scale_max) / 2 - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) - num_patches = (target_height / patch_size) * (target_width / patch_size) - if num_patches <= max_num_patches: - scale_min = scale - else: - scale_max = scale - scale = scale_min - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) - return target_height, target_width - - -@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") -@strict -class IsaacConfig(PretrainedConfig): - r""" - vision_config (`IsaacVisionConfig` or `dict`, *optional*): - Configuration for the Isaac vision tower. Dictionaries are converted to [`IsaacVisionConfig`]. If unset, - the default [`IsaacVisionConfig`] is used. - text_config (`IsaacTextConfig` or `dict`, *optional*): - Configuration for the text backbone. Dictionaries are converted to [`IsaacTextConfig`]. - vision_rescale_factor (`float`, *optional*, defaults to 1 / 255): - Rescale factor applied by the image processor before normalization. - max_sequence_length (`int`, *optional*, defaults to 16384): - Maximum multimodal sequence length produced by the processor and expected by the model. 
- Example: - - ```python - >>> from transformers import IsaacConfig, IsaacModel +@auto_docstring +class IsaacModel(Qwen3VLModel): + input_modalities = ("image", "text") + supports_gradient_checkpointing = True + _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"] + _can_compile_fullgraph = False + _supports_flex_attn = False + _tied_weights_keys = {} + _input_embed_layer = "language_model.embed_tokens" - >>> configuration = IsaacConfig() - >>> model = IsaacModel(configuration) - >>> configuration = model.config - ``` - """ + def __init__(self, config: IsaacConfig): + PreTrainedModel.__init__(self, config) + self.language_model = IsaacTextModel._from_config(config.text_config) + self.visual = IsaacVisionModel(config.vision_config) + self.multimodal_projector = IsaacMultiModalProjector(config) + self.max_sequence_length = config.max_sequence_length + self.vision_rescale_factor = config.vision_rescale_factor + self.rope_deltas = None - model_type = "isaac" - sub_configs = {"vision_config": IsaacVisionConfig, "text_config": IsaacTextConfig} - vision_config: IsaacVisionConfig | dict | None = None - text_config: IsaacTextConfig | dict | None = None - vision_rescale_factor: float = 1 / 255 - max_sequence_length: int = 16384 + self.post_init() - def __post_init__(self, **kwargs): - for key in ("use_cache", "rope_theta", "max_position_embeddings"): - kwargs.pop(key, None) + @can_return_tuple + @auto_docstring + def get_image_features( + self, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + image_metadata: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values (`torch.FloatTensor`, *optional*): + Batch-major patch vectors shaped `(batch_size, max_images, max_patches, patch_dim)`. + image_grid_thw (`torch.LongTensor`, *optional*): + Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. + image_metadata (`torch.Tensor`, *optional*): + Batch-major per-slot metadata `(offset, length)` shaped `(batch_size, max_images, 2)`. + """ + active_slot_mask = image_grid_thw[..., 0].eq(1) + flat_pixel_values = pixel_values[active_slot_mask] + flat_image_grid_thw = image_grid_thw[active_slot_mask] - if isinstance(self.text_config, dict): - self.text_config = self.sub_configs["text_config"](**self.text_config) - elif self.text_config is None: - self.text_config = self.sub_configs["text_config"]() - elif not isinstance(self.text_config, IsaacTextConfig): - raise TypeError( - f"text_config must be a dict or an IsaacTextConfig instance, got {type(self.text_config).__name__}." - ) + vision_outputs: BaseModelOutputWithPooling = self.visual( + pixel_values=flat_pixel_values, + image_grid_thw=flat_image_grid_thw, + return_dict=True, + **kwargs, + ) + projected_features = self.multimodal_projector(vision_outputs.last_hidden_state) - if isinstance(self.vision_config, dict): - self.vision_config = self.sub_configs["vision_config"](**self.vision_config) - elif self.vision_config is None: - self.vision_config = self.sub_configs["vision_config"]() - elif not isinstance(self.vision_config, IsaacVisionConfig): - raise TypeError( - f"vision_config must be a dict or an IsaacVisionConfig instance, got {type(self.vision_config).__name__}." 
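+        # Illustrative sizing: a slot with grid (T=1, H=28, W=28) and pixel_shuffle_scale_factor=2
+        # contributes (28 // 2) * (28 // 2) = 196 projected tokens before the offset/length
+        # window below is applied.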
+ # Truncate image features using offset and length + if image_metadata is None: + pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor + downsampled_height = flat_image_grid_thw[:, 1].div(pixel_shuffle_scale, rounding_mode="floor") + downsampled_width = flat_image_grid_thw[:, 2].div(pixel_shuffle_scale, rounding_mode="floor") + lengths = downsampled_height * downsampled_width + offsets = torch.zeros_like(lengths) + else: + torch_compilable_check( + image_metadata.shape[:2] == image_grid_thw.shape[:2], + "IsaacModel.get_image_features expects batch-major metadata aligned with `image_grid_thw`.", ) + offsets = image_metadata[active_slot_mask][:, 0] + lengths = image_metadata[active_slot_mask][:, 1] - self.vision_rescale_factor = float(self.vision_rescale_factor) + image_features = tuple( + projected_features[image_idx, offset : offset + length] + for image_idx, (offset, length) in enumerate(zip(offsets.tolist(), lengths.tolist(), strict=True)) + ) - super().__post_init__(**kwargs) + return BaseModelOutputWithPooling( + last_hidden_state=projected_features, + pooler_output=image_features, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + ) + def get_placeholder_mask( + self, + mm_token_type_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor, + ) -> torch.BoolTensor: + image_token_mask = mm_token_type_ids.to(dtype=torch.long) == 1 + n_image_tokens = image_token_mask.sum() + image_token_mask = image_token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[image_token_mask].numel() == image_features.numel(), + f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", + ) + return image_token_mask -@auto_docstring -class IsaacProcessor(ProcessorMixin): - class _SinglePoint(NamedTuple): - x: int - y: int - mention: str | None = None - t: float | None = None - - class _BoundingBox(NamedTuple): - top_left: Any - bottom_right: Any - mention: str | None = None - t: float | None = None - - class _Polygon(NamedTuple): - points: tuple[Any, ...] - mention: str | None = None - t: float | None = None - - _point_box_or_polygon_tag = re.compile( - r"<(?Ppoint|point_box|polygon)(?P[^>]*)>(?P[\s\S]*?)", re.IGNORECASE - ) - _attr_re = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))") - _coord_re = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)") + def get_video_features(self, **super_kwargs): + raise AttributeError("Isaac is image-only and does not support `pixel_values_videos` or `video_grid_thw`.") - def __init__( + def get_vision_position_ids( self, - image_processor, - tokenizer, - chat_template: str | dict[str, str] | None = None, - max_sequence_length: int = 16384, - ): - """ - Args: - chat_template (`str` or `dict[str, str]`, *optional*): - Chat template override forwarded to [`~processing_utils.ProcessorMixin`]. - max_sequence_length (`int`, *optional*, defaults to 16384): - Maximum packed multimodal sequence length produced by the processor. 
- """ - if chat_template is None: - chat_template = getattr(tokenizer, "chat_template", None) - - self.image_processor = image_processor - super().__init__(image_processor, tokenizer, chat_template=chat_template) - self.text_pad_token_id = self.pad_token_id = tokenizer.pad_token_id - self.image_token = getattr(tokenizer, "image_pad_token", None) or getattr(tokenizer, "image_token", None) - self.image_token_id = getattr(tokenizer, "image_pad_token_id", None) or getattr( - tokenizer, "image_token_id", None + start_position: int, + grid_thw: torch.LongTensor, + image_metadata: torch.LongTensor, + ) -> torch.LongTensor: + pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor + height = grid_thw[1].div(pixel_shuffle_scale, rounding_mode="floor").item() + width = grid_thw[2].div(pixel_shuffle_scale, rounding_mode="floor").item() + token_positions = torch.arange(height * width, device=grid_thw.device, dtype=torch.long) + vision_position_ids = torch.stack( + ( + torch.full((token_positions.shape[0],), start_position, device=grid_thw.device, dtype=torch.long), + token_positions.div(width, rounding_mode="floor"), + token_positions.remainder(width), + ), + dim=0, ) + token_offset = int(image_metadata[0].item()) + token_length = int(image_metadata[1].item()) + return vision_position_ids[:, token_offset : token_offset + token_length] - self.max_sequence_length = max_sequence_length + def get_rope_index( + self, + input_ids: torch.LongTensor, + mm_token_type_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + image_metadata: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if attention_mask is None: + if input_ids is None: + attention_mask = mm_token_type_ids.new_ones(mm_token_type_ids.shape, dtype=torch.long) + else: + attention_mask = input_ids.new_ones(input_ids.shape, dtype=torch.long) - @property - def model_input_names(self): - return super().model_input_names + ["mm_token_type_ids", "image_metadata"] + if input_ids is None: + batch_size, seq_len = attention_mask.shape + position_dtype = torch.long + else: + batch_size, seq_len = input_ids.shape + position_dtype = input_ids.dtype - @staticmethod - def _maybe_float(value: str | None) -> float | None: - if value is None: - return None - try: - return float(value) - except ValueError: - return None + device = attention_mask.device + mm_token_type_ids = mm_token_type_ids.to(dtype=torch.long) + image_grid_thw = image_grid_thw.to(dtype=torch.long) + image_metadata = image_metadata.to(dtype=torch.long) + attention_mask = attention_mask.to(dtype=torch.long) + active_slot_mask = image_grid_thw[..., 0].eq(1) - @classmethod - def _parse_attrs(cls, attr_text: str) -> dict[str, str]: - attrs = {} - for match in cls._attr_re.finditer(attr_text or ""): - key = match.group(1) - value = match.group(2) or match.group(3) or "" - attrs[key] = value - return attrs + position_ids = torch.zeros((3, batch_size, seq_len), device=device, dtype=position_dtype) + rope_deltas = torch.zeros((batch_size, 1), device=device, dtype=torch.long) - @classmethod - def _parse_point_body( - cls, - body: str, - mention: str | None = None, - t: str | None = None, - ) -> Any: - match = cls._coord_re.search(body) - if not match: - raise ValueError(f"Malformed tag: {body!r}") - x, y = int(match.group(1)), int(match.group(2)) - return cls._SinglePoint(x=x, y=y, mention=mention, t=cls._maybe_float(t)) + for batch_idx in range(batch_size): + sample_attention_mask = attention_mask[batch_idx].bool() + 
sample_token_types = mm_token_type_ids[batch_idx][sample_attention_mask] + sample_grids = image_grid_thw[batch_idx] + sample_metadata = image_metadata[batch_idx] + sample_active_slots = active_slot_mask[batch_idx] - @classmethod - def _parse_box_body( - cls, - body: str, - mention: str | None = None, - t: str | None = None, - ) -> Any: - coords = list(cls._coord_re.finditer(body)) - if len(coords) < 2: - raise ValueError(f"Malformed tag: {body!r}") + current_pos = 0 + image_idx = 0 + seq_pos = 0 + llm_pos_ids_list = [] - top_left = cls._SinglePoint(x=int(coords[0].group(1)), y=int(coords[0].group(2))) - bottom_right = cls._SinglePoint(x=int(coords[1].group(1)), y=int(coords[1].group(2))) - return cls._BoundingBox(top_left=top_left, bottom_right=bottom_right, mention=mention, t=cls._maybe_float(t)) + while seq_pos < sample_token_types.shape[0]: + modality_type = int(sample_token_types[seq_pos].item()) + if modality_type == 0: + group_end = seq_pos + 1 + while group_end < sample_token_types.shape[0] and sample_token_types[group_end] == 0: + group_end += 1 + group_length = group_end - seq_pos + llm_pos_ids_list.append( + torch.arange(group_length, device=device, dtype=torch.long).view(1, -1).expand(3, -1) + + current_pos + ) + current_pos += group_length + seq_pos = group_end + else: + while image_idx < sample_metadata.shape[0] and ( + not bool(sample_active_slots[image_idx].item()) or sample_metadata[image_idx, 1].item() == 0 + ): + image_idx += 1 + torch_compilable_check( + image_idx < sample_metadata.shape[0], + "Isaac multimodal sequence has more visible image tokens than batch-major image metadata slots.", + ) + token_length = int(sample_metadata[image_idx, 1].item()) + torch_compilable_check( + token_length <= sample_token_types.shape[0] - seq_pos, + "Isaac image metadata length exceeds the remaining multimodal placeholder span.", + ) + llm_pos_ids_list.append( + self.get_vision_position_ids(current_pos, sample_grids[image_idx], sample_metadata[image_idx]) + ) + current_pos += 1 + seq_pos += token_length + image_idx += 1 - @classmethod - def _parse_polygon_body( - cls, - body: str, - mention: str | None = None, - t: str | None = None, - ) -> Any: - coords = list(cls._coord_re.finditer(body)) - if len(coords) < 3: - raise ValueError(f"Malformed tag: {body!r}") + llm_positions = ( + torch.cat(llm_pos_ids_list, dim=1) + if llm_pos_ids_list + else torch.zeros((3, 0), device=device, dtype=torch.long) + ) + position_ids[:, batch_idx, sample_attention_mask] = llm_positions + rope_deltas[batch_idx, 0] = ( + llm_positions.max() + 1 - sample_token_types.shape[0] if llm_positions.numel() > 0 else 0 + ) - points = tuple(cls._SinglePoint(x=int(coord.group(1)), y=int(coord.group(2))) for coord in coords) - return cls._Polygon(points=points, mention=mention, t=cls._maybe_float(t)) + return position_ids, rope_deltas - @classmethod - def clean_text_and_extract_points( - cls, - text: str, - expected: str | None = None, - ) -> tuple[str, list[Any]]: - results: list[Any] = [] - for match in cls._point_box_or_polygon_tag.finditer(text or ""): - tag = match.group("tag").lower() - attrs = cls._parse_attrs(match.group("attrs")) - mention = attrs.get("mention") - t = attrs.get("t") - if tag == "point": - if expected not in (None, "point"): - continue - results.append(cls._parse_point_body(match.group("body"), mention=mention, t=t)) - elif tag == "point_box": - if expected not in (None, "box"): - continue - results.append(cls._parse_box_body(match.group("body"), mention=mention, t=t)) - else: - if expected not 
in (None, "polygon"): - continue - results.append(cls._parse_polygon_body(match.group("body"), mention=mention, t=t)) + def compute_3d_position_ids( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | None, + mm_token_type_ids: torch.Tensor | None = None, + image_grid_thw: torch.Tensor | None = None, + image_metadata: torch.Tensor | None = None, + past_key_values: Cache | None = None, + ) -> torch.Tensor: + past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() + has_multimodal = ( + image_grid_thw is not None + and image_metadata is not None + and bool(image_grid_thw[..., 0].eq(1).any().item()) + ) + if has_multimodal and mm_token_type_ids is None and input_ids is not None: + raise ValueError( + "Multimodal data was passed (via `image_grid_thw` or `image_metadata`) but `mm_token_type_ids` is " + "missing. Please pass `mm_token_type_ids` to the model so that multimodal RoPE (M-RoPE) can be " + "computed correctly. `mm_token_type_ids` is returned by the processor alongside `input_ids`." + ) - clean_text = re.sub(r"\s+", " ", cls._point_box_or_polygon_tag.sub("", text or "")).strip() - return clean_text, results + if has_multimodal and past_seen_tokens == 0: + position_ids, rope_deltas = self.get_rope_index( + input_ids=input_ids, + mm_token_type_ids=mm_token_type_ids, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, + attention_mask=attention_mask, + ) + self.rope_deltas = rope_deltas + return position_ids - def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): - vision_data = {} - if image_sizes is not None: - images_kwargs = dict(IsaacProcessorKwargs._defaults.get("images_kwargs", {})) - images_kwargs.update(kwargs) + if self.rope_deltas is None: + return None - num_image_patches = [ - self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) - for image_size in image_sizes - ] - pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale") or self.image_processor.pixel_shuffle_scale - num_image_tokens = [num_patches // pixel_shuffle_scale**2 for num_patches in num_image_patches] - vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + rope_deltas = torch.as_tensor(self.rope_deltas, device=inputs_embeds.device, dtype=torch.long).reshape(-1, 1) + if rope_deltas.shape[0] != inputs_embeds.shape[0]: + if inputs_embeds.shape[0] % rope_deltas.shape[0] == 0: + rope_deltas = rope_deltas.repeat_interleave(inputs_embeds.shape[0] // rope_deltas.shape[0], dim=0) + else: + rope_deltas = rope_deltas[:1].expand(inputs_embeds.shape[0], -1) - return MultiModalData(**vision_data) + if attention_mask is not None and attention_mask.shape[-1] > inputs_embeds.shape[1]: + rope_position = attention_mask.long().cumsum(dim=-1) - 1 + rope_position = rope_position.masked_fill(attention_mask == 0, 0) + rope_position = rope_position[:, -inputs_embeds.shape[1] :] + else: + rope_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + dtype=torch.long, + ).view(1, -1) + rope_position = rope_position.expand(inputs_embeds.shape[0], -1) - def post_process_generation( - self, - text: str, - expected: str | None = None, - cleanup_and_extract: bool = True, - ) -> str | tuple[str, list[Any]]: - if cleanup_and_extract: - return self.clean_text_and_extract_points(text, expected=expected) - return text + position_ids = rope_position.view(1, inputs_embeds.shape[0], -1).expand(3, -1, -1) + return 
position_ids + rope_deltas.to(device=inputs_embeds.device).unsqueeze(0) - def post_process_image_text_to_text( - self, - generated_outputs, - skip_special_tokens: bool = True, - cleanup_and_extract: bool = False, - expected: str | None = None, - **kwargs, - ): - generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs) - return [ - self.post_process_generation(text, expected=expected, cleanup_and_extract=cleanup_and_extract) - for text in generated_texts - ] + @auto_docstring( + custom_intro=""" + Forward pass with multimodal MRoPE position ids. - def __call__( + When image placeholders are present, Isaac computes vision features, scatters them into the token + embeddings, and runs the shared text backbone on the mixed sequence. + """, + ) + @can_return_tuple + def forward( self, - text: str | list[str], - images: ImageInput | None = None, - return_tensors: str | TensorType | None = TensorType.PYTORCH, - **kwargs, - ) -> BatchFeature: - output_kwargs = self._merge_kwargs( - IsaacProcessorKwargs, - tokenizer_init_kwargs=self.tokenizer.init_kwargs, + input_ids: torch.LongTensor | None = None, + mm_token_type_ids: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + image_metadata: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPast: + r""" + mm_token_type_ids (`torch.LongTensor`, *optional*): + Multimodal token type ids aligned with the token sequence, using `0 -> text` and `1 -> image`. + pixel_values (`torch.FloatTensor`, *optional*): + Batch-major patch vectors shaped `(batch_size, max_images, max_patches, patch_dim)`. + image_grid_thw (`torch.LongTensor`, *optional*): + Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. + image_metadata (`torch.LongTensor`, *optional*): + Batch-major per-slot metadata shaped `(batch_size, max_images, 2)` with `(offset, length)`. 
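+
+        Example (illustrative):
+
+        ```python
+        >>> # token layout: [text, text, <image>, <image>, text]
+        >>> mm_token_type_ids = torch.tensor([[0, 0, 1, 1, 0]])
+        >>> # positions marked 1 are overwritten with rows taken from
+        >>> # `get_image_features(...).pooler_output`
+        ```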
+ """ + if (input_ids is None) == (inputs_embeds is None): + raise ValueError("You must specify exactly one of `input_ids` or `inputs_embeds`.") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_mask = None + if pixel_values is not None and image_grid_thw is not None: + image_outputs = self.get_image_features( + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, + return_dict=True, + ) + image_embeds = image_outputs.pooler_output + if len(image_embeds) > 0: + image_embeds = torch.cat(image_embeds, dim=0).to( + device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + image_mask = self.get_placeholder_mask( + mm_token_type_ids=mm_token_type_ids, + inputs_embeds=inputs_embeds, + image_features=image_embeds, + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if isinstance(attention_mask, dict): + attention_mask = attention_mask["full_attention"] + + past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() + computed_position_ids = self.compute_3d_position_ids( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + mm_token_type_ids=mm_token_type_ids, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, + attention_mask=attention_mask, + past_key_values=past_key_values, + ) + if computed_position_ids is not None: + position_ids = computed_position_ids + elif past_seen_tokens > 0: + position_ids = None + elif position_ids is not None and past_seen_tokens == 0: + position_ids = position_ids.to(device=inputs_embeds.device) + if position_ids.ndim == 2: + position_ids = position_ids.view(1, position_ids.shape[0], -1).expand(3, -1, -1) + + outputs = self.language_model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + visual_pos_masks=image_mask[..., 0] if image_mask is not None else None, + deepstack_visual_embeds=None, + use_cache=use_cache, **kwargs, ) - text_kwargs = output_kwargs["text_kwargs"] - truncation = text_kwargs.pop("truncation", None) - max_length = text_kwargs.pop("max_length", None) - padding = text_kwargs.pop("padding", True) - padding_side = text_kwargs.pop("padding_side", "left") - return_attention_mask = text_kwargs.pop("return_attention_mask", True) - return_mm_token_type_ids = text_kwargs.pop("return_mm_token_type_ids", True) - pad_to_multiple_of = text_kwargs.pop("pad_to_multiple_of", None) - text_kwargs.pop("return_tensors", None) - text_kwargs.pop("return_overflowing_tokens", None) - text_kwargs.setdefault("add_special_tokens", False) - - texts = [text] if isinstance(text, str) else text - if images is None: - batched_images = [[] for _ in texts] - else: - fetched_images = self.image_processor.fetch_images(images) - batched_images = make_nested_list_of_images(fetched_images) - if len(batched_images) != len(texts): - num_images_in_text = [text_value.count(self.image_token) for text_value in texts] - num_images_in_images = [len(sample_images) for sample_images in batched_images] - add_message = "" - if sum(num_images_in_text) == sum(num_images_in_images): - add_message = " Make sure to pass your images as a nested list, where each sub-list holds images for one text sample." 
- raise ValueError( - f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(texts)}).{add_message}" - ) + return BaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + - pairs = list(zip(texts, batched_images, strict=True)) - image_inputs = self.image_processor(images=batched_images, return_tensors=TensorType.PYTORCH) - image_grid_thw = image_inputs["image_grid_thw"] - image_metadata = None - vision_segment_lengths = None - if image_grid_thw is not None: - batch_size, max_images = image_grid_thw.shape[:2] - image_metadata = torch.zeros((batch_size, max_images, 2), dtype=torch.long) - grid_heights = image_grid_thw[..., 1] - grid_widths = image_grid_thw[..., 2] - vision_segment_lengths = (grid_heights // self.image_processor.pixel_shuffle_scale) * ( - grid_widths // self.image_processor.pixel_shuffle_scale - ) +@dataclass +class IsaacCausalLMOutputWithPast(CausalLMOutputWithPast): + """ + Base class for Isaac causal language model (or autoregressive) outputs. - expanded_texts = [] - expected_image_lengths_per_sample = [] + Args: + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ - for batch_idx, (text_value, sample_images) in enumerate(pairs): - segments = text_value.split(self.image_token) - num_images = len(segments) - 1 - num_provided_images = len(sample_images) - if num_images != num_provided_images: - raise ValueError( - f"IsaacProcessor expects one image per image token, got {num_images} tokens and {num_provided_images} images in sample with text {text_value} " - ) + rope_deltas: torch.LongTensor | None = None - expected_image_lengths = [] - expanded_text_parts = [segments[0]] - for image_idx in range(num_images): - segment_length = int(vision_segment_lengths[batch_idx, image_idx].item()) - expected_image_lengths.append(segment_length) - expanded_text_parts.append(self.image_token * segment_length) - expanded_text_parts.append(segments[image_idx + 1]) - - expected_image_lengths_per_sample.append(expected_image_lengths) - expanded_texts.append("".join(expanded_text_parts)) - - effective_max_length = self.max_sequence_length - if max_length is not None and (truncation or padding == "max_length"): - effective_max_length = max_length - - self.tokenizer.truncation_side = "left" - self.tokenizer.padding_side = padding_side - tokenized_text_inputs = self.tokenizer( - expanded_texts, - truncation=True, - max_length=effective_max_length, - padding=padding, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=True, - stride=0, - return_tensors=None, - **text_kwargs, - ) - kept_input_ids_per_sample: list[list[int] | None] = [None] * len(texts) - overflow_input_ids_per_sample: list[list[list[int]]] = [[] for _ in texts] - overflow_to_sample_mapping = tokenized_text_inputs.get("overflow_to_sample_mapping") - if overflow_to_sample_mapping is None: - overflow_to_sample_mapping = list(range(len(tokenized_text_inputs["input_ids"]))) +@auto_docstring +class IsaacForConditionalGeneration(Qwen3VLForConditionalGeneration, GenerationMixin): + config_class = IsaacConfig + input_modalities = ("image", "text") + _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"] + _can_compile_fullgraph = False + _tied_weights_keys = {"lm_head.weight": 
"model.language_model.embed_tokens.weight"} - for row_input_ids, sample_idx in zip( - tokenized_text_inputs["input_ids"], overflow_to_sample_mapping, strict=True - ): - sample_idx = int(sample_idx) - if kept_input_ids_per_sample[sample_idx] is None: - kept_input_ids_per_sample[sample_idx] = row_input_ids - else: - overflow_input_ids_per_sample[sample_idx].append(row_input_ids) + def __init__(self, config: IsaacConfig): + PreTrainedModel.__init__(self, config) + self.model = IsaacModel(config) + self.vocab_size = config.get_text_config().vocab_size + self.lm_head = nn.Linear(config.get_text_config().hidden_size, config.get_text_config().vocab_size, bias=False) + self.post_init() - for batch_idx, expected_image_lengths in enumerate(expected_image_lengths_per_sample): - dropped_image_tokens = sum( - overflow_input_ids.count(self.image_token_id) - for overflow_input_ids in overflow_input_ids_per_sample[batch_idx] - ) + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + mm_token_type_ids: torch.IntTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | IsaacCausalLMOutputWithPast: + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + mm_token_type_ids=mm_token_type_ids, + **kwargs, + ) - remaining_dropped = dropped_image_tokens - for image_idx, expected_length in enumerate(expected_image_lengths): - if remaining_dropped <= 0: - offset = 0 - length = expected_length - elif remaining_dropped < expected_length: - offset = remaining_dropped - length = expected_length - offset - remaining_dropped = 0 - else: - offset = 0 - length = 0 - remaining_dropped -= expected_length + hidden_states = outputs[0] + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) - # Record which suffix of this image's placeholder span survives left truncation. - # The model still encodes the full image and uses this window for both feature gathering and vision RoPE. 
-                image_metadata[batch_idx, image_idx, 0] = offset
-                image_metadata[batch_idx, image_idx, 1] = length
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)

-        input_ids = torch.tensor(kept_input_ids_per_sample, dtype=torch.long)
-        attention_mask = input_ids.ne(self.pad_token_id).to(dtype=torch.long)
+        return IsaacCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            rope_deltas=self.model.rope_deltas,
+        )

-        data = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "pixel_values": image_inputs["pixel_values"],
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values: Cache = None,
+        attention_mask: torch.Tensor | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        mm_token_type_ids: torch.LongTensor | None = None,
+        pixel_values: torch.Tensor | None = None,
+        image_grid_thw: torch.LongTensor | None = None,
+        image_metadata: torch.LongTensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        is_first_iteration: bool = False,
+        use_cache: bool = True,
+        **kwargs,
+    ) -> dict[str, Any]:
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            pixel_values=pixel_values,
+            image_grid_thw=image_grid_thw,
+            is_first_iteration=is_first_iteration,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        multimodal_inputs = {
+            "mm_token_type_ids": mm_token_type_ids,
+            "pixel_values": pixel_values,
             "image_grid_thw": image_grid_thw,
             "image_metadata": image_metadata,
         }
-        if return_mm_token_type_ids:
-            data["mm_token_type_ids"] = input_ids.eq(self.image_token_id).to(dtype=torch.long)
+        is_prefill = is_first_iteration or not use_cache
+        for key, value in multimodal_inputs.items():
+            model_inputs[key] = value if is_prefill else None
+        if model_inputs["mm_token_type_ids"] is not None:
+            sequence_length = None
+            if model_inputs.get("input_ids") is not None:
+                sequence_length = model_inputs["input_ids"].shape[1]
+            elif model_inputs.get("inputs_embeds") is not None:
+                sequence_length = model_inputs["inputs_embeds"].shape[1]

-        return BatchFeature(
-            data=data,
-            tensor_type=return_tensors,
-        )
+            if sequence_length is not None:
+                current_length = model_inputs["mm_token_type_ids"].shape[1]
+                if current_length < sequence_length:
+                    padding = model_inputs["mm_token_type_ids"].new_zeros(
+                        (model_inputs["mm_token_type_ids"].shape[0], sequence_length - current_length)
+                    )
+                    model_inputs["mm_token_type_ids"] = torch.cat([model_inputs["mm_token_type_ids"], padding], dim=1)
+                elif current_length > sequence_length:
+                    model_inputs["mm_token_type_ids"] = model_inputs["mm_token_type_ids"][:, -sequence_length:]
+        return model_inputs

-class IsaacRotaryEmbedding(Qwen3VLTextRotaryEmbedding):
-    def __init__(self, config: IsaacTextConfig, device=None):
-        rope_parameters = config.rope_parameters
+    def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs):
+        text_positions = GenerationMixin._prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs)

-        super().__init__(
-            config,
-            device=device if device is not None and getattr(device, "type", None) != "meta" else None,
-        )
+        past_length = 0
+        if (cache := model_kwargs.get("past_key_values")) is not None:
+            past_length = cache.get_seq_length()
+        if past_length != 0 and self.model.rope_deltas is not None:
+            return text_positions[None, ...] + self.model.rope_deltas

-        self.mrope_section = self._resolve_mrope_section(rope_parameters.get("mrope_section"), self.inv_freq.shape[0])
-        self.hidden_size = config.hidden_size
+        if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
+            inputs_tensor = model_kwargs["input_ids"]

-    @staticmethod
-    def _resolve_mrope_section(section: list[int] | None, rotary_half_dim: int) -> list[int]:
-        if section is None:
-            weights = (2, 1, 1)
-            base = [rotary_half_dim * w // sum(weights) for w in weights]
-            base[0] += rotary_half_dim - sum(base)
-            return base
+        is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
+        if (
+            is_input_ids
+            and model_kwargs.get("mm_token_type_ids") is not None
+            and model_kwargs.get("image_grid_thw") is not None
+            and model_kwargs.get("image_metadata") is not None
+        ):
+            model_kwargs = {k: v for k, v in model_kwargs.items() if k != "input_ids"}
+            vision_positions, rope_deltas = self.model.get_rope_index(inputs_tensor, **model_kwargs)
+            self.model.rope_deltas = rope_deltas
+        else:
+            vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1)
+            self.model.rope_deltas = torch.zeros(
+                inputs_tensor.shape[0], 1, dtype=torch.long, device=inputs_tensor.device
+            )

-        section = [int(v) for v in section]
-        return section
+        return torch.cat([text_positions[None, ...], vision_positions], dim=0)

-    def apply_interleaved_mrope(self, freqs: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
-        chunks = freqs.split(tuple(mrope_section), dim=-1)
-        return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
+    def _expand_inputs_for_generation(
+        self,
+        expand_size: int = 1,
+        is_encoder_decoder: bool = False,
+        input_ids: torch.LongTensor | None = None,
+        **model_kwargs,
+    ) -> tuple[torch.LongTensor, dict[str, Any]]:
+        position_ids = model_kwargs.pop("position_ids", None)
+        if expand_size == 1:
+            if position_ids is not None:
+                model_kwargs["position_ids"] = position_ids
+            return input_ids, model_kwargs
+
+        visual_keys = ["pixel_values", "image_grid_thw", "image_metadata"]
+        for key in visual_keys:
+            value = model_kwargs.get(key)
+            if value is not None:
+                model_kwargs[key] = value.repeat_interleave(expand_size, dim=0)

-class IsaacTextAttention(Qwen3VLTextAttention):
-    pass
+        if input_ids is not None:
+            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+        for key, value in list(model_kwargs.items()):
+            if key == "position_ids" and value is not None and value.ndim == 3:
+                model_kwargs[key] = value.repeat_interleave(expand_size, dim=1)
+            elif value is not None and isinstance(value, torch.Tensor) and key not in visual_keys:
+                model_kwargs[key] = value.repeat_interleave(expand_size, dim=0)

-class IsaacTextDecoderLayer(Qwen3VLTextDecoderLayer):
-    pass
+        if position_ids is not None:
+            dim = 1 if position_ids.ndim == 3 else 0
+            model_kwargs["position_ids"] = position_ids.repeat_interleave(expand_size, dim=dim)
+        return input_ids, model_kwargs

-class IsaacTextModel(Qwen3VLTextModel):
-    def __init__(self, config: IsaacTextConfig):
-        super().__init__(config)
-        self.rotary_emb = IsaacRotaryEmbedding(config=config, device=self.device)

+# --------------------------------Isaac Image Processor--------------------------------

-@auto_docstring
-class IsaacModel(Qwen3VLModel):
-    input_modalities = ("image", "text")
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"]
-    _can_compile_fullgraph = False
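# [Editor's note: illustrative aside, not part of the patch.] A runnable sketch, with
# assumed shapes, of the interleaved MRoPE trick in `apply_interleaved_mrope` removed
# above: per-axis rotary frequencies for (t, h, w) are split into `mrope_section` chunks,
# and the chunks are taken round-robin from the three axis planes.
import torch

def interleave_mrope(freqs: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
    # freqs: (3, batch, seq_len, rotary_half_dim), one frequency plane per axis (t, h, w)
    chunks = freqs.split(tuple(mrope_section), dim=-1)
    # chunk 0 comes from the temporal plane, chunk 1 from height, chunk 2 from width, repeating
    return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)

# With rotary_half_dim=64 and the default (2, 1, 1) weighting, `_resolve_mrope_section`
# above would yield [32, 16, 16]; the interleave then keeps the output shape unchanged.
out = interleave_mrope(torch.randn(3, 1, 5, 64), [32, 16, 16])
assert out.shape == (1, 5, 64)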
-    _supports_flex_attn = False
-    _tied_weights_keys = {}
-    _input_embed_layer = "language_model.embed_tokens"

+class IsaacImageProcessorKwargs(ImagesKwargs, total=False):
+    """
+    patch_size (`int`, *optional*):
+        Side length (in pixels) for square patches extracted from resized images.
+    max_num_patches (`int`, *optional*):
+        Upper bound on extracted patches per image after resizing.
+    min_num_patches (`int`, *optional*):
+        Lower bound on extracted patches per image after resizing.
+    pixel_shuffle_scale (`int`, *optional*):
+        Pixel-shuffle reduction factor applied in the vision tower.
+    """

-    def __init__(self, config: IsaacConfig):
-        PreTrainedModel.__init__(self, config)
-        self.language_model = IsaacTextModel._from_config(config.text_config)
-        self.visual = IsaacVisionTransformer(config.vision_config)
-        self.multimodal_projector = IsaacMultiModalProjector(config)
-        self.max_sequence_length = config.max_sequence_length
-        self.vision_rescale_factor = config.vision_rescale_factor
-        self.rope_deltas = None
+    patch_size: int
+    max_num_patches: int
+    min_num_patches: int
+    pixel_shuffle_scale: int
+
+
+def get_scaled_image_size(
+    scale: float,
+    original_size: int,
+    patch_size: int,
+    pixel_shuffle_scale: int,
+) -> int:
+    scaled_size = scale * original_size
+    divisor = patch_size * pixel_shuffle_scale
+    scaled_size = math.ceil(scaled_size / divisor) * divisor
+    scaled_size = max(divisor, scaled_size)
+    return int(scaled_size)
+
+
+def get_image_size_for_max_num_patches(
+    image_height: int,
+    image_width: int,
+    patch_size: int,
+    max_num_patches: int,
+    min_num_patches: int | None = None,
+    eps: float = 1e-5,
+    pixel_shuffle_scale: int = 1,
+) -> tuple[int, int]:
+    r"""Compute a target resolution whose patch grid satisfies the patching constraints.
+
+    Args:
+        image_height (`int`):
+            Height in pixels of the source image prior to any resizing.
+        image_width (`int`):
+            Width in pixels of the source image prior to any resizing.
+        patch_size (`int`):
+            Size of the square patch used by the vision encoder.
+        max_num_patches (`int`):
+            Upper bound on `(height / patch_size) * (width / patch_size)` after resizing.
+        min_num_patches (`int`, *optional*):
+            Lower bound on the number of patches. When provided, the image will be scaled up if necessary.
+        eps (`float`, *optional*, defaults to 1e-5):
+            Convergence tolerance for the internal binary search to determine the target dimensions.
+        pixel_shuffle_scale (`int`, *optional*, defaults to 1):
+            Additional stride multiplier applied when pixel shuffle later reduces spatial resolution.

-        self.post_init()
+    Returns:
+        `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale`
+        and respect both the maximum and optional minimum patch-count constraints.
+    """

-    @can_return_tuple
-    @auto_docstring
-    def get_image_features(
-        self,
-        pixel_values: torch.Tensor,
-        image_grid_thw: torch.Tensor,
-        image_metadata: torch.Tensor | None = None,
-        **kwargs: Unpack[TransformersKwargs],
-    ) -> tuple | BaseModelOutputWithPooling:
-        """
-        Args:
-            pixel_values (`torch.Tensor`):
-                Batch-major patch vectors with shape `(batch_size, max_images, max_patches, patch_dim)`.
-            image_grid_thw (`torch.Tensor`):
-                Batch-major grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries.
-            image_metadata (`torch.Tensor`, *optional*):
-                Batch-major per-slot metadata `(offset, length)` shaped `(batch_size, max_images, 2)`.
-        """
-        if pixel_values.shape[0] == 0:
-            hidden_size = self.config.get_text_config().hidden_size
-            return BaseModelOutputWithPooling(
-                last_hidden_state=pixel_values.new_zeros((0, 0, hidden_size)),
-                pooler_output=(),
-                hidden_states=None,
-                attentions=None,
-            )
+    # Ensure divisibility
+    divisor = patch_size * pixel_shuffle_scale
+    adjusted_height = math.ceil(image_height / divisor) * divisor
+    adjusted_height = max(divisor, adjusted_height)
+    adjusted_width = math.ceil(image_width / divisor) * divisor
+    adjusted_width = max(divisor, adjusted_width)

-        image_grid_thw = image_grid_thw.to(dtype=torch.long)
-        active_slot_mask = image_grid_thw[..., 0].eq(1)
-        if not active_slot_mask.any():
-            hidden_size = self.config.get_text_config().hidden_size
-            return BaseModelOutputWithPooling(
-                last_hidden_state=pixel_values.new_zeros((0, 0, hidden_size)),
-                pooler_output=(),
-                hidden_states=None,
-                attentions=None,
-            )
+    num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size)

-        flat_pixel_values = pixel_values[active_slot_mask]
-        flat_image_grid_thw = image_grid_thw[active_slot_mask]
-        vision_outputs: BaseModelOutputWithPooling = self.visual(
-            pixel_values=flat_pixel_values,
-            image_grid_thw=flat_image_grid_thw,
-            return_dict=True,
-            **kwargs,
-        )
-        projected_features = self.multimodal_projector(vision_outputs.last_hidden_state)
+    if min_num_patches is not None and num_patches < min_num_patches:
+        # Scale up via binary search to satisfy the minimum patch budget while
+        # preserving divisibility by patch_size * pixel_shuffle_scale.
+        scale_min, scale_max = 1.0, 100.0
+        while (scale_max - scale_min) >= eps:
+            scale = (scale_min + scale_max) / 2
+            target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+            target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+            num_patches = (target_height / patch_size) * (target_width / patch_size)
+            if num_patches >= min_num_patches:
+                scale_max = scale
+            else:
+                scale_min = scale
+        scale = scale_max
+        target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+        target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+        return target_height, target_width
+    elif num_patches <= max_num_patches:
+        return adjusted_height, adjusted_width
+    else:
+        # Scale down
+        scale_min, scale_max = eps / 10, 1.0
+        while (scale_max - scale_min) >= eps:
+            scale = (scale_min + scale_max) / 2
+            target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+            target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+            num_patches = (target_height / patch_size) * (target_width / patch_size)
+            if num_patches <= max_num_patches:
+                scale_min = scale
+            else:
+                scale_max = scale
+        scale = scale_min
+        target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+        target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+        return target_height, target_width

-        pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor
-        full_lengths = flat_image_grid_thw[:, 1].div(pixel_shuffle_scale, rounding_mode="floor") * flat_image_grid_thw[
-            :, 2
-        ].div(pixel_shuffle_scale, rounding_mode="floor")
-        if image_metadata is None:
-            offsets = torch.zeros_like(full_lengths)
-            lengths = full_lengths
-        else:
-            torch_compilable_check(
-                image_metadata.shape[:2] == image_grid_thw.shape[:2],
-                "IsaacModel.get_image_features expects batch-major metadata aligned with `image_grid_thw`.",
-            )
-            active_metadata = image_metadata[active_slot_mask]
-            offsets = active_metadata[:, 0].to(device=projected_features.device, dtype=torch.long)
-            lengths = active_metadata[:, 1].to(device=projected_features.device, dtype=torch.long)
-        image_features = tuple(
-            projected_features[image_idx, offset : offset + length]
-            for image_idx, (offset, length) in enumerate(zip(offsets.tolist(), lengths.tolist(), strict=True))
-        )

+@auto_docstring
+class IsaacImageProcessor(TorchvisionBackend):
+    model_input_names = ["pixel_values", "image_grid_thw"]
+    valid_kwargs = IsaacImageProcessorKwargs

-        return BaseModelOutputWithPooling(
-            last_hidden_state=projected_features,
-            pooler_output=image_features,
-            hidden_states=vision_outputs.hidden_states,
-            attentions=vision_outputs.attentions,
-        )
+    resample = PILImageResampling.BILINEAR
+    do_resize = True
+    do_center_crop = False
+    patch_size = 16
+    max_num_patches = 256
+    min_num_patches = None
+    pixel_shuffle_scale = 1
+    do_pad = True
+    do_rescale = True
+    do_normalize = True
+    image_mean = IMAGENET_STANDARD_MEAN
+    image_std = IMAGENET_STANDARD_STD
+    do_convert_rgb = True
+    disable_grouping = False

-    def get_placeholder_mask(
-        self,
-        mm_token_type_ids: torch.LongTensor,
-        inputs_embeds: torch.FloatTensor,
-        image_features: torch.FloatTensor,
-    ) -> torch.BoolTensor:
-        image_token_mask = mm_token_type_ids.to(dtype=torch.long) == 1
-        n_image_tokens = image_token_mask.sum()
-        image_token_mask = image_token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
-        torch_compilable_check(
-            inputs_embeds[image_token_mask].numel() == image_features.numel(),
-            f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}",
-        )
-        return image_token_mask
+    def __init__(self, **kwargs: Unpack[IsaacImageProcessorKwargs]):
+        super().__init__(**kwargs)

-    def get_video_features(
-        self,
-        pixel_values_videos: torch.FloatTensor,
-        video_grid_thw: torch.LongTensor | None = None,
-        **kwargs: Unpack[TransformersKwargs],
-    ) -> tuple | BaseModelOutputWithPooling:
-        raise ValueError("Isaac is image-only and does not support `pixel_values_videos` or `video_grid_thw`.")
+    def _validate_preprocess_kwargs(self, **kwargs):
+        # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for.
+        kwargs.pop("do_resize", None)
+        return super()._validate_preprocess_kwargs(**kwargs)

-    def get_vision_position_ids(
+    def _prepare_images_structure(
         self,
-        start_position: int,
-        grid_thw: torch.LongTensor,
-        image_metadata: torch.LongTensor,
-    ) -> torch.LongTensor:
-        pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor
-        height = grid_thw[1].div(pixel_shuffle_scale, rounding_mode="floor").item()
-        width = grid_thw[2].div(pixel_shuffle_scale, rounding_mode="floor").item()
-        token_positions = torch.arange(height * width, device=grid_thw.device, dtype=torch.long)
-        vision_position_ids = torch.stack(
-            (
-                torch.full((token_positions.shape[0],), start_position, device=grid_thw.device, dtype=torch.long),
-                token_positions.div(width, rounding_mode="floor"),
-                token_positions.remainder(width),
-            ),
-            dim=0,
-        )
-        token_offset = int(image_metadata[0].item())
-        token_length = int(image_metadata[1].item())
-        return vision_position_ids[:, token_offset : token_offset + token_length]
+        images: ImageInput,
+        expected_ndims: int = 3,
+    ) -> ImageInput:
+        images = self.fetch_images(images)
+        return make_nested_list_of_images(images, expected_ndims=expected_ndims)

-    def get_rope_index(
+    def resize(
         self,
-        input_ids: torch.LongTensor | None,
-        mm_token_type_ids: torch.Tensor,
-        image_grid_thw: torch.Tensor | None = None,
-        attention_mask: torch.Tensor | None = None,
-        image_metadata: torch.Tensor | None = None,
+        image: torch.Tensor,
+        size: SizeDict,
         **kwargs,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        if image_grid_thw is None or image_metadata is None:
-            raise ValueError("Isaac multimodal RoPE requires both `image_grid_thw` and `image_metadata`.")
-
-        if attention_mask is None:
-            if input_ids is None:
-                attention_mask = mm_token_type_ids.new_ones(mm_token_type_ids.shape, dtype=torch.long)
-            else:
-                attention_mask = input_ids.new_ones(input_ids.shape, dtype=torch.long)
+    ) -> torch.Tensor:
+        if image.dtype == torch.uint8:
+            image = F.interpolate(image.float(), size=(size.height, size.width), mode="bilinear", align_corners=False)
+            return image.clamp(0, 255).round().to(torch.uint8)
+        return F.interpolate(image, size=(size.height, size.width), mode="bilinear", align_corners=False)

-        if input_ids is None:
-            batch_size, seq_len = attention_mask.shape
-            position_dtype = torch.long
-        else:
-            batch_size, seq_len = input_ids.shape
-            position_dtype = input_ids.dtype
+    def pack_images(
+        self,
+        vision_patches: list[list[torch.Tensor]],
+        vision_token_grids: list[list[torch.Tensor]],
+    ) -> dict[str, torch.Tensor | None]:
+        batch_size = len(vision_patches)
+        flat_patches = [patches for sample_patches in vision_patches for patches in sample_patches]
+        if len(flat_patches) == 0:
+            return {"pixel_values": None, "image_grid_thw": None}

-        device = attention_mask.device
-        mm_token_type_ids = mm_token_type_ids.to(dtype=torch.long)
-        image_grid_thw = image_grid_thw.to(dtype=torch.long)
-        image_metadata = image_metadata.to(dtype=torch.long)
-        attention_mask = attention_mask.to(dtype=torch.long)
-        active_slot_mask = image_grid_thw[..., 0].eq(1)
+        first_patch = flat_patches[0]
+        max_patches = max(patches.shape[0] for patches in flat_patches)
+        max_images = max((len(sample_patches) for sample_patches in vision_patches), default=0)

-        position_ids = torch.zeros((3, batch_size, seq_len), device=device, dtype=position_dtype)
-        rope_deltas = torch.zeros((batch_size, 1), device=device, dtype=torch.long)
+        patch_dim = first_patch.shape[-1]
+        tensors = {
+            "pixel_values": torch.zeros(
+                (batch_size, max_images, max_patches, patch_dim),
+                device=first_patch.device,
+                dtype=first_patch.dtype,
+            ),
+            "image_grid_thw": torch.zeros((batch_size, max_images, 3), device=first_patch.device, dtype=torch.long),
+        }

-        for batch_idx in range(batch_size):
-            sample_attention_mask = attention_mask[batch_idx].bool()
-            sample_token_types = mm_token_type_ids[batch_idx][sample_attention_mask]
-            sample_grids = image_grid_thw[batch_idx]
-            sample_metadata = image_metadata[batch_idx]
-            sample_active_slots = active_slot_mask[batch_idx]
+        for batch_idx, (sample_patches, sample_token_grids) in enumerate(
+            zip(vision_patches, vision_token_grids, strict=True)
+        ):
+            for image_slot, (patches, token_grid) in enumerate(zip(sample_patches, sample_token_grids, strict=True)):
+                patch_count = int(patches.shape[0])
+                tensors["pixel_values"][batch_idx, image_slot, :patch_count] = patches
+                tensors["image_grid_thw"][batch_idx, image_slot, 0] = 1
+                tensors["image_grid_thw"][batch_idx, image_slot, 1:] = token_grid

-            current_pos = 0
-            image_idx = 0
-            seq_pos = 0
-            llm_pos_ids_list = []
+        return tensors

-            while seq_pos < sample_token_types.shape[0]:
-                modality_type = int(sample_token_types[seq_pos].item())
-                if modality_type == 0:
-                    group_end = seq_pos + 1
-                    while group_end < sample_token_types.shape[0] and sample_token_types[group_end] == 0:
-                        group_end += 1
-                    group_length = group_end - seq_pos
-                    llm_pos_ids_list.append(
-                        torch.arange(group_length, device=device, dtype=torch.long).view(1, -1).expand(3, -1)
-                        + current_pos
-                    )
-                    current_pos += group_length
-                    seq_pos = group_end
-                else:
-                    while image_idx < sample_metadata.shape[0] and (
-                        not bool(sample_active_slots[image_idx].item()) or sample_metadata[image_idx, 1].item() == 0
-                    ):
-                        image_idx += 1
-                    torch_compilable_check(
-                        image_idx < sample_metadata.shape[0],
-                        "Isaac multimodal sequence has more visible image tokens than batch-major image metadata slots.",
-                    )
-                    token_length = int(sample_metadata[image_idx, 1].item())
-                    torch_compilable_check(
-                        token_length <= sample_token_types.shape[0] - seq_pos,
-                        "Isaac image metadata length exceeds the remaining multimodal placeholder span.",
-                    )
-                    llm_pos_ids_list.append(
-                        self.get_vision_position_ids(current_pos, sample_grids[image_idx], sample_metadata[image_idx])
-                    )
+    def _preprocess(
+        self,
+        images: list[list[torch.Tensor]],
+        do_resize: bool,
+        resample: PILImageResampling | tvF.InterpolationMode | int | None,
+        do_rescale: bool,
+        rescale_factor: float,
+        do_normalize: bool,
+        image_mean: float | list[float] | None,
+        image_std: float | list[float] | None,
+        do_pad: bool,
+        patch_size: int,
+        max_num_patches: int,
+        min_num_patches: int,
+        pixel_shuffle_scale: int,
+        disable_grouping: bool | None,
+        return_tensors: str | TensorType | None,
+        **kwargs,
+    ) -> BatchFeature:
+        if all(len(sample_images) == 0 for sample_images in images):
+            return BatchFeature(data={"pixel_values": None, "image_grid_thw": None}, tensor_type=return_tensors)
+
+        grouped_images, grouped_images_index = group_images_by_shape(
+            images, disable_grouping=disable_grouping, is_nested=True
+        )
+        grouped_outputs = {}
+        for shape, stacked_images in grouped_images.items():
+            grouped_batch_size, channels, original_height, original_width = stacked_images.shape
+            if do_resize:
+                target_height, target_width = get_image_size_for_max_num_patches(
+                    original_height,
+                    original_width,
+                    patch_size,
+                    max_num_patches,
+                    min_num_patches=min_num_patches,
+                    pixel_shuffle_scale=pixel_shuffle_scale,
+                )
+                image_batch = self.resize(
+                    stacked_images, SizeDict(height=target_height, width=target_width), resample=resample
+                )
+            else:
+                if (original_height % patch_size) or (original_width % patch_size):
+                    raise ValueError(
+                        f"Image dimensions (h={original_height}, w={original_width}) must be divisible by patch_size={patch_size} when resize is disabled; enable resizing or adjust the input resolution."
+                    )
-                    current_pos += 1
-                    seq_pos += token_length
-                    image_idx += 1
+                image_batch, target_height, target_width = stacked_images, original_height, original_width

-            llm_positions = (
-                torch.cat(llm_pos_ids_list, dim=1)
-                if llm_pos_ids_list
-                else torch.zeros((3, 0), device=device, dtype=torch.long)
-            )
+            image_batch = self.rescale_and_normalize(
+                image_batch,
+                do_rescale=do_rescale,
+                rescale_factor=rescale_factor,
+                do_normalize=do_normalize,
+                image_mean=image_mean,
+                image_std=image_std,
+            )
-            position_ids[:, batch_idx, sample_attention_mask] = llm_positions
-            rope_deltas[batch_idx, 0] = (
-                llm_positions.max() + 1 - sample_token_types.shape[0] if llm_positions.numel() > 0 else 0
-            )
+
+            patches = torch_extract_patches(image_batch, patch_size, patch_size)
+            _, height_tokens, width_tokens, patch_dim = patches.shape
+
+            token_grid = (
+                torch.tensor([height_tokens, width_tokens], device=patches.device).long().expand(grouped_batch_size, 2)
+            )
-        return position_ids, rope_deltas
+            if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale):
+                raise ValueError(
+                    f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale};"
+                    f" adjust resize/patch parameters or disable pixel shuffle."
+                )

-    def compute_3d_position_ids(
+            grouped_outputs[shape] = (
+                patches.reshape(grouped_batch_size, -1, patch_dim),
+                token_grid,
+            )
+
+        keys = ("vision_patches", "vision_token_grids")
+        nested_outputs = {}
+        for i, key in enumerate(keys):
+            nested_outputs[key] = reorder_images(
+                {shape: values[i] for shape, values in grouped_outputs.items()},
+                dict(grouped_images_index),
+                is_nested=True,
+            )
+
+        if not do_pad:
+            raise ValueError("IsaacImageProcessor doesn't support `do_pad=False` mode.")
+
+        tensors = self.pack_images(
+            vision_patches=nested_outputs["vision_patches"],
+            vision_token_grids=nested_outputs["vision_token_grids"],
+        )
+
+        return BatchFeature(data=tensors, tensor_type=return_tensors)
+
+    def get_number_of_image_patches(
         self,
-        input_ids: torch.Tensor,
-        inputs_embeds: torch.Tensor,
-        attention_mask: torch.Tensor | None,
-        mm_token_type_ids: torch.Tensor | None = None,
-        image_grid_thw: torch.Tensor | None = None,
-        image_metadata: torch.Tensor | None = None,
-        past_key_values: Cache | None = None,
-    ) -> torch.Tensor:
-        past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
+        image_height: int,
+        image_width: int,
+        images_kwargs: dict[str, Any] | None = None,
+    ) -> int:
+        images_kwargs = images_kwargs or {}
+        patch_size = images_kwargs.get("patch_size", self.patch_size)
+        max_num_patches = images_kwargs.get("max_num_patches", self.max_num_patches)
+        min_num_patches = images_kwargs.get("min_num_patches", self.min_num_patches)
+        pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale", self.pixel_shuffle_scale)

-        has_multimodal = (
-            image_grid_thw is not None
-            and image_metadata is not None
-            and bool(image_grid_thw[..., 0].eq(1).any().item())
-        )
+        target_height, target_width = get_image_size_for_max_num_patches(
+            image_height,
+            image_width,
+            patch_size,
+            max_num_patches,
+            min_num_patches=min_num_patches,
+            pixel_shuffle_scale=pixel_shuffle_scale,
+        )
-        if has_multimodal and mm_token_type_ids is None and input_ids is not None:
-            raise ValueError(
-                "Multimodal data was passed (via `image_grid_thw`) but `mm_token_type_ids` is missing. "
-                "Please pass `mm_token_type_ids` so Isaac can build multimodal RoPE positions."
-            )
+        return (target_height // patch_size) * (target_width // patch_size)

-        if has_multimodal and past_seen_tokens == 0:
-            position_ids, rope_deltas = self.get_rope_index(
-                input_ids=input_ids,
-                mm_token_type_ids=mm_token_type_ids,
-                image_grid_thw=image_grid_thw,
-                image_metadata=image_metadata,
-                attention_mask=attention_mask,
-            )
-            self.rope_deltas = rope_deltas
-            return position_ids

-        if self.rope_deltas is None:
-            return None
+
+# --------------------------------Isaac Processor--------------------------------

-        rope_deltas = torch.as_tensor(self.rope_deltas, device=inputs_embeds.device, dtype=torch.long).reshape(-1, 1)
-        if rope_deltas.shape[0] != inputs_embeds.shape[0]:
-            if inputs_embeds.shape[0] % rope_deltas.shape[0] == 0:
-                rope_deltas = rope_deltas.repeat_interleave(inputs_embeds.shape[0] // rope_deltas.shape[0], dim=0)
-            else:
-                rope_deltas = rope_deltas[:1].expand(inputs_embeds.shape[0], -1)

-        if attention_mask is not None and attention_mask.shape[-1] > inputs_embeds.shape[1]:
-            rope_position = attention_mask.long().cumsum(dim=-1) - 1
-            rope_position = rope_position.masked_fill(attention_mask == 0, 0)
-            rope_position = rope_position[:, -inputs_embeds.shape[1] :]
-        else:
-            rope_position = torch.arange(
-                past_seen_tokens,
-                past_seen_tokens + inputs_embeds.shape[1],
-                device=inputs_embeds.device,
-                dtype=torch.long,
-            ).view(1, -1)
-            rope_position = rope_position.expand(inputs_embeds.shape[0], -1)
+class IsaacProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs = IsaacImageProcessorKwargs
+    _defaults = {
+        "text_kwargs": {
+            "padding": True,
+            "truncation": True,
+            "truncation_side": "left",
+            "return_attention_mask": True,
+            "return_overflowing_tokens": True,
+            "return_mm_token_type_ids": True,
+            "add_special_tokens": False,
+        },
+    }

-        position_ids = rope_position.view(1, inputs_embeds.shape[0], -1).expand(3, -1, -1)
-        return position_ids + rope_deltas.to(device=inputs_embeds.device).unsqueeze(0)

+class SinglePoint(NamedTuple):
+    x: int
+    y: int
+    mention: str | None = None
+    t: float | None = None

-    @auto_docstring(
-        custom_intro="""
-        Forward pass with multimodal MRoPE position ids.
-
-        When image placeholders are present, Isaac computes vision features, scatters them into the token
-        embeddings, and runs the shared text backbone on the mixed sequence.
-        """,
-    )
-    @can_return_tuple
-    def forward(
-        self,
-        input_ids: torch.LongTensor | None = None,
-        mm_token_type_ids: torch.LongTensor | None = None,
-        pixel_values: torch.Tensor | None = None,
-        image_grid_thw: torch.LongTensor | None = None,
-        image_metadata: torch.LongTensor | None = None,
-        attention_mask: torch.Tensor | None = None,
-        position_ids: torch.LongTensor | None = None,
-        past_key_values: Cache | None = None,
-        inputs_embeds: torch.FloatTensor | None = None,
-        use_cache: bool | None = None,
-        **kwargs: Unpack[TransformersKwargs],
-    ) -> tuple | BaseModelOutputWithPast:
-        """
-        Args:
-            mm_token_type_ids (`torch.LongTensor`, *optional*):
-                Multimodal token type ids aligned with the token sequence, using `0 -> text` and `1 -> image`.
-            pixel_values (`torch.FloatTensor`, *optional*):
-                Batch-major patch vectors shaped `(batch_size, max_images, max_patches, patch_dim)`.
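# [Editor's note: illustrative aside, not part of the patch.] A worked example, with
# assumed numbers, of the sizing pipeline defined above: `get_image_size_for_max_num_patches`
# snaps a 750x1000 image to multiples of patch_size * pixel_shuffle_scale = 32 and scales it
# down until the patch grid fits the budget; the placeholder-token count then divides the
# patch count by pixel_shuffle_scale**2, mirroring `get_number_of_image_patches`.
target_h, target_w = get_image_size_for_max_num_patches(
    image_height=750,
    image_width=1000,
    patch_size=16,
    max_num_patches=256,
    pixel_shuffle_scale=2,
)
num_patches = (target_h // 16) * (target_w // 16)  # <= 256 by construction
num_placeholder_tokens = num_patches // (2 ** 2)   # merge_length = pixel_shuffle_scale**2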
- image_grid_thw (`torch.LongTensor`, *optional*): - Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. - image_metadata (`torch.LongTensor`, *optional*): - Batch-major per-slot metadata shaped `(batch_size, max_images, 2)` with `(offset, length)`. - """ - if (input_ids is None) == (inputs_embeds is None): - raise ValueError("You must specify exactly one of `input_ids` or `inputs_embeds`.") - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) +class BoundingBox(NamedTuple): + top_left: Any + bottom_right: Any + mention: str | None = None + t: float | None = None - batch_size, seq_len = inputs_embeds.shape[:2] - if mm_token_type_ids is None: - mm_token_type_ids = torch.full((batch_size, seq_len), 0, device=inputs_embeds.device, dtype=torch.long) - else: - mm_token_type_ids = mm_token_type_ids.to(device=inputs_embeds.device, dtype=torch.long) - if mm_token_type_ids.shape[1] < seq_len: - padding = mm_token_type_ids.new_zeros((batch_size, seq_len - mm_token_type_ids.shape[1])) - mm_token_type_ids = torch.cat([mm_token_type_ids, padding], dim=1) - elif mm_token_type_ids.shape[1] > seq_len: - mm_token_type_ids = mm_token_type_ids[:, -seq_len:] - if image_metadata is not None: - image_metadata = image_metadata.to(device=inputs_embeds.device, dtype=torch.long) +class Polygon(NamedTuple): + points: tuple[Any, ...] + mention: str | None = None + t: float | None = None - image_mask = None - has_active_images = ( - pixel_values is not None and image_grid_thw is not None and bool(image_grid_thw[..., 0].eq(1).any().item()) - ) - if has_active_images: - image_outputs = self.get_image_features( - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - image_metadata=image_metadata, - return_dict=True, - ) - image_embeds = image_outputs.pooler_output - if len(image_embeds) > 0: - image_embeds = torch.cat(image_embeds, dim=0).to( - device=inputs_embeds.device, dtype=inputs_embeds.dtype - ) - image_mask = self.get_placeholder_mask( - mm_token_type_ids=mm_token_type_ids, - inputs_embeds=inputs_embeds, - image_features=image_embeds, - ) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - if isinstance(attention_mask, dict): - attention_mask = attention_mask["full_attention"] +_point_box_or_polygon_tag = re.compile( + r"<(?Ppoint|point_box|polygon)(?P[^>]*)>(?P[\s\S]*?)", re.IGNORECASE +) +_attr_re = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))") +_coord_re = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)") - past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() - computed_position_ids = self.compute_3d_position_ids( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - mm_token_type_ids=mm_token_type_ids, - image_grid_thw=image_grid_thw, - image_metadata=image_metadata, - attention_mask=attention_mask, - past_key_values=past_key_values, + +@auto_docstring +class IsaacProcessor(ProcessorMixin): + def __init__( + self, + image_processor, + tokenizer, + chat_template: str | dict[str, str] | None = None, + max_sequence_length: int = 16384, + ): + r""" + max_sequence_length (`int`, *optional*, defaults to 16384): + Maximum packed multimodal sequence length produced by the processor. 
+ """ + if chat_template is None: + chat_template = getattr(tokenizer, "chat_template", None) + + self.pad_token_id = tokenizer.pad_token_id + self.image_token = getattr(tokenizer, "image_pad_token", None) or getattr(tokenizer, "image_token", None) + self.image_token_id = getattr(tokenizer, "image_pad_token_id", None) or getattr( + tokenizer, "image_token_id", None ) - if computed_position_ids is not None: - position_ids = computed_position_ids - elif past_seen_tokens > 0: - position_ids = None - elif position_ids is not None and past_seen_tokens == 0: - position_ids = position_ids.to(device=inputs_embeds.device) - if position_ids.ndim == 2: - position_ids = position_ids.view(1, position_ids.shape[0], -1).expand(3, -1, -1) + self.max_sequence_length = max_sequence_length + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + text: str | list[str], + images: ImageInput | None = None, + **kwargs, + ) -> BatchFeature: + output_kwargs = self._merge_kwargs( + IsaacProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + # 1. Validate number of that text and images match + texts = [text] if isinstance(text, str) else text.copy() + rendered_image_token = "" + if self.image_token is not None and self.image_token != rendered_image_token: + # Isaac's current chat template still renders ``, while the tokenizer exposes + # `<|image_pad|>`. Normalize here so apply_chat_template(..., tokenize=True, + # return_dict=True) follows the standard ProcessorMixin path. + texts = [text_value.replace(rendered_image_token, self.image_token) for text_value in texts] + if images is None: + batched_images = [[] for _ in texts] + else: + fetched_images = self.image_processor.fetch_images(images) + batched_images = make_nested_list_of_images(fetched_images) + if len(batched_images) != len(texts): + num_images_in_text = [text_value.count(self.image_token) for text_value in texts] + num_images_in_images = [len(sample_images) for sample_images in batched_images] + add_message = "" + if sum(num_images_in_text) == sum(num_images_in_images): + add_message = " Make sure to pass your images as a nested list, where each sub-list holds images for one text sample." + raise ValueError( + f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(texts)}).{add_message}" + ) - outputs = self.language_model( - input_ids=None, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - visual_pos_masks=image_mask[..., 0] if image_mask is not None else None, - deepstack_visual_embeds=None, - use_cache=use_cache, - **kwargs, - ) + # 2. Process images + image_inputs = self.image_processor(images=batched_images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] - outputs_with_rope = BaseModelOutputWithPast( - last_hidden_state=outputs.last_hidden_state, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - outputs_with_rope["rope_deltas"] = self.rope_deltas - return outputs_with_rope + # 3. 
Expand text with image placeholders + merge_length = self.image_processor.pixel_shuffle_scale**2 + if image_grid_thw is None: + vision_segment_lengths = None + else: + vision_segment_lengths = image_grid_thw.prod(dim=-1) // merge_length + for batch_idx in range(len(texts)): + image_idx = 0 + while self.image_token in texts[batch_idx]: + num_image_tokens = vision_segment_lengths[batch_idx, image_idx] + texts[batch_idx] = texts[batch_idx].replace( + self.image_token, "<|placeholder|>" * num_image_tokens, 1 + ) + image_idx += 1 + texts[batch_idx] = texts[batch_idx].replace("<|placeholder|>", self.image_token) + + # 4. Process text + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids") + max_length = output_kwargs["text_kwargs"].pop("max_length", None) + max_length = self.max_sequence_length if max_length is None else max_length + text_inputs = self.tokenizer(texts, max_length=max_length, **output_kwargs["text_kwargs"]) + + truncated_input_ids: list[list[int] | None] = [None] * len(texts) + truncated_attention_mask: list[list[int] | None] = [None] * len(texts) + offset_mappings = text_inputs.get("offset_mapping") + truncated_offset_mapping: list[list[list[int]] | None] | None = None + if offset_mappings is not None: + truncated_offset_mapping = [None] * len(texts) + overflow_input_ids_per_sample = defaultdict(int) + + # 5. Drop overflowing token ids + if offset_mappings is None: + iterator = zip( + text_inputs["overflow_to_sample_mapping"], text_inputs["input_ids"], text_inputs["attention_mask"] + ) + else: + iterator = zip( + text_inputs["overflow_to_sample_mapping"], + text_inputs["input_ids"], + text_inputs["attention_mask"], + offset_mappings, + ) + for sample in iterator: + if offset_mappings is None: + batch_idx, input_ids, attention_mask = sample + offset_mapping = None + else: + batch_idx, input_ids, attention_mask, offset_mapping = sample -@auto_docstring -class IsaacForConditionalGeneration(Qwen3VLForConditionalGeneration, GenerationMixin): - config_class = IsaacConfig - input_modalities = ("image", "text") - _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"] - _can_compile_fullgraph = False - _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + if truncated_input_ids[batch_idx] is None: + truncated_input_ids[batch_idx] = input_ids + truncated_attention_mask[batch_idx] = attention_mask + if truncated_offset_mapping is not None: + truncated_offset_mapping[batch_idx] = offset_mapping + else: + overflow_input_ids_per_sample[batch_idx] += input_ids.count(self.image_token_id) - def __init__(self, config: IsaacConfig): - PreTrainedModel.__init__(self, config) - self.model = IsaacModel(config) - self.vocab_size = config.get_text_config().vocab_size - self.lm_head = nn.Linear(config.get_text_config().hidden_size, config.get_text_config().vocab_size, bias=False) - self.post_init() + # 6. Do the same for overflowing pixel values. Isaac truncates images based on `max_length` + # We can't really truncate pixels, so we pass over an image offset mask. 
Model will crop off + # truncated image pixels at run-time using this mask + image_metadata = None + if image_grid_thw is not None: + batch_size, max_images = image_grid_thw.shape[:2] + image_metadata = torch.zeros((batch_size, max_images, 2), dtype=torch.long) + for batch_idx, image_lengths in enumerate(vision_segment_lengths): + remaining_dropped = overflow_input_ids_per_sample[batch_idx] + for image_idx, length in enumerate(image_lengths): + offset = 0 + if 0 < remaining_dropped < length: + offset = remaining_dropped + length -= offset + remaining_dropped = 0 + elif remaining_dropped >= length: + dropped_length = length + length = 0 + remaining_dropped -= dropped_length + + # Record which suffix of this image's placeholder span survives left truncation. + # The model still encodes the full image and uses this window for both feature gathering and vision RoPE. + image_metadata[batch_idx, image_idx, 0] = offset + image_metadata[batch_idx, image_idx, 1] = length - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask: torch.Tensor | None = None, - inputs_embeds: torch.FloatTensor | None = None, - mm_token_type_ids: torch.LongTensor | None = None, - pixel_values: torch.Tensor | None = None, - image_grid_thw: torch.LongTensor | None = None, - image_metadata: torch.LongTensor | None = None, - position_ids: torch.LongTensor | None = None, - is_first_iteration=False, - use_cache=True, - **kwargs, - ) -> dict[str, Any]: - model_inputs = super().prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - is_first_iteration=is_first_iteration, - use_cache=use_cache, - **kwargs, - ) - is_prefill = is_first_iteration or not use_cache - multimodal_inputs = { - "mm_token_type_ids": mm_token_type_ids, - "pixel_values": pixel_values, - "image_grid_thw": image_grid_thw, + data = { + "input_ids": torch.tensor(truncated_input_ids, dtype=torch.long), + "attention_mask": torch.tensor(truncated_attention_mask, dtype=torch.long), "image_metadata": image_metadata, + **image_inputs, } - for key, value in multimodal_inputs.items(): - model_inputs[key] = value if is_prefill else None - if model_inputs["mm_token_type_ids"] is not None: - sequence_length = None - if model_inputs.get("input_ids") is not None: - sequence_length = model_inputs["input_ids"].shape[1] - elif model_inputs.get("inputs_embeds") is not None: - sequence_length = model_inputs["inputs_embeds"].shape[1] + if truncated_offset_mapping is not None: + data["offset_mapping"] = torch.tensor(truncated_offset_mapping, dtype=torch.long) - if sequence_length is not None: - current_length = model_inputs["mm_token_type_ids"].shape[1] - if current_length < sequence_length: - padding = model_inputs["mm_token_type_ids"].new_zeros( - (model_inputs["mm_token_type_ids"].shape[0], sequence_length - current_length) - ) - model_inputs["mm_token_type_ids"] = torch.cat([model_inputs["mm_token_type_ids"], padding], dim=1) - elif current_length > sequence_length: - model_inputs["mm_token_type_ids"] = model_inputs["mm_token_type_ids"][:, -sequence_length:] - return model_inputs + if return_mm_token_type_ids: + data["mm_token_type_ids"] = self.create_mm_token_type_ids(data["input_ids"]) - def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs): - text_positions = GenerationMixin._prepare_position_ids_for_generation(self, inputs_tensor, 
model_kwargs) + return BatchFeature(data=data, tensor_type=return_tensors) - past_length = 0 - if (cache := model_kwargs.get("past_key_values")) is not None: - past_length = cache.get_seq_length() - if past_length != 0 and self.model.rope_deltas is not None: - return text_positions[None, ...] + self.model.rope_deltas + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + vision_data = {} + if image_sizes is not None: + images_kwargs = dict(IsaacProcessorKwargs._defaults.get("images_kwargs", {})) + images_kwargs.update(kwargs) - if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0: - inputs_tensor = model_kwargs["input_ids"] + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) + for image_size in image_sizes + ] + pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale") or self.image_processor.pixel_shuffle_scale + num_image_tokens = [num_patches // pixel_shuffle_scale**2 for num_patches in num_image_patches] + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) - is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long] - if ( - is_input_ids - and model_kwargs.get("mm_token_type_ids") is not None - and model_kwargs.get("image_grid_thw") is not None - and model_kwargs.get("image_metadata") is not None - ): - model_kwargs = {k: v for k, v in model_kwargs.items() if k != "input_ids"} - vision_positions, rope_deltas = self.model.get_rope_index(inputs_tensor, **model_kwargs) - self.model.rope_deltas = rope_deltas - else: - vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1) - self.model.rope_deltas = torch.zeros( - inputs_tensor.shape[0], 1, dtype=torch.long, device=inputs_tensor.device - ) + return MultiModalData(**vision_data) - return torch.cat([text_positions[None, ...], vision_positions], dim=0) + @property + def model_input_names(self): + return super().model_input_names + ["mm_token_type_ids", "image_metadata"] - def _expand_inputs_for_generation( - self, - expand_size: int = 1, - is_encoder_decoder: bool = False, - input_ids: torch.LongTensor | None = None, - **model_kwargs, - ) -> tuple[torch.LongTensor, dict[str, Any]]: - position_ids = model_kwargs.pop("position_ids", None) - if expand_size == 1: - if position_ids is not None: - model_kwargs["position_ids"] = position_ids - return input_ids, model_kwargs + @staticmethod + def _maybe_float(value: str | None) -> float | None: + try: + return float(value) + except (ValueError, TypeError): + return None - visual_keys = ["pixel_values", "image_grid_thw", "image_metadata"] - for key in visual_keys: - value = model_kwargs.get(key) - if value is not None: - model_kwargs[key] = value.repeat_interleave(expand_size, dim=0) + @classmethod + def _parse_attrs(cls, attr_text: str) -> dict[str, str]: + attrs = {} + for match in _attr_re.finditer(attr_text): + key = match.group(1) + value = match.group(2) or match.group(3) or "" + attrs[key] = value + return attrs - if input_ids is not None: - input_ids = input_ids.repeat_interleave(expand_size, dim=0) + @classmethod + def _parse_point_body( + cls, + body: str, + mention: str | None = None, + t: str | None = None, + ) -> Any: + match = _coord_re.search(body) + if not match: + raise ValueError(f"Malformed tag: {body!r}") + x, y = int(match.group(1)), int(match.group(2)) + return SinglePoint(x=x, y=y, mention=mention, t=cls._maybe_float(t)) - for key, value in list(model_kwargs.items()): - if key == "position_ids" 
and value is not None and value.ndim == 3: - model_kwargs[key] = value.repeat_interleave(expand_size, dim=1) - elif value is not None and isinstance(value, torch.Tensor) and key not in visual_keys: - model_kwargs[key] = value.repeat_interleave(expand_size, dim=0) + @classmethod + def _parse_box_body( + cls, + body: str, + mention: str | None = None, + t: str | None = None, + ) -> Any: + coords = list(_coord_re.finditer(body)) + if len(coords) < 2: + raise ValueError(f"Malformed tag: {body!r}") - if position_ids is not None: - dim = 1 if position_ids.ndim == 3 else 0 - model_kwargs["position_ids"] = position_ids.repeat_interleave(expand_size, dim=dim) - return input_ids, model_kwargs + top_left = SinglePoint(x=int(coords[0].group(1)), y=int(coords[0].group(2))) + bottom_right = SinglePoint(x=int(coords[1].group(1)), y=int(coords[1].group(2))) + return BoundingBox(top_left=top_left, bottom_right=bottom_right, mention=mention, t=cls._maybe_float(t)) + + @classmethod + def _parse_polygon_body( + cls, + body: str, + mention: str | None = None, + t: str | None = None, + ) -> Any: + coords = list(_coord_re.finditer(body)) + if len(coords) < 3: + raise ValueError(f"Malformed tag: {body!r}") + + points = tuple(SinglePoint(x=int(coord.group(1)), y=int(coord.group(2))) for coord in coords) + return Polygon(points=points, mention=mention, t=cls._maybe_float(t)) + + @classmethod + def clean_text_and_extract_points( + cls, + text: str, + expected: str | None = None, + ) -> tuple[str, list[Any]]: + results: list[Any] = [] + for match in _point_box_or_polygon_tag.finditer(text): + tag = match.group("tag").lower() + attrs = cls._parse_attrs(match.group("attrs")) + mention = attrs.get("mention") + t = attrs.get("t") + if tag == "point": + if expected not in (None, "point"): + continue + results.append(cls._parse_point_body(match.group("body"), mention=mention, t=t)) + elif tag == "point_box": + if expected not in (None, "box"): + continue + results.append(cls._parse_box_body(match.group("body"), mention=mention, t=t)) + else: + if expected not in (None, "polygon"): + continue + results.append(cls._parse_polygon_body(match.group("body"), mention=mention, t=t)) + + clean_text = re.sub(r"\s+", " ", _point_box_or_polygon_tag.sub("", text or "")).strip() + return clean_text, results + + def post_process_generation( + self, + text: str, + expected: str | None = None, + cleanup_and_extract: bool = True, + ) -> str | tuple[str, list[Any]]: + if cleanup_and_extract: + return self.clean_text_and_extract_points(text, expected=expected) + return text + + def post_process_image_text_to_text( + self, + generated_outputs, + skip_special_tokens: bool = True, + cleanup_and_extract: bool = False, + expected: str | None = None, + **kwargs, + ): + generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs) + return [ + self.post_process_generation(text, expected=expected, cleanup_and_extract=cleanup_and_extract) + for text in generated_texts + ] __all__ = [ @@ -1681,7 +1630,7 @@ def _expand_inputs_for_generation( "IsaacTextConfig", "IsaacTextModel", "IsaacVisionConfig", - "IsaacVisionTransformer", + "IsaacVisionModel", "IsaacModel", "IsaacPreTrainedModel", # noqa: F822 "IsaacForConditionalGeneration", diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py index 4ba3b68e5542..563ea90bde65 100644 --- a/src/transformers/models/isaac/processing_isaac.py +++ b/src/transformers/models/isaac/processing_isaac.py @@ -19,54 +19,49 
diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py
index 4ba3b68e5542..563ea90bde65 100644
--- a/src/transformers/models/isaac/processing_isaac.py
+++ b/src/transformers/models/isaac/processing_isaac.py
@@ -19,54 +19,49 @@
 # limitations under the License.
 import re
-from typing import Any, NamedTuple
+from collections import defaultdict
+from typing import Any
 
 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, make_nested_list_of_images
 from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin
-from ...utils import TensorType, auto_docstring
+from ...utils import auto_docstring
 from ...utils.import_utils import is_torch_available
+from .image_processing_isaac import IsaacImageProcessorKwargs
+from .modeling_isaac import BoundingBox, Polygon, SinglePoint
 
 
 if is_torch_available():
     import torch
 
 
+# --------------------------------Isaac Processor--------------------------------
+
+
 class IsaacProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs = IsaacImageProcessorKwargs
     _defaults = {
         "text_kwargs": {
             "padding": True,
+            "truncation": True,
+            "truncation_side": "left",
             "return_attention_mask": True,
+            "return_overflowing_tokens": True,
             "return_mm_token_type_ids": True,
+            "add_special_tokens": False,
        },
    }
 
 
+_point_box_or_polygon_tag = re.compile(
+    r"<(?P<tag>point|point_box|polygon)(?P<attrs>[^>]*)>(?P<body>[\s\S]*?)</(?P=tag)>", re.IGNORECASE
+)
+_attr_re = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))")
+_coord_re = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")
+
+
 @auto_docstring
 class IsaacProcessor(ProcessorMixin):
-    class _SinglePoint(NamedTuple):
-        x: int
-        y: int
-        mention: str | None = None
-        t: float | None = None
-
-    class _BoundingBox(NamedTuple):
-        top_left: Any
-        bottom_right: Any
-        mention: str | None = None
-        t: float | None = None
-
-    class _Polygon(NamedTuple):
-        points: tuple[Any, ...]
-        mention: str | None = None
-        t: float | None = None
-
-    _point_box_or_polygon_tag = re.compile(
-        r"<(?P<tag>point|point_box|polygon)(?P<attrs>[^>]*)>(?P<body>[\s\S]*?)</(?P=tag)>", re.IGNORECASE
-    )
-    _attr_re = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))")
-    _coord_re = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")
-
     def __init__(
         self,
         image_processor,
@@ -74,25 +69,173 @@ def __init__(
         chat_template: str | dict[str, str] | None = None,
         max_sequence_length: int = 16384,
     ):
-        """
-        Args:
-            chat_template (`str` or `dict[str, str]`, *optional*):
-                Chat template override forwarded to [`~processing_utils.ProcessorMixin`].
-            max_sequence_length (`int`, *optional*, defaults to 16384):
-                Maximum packed multimodal sequence length produced by the processor.
+        r"""
+        max_sequence_length (`int`, *optional*, defaults to 16384):
+            Maximum packed multimodal sequence length produced by the processor.
         """
         if chat_template is None:
             chat_template = getattr(tokenizer, "chat_template", None)
-        self.image_processor = image_processor
-        super().__init__(image_processor, tokenizer, chat_template=chat_template)
-        self.text_pad_token_id = self.pad_token_id = tokenizer.pad_token_id
+        self.pad_token_id = tokenizer.pad_token_id
         self.image_token = getattr(tokenizer, "image_pad_token", None) or getattr(tokenizer, "image_token", None)
         self.image_token_id = getattr(tokenizer, "image_pad_token_id", None) or getattr(
             tokenizer, "image_token_id", None
         )
         self.max_sequence_length = max_sequence_length
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+    def __call__(
+        self,
+        text: str | list[str],
+        images: ImageInput | None = None,
+        **kwargs,
+    ) -> BatchFeature:
+        output_kwargs = self._merge_kwargs(
+            IsaacProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        # 1. Validate that the number of texts and images match
+        texts = [text] if isinstance(text, str) else text.copy()
+        rendered_image_token = "<image>"
+        if self.image_token is not None and self.image_token != rendered_image_token:
+            # Isaac's current chat template still renders `<image>`, while the tokenizer exposes
+            # `<|image_pad|>`. Normalize here so apply_chat_template(..., tokenize=True,
+            # return_dict=True) follows the standard ProcessorMixin path.
+            texts = [text_value.replace(rendered_image_token, self.image_token) for text_value in texts]
+        if images is None:
+            batched_images = [[] for _ in texts]
+        else:
+            fetched_images = self.image_processor.fetch_images(images)
+            batched_images = make_nested_list_of_images(fetched_images)
+            if len(batched_images) != len(texts):
+                num_images_in_text = [text_value.count(self.image_token) for text_value in texts]
+                num_images_in_images = [len(sample_images) for sample_images in batched_images]
+                add_message = ""
+                if sum(num_images_in_text) == sum(num_images_in_images):
+                    add_message = " Make sure to pass your images as a nested list, where each sub-list holds images for one text sample."
+                raise ValueError(
+                    f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(texts)}).{add_message}"
+                )
+
+        # 2. Process images
+        image_inputs = self.image_processor(images=batched_images, **output_kwargs["images_kwargs"])
+        image_grid_thw = image_inputs["image_grid_thw"]
+
+        # 3. Expand text with image placeholders
+        merge_length = self.image_processor.pixel_shuffle_scale**2
+        if image_grid_thw is None:
+            vision_segment_lengths = None
+        else:
+            vision_segment_lengths = image_grid_thw.prod(dim=-1) // merge_length
+        for batch_idx in range(len(texts)):
+            image_idx = 0
+            while self.image_token in texts[batch_idx]:
+                num_image_tokens = vision_segment_lengths[batch_idx, image_idx]
+                texts[batch_idx] = texts[batch_idx].replace(
+                    self.image_token, "<|placeholder|>" * num_image_tokens, 1
+                )
+                image_idx += 1
+            texts[batch_idx] = texts[batch_idx].replace("<|placeholder|>", self.image_token)
+
+        # 4. Process text
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids")
+        max_length = output_kwargs["text_kwargs"].pop("max_length", None)
+        max_length = self.max_sequence_length if max_length is None else max_length
+        text_inputs = self.tokenizer(texts, max_length=max_length, **output_kwargs["text_kwargs"])
+
+        truncated_input_ids: list[list[int] | None] = [None] * len(texts)
+        truncated_attention_mask: list[list[int] | None] = [None] * len(texts)
+        offset_mappings = text_inputs.get("offset_mapping")
+        truncated_offset_mapping: list[list[list[int]] | None] | None = None
+        if offset_mappings is not None:
+            truncated_offset_mapping = [None] * len(texts)
+        overflow_input_ids_per_sample = defaultdict(int)
+
+        # 5. Drop overflowing token ids
+        if offset_mappings is None:
+            iterator = zip(
+                text_inputs["overflow_to_sample_mapping"], text_inputs["input_ids"], text_inputs["attention_mask"]
+            )
+        else:
+            iterator = zip(
+                text_inputs["overflow_to_sample_mapping"],
+                text_inputs["input_ids"],
+                text_inputs["attention_mask"],
+                offset_mappings,
+            )
+
+        for sample in iterator:
+            if offset_mappings is None:
+                batch_idx, input_ids, attention_mask = sample
+                offset_mapping = None
+            else:
+                batch_idx, input_ids, attention_mask, offset_mapping = sample
+
+            if truncated_input_ids[batch_idx] is None:
+                truncated_input_ids[batch_idx] = input_ids
+                truncated_attention_mask[batch_idx] = attention_mask
+                if truncated_offset_mapping is not None:
+                    truncated_offset_mapping[batch_idx] = offset_mapping
+            else:
+                overflow_input_ids_per_sample[batch_idx] += input_ids.count(self.image_token_id)
+
+        # 6. Do the same for overflowing pixel values. Isaac truncates images based on `max_length`.
+        # We can't truncate the pixels themselves, so we pass an image offset mask instead; the model
+        # crops off truncated image pixels at run time using this mask.
+        image_metadata = None
+        if image_grid_thw is not None:
+            batch_size, max_images = image_grid_thw.shape[:2]
+            image_metadata = torch.zeros((batch_size, max_images, 2), dtype=torch.long)
+            for batch_idx, image_lengths in enumerate(vision_segment_lengths):
+                remaining_dropped = overflow_input_ids_per_sample[batch_idx]
+                for image_idx, length in enumerate(image_lengths):
+                    offset = 0
+                    if 0 < remaining_dropped < length:
+                        offset = remaining_dropped
+                        length -= offset
+                        remaining_dropped = 0
+                    elif remaining_dropped >= length:
+                        dropped_length = length
+                        length = 0
+                        remaining_dropped -= dropped_length
+
+                    # Record which suffix of this image's placeholder span survives left truncation.
+                    # The model still encodes the full image and uses this window for both feature
+                    # gathering and vision RoPE.
+                    image_metadata[batch_idx, image_idx, 0] = offset
+                    image_metadata[batch_idx, image_idx, 1] = length
+
+        data = {
+            "input_ids": torch.tensor(truncated_input_ids, dtype=torch.long),
+            "attention_mask": torch.tensor(truncated_attention_mask, dtype=torch.long),
+            "image_metadata": image_metadata,
+            **image_inputs,
+        }
+        if truncated_offset_mapping is not None:
+            data["offset_mapping"] = torch.tensor(truncated_offset_mapping, dtype=torch.long)
+
+        if return_mm_token_type_ids:
+            data["mm_token_type_ids"] = self.create_mm_token_type_ids(data["input_ids"])
+
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+        vision_data = {}
+        if image_sizes is not None:
+            images_kwargs = dict(IsaacProcessorKwargs._defaults.get("images_kwargs", {}))
+            images_kwargs.update(kwargs)
+
+            num_image_patches = [
+                self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
+                for image_size in image_sizes
+            ]
+            pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale") or self.image_processor.pixel_shuffle_scale
+            num_image_tokens = [num_patches // pixel_shuffle_scale**2 for num_patches in num_image_patches]
+            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+
+        return MultiModalData(**vision_data)
 
     @property
     def model_input_names(self):
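A standalone sketch (assumed names, not part of the patch) of the (offset, length) bookkeeping in step 6 above: given how many image placeholder tokens were dropped by left truncation, compute the surviving window for each image slot, in order.

    def truncation_windows(image_lengths: list[int], dropped: int) -> list[tuple[int, int]]:
        windows = []
        for length in image_lengths:
            offset = 0
            if 0 < dropped < length:      # image partially truncated from the left
                offset, length, dropped = dropped, length - dropped, 0
            elif dropped >= length:       # image fully truncated
                dropped, length = dropped - length, 0
            windows.append((offset, length))
        return windows

    # Two images of 4 tokens each with 6 placeholder tokens dropped on the left:
    # the first image vanishes, the second keeps only its last 2 tokens.
    print(truncation_windows([4, 4], dropped=6))  # [(0, 0), (2, 2)]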
@@ -100,17 +243,15 @@ def model_input_names(self):
 
     @staticmethod
     def _maybe_float(value: str | None) -> float | None:
-        if value is None:
-            return None
         try:
             return float(value)
-        except ValueError:
+        except (ValueError, TypeError):
             return None
 
     @classmethod
     def _parse_attrs(cls, attr_text: str) -> dict[str, str]:
         attrs = {}
-        for match in cls._attr_re.finditer(attr_text or ""):
+        for match in _attr_re.finditer(attr_text):
             key = match.group(1)
             value = match.group(2) or match.group(3) or ""
             attrs[key] = value
@@ -123,11 +264,11 @@ def _parse_point_body(
         mention: str | None = None,
         t: str | None = None,
     ) -> Any:
-        match = cls._coord_re.search(body)
+        match = _coord_re.search(body)
         if not match:
             raise ValueError(f"Malformed tag: {body!r}")
         x, y = int(match.group(1)), int(match.group(2))
-        return cls._SinglePoint(x=x, y=y, mention=mention, t=cls._maybe_float(t))
+        return SinglePoint(x=x, y=y, mention=mention, t=cls._maybe_float(t))
 
     @classmethod
     def _parse_box_body(
@@ -136,13 +277,13 @@ def _parse_box_body(
         mention: str | None = None,
         t: str | None = None,
     ) -> Any:
-        coords = list(cls._coord_re.finditer(body))
+        coords = list(_coord_re.finditer(body))
         if len(coords) < 2:
             raise ValueError(f"Malformed tag: {body!r}")
 
-        top_left = cls._SinglePoint(x=int(coords[0].group(1)), y=int(coords[0].group(2)))
-        bottom_right = cls._SinglePoint(x=int(coords[1].group(1)), y=int(coords[1].group(2)))
-        return cls._BoundingBox(top_left=top_left, bottom_right=bottom_right, mention=mention, t=cls._maybe_float(t))
+        top_left = SinglePoint(x=int(coords[0].group(1)), y=int(coords[0].group(2)))
+        bottom_right = SinglePoint(x=int(coords[1].group(1)), y=int(coords[1].group(2)))
+        return BoundingBox(top_left=top_left, bottom_right=bottom_right, mention=mention, t=cls._maybe_float(t))
 
     @classmethod
     def _parse_polygon_body(
@@ -151,12 +292,12 @@ def _parse_polygon_body(
         mention: str | None = None,
         t: str | None = None,
     ) -> Any:
-        coords = list(cls._coord_re.finditer(body))
+        coords = list(_coord_re.finditer(body))
         if len(coords) < 3:
             raise ValueError(f"Malformed tag: {body!r}")
 
-        points = tuple(cls._SinglePoint(x=int(coord.group(1)), y=int(coord.group(2))) for coord in coords)
-        return cls._Polygon(points=points, mention=mention, t=cls._maybe_float(t))
+        points = tuple(SinglePoint(x=int(coord.group(1)), y=int(coord.group(2))) for coord in coords)
+        return Polygon(points=points, mention=mention, t=cls._maybe_float(t))
 
     @classmethod
     def clean_text_and_extract_points(
@@ -165,7 +306,7 @@ def clean_text_and_extract_points(
         expected: str | None = None,
     ) -> tuple[str, list[Any]]:
         results: list[Any] = []
-        for match in cls._point_box_or_polygon_tag.finditer(text or ""):
+        for match in _point_box_or_polygon_tag.finditer(text):
             tag = match.group("tag").lower()
             attrs = cls._parse_attrs(match.group("attrs"))
             mention = attrs.get("mention")
@@ -183,25 +324,9 @@ def clean_text_and_extract_points(
                     continue
                 results.append(cls._parse_polygon_body(match.group("body"), mention=mention, t=t))
 
-        clean_text = re.sub(r"\s+", " ", cls._point_box_or_polygon_tag.sub("", text or "")).strip()
+        clean_text = re.sub(r"\s+", " ", _point_box_or_polygon_tag.sub("", text or "")).strip()
         return clean_text, results
 
-    def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
-        vision_data = {}
-        if image_sizes is not None:
-            images_kwargs = dict(IsaacProcessorKwargs._defaults.get("images_kwargs", {}))
-            images_kwargs.update(kwargs)
-
-            num_image_patches = [
-                self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
-                for image_size in image_sizes
-            ]
-            pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale") or self.image_processor.pixel_shuffle_scale
-            num_image_tokens = [num_patches // pixel_shuffle_scale**2 for num_patches in num_image_patches]
-            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
-
-        return MultiModalData(**vision_data)
-
     def post_process_generation(
         self,
         text: str,
@@ -226,160 +351,5 @@ def post_process_image_text_to_text(
             for text in generated_texts
         ]
 
-    def __call__(
-        self,
-        text: str | list[str],
-        images: ImageInput | None = None,
-        return_tensors: str | TensorType | None = TensorType.PYTORCH,
-        **kwargs,
-    ) -> BatchFeature:
-        output_kwargs = self._merge_kwargs(
-            IsaacProcessorKwargs,
-            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
-            **kwargs,
-        )
-        text_kwargs = output_kwargs["text_kwargs"]
-        truncation = text_kwargs.pop("truncation", None)
-        max_length = text_kwargs.pop("max_length", None)
-        padding = text_kwargs.pop("padding", True)
-        padding_side = text_kwargs.pop("padding_side", "left")
-        return_attention_mask = text_kwargs.pop("return_attention_mask", True)
-        return_mm_token_type_ids = text_kwargs.pop("return_mm_token_type_ids", True)
-        pad_to_multiple_of = text_kwargs.pop("pad_to_multiple_of", None)
-        text_kwargs.pop("return_tensors", None)
-        text_kwargs.pop("return_overflowing_tokens", None)
-        text_kwargs.setdefault("add_special_tokens", False)
-
-        texts = [text] if isinstance(text, str) else text
-        if images is None:
-            batched_images = [[] for _ in texts]
-        else:
-            fetched_images = self.image_processor.fetch_images(images)
-            batched_images = make_nested_list_of_images(fetched_images)
-            if len(batched_images) != len(texts):
-                num_images_in_text = [text_value.count(self.image_token) for text_value in texts]
-                num_images_in_images = [len(sample_images) for sample_images in batched_images]
-                add_message = ""
-                if sum(num_images_in_text) == sum(num_images_in_images):
-                    add_message = " Make sure to pass your images as a nested list, where each sub-list holds images for one text sample."
-
-                raise ValueError(
-                    f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(texts)}).{add_message}"
-                )
-
-        pairs = list(zip(texts, batched_images, strict=True))
-        image_inputs = self.image_processor(images=batched_images, return_tensors=TensorType.PYTORCH)
-        image_grid_thw = image_inputs["image_grid_thw"]
-        image_metadata = None
-        vision_segment_lengths = None
-        if image_grid_thw is not None:
-            batch_size, max_images = image_grid_thw.shape[:2]
-            image_metadata = torch.zeros((batch_size, max_images, 2), dtype=torch.long)
-            grid_heights = image_grid_thw[..., 1]
-            grid_widths = image_grid_thw[..., 2]
-            vision_segment_lengths = (grid_heights // self.image_processor.pixel_shuffle_scale) * (
-                grid_widths // self.image_processor.pixel_shuffle_scale
-            )
-
-        expanded_texts = []
-        expected_image_lengths_per_sample = []
-
-        for batch_idx, (text_value, sample_images) in enumerate(pairs):
-            segments = text_value.split(self.image_token)
-            num_images = len(segments) - 1
-            num_provided_images = len(sample_images)
-            if num_images != num_provided_images:
-                raise ValueError(
-                    f"IsaacProcessor expects one image per image token, got {num_images} tokens and {num_provided_images} images in sample with text {text_value} "
-                )
-
-            expected_image_lengths = []
-            expanded_text_parts = [segments[0]]
-            for image_idx in range(num_images):
-                segment_length = int(vision_segment_lengths[batch_idx, image_idx].item())
-                expected_image_lengths.append(segment_length)
-                expanded_text_parts.append(self.image_token * segment_length)
-                expanded_text_parts.append(segments[image_idx + 1])
-
-            expected_image_lengths_per_sample.append(expected_image_lengths)
-            expanded_texts.append("".join(expanded_text_parts))
-
-        effective_max_length = self.max_sequence_length
-        if max_length is not None and (truncation or padding == "max_length"):
-            effective_max_length = max_length
-
-        self.tokenizer.truncation_side = "left"
-        self.tokenizer.padding_side = padding_side
-        tokenized_text_inputs = self.tokenizer(
-            expanded_texts,
-            truncation=True,
-            max_length=effective_max_length,
-            padding=padding,
-            pad_to_multiple_of=pad_to_multiple_of,
-            return_attention_mask=return_attention_mask,
-            return_overflowing_tokens=True,
-            stride=0,
-            return_tensors=None,
-            **text_kwargs,
-        )
-
-        kept_input_ids_per_sample: list[list[int] | None] = [None] * len(texts)
-        overflow_input_ids_per_sample: list[list[list[int]]] = [[] for _ in texts]
-        overflow_to_sample_mapping = tokenized_text_inputs.get("overflow_to_sample_mapping")
-        if overflow_to_sample_mapping is None:
-            overflow_to_sample_mapping = list(range(len(tokenized_text_inputs["input_ids"])))
-
-        for row_input_ids, sample_idx in zip(
-            tokenized_text_inputs["input_ids"], overflow_to_sample_mapping, strict=True
-        ):
-            sample_idx = int(sample_idx)
-            if kept_input_ids_per_sample[sample_idx] is None:
-                kept_input_ids_per_sample[sample_idx] = row_input_ids
-            else:
-                overflow_input_ids_per_sample[sample_idx].append(row_input_ids)
-
-        for batch_idx, expected_image_lengths in enumerate(expected_image_lengths_per_sample):
-            dropped_image_tokens = sum(
-                overflow_input_ids.count(self.image_token_id)
-                for overflow_input_ids in overflow_input_ids_per_sample[batch_idx]
-            )
-
-            remaining_dropped = dropped_image_tokens
-            for image_idx, expected_length in enumerate(expected_image_lengths):
-                if remaining_dropped <= 0:
-                    offset = 0
-                    length = expected_length
-                elif remaining_dropped < expected_length:
-                    offset = remaining_dropped
-                    length = expected_length - offset
-                    remaining_dropped = 0
-                else:
-                    offset = 0
-                    length = 0
-                    remaining_dropped -= expected_length
-
-                # Record which suffix of this image's placeholder span survives left truncation.
-                # The model still encodes the full image and uses this window for both feature gathering and vision RoPE.
-                image_metadata[batch_idx, image_idx, 0] = offset
-                image_metadata[batch_idx, image_idx, 1] = length
-
-        input_ids = torch.tensor(kept_input_ids_per_sample, dtype=torch.long)
-        attention_mask = input_ids.ne(self.pad_token_id).to(dtype=torch.long)
-
-        data = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "pixel_values": image_inputs["pixel_values"],
-            "image_grid_thw": image_grid_thw,
-            "image_metadata": image_metadata,
-        }
-        if return_mm_token_type_ids:
-            data["mm_token_type_ids"] = input_ids.eq(self.image_token_id).to(dtype=torch.long)
-
-        return BatchFeature(
-            data=data,
-            tensor_type=return_tensors,
-        )
-
 
 __all__ = ["IsaacProcessor"]
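A hypothetical end-to-end sketch of the processor above (the checkpoint id is a placeholder, not a released model): one `<|image_pad|>` placeholder per image is expanded to that image's token count, and `image_metadata` carries the per-image (offset, length) truncation windows.

    from PIL import Image
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("org/isaac-checkpoint")  # placeholder id
    inputs = processor(
        text="Describe this image: <|image_pad|>",
        images=[[Image.new("RGB", (64, 64))]],  # nested list: one sub-list per text sample
        return_tensors="pt",
    )
    print(inputs["input_ids"].shape, inputs["image_grid_thw"].shape, inputs["image_metadata"].shape)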
diff --git a/tests/models/isaac/test_image_processing_isaac.py b/tests/models/isaac/test_image_processing_isaac.py
index b03ec3337972..7c7627805a5f 100644
--- a/tests/models/isaac/test_image_processing_isaac.py
+++ b/tests/models/isaac/test_image_processing_isaac.py
@@ -16,16 +16,10 @@
 import unittest
 
 import numpy as np
-import pytest
 
-from transformers.testing_utils import (
-    require_torch,
-    require_torch_accelerator,
-    require_vision,
-    slow,
-    torch_device,
-)
-from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
+from transformers.models.isaac.image_processing_isaac import get_image_size_for_max_num_patches
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
 
@@ -36,9 +30,6 @@
 if is_vision_available():
     from PIL import Image
 
-if is_torchvision_available():
-    from transformers.models.isaac.image_processing_isaac import IsaacImageProcessor
-
 
 def _make_dummy_image(size=(32, 32), color=(255, 0, 0)):
     return Image.new("RGB", size, color=color)
@@ -117,14 +108,38 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F
         )
         return [[image] for image in images]
 
+    def expected_output_image_shape(self, images):
+        max_images = 0
+        max_patches = 0
+        for sample_images in images:
+            if not isinstance(sample_images, (list, tuple)):
+                sample_images = [sample_images]
+
+            max_images = max(max_images, len(sample_images))
+            for image in sample_images:
+                if isinstance(image, Image.Image):
+                    width, height = image.size
+                elif isinstance(image, np.ndarray):
+                    height, width = image.shape[:2]
+                else:
+                    height, width = image.shape[-2:]
+
+                target_height, target_width = get_image_size_for_max_num_patches(
+                    image_height=height,
+                    image_width=width,
+                    patch_size=self.patch_size,
+                    max_num_patches=self.max_num_patches,
+                    min_num_patches=self.min_num_patches,
+                    pixel_shuffle_scale=self.pixel_shuffle_scale,
+                )
+                max_patches = max(max_patches, (target_height // self.patch_size) * (target_width // self.patch_size))
+
+        return (max_images, max_patches, self.patch_dim)
+
 
 @require_torch
 @require_vision
 class IsaacImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
-    image_processing_class = None
-    fast_image_processing_class = IsaacImageProcessor if is_torchvision_available() else None
-    test_slow_image_processor = False
-
     def setUp(self):
         super().setUp()
         self.image_processor_tester = IsaacImageProcessingTester(self)
@@ -133,153 +148,86 @@ def setUp(self):
     @property
     def image_processor_dict(self):
         return self.image_processor_tester.prepare_image_processor_dict()
 
-    def _assert_output_contract(
-        self,
-        encoding,
-        *,
-        expected_batch_size=None,
-        expected_max_images=None,
-        expected_patch_dim=None,
-    ):
-        self.assertEqual(set(encoding.keys()), {"pixel_values", "image_grid_thw"})
-
-        pixel_values = encoding["pixel_values"]
-        image_grid_thw = encoding["image_grid_thw"]
-
-        if expected_batch_size is None:
-            self.assertIsNone(pixel_values)
-            self.assertIsNone(image_grid_thw)
-            return
-
-        self.assertIsNotNone(pixel_values)
-        self.assertIsNotNone(image_grid_thw)
-        self.assertEqual(pixel_values.dtype, torch.float32)
-        self.assertEqual(image_grid_thw.dtype, torch.long)
-
-        if expected_batch_size is not None:
-            self.assertEqual(pixel_values.shape[0], expected_batch_size)
-            self.assertEqual(image_grid_thw.shape[0], expected_batch_size)
-        if expected_max_images is not None:
-            self.assertEqual(pixel_values.shape[1], expected_max_images)
-            self.assertEqual(image_grid_thw.shape[1], expected_max_images)
-        if expected_patch_dim is not None:
-            self.assertEqual(pixel_values.shape[-1], expected_patch_dim)
-
-        self.assertEqual(tuple(image_grid_thw.shape), (pixel_values.shape[0], pixel_values.shape[1], 3))
-
-        active_slots = image_grid_thw[..., 0].eq(1)
-        self.assertTrue(torch.all(image_grid_thw[~active_slots].eq(0)))
-        self.assertTrue(torch.all(image_grid_thw[active_slots, 1:] > 0))
-
-        expected_patch_counts = image_grid_thw[..., 1] * image_grid_thw[..., 2]
-        token_positions = torch.arange(pixel_values.shape[2], device=pixel_values.device).view(1, 1, -1)
-        image_patch_attention_mask = active_slots.unsqueeze(-1) & token_positions.lt(
-            expected_patch_counts.unsqueeze(-1)
-        )
-
-        padded_patch_rows = pixel_values[~image_patch_attention_mask]
-        if padded_patch_rows.numel() > 0:
-            self.assertTrue(torch.all(padded_patch_rows == 0))
-
-    def _assert_encoding_close(self, eager_encoding, compiled_encoding):
-        torch.testing.assert_close(
-            eager_encoding["pixel_values"],
-            compiled_encoding["pixel_values"],
-            atol=1e-4,
-            rtol=1e-4,
-        )
-        torch.testing.assert_close(eager_encoding["image_grid_thw"], compiled_encoding["image_grid_thw"])
-
-    def test_image_processor_properties(self):
-        for image_processing_class in self.image_processing_classes.values():
-            image_processor = image_processing_class(**self.image_processor_dict)
-            self.assertTrue(hasattr(image_processor, "do_resize"))
-            self.assertTrue(hasattr(image_processor, "do_rescale"))
-            self.assertTrue(hasattr(image_processor, "rescale_factor"))
-            self.assertTrue(hasattr(image_processor, "do_normalize"))
-            self.assertTrue(hasattr(image_processor, "image_mean"))
-            self.assertTrue(hasattr(image_processor, "image_std"))
-            self.assertTrue(hasattr(image_processor, "patch_size"))
-            self.assertTrue(hasattr(image_processor, "max_num_patches"))
-            self.assertTrue(hasattr(image_processor, "min_num_patches"))
-            self.assertTrue(hasattr(image_processor, "pixel_shuffle_scale"))
-            self.assertTrue(hasattr(image_processor, "do_convert_rgb"))
-
     def test_call_pil(self):
         for image_processing_class in self.image_processing_classes.values():
-            image_processor = image_processing_class(**self.image_processor_dict)
+            image_processing = image_processing_class(**self.image_processor_dict)
             image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
-
-            for image in image_inputs:
-                self.assertIsInstance(image[0], Image.Image)
-
-            single_output = image_processor(image_inputs[0], return_tensors="pt")
-            self._assert_output_contract(
-                single_output,
-                expected_batch_size=1,
-                expected_max_images=1,
-                expected_patch_dim=self.image_processor_tester.patch_dim,
-            )
-
-            batched_output = image_processor(image_inputs, return_tensors="pt")
-            self._assert_output_contract(
-                batched_output,
-                expected_batch_size=self.image_processor_tester.batch_size,
-                expected_max_images=1,
-                expected_patch_dim=self.image_processor_tester.patch_dim,
+            for sample_images in image_inputs:
+                self.assertEqual(len(sample_images), 1)
+                self.assertIsInstance(sample_images[0], Image.Image)
+
+            encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+            self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+            self.assertEqual(
+                tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
             )
 
     def test_call_numpy(self):
         for image_processing_class in self.image_processing_classes.values():
-            image_processor = image_processing_class(**self.image_processor_dict)
+            image_processing = image_processing_class(**self.image_processor_dict)
             image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
-
-            for image in image_inputs:
-                self.assertIsInstance(image[0], np.ndarray)
-
-            single_output = image_processor(image_inputs[0], return_tensors="pt")
-            self._assert_output_contract(
-                single_output,
-                expected_batch_size=1,
-                expected_max_images=1,
-                expected_patch_dim=self.image_processor_tester.patch_dim,
-            )
-
-            batched_output = image_processor(image_inputs, return_tensors="pt")
-            self._assert_output_contract(
-                batched_output,
-                expected_batch_size=self.image_processor_tester.batch_size,
-                expected_max_images=1,
-                expected_patch_dim=self.image_processor_tester.patch_dim,
+            for sample_images in image_inputs:
+                self.assertEqual(len(sample_images), 1)
+                self.assertIsInstance(sample_images[0], np.ndarray)
+
+            encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+            self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+            self.assertEqual(
+                tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
             )
 
     def test_call_pytorch(self):
         for image_processing_class in self.image_processing_classes.values():
-            image_processor = image_processing_class(**self.image_processor_dict)
+            image_processing = image_processing_class(**self.image_processor_dict)
             image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
+            for sample_images in image_inputs:
+                self.assertEqual(len(sample_images), 1)
+                self.assertIsInstance(sample_images[0], torch.Tensor)
+
+            encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+            self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+            self.assertEqual(
+                tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
+            )
 
-            for image in image_inputs:
-                self.assertIsInstance(image[0], torch.Tensor)
+    @unittest.skip(reason="Isaac image processor 4-channel coverage is not defined")
+    def test_call_numpy_4_channels(self):
+        pass
 
-            single_output = image_processor(image_inputs[0], return_tensors="pt")
-            self._assert_output_contract(
-                single_output,
-                expected_batch_size=1,
-                expected_max_images=1,
-                expected_patch_dim=self.image_processor_tester.patch_dim,
+    def test_flat_list_is_single_multi_image_sample(self):
+        for image_processing_class in self.image_processing_classes.values():
+            image_processor = image_processing_class(
+                **{
+                    **self.image_processor_dict,
+                    "do_resize": False,
+                    "patch_size": 16,
+                    "max_num_patches": 64,
+                    "min_num_patches": 1,
+                    "pixel_shuffle_scale": 1,
+                }
             )
+            image_inputs = [
+                _make_dummy_image(size=(32, 32), color=(255, 0, 0)),
+                _make_dummy_image(size=(32, 32), color=(0, 255, 0)),
+            ]
 
-            batched_output = image_processor(image_inputs, return_tensors="pt")
-            self._assert_output_contract(
-                batched_output,
-                expected_batch_size=self.image_processor_tester.batch_size,
-                expected_max_images=1,
-                expected_patch_dim=self.image_processor_tester.patch_dim,
-            )
+            encoding = image_processor(image_inputs, return_tensors="pt")
+            self.assertEqual(tuple(encoding["pixel_values"].shape), (1, 2, 4, 768))
 
-    @unittest.skip(reason="Isaac image processor 4-channel coverage is not defined yet")
-    def test_call_numpy_4_channels(self):
-        pass
+            expected_grids = torch.tensor([[[1, 2, 2], [1, 2, 2]]], dtype=torch.long)
+            torch.testing.assert_close(encoding["image_grid_thw"], expected_grids)
 
     def test_nested_multi_image_batch_preserves_grids_and_padding(self):
         for image_processing_class in self.image_processing_classes.values():
@@ -302,12 +250,6 @@ def test_nested_multi_image_batch_preserves_grids_and_padding(self):
             ]
 
             encoding = image_processor(image_inputs, return_tensors="pt")
-            self._assert_output_contract(
-                encoding,
-                expected_batch_size=2,
-                expected_max_images=2,
-                expected_patch_dim=768,
-            )
             self.assertEqual(tuple(encoding["pixel_values"].shape), (2, 2, 6, 768))
 
             expected_grids = torch.tensor(
@@ -319,26 +261,7 @@ def test_nested_multi_image_batch_preserves_grids_and_padding(self):
             )
             torch.testing.assert_close(encoding["image_grid_thw"], expected_grids)
-
-    def test_all_empty_images_returns_none_visual_fields(self):
-        for image_processing_class in self.image_processing_classes.values():
-            image_processor = image_processing_class(**self.image_processor_dict)
-            encoding = image_processor([[], []], return_tensors="pt")
-
-            self._assert_output_contract(encoding, expected_batch_size=None)
-
-    def test_do_resize_false_requires_patch_divisibility(self):
-        for image_processing_class in self.image_processing_classes.values():
-            image_processor = image_processing_class(
-                **{
-                    **self.image_processor_dict,
-                    "do_resize": False,
-                    "patch_size": 16,
-                }
-            )
-
-            with self.assertRaisesRegex(ValueError, "must be divisible by patch_size"):
-                image_processor([[_make_dummy_image(size=(31, 32))]], return_tensors="pt")
+            self.assertTrue(torch.all(encoding["pixel_values"][0, 1] == 0))
 
     def test_pixel_shuffle_scale_requires_divisible_token_grid(self):
         for image_processing_class in self.image_processing_classes.values():
@@ -353,52 +276,3 @@ def test_pixel_shuffle_scale_requires_divisible_token_grid(self):
 
             with self.assertRaisesRegex(ValueError, "must be divisible by pixel_shuffle_scale"):
                 image_processor([[_make_dummy_image(size=(32, 16))]], return_tensors="pt")
-
-    def test_cast_dtype_device(self):
-        for image_processing_class in self.image_processing_classes.values():
-            image_processor = image_processing_class(**self.image_processor_dict)
-            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
-
-            encoding = image_processor(image_inputs, return_tensors="pt")
-            self.assertEqual(encoding["pixel_values"].device, torch.device("cpu"))
-            self.assertEqual(encoding["pixel_values"].dtype, torch.float32)
-            self.assertEqual(encoding["image_grid_thw"].dtype, torch.long)
-
-            encoding = image_processor(image_inputs, return_tensors="pt").to(torch.float16)
-            self.assertEqual(encoding["pixel_values"].device, torch.device("cpu"))
-            self.assertEqual(encoding["pixel_values"].dtype, torch.float16)
-            self.assertEqual(encoding["image_grid_thw"].dtype, torch.long)
-
-            encoding = image_processor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16)
-            self.assertEqual(encoding["pixel_values"].device, torch.device("cpu"))
-            self.assertEqual(encoding["pixel_values"].dtype, torch.bfloat16)
-            self.assertEqual(encoding["image_grid_thw"].dtype, torch.long)
-
-            with self.assertRaises(TypeError):
-                _ = image_processor(image_inputs, return_tensors="pt").to(torch.bfloat16, "cpu")
-
-            encoding = image_processor(image_inputs, return_tensors="pt")
-            encoding.update({"input_ids": torch.LongTensor([[1, 2, 3], [4, 5, 6]])})
-            encoding = encoding.to(torch.float16)
-
-            self.assertEqual(encoding["pixel_values"].device, torch.device("cpu"))
-            self.assertEqual(encoding["pixel_values"].dtype, torch.float16)
-            self.assertEqual(encoding["image_grid_thw"].dtype, torch.long)
-            self.assertEqual(encoding["input_ids"].dtype, torch.long)
-
-    @slow
-    @require_torch_accelerator
-    @require_vision
-    @pytest.mark.torch_compile_test
-    def test_can_compile_torchvision_backend(self):
-        if "torchvision" not in self.image_processing_classes:
-            self.skipTest("Skipping compilation test as torchvision backend is not available")
-
-        torch.compiler.reset()
-        input_image = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
-        image_processor = self.image_processing_classes["torchvision"](**self.image_processor_dict)
-        output_eager = image_processor(input_image, device=torch_device, return_tensors="pt")
-
-        image_processor = torch.compile(image_processor, mode="reduce-overhead")
-        output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt")
-        self._assert_encoding_close(output_eager, output_compiled)
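A self-contained sketch (the helper here is illustrative; the real rounding lives in get_image_size_for_max_num_patches) of the patch-grid arithmetic the tester relies on: image dimensions are brought to patch multiples, so a 32x32 image with patch_size=16 yields a 2x2 grid of 4 patches, and each patch row carries 3 * 16 * 16 = 768 values, matching the (1, 2, 4, 768) expectation in the flat-list test above.

    def patch_grid(height: int, width: int, patch_size: int = 16) -> tuple[int, int]:
        # assumes dimensions already divisible by patch_size (do_resize=False case)
        return height // patch_size, width // patch_size

    h, w = patch_grid(32, 32)
    print(h * w, 3 * 16 * 16)  # 4 patches per image, 768 values per patch row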
diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py
index b43c2e183ca3..ccfb83411f66 100644
--- a/tests/models/isaac/test_modeling_isaac.py
+++ b/tests/models/isaac/test_modeling_isaac.py
@@ -17,36 +17,25 @@
 import base64
 import io
 import os
-import re
 import unittest
 from functools import lru_cache
 from pathlib import Path
-from types import SimpleNamespace
-from unittest.mock import patch
 
 import pytest
 from huggingface_hub import is_offline_mode
 
-from tests.generation.test_utils import GenerationTesterMixin
+from tests.generation.test_utils import (
+    GenerationTesterMixin,
+)
 from tests.test_configuration_common import ConfigTester
 from tests.test_pipeline_mixin import PipelineTesterMixin
 from transformers import (
     IsaacConfig,
     IsaacForConditionalGeneration,
     IsaacModel,
-    PythonBackend,
     is_torch_available,
 )
-from transformers.image_utils import load_image
-from transformers.masking_utils import create_bidirectional_mask
-from transformers.models.isaac.image_processing_isaac import IsaacImageProcessor
-from transformers.models.isaac.modeling_isaac import (
-    IsaacVisionAttention,
-    IsaacVisionConfig,
-    pixel_shuffle_padded,
-)
 from transformers.models.isaac.processing_isaac import IsaacProcessor
-from transformers.pipelines import ImageTextToTextPipeline
 from transformers.testing_utils import (
     require_flash_attn,
     require_torch,
@@ -80,58 +69,6 @@
 ISAAC_IMAGE_TOKEN = "<|image_pad|>"
 
 
-def document_to_messages(
-    document: list[dict], image_token: str = ISAAC_IMAGE_TOKEN
-) -> tuple[list[dict[str, str]], list[Image]]:
-    """
-    Convert a Document to messages format compatible with chat templates.
-    Each content turn creates its own message entry.
-
-    Args:
-        document: list of dicts containing Text and/or Image content
-        image_token: Token to use for image placeholder
-
-    Returns:
-        Tuple of (messages, images) where messages is a list of dicts with 'role' and 'content'
-    """
-    messages = []
-    images = []
-
-    for item in document:
-        itype = item.get("type")
-        if itype == "text":
-            content = item.get("content")
-            if content:
-                messages.append(
-                    {
-                        "role": item.get("role", "user"),
-                        "content": content,
-                    }
-                )
-        elif itype == "image":
-            content = item.get("content")
-            if content:
-                img = load_image(content)
-                images.append(img)
-                messages.append(
-                    {
-                        "role": item.get("role", "user"),
-                        "content": image_token,
-                    }
-                )
-
-    return messages, images
-
-
-def strip_trailing_stop_string(text: str, stop_strings: list[str] | tuple[str, ...] | None = None) -> str:
-    if stop_strings is not None:
-        for stop_string in stop_strings:
-            if text.endswith(stop_string):
-                text = text[: -len(stop_string)]
-                break
-    return re.sub(r"^\n{2,}", "\n", text)
-
-
 def compute_logits_statistics(tensor: torch.Tensor) -> dict[str, object]:
     """
     Summarize logits with simple statistics that are stable across minor
@@ -156,115 +93,6 @@ def _rounded(value: torch.Tensor | float) -> float:
     }
 
 
-def infer_pad_from_tail(sequence: torch.Tensor) -> tuple[int | None, int]:
-    """
-    Infer the pad value used in a 1D sequence by scanning the repeated tail.
-
-    Returns (pad_value or None if no padding detected, last_nonpad_index).
-    """
-
-    if sequence.ndim != 1:
-        raise ValueError("sequence must be 1D")
-
-    pad_candidate = sequence[-1].item()
-    idx = sequence.shape[0] - 1
-    while idx >= 0 and sequence[idx].item() == pad_candidate:
-        idx -= 1
-
-    if idx == sequence.shape[0] - 1:
-        return None, idx
-    if idx < 0:
-        return pad_candidate, -1
-    return pad_candidate, idx
-
-
-def _pixel_shuffle_reference(x: torch.Tensor, token_grids: torch.Tensor, scale_factor: int):
-    num_images, _, embed_dim = x.shape
-    output_lengths = []
-    for i in range(num_images):
-        h, w = token_grids[i].tolist()
-        output_lengths.append((h // scale_factor) * (w // scale_factor))
-
-    max_output_tokens = max(output_lengths, default=0)
-    output_dim = embed_dim * scale_factor * scale_factor
-    out = x.new_zeros((num_images, max_output_tokens, output_dim))
-    out_mask = torch.zeros((num_images, max_output_tokens), device=x.device, dtype=torch.long)
-
-    for i in range(num_images):
-        h, w = token_grids[i].tolist()
-        if h == 0 or w == 0:
-            continue
-        seq_len = h * w
-        tokens = x[i, :seq_len]
-        hb, wb = h // scale_factor, w // scale_factor
-        t = tokens.view(h, w, embed_dim).permute(2, 0, 1).unsqueeze(0)
-        t = torch.nn.functional.pixel_unshuffle(t, downscale_factor=scale_factor)
-        t = t.view(1, embed_dim, scale_factor, scale_factor, hb, wb)
-        t = t.permute(0, 4, 5, 2, 3, 1).contiguous().view(hb * wb, output_dim)
-        out[i, : hb * wb] = t
-        out_mask[i, : hb * wb] = 1
-
-    return out, out_mask, torch.tensor(output_lengths, device=x.device, dtype=torch.long)
-
-
-def create_isaac_processor(
-    tokenizer,
-    isaac_config,
-    *,
-    image_processor=None,
-    **overrides,
-):
-    """Helper to construct IsaacProcessor without requiring an IsaacConfig instance."""
-    vision_config = isaac_config.vision_config
-    params = {
-        "max_sequence_length": isaac_config.max_sequence_length,
-        "vision_patch_size": vision_config.patch_size,
-        "vision_max_num_patches": vision_config.num_patches,
-        "vision_min_num_patches": getattr(vision_config, "min_num_patches", None),
-        "pixel_shuffle_scale": vision_config.pixel_shuffle_scale_factor,
-        "rescale_factor": isaac_config.vision_rescale_factor,
-    }
-    params.update(overrides)
-
-    processor_image = image_processor
-    if processor_image is None:
-        image_processor_kwargs = {
-            "patch_size": params["vision_patch_size"],
-            "max_num_patches": params["vision_max_num_patches"],
-            "min_num_patches": params["vision_min_num_patches"],
-            "pixel_shuffle_scale": params["pixel_shuffle_scale"],
-            "rescale_factor": params["rescale_factor"],
-        }
-        if "image_mean" in params:
-            image_processor_kwargs["image_mean"] = params["image_mean"]
-        if "image_std" in params:
-            image_processor_kwargs["image_std"] = params["image_std"]
-        processor_image = IsaacImageProcessor(**image_processor_kwargs)
-    processor_params = {
-        "max_sequence_length": isaac_config.max_sequence_length,
-    }
-
-    return IsaacProcessor(
-        image_processor=processor_image,
-        tokenizer=tokenizer,
-        **processor_params,
-    )
-
-
-def to_model_multimodal_inputs(processor_output, device):
-    keys = (
-        "mm_token_type_ids",
-        "pixel_values",
-        "image_grid_thw",
-        "image_metadata",
-    )
-    return {
-        key: (value.to(device) if isinstance(value, torch.Tensor) else value)
-        for key, value in processor_output.items()
-        if key in keys
-    }
-
-
 def pack_image_inputs(pixel_values, image_token_grids, image_token_offsets=None, image_token_lengths=None):
     batch_size, max_images, _, _ = pixel_values.shape
     device = pixel_values.device
@@ -320,68 +148,6 @@ def _reference_checkpoint_or_skip():
     return MODEL_ID
 
 
-class SimpleIsaacTokenizer(PythonBackend):
-    vocab_files_names = {}
-    model_input_names = ["input_ids"]
-
-    def __init__(self):
-        self._vocab = {
-            "<pad>": 0,
-            "<s>": 1,
-            "</s>": 2,
-            "<unk>": 3,
-            ISAAC_IMAGE_TOKEN: 4,
-        }
-        self._ids_to_tokens = {idx: tok for tok, idx in self._vocab.items()}
-        super().__init__(
-            bos_token="<s>",
-            eos_token="</s>",
-            pad_token="<pad>",
-            unk_token="<unk>",
-            extra_special_tokens={"image_pad_token": ISAAC_IMAGE_TOKEN},
-            model_max_length=512,
-        )
-        self.image_pad_token = ISAAC_IMAGE_TOKEN
-        self.image_pad_token_id = self._vocab[self.image_pad_token]
-        self.chat_template = (
-            "{% for message in messages %}"
-            "{{ message['role'] }}: {{ message['content'] | trim }}\n"
-            "{% endfor %}"
-            "{% if add_generation_prompt %}assistant:{% endif %}"
-        )
-
-    def get_vocab(self):
-        return dict(self._vocab)
-
-    def _tokenize(self, text):
-        clean = text.replace("\n", " ").strip()
-        if not clean:
-            return []
-        return [token for token in clean.split(" ") if token]
-
-    def _convert_token_to_id(self, token):
-        if token not in self._vocab:
-            next_id = len(self._vocab)
-            self._vocab[token] = next_id
-            self._ids_to_tokens[next_id] = token
-        return self._vocab[token]
-
-    def _convert_id_to_token(self, index):
-        return self._ids_to_tokens.get(index, self.unk_token)
-
-    @property
-    def vocab_size(self) -> int:
-        return len(self._vocab)
-
-    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
-        if token_ids_1 is not None:
-            token_ids_0 = token_ids_0 + token_ids_1
-        return [self.bos_token_id] + list(token_ids_0) + [self.eos_token_id]
-
-    def save_vocabulary(self, save_directory, filename_prefix=None):
-        return ()
-
-
 class IsaacModelTester:
     def __init__(
         self,
@@ -547,383 +313,6 @@ def prepare_config_and_inputs_for_generate(self, batch_size=2):
 
         return config, filtered_inputs_dict
 
-    @unittest.skip(reason="Assisted decoding not supported; Qwen3 backbone does not implement returning attentions")
-    def test_assisted_decoding_matches_greedy_search_0_random(self):
-        pass
-
-    @unittest.skip(reason="Assisted decoding not supported; Qwen3 backbone does not implement returning attentions")
-    def test_assisted_decoding_matches_greedy_search_1_same(self):
-        pass
-
-    @unittest.skip(reason="Unsupported")
-    def test_flash_attn_kernels_inference_equivalence(self):
-        pass
-
-    @unittest.skip(reason="Assisted decoding not supported; Qwen3 backbone does not implement returning attentions")
-    def test_assisted_decoding_sample(self):
-        pass
-
-    @unittest.skip(reason="Prompt lookup decoding not supported; Qwen3 backbone does not return attentions")
-    def test_prompt_lookup_decoding_matches_greedy_search(self):
-        pass
-
-    @unittest.skip(reason="Output attentions not supported")
-    def test_retain_grad_hidden_states_attentions(self):
-        pass
-
-    def test_text_only_forward_ignores_metadata_without_vision_patches(self):
-        config, input_ids, attention_mask, _ = self.model_tester.prepare_config_and_inputs()
-        model = IsaacModel(config)
-        model.to(torch_device)
-        model.eval()
-
-        with torch.no_grad():
-            reference = model(input_ids=input_ids, attention_mask=attention_mask)
-
-        with patch.object(model, "get_image_features", wraps=model.get_image_features) as mock_get_image_features:
-            with torch.no_grad():
-                result = model(
-                    input_ids=input_ids,
-                    attention_mask=attention_mask,
-                    image_grid_thw=None,
-                    image_metadata=None,
-                )
-
-        mock_get_image_features.assert_not_called()
-        torch.testing.assert_close(result.last_hidden_state, reference.last_hidden_state)
-
-    def test_image_text_to_text_pipeline_supports_text_only_inputs(self):
-        config = self.model_tester.get_config()
-        model = IsaacForConditionalGeneration(config).to(torch_device).eval()
-        processor = create_isaac_processor(SimpleIsaacTokenizer(), config)
-        pipe = ImageTextToTextPipeline(model=model, processor=processor, max_new_tokens=4)
-
-        outputs = pipe(text="What is two plus two?", return_full_text=False)
-
-        self.assertEqual(len(outputs), 1)
-        self.assertEqual(outputs[0]["input_text"], "What is two plus two?")
-        self.assertIsInstance(outputs[0]["generated_text"], str)
-
-    def test_get_image_features_pooler_output_is_scatter_ready(self):
-        config = self.model_tester.get_config()
-        model = IsaacModel(config)
-        model.to(torch_device)
-        model.eval()
-
-        patch_size = self.model_tester.vision_config["patch_size"]
-        patch_dim = self.model_tester.vision_config["num_channels"] * patch_size * patch_size
-        pixel_values = torch.randn((2, 2, 4, patch_dim), device=torch_device, dtype=torch.float32)
-        image_token_grids = torch.tensor(
-            [[[2, 2], [2, 2]], [[2, 2], [0, 0]]],
-            device=torch_device,
-            dtype=torch.long,
-        )
-        image_token_offsets = torch.tensor([[1, 0], [2, 0]], device=torch_device, dtype=torch.long)
-        image_token_lengths = torch.tensor([[2, 1], [1, 0]], device=torch_device, dtype=torch.long)
-        pixel_values, image_grid_thw, image_metadata = pack_image_inputs(
-            pixel_values=pixel_values,
-            image_token_grids=image_token_grids,
-            image_token_offsets=image_token_offsets,
-            image_token_lengths=image_token_lengths,
-        )
-
-        with torch.no_grad():
-            outputs = model.get_image_features(
-                pixel_values=pixel_values,
-                image_grid_thw=image_grid_thw,
-                image_metadata=image_metadata,
-                return_dict=True,
-            )
-
-        expected = torch.cat(
-            (
-                outputs.last_hidden_state[0, 1:3],
-                outputs.last_hidden_state[1, 0:1],
-                outputs.last_hidden_state[2, 2:3],
-            ),
-            dim=0,
-        )
-        pooled_output = torch.cat(outputs.pooler_output, dim=0)
-
-        self.assertEqual(pooled_output.ndim, 2)
-        torch.testing.assert_close(pooled_output, expected)
-
-    def test_get_rope_index_batch_major_skips_padded_and_fully_truncated_slots(self):
-        config = self.model_tester.get_config()
-        model = IsaacModel(config).to(torch_device).eval()
-
-        input_ids = torch.zeros((2, 8), device=torch_device, dtype=torch.long)
-        attention_mask = torch.ones_like(input_ids)
-        mm_token_type_ids = torch.tensor(
-            [
-                [0, 0, 1, 1, 0, 1, 0, 0],
-                [0, 1, 0, 0, 0, 0, 0, 0],
-            ],
-            device=torch_device,
-            dtype=torch.long,
-        )
-        image_grid_thw = torch.tensor(
-            [
-                [[1, 2, 2], [1, 2, 2], [1, 2, 2]],
-                [[1, 2, 2], [0, 0, 0], [0, 0, 0]],
-            ],
-            device=torch_device,
-            dtype=torch.long,
-        )
-        image_metadata = torch.tensor(
-            [
-                [[1, 2], [0, 0], [2, 1]],
-                [[0, 1], [0, 0], [0, 0]],
-            ],
-            device=torch_device,
-            dtype=torch.long,
-        )
-
-        position_ids, rope_deltas = model.get_rope_index(
-            input_ids=input_ids,
-            mm_token_type_ids=mm_token_type_ids,
-            image_grid_thw=image_grid_thw,
-            image_metadata=image_metadata,
-            attention_mask=attention_mask,
-        )
-
-        expected_sample0 = torch.tensor(
-            [
-                [0, 1, 2, 2, 3, 4, 5, 6],
-                [0, 1, 0, 1, 3, 1, 5, 6],
-                [0, 1, 1, 0, 3, 0, 5, 6],
-            ],
-            device=torch_device,
-            dtype=torch.long,
-        )
-        expected_sample1 = torch.tensor(
-            [
-                [0, 1, 2, 3, 4, 5, 6, 7],
-                [0, 0, 2, 3, 4, 5, 6, 7],
-                [0, 0, 2, 3, 4, 5, 6, 7],
-            ],
-            device=torch_device,
-            dtype=torch.long,
-        )
-
-        torch.testing.assert_close(position_ids[:, 0], expected_sample0)
-        torch.testing.assert_close(position_ids[:, 1], expected_sample1)
-        torch.testing.assert_close(
-            rope_deltas,
-            torch.tensor([[-1], [0]], device=torch_device, dtype=torch.long),
-        )
-
-    def test_forward_scatters_batch_major_image_features_in_slot_order(self):
-        config = self.model_tester.get_config()
-        model = IsaacModel(config).to(torch_device).eval()
-
-        input_ids = torch.randint(
-            0,
-            config.get_text_config().vocab_size,
-            (2, 6),
-            device=torch_device,
-            dtype=torch.long,
-        )
-        mm_token_type_ids = torch.tensor(
-            [
-                [0, 1, 1, 0, 1, 0],
-                [0, 0, 0, 0, 0, 0],
-            ],
-            device=torch_device,
-            dtype=torch.long,
-        )
-        patch_size = self.model_tester.vision_config["patch_size"]
-        patch_dim = self.model_tester.vision_config["num_channels"] * patch_size * patch_size
-        pixel_values = torch.zeros((2, 2, 4, patch_dim), device=torch_device, dtype=torch.float32)
-        image_grid_thw = torch.tensor(
-            [
-                [[1, 2, 2], [1, 2, 2]],
-                [[0, 0, 0], [0, 0, 0]],
-            ],
-            device=torch_device,
-            dtype=torch.long,
-        )
-        image_metadata = torch.tensor(
-            [
-                [[0, 2], [1, 1]],
-                [[0, 0], [0, 0]],
-            ],
-            device=torch_device,
-            dtype=torch.long,
-        )
-
-        hidden_size = config.get_text_config().hidden_size
-        scattered_features = (
-            torch.full((2, hidden_size), 11.0, device=torch_device),
-            torch.full((1, hidden_size), 22.0, device=torch_device),
-        )
-        captured = {}
-
-        def fake_language_model(**kwargs):
-            captured["inputs_embeds"] = kwargs["inputs_embeds"].detach().clone()
-            return SimpleNamespace(
-                last_hidden_state=kwargs["inputs_embeds"],
-                past_key_values=None,
-                hidden_states=None,
-                attentions=None,
-            )
-
-        with patch.object(
-            model,
-            "get_image_features",
-            return_value=SimpleNamespace(pooler_output=scattered_features),
-        ) as mock_get_image_features:
-            with patch.object(model, "compute_3d_position_ids", return_value=None):
-                with patch.object(model.language_model, "forward", side_effect=fake_language_model):
-                    model(
-                        input_ids=input_ids,
-                        mm_token_type_ids=mm_token_type_ids,
-                        pixel_values=pixel_values,
-                        image_grid_thw=image_grid_thw,
-                        image_metadata=image_metadata,
-                    )
-
-        mock_get_image_features.assert_called_once()
-        call_kwargs = mock_get_image_features.call_args.kwargs
-        torch.testing.assert_close(call_kwargs["pixel_values"], pixel_values)
-        torch.testing.assert_close(call_kwargs["image_grid_thw"], image_grid_thw)
-        torch.testing.assert_close(call_kwargs["image_metadata"], image_metadata)
-
-        scattered = captured["inputs_embeds"][mm_token_type_ids.bool()]
-        expected = torch.cat(scattered_features, dim=0).to(dtype=scattered.dtype)
-        torch.testing.assert_close(scattered, expected)
-
-    def test_prepare_position_ids_for_generation_uses_batch_major_rope(self):
-        config = self.model_tester.get_config()
-        model = IsaacForConditionalGeneration(config).to(torch_device).eval()
-
-        input_ids = torch.tensor([[4, 5, 6], [7, 8, 9]], device=torch_device, dtype=torch.long)
-        mm_token_type_ids = torch.tensor([[0, 1, 0], [0, 0, 0]], device=torch_device, dtype=torch.long)
-        image_grid_thw = torch.tensor(
-            [
-                [[1, 2, 2]],
-                [[0, 0, 0]],
-            ],
-            device=torch_device,
-            dtype=torch.long,
-        )
-        image_metadata = torch.tensor(
-            [
-                [[0, 1]],
-                [[0, 0]],
-            ],
-            device=torch_device,
-            dtype=torch.long,
-        )
-        expected_positions = torch.arange(18, device=torch_device, dtype=torch.long).view(3, 2, 3)
-        expected_deltas = torch.tensor([[0], [1]], device=torch_device, dtype=torch.long)
-
-        with patch.object(
-            model.model,
-            "get_rope_index",
-            return_value=(expected_positions, expected_deltas),
-        ) as mock_get_rope_index:
-            position_ids = model._prepare_position_ids_for_generation(
-                input_ids,
-                {
-                    "input_ids": input_ids,
-                    "mm_token_type_ids": mm_token_type_ids,
-                    "image_grid_thw": image_grid_thw,
-                    "image_metadata": image_metadata,
-                    "attention_mask": torch.ones_like(input_ids),
-                },
-            )
-
-        mock_get_rope_index.assert_called_once()
-        torch.testing.assert_close(position_ids[1:], expected_positions)
-        torch.testing.assert_close(model.model.rope_deltas, expected_deltas)
-
-    def test_expand_inputs_for_generation_repeats_batch_major_visual_tensors(self):
-        config = self.model_tester.get_config()
-        model = IsaacForConditionalGeneration(config).to(torch_device).eval()
-
-        input_ids = torch.tensor([[1, 2], [3, 4]], device=torch_device, dtype=torch.long)
-        mm_token_type_ids = torch.tensor([[0, 1], [1, 0]], device=torch_device, dtype=torch.long)
-        pixel_values = torch.arange(2 * 2 * 3 * 4, device=torch_device, dtype=torch.float32).view(2, 2, 3, 4)
-        image_grid_thw = torch.tensor(
-            [
-                [[1, 2, 2], [0, 0, 0]],
-                [[1, 2, 2], [1, 2, 2]],
-            ],
-            device=torch_device,
-            dtype=torch.long,
-        )
-        image_metadata = torch.tensor(
-            [
-                [[0, 1], [0, 0]],
-                [[1, 2], [0, 1]],
-            ],
-            device=torch_device,
-            dtype=torch.long,
-        )
-
-        expanded_input_ids, expanded_kwargs = model._expand_inputs_for_generation(
-            expand_size=2,
-            input_ids=input_ids,
-            mm_token_type_ids=mm_token_type_ids,
-            pixel_values=pixel_values,
-            image_grid_thw=image_grid_thw,
-            image_metadata=image_metadata,
-        )
-
-        torch.testing.assert_close(expanded_input_ids, input_ids.repeat_interleave(2, dim=0))
-        torch.testing.assert_close(expanded_kwargs["mm_token_type_ids"], mm_token_type_ids.repeat_interleave(2, dim=0))
-        torch.testing.assert_close(expanded_kwargs["pixel_values"], pixel_values.repeat_interleave(2, dim=0))
-        torch.testing.assert_close(expanded_kwargs["image_grid_thw"], image_grid_thw.repeat_interleave(2, dim=0))
-        torch.testing.assert_close(expanded_kwargs["image_metadata"], image_metadata.repeat_interleave(2, dim=0))
-
-    def test_for_conditional_generation(self):
-        config, input_ids, attention_mask, labels = self.model_tester.prepare_config_and_inputs()
-        model = IsaacForConditionalGeneration(config)
-        model.to(torch_device)
-        model.eval()
-        with torch.no_grad():
-            result = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
-
-        self.assertEqual(
-            result.logits.shape,
-            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.vocab_size),
-        )
-        self.assertIsNotNone(result.loss)
-
-    def test_isaac_for_conditional_generation_initialization(self):
-        config = self.model_tester.get_config()
-        model = IsaacForConditionalGeneration(config)
-        model.to(torch_device)
-
-        self.assertTrue(hasattr(model, "model"))
-        self.assertTrue(hasattr(model, "lm_head"))
-        self.assertTrue(hasattr(model.model, "visual"))
-        self.assertTrue(hasattr(model.model, "multimodal_projector"))
-
-        input_vocab_size = model.get_input_embeddings().num_embeddings
-        output_vocab_size = model.get_output_embeddings().out_features
-        input_ids = torch.randint(0, input_vocab_size, (1, 10), device=torch_device, dtype=torch.long)
-        with torch.no_grad():
-            outputs = model(input_ids=input_ids, return_dict=True)
-        self.assertEqual(outputs.logits.shape, (1, 10, output_vocab_size))
-
-    def test_isaac_for_conditional_generation_loss_and_generate_flag(self):
-        config = self.model_tester.get_config()
-        model = IsaacForConditionalGeneration(config).to(torch_device)
-        self.assertTrue(model.can_generate())
-
-        batch_size, seq_len = 1, 8
-        input_vocab_size = model.get_input_embeddings().num_embeddings
-        output_vocab_size = model.get_output_embeddings().out_features
-        input_ids = torch.randint(0, input_vocab_size, (batch_size, seq_len), device=torch_device)
-        labels = torch.randint(0, output_vocab_size, (batch_size, seq_len), device=torch_device)
-        with torch.no_grad():
-            outputs = model(input_ids=input_ids, labels=labels, return_dict=True)
-        self.assertIsNotNone(outputs.loss)
-        self.assertEqual(outputs.loss.ndim, 0)
-        self.assertEqual(outputs.logits.shape, (batch_size, seq_len, output_vocab_size))
-
     @pytest.mark.generate
     def test_left_padding_compatibility(self):
         _, inputs_dict = self.prepare_config_and_inputs_for_generate()
@@ -938,6 +327,18 @@ def test_left_padding_compatibility(self):
             padded_custom_inputs={"mm_token_type_ids": padded_mm_token_type_ids},
         )
 
+    @unittest.skip(reason="Assisted decoding not supported; Qwen3 backbone does not implement returning attentions")
+    def test_assisted_decoding_matches_greedy_search_0_random(self):
+        pass
+
+    @unittest.skip(reason="Assisted decoding not supported; Qwen3 backbone does not implement returning attentions")
+    def test_assisted_decoding_matches_greedy_search_1_same(self):
+        pass
+
+    @unittest.skip(reason="Unsupported")
+    def test_flash_attn_kernels_inference_equivalence(self):
+        pass
+
     @unittest.skip(reason="Isaac is image-only.")
     def test_get_video_features_output_0(self):
         pass
@@ -959,148 +360,6 @@ def test_get_video_features_attentions(self):
         pass
 
 
-@require_torch
-class IsaacPixelShufflePaddedTest(unittest.TestCase):
-    def test_pixel_shuffle_padded_matches_reference_no_attention_mask(self):
-        x = torch.arange(2 * 16 * 4, device=torch_device, dtype=torch.float32).view(2, 16, 4)
-        token_grids = torch.tensor([[4, 4], [2, 4]], device=torch_device, dtype=torch.long)
-        expected_hidden, expected_mask, expected_lengths = _pixel_shuffle_reference(x, token_grids, scale_factor=2)
-
-        hidden = pixel_shuffle_padded(hidden_states=x, token_grids=token_grids, scale_factor=2)
-
-        torch.testing.assert_close(hidden, expected_hidden)
-
-    def test_pixel_shuffle_padded_raises_on_non_divisible_grid(self):
-        x = torch.randn(1, 15, 8, device=torch_device)
-        token_grids = torch.tensor([[3, 5]], device=torch_device, dtype=torch.long)
-
-        with pytest.raises(ValueError, match="divisible"):
-            pixel_shuffle_padded(hidden_states=x, token_grids=token_grids, scale_factor=2)
-
-    def test_pixel_shuffle_padded_zero_grid(self):
-        x = torch.randn(1, 4, 8, device=torch_device)
-        token_grids = torch.tensor([[0, 0]], device=torch_device, dtype=torch.long)
-
-        hidden = pixel_shuffle_padded(hidden_states=x, token_grids=token_grids, scale_factor=2)
-
-        self.assertEqual(hidden.shape, (1, 0, 32))
-
-
-@require_torch
-@require_flash_attn
-class IsaacAttentionDtypeTest(unittest.TestCase):
-    def _make_config(self):
-        return IsaacVisionConfig(
-            hidden_size=32,
-            intermediate_size=64,
-            num_hidden_layers=1,
-            num_attention_heads=4,
-            num_channels=3,
-            num_patches=64,
-            patch_size=4,
-            attention_dropout=0.0,
-            pixel_shuffle_scale_factor=1,
-        )
-
-    def _skip_if_no_cuda_bf16(self):
-        if not torch.cuda.is_available():
-            pytest.skip("CUDA required for flash attention dtype/parity tests.")
-        if not torch.cuda.is_bf16_supported():
-            pytest.skip("CUDA bfloat16 support required.")
-
-    def test_flash_attention_matches_weight_dtype_bf16(self):
-        self._skip_if_no_cuda_bf16()
-        torch.manual_seed(0)
-
-        device = torch.device("cuda")
-        config = self._make_config()
-        config._attn_implementation = "flash_attention_2"
-
-        attn = IsaacVisionAttention(config).to(device=device, dtype=torch.bfloat16).eval()
-
-        hidden_states = torch.randn(2, 4, config.hidden_size, device=device, dtype=torch.bfloat16)
-
-        with torch.no_grad():
-            attn_output, _ = attn(hidden_states)
-
-        assert attn_output.dtype == attn.out_proj.weight.dtype
-        assert attn_output.dtype == hidden_states.dtype
-
-    def test_flash_attention_matches_weight_dtype_bf16_with_padding(self):
-        self._skip_if_no_cuda_bf16()
-        torch.manual_seed(0)
-
-        device = torch.device("cuda")
-        config = self._make_config()
-        config._attn_implementation = "flash_attention_2"
-
-        attn = IsaacVisionAttention(config).to(device=device, dtype=torch.bfloat16).eval()
-
-        hidden_states = torch.randn(2, 4, config.hidden_size, device=device, dtype=torch.bfloat16)
-        attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], device=device, dtype=torch.bool)
-
-        with torch.no_grad():
-            attn_output, _ = attn(hidden_states, attention_mask=attention_mask)
-
-        assert attn_output.dtype == attn.out_proj.weight.dtype
-        assert attn_output.dtype == hidden_states.dtype
-
-    def test_flash_attention_matches_weight_dtype_bf16_with_prepared_mask(self):
-        self._skip_if_no_cuda_bf16()
-        torch.manual_seed(0)
-
-        device = torch.device("cuda")
-        config = self._make_config()
-        config._attn_implementation = "flash_attention_2"
-
-        attn = IsaacVisionAttention(config).to(device=device, dtype=torch.bfloat16).eval()
-
-        hidden_states = torch.randn(1, 5, config.hidden_size, device=device, dtype=torch.bfloat16)
-        attention_mask = torch.tensor([[1, 1, 1, 0, 0]], device=device, dtype=torch.long)
-        prepared_attention_mask = create_bidirectional_mask(
-            config=config,
-            inputs_embeds=hidden_states,
-            attention_mask=attention_mask,
-        )
-
-        with torch.no_grad():
-            attn_output, _ = attn(hidden_states, attention_mask=prepared_attention_mask)
-
-        assert attn_output.dtype == attn.out_proj.weight.dtype
-        assert attn_output.dtype == hidden_states.dtype
-
-    def test_flash_attention_parity_with_sdpa_bf16(self):
-        self._skip_if_no_cuda_bf16()
-        torch.manual_seed(0)
-
-        device = torch.device("cuda")
-        config_sdpa = self._make_config()
-        config_sdpa._attn_implementation = "sdpa"
-
-        config_fa2 = self._make_config()
-        config_fa2._attn_implementation = "flash_attention_2"
-
-        attn_sdpa = IsaacVisionAttention(config_sdpa).to(device=device, dtype=torch.bfloat16).eval()
-        attn_fa2 = IsaacVisionAttention(config_fa2).to(device=device, dtype=torch.bfloat16).eval()
-
-        # Align weights so the only difference is the backend
-        attn_fa2.load_state_dict(attn_sdpa.state_dict())
-
-        hidden_states = torch.randn(2, 4, config_sdpa.hidden_size, device=device, dtype=torch.bfloat16)
-
-        with torch.no_grad():
-            out_sdpa, _ = attn_sdpa(hidden_states)
-            out_fa2, _ = attn_fa2(hidden_states)
-
-        torch.testing.assert_close(
-            out_fa2.float(),
-            out_sdpa.float(),
-            rtol=1e-3,
-            atol=1e-3,
-            msg="FlashAttention2 output deviates from SDPA baseline beyond tolerance",
-        )
-
-
 @require_torch
 @require_vision
 @slow
@@ -1123,288 +382,251 @@ def setUp(self):
         self.model = self.model.to(device=self.device, dtype=self.dtype)
         self.model.eval()
 
-    def _generate_from_messages(self, messages, images, num_tokens=None, generate_kwargs=None):
-        prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip()
-        processor_output = self.processor(text=prompt, images=images or None, return_tensors="pt")
-        input_ids = processor_output["input_ids"].to(self.device)
-        attention_mask = processor_output.get("attention_mask")
-        if attention_mask is None:
-            pad_id = self.tokenizer.pad_token_id
-            if pad_id is None:
-                pad_id = getattr(self.processor, "pad_token_id", 0)
-            attention_mask =
processor_output["input_ids"].ne(pad_id).long() - attention_mask = attention_mask.to(self.device) - prompt_len = input_ids.shape[1] - multimodal_inputs = to_model_multimodal_inputs(processor_output, self.device) - generate_kwargs = {} if generate_kwargs is None else dict(generate_kwargs) - generate_kwargs.setdefault("max_new_tokens", num_tokens or self.max_new_tokens) - generate_kwargs.setdefault("do_sample", False) - generate_kwargs.setdefault("pad_token_id", self.tokenizer.eos_token_id) - generate_kwargs.setdefault("eos_token_id", self.tokenizer.eos_token_id) - generate_kwargs.setdefault("return_dict_in_generate", True) - generate_kwargs.setdefault("output_logits", True) - - with torch.no_grad(): - outputs = self.model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - **multimodal_inputs, - **generate_kwargs, - ) - - generated_ids = outputs.sequences - generated_tail = generated_ids[:, prompt_len:] - generated_text = self.tokenizer.decode(generated_tail[0], skip_special_tokens=True) - return generated_text - def test_generate_from_image_text(self): image = _load_red_dot_image() if image is None: pytest.skip("PIL.Image is required for Isaac generation tests.") - messages = [ - {"role": "user", "content": "Describe this image:"}, - {"role": "user", "content": self.processor.image_token}, + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image:"}, + {"type": "image", "image": image}, + ], + } ] - generated_text = self._generate_from_messages(messages, [image]) + inputs = self.processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + ).to(self.device, dtype=self.dtype) + + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=self.max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + ) + + generated_ids = outputs.sequences[:, inputs["input_ids"].shape[1] :] + generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0] expected_fragment = "The image is a close-up photograph of a red cross symbol." assert expected_fragment in generated_text def test_generate_from_text_only(self): - document = [ + conversation = [ { - "type": "text", - "content": "What is the pythogorean theorem?", "role": "user", + "content": [{"type": "text", "text": "What is the pythogorean theorem?"}], } ] - messages, _ = document_to_messages(document) - generated_text = self._generate_from_messages(messages, [], num_tokens=100) - expected_fragmenet = "The Pythagorean theorem is a fundamental principle in geometry that relates the lengths of the sides of a right-angled triangle. 
Let's break down the theorem step by step:"
+        inputs = self.processor.apply_chat_template(
+            conversation,
+            tokenize=True,
+            return_dict=True,
+            add_generation_prompt=True,
+            return_tensors="pt",
+        ).to(self.device, dtype=self.dtype)
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=100,
+                do_sample=False,
+                pad_token_id=self.tokenizer.eos_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
+                return_dict_in_generate=True,
+            )
+
+        generated_ids = outputs.sequences[:, inputs["input_ids"].shape[1] :]
+        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        expected_fragment = "The Pythagorean theorem is a fundamental principle in geometry that relates the lengths of the sides of a right-angled triangle. Let's break it down step by step:"
-        assert expected_fragmenet in generated_text
+        assert expected_fragment in generated_text
 
     def test_vqa_from_image(self):
-        document = [
+        conversation = [
             {
-                "type": "image",
-                "content": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp",
                 "role": "user",
-            },
-            {
-                "type": "text",
-                "content": "Is it safe to cross the street at this moment?",
-                "role": "user",
-            },
+                "content": [
+                    {
+                        "type": "image",
+                        "url": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp",
+                    },
+                    {"type": "text", "text": "Is it safe to cross the street at this moment?"},
+                ],
+            }
         ]
-        messages, images = document_to_messages(document, image_token=self.processor.image_token)
-        generated_text = self._generate_from_messages(messages, images, num_tokens=256)
-        expected_response = "\nNo, it is not safe to cross the street at this moment. The traffic light for pedestrians is red, indicating that it is not safe to cross."
-        assert generated_text == expected_response
-
-    def _generate_batch(self, prompts, images_list, num_tokens=None, generate_kwargs=None):
-        processor_output = self.processor(text=prompts, images=images_list, return_tensors="pt")
-        input_ids = processor_output["input_ids"]
-        if input_ids.dim() == 1:
-            input_ids = input_ids.unsqueeze(0)
-
-        # Use processor-provided attention_mask if available; otherwise fallback.
- attention_mask = processor_output.get("attention_mask", None) - if attention_mask is None: - pad_id = self.tokenizer.pad_token_id - if pad_id is None: - pad_id = getattr(self.processor, "pad_token_id", 0) - attention_mask = input_ids.ne(pad_id).long() - - input_ids = input_ids.to(self.device) - attention_mask = attention_mask.to(self.device) - - multimodal_inputs = to_model_multimodal_inputs(processor_output, self.device) - generate_kwargs = {} if generate_kwargs is None else dict(generate_kwargs) - generate_kwargs.setdefault("max_new_tokens", num_tokens or self.max_new_tokens) - generate_kwargs.setdefault("do_sample", False) - generate_kwargs.setdefault("pad_token_id", self.tokenizer.eos_token_id) - generate_kwargs.setdefault("eos_token_id", self.tokenizer.eos_token_id) - generate_kwargs.setdefault("return_dict_in_generate", True) + inputs = self.processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + ).to(self.device, dtype=self.dtype) with torch.no_grad(): outputs = self.model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - **multimodal_inputs, - **generate_kwargs, + **inputs, + max_new_tokens=256, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, ) - sequences = outputs.sequences - generated_texts = [] - for i in range(sequences.shape[0]): - tail_ids = sequences[i, :] # only newly generated tokens - generated_texts.append(self.tokenizer.decode(tail_ids, skip_special_tokens=True)) - return generated_texts + generated_ids = outputs.sequences[:, inputs["input_ids"].shape[1] :] + generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + expected_response = "\nNo, it is not safe to cross the street at this moment. The traffic light for pedestrians is red, indicating that it is not safe to cross." 
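+        # With do_sample=False the decode is greedy and deterministic for a pinned
+        # checkpoint, so the full response is compared exactly here (longer free-form
+        # generations elsewhere in this suite use substring checks instead).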
+ assert generated_text == expected_response def test_logit_equivalence(self): image = _load_red_dot_image() if image is None: pytest.skip("PIL.Image is required for Isaac generation tests.") - image_bytes = base64.b64decode(RED_DOT_B64) - pil_image = Image.open(io.BytesIO(image_bytes)) - images = [] - images.append(pil_image) - num_tokens = 10 - - messages = [ - {"role": "user", "content": "Describe this image:"}, - {"role": "user", "content": self.processor.image_token}, + + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image:"}, + {"type": "image", "image": image}, + ], + } ] - prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip() - processor_output = self.processor(text=prompt, images=images, return_tensors="pt") - input_ids = processor_output["input_ids"] - device = next(self.model.parameters()).device - input_ids = input_ids.to(device) - attention_mask = processor_output.get("attention_mask") - if attention_mask is not None: - attention_mask = attention_mask.to(device) - multimodal_inputs = to_model_multimodal_inputs(processor_output, device) + inputs = self.processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + ).to(self.device, dtype=self.dtype) with torch.no_grad(): outputs = self.model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - **multimodal_inputs, - max_new_tokens=num_tokens or self.max_new_tokens, + **inputs, + max_new_tokens=10, do_sample=False, pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id, return_dict_in_generate=True, output_logits=True, ) + hf_logits = torch.cat(outputs.logits, dim=0) logit_stats = compute_logits_statistics(hf_logits) expected_logit_stats = { "shape": [10, 151936], "numel": 1519360, - "mean": 0.0879677375, - "std": 2.8382794404, - "min": -12.125, + "mean": 0.0608877803, + "std": 2.8308793244, + "min": -12.0625, "max": 31.0, - "sum": 133654.661714755, - "l2_norm": 3500.2090570868, + "sum": 92510.4578057677, + "l2_norm": 3490.2146142251, } assert logit_stats == expected_logit_stats def test_batched_generation_matches_individual(self): - # Build individual scenarios matching existing integration tests - red_image = _load_red_dot_image() - if red_image is None: + image = _load_red_dot_image() + if image is None: pytest.skip("PIL.Image is required for Isaac generation tests.") - vqa_document = [ - { - "type": "image", - "content": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp", - "role": "user", - }, - { - "type": "text", - "content": "Is it safe to cross the street at this moment?", - "role": "user", - }, - ] - - # Text-only - doc_text_only = [{"type": "text", "content": "What is the pythogorean theorem?", "role": "user"}] - messages_text_only, images_text_only = document_to_messages(doc_text_only) - single_text_only = self._generate_from_messages( - messages_text_only, images_text_only, num_tokens=self.max_new_tokens - ) - assert single_text_only, "Text-only single generation is empty" - - # Image + text - messages_image_text = [ - {"role": "user", "content": "Describe this image:"}, - {"role": "user", "content": self.processor.image_token}, + conversations = [ + [ + { + "role": "user", + "content": [{"type": "text", "text": "What is the pythogorean theorem?"}], + } + ], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe 
this image:"}, + {"type": "image", "image": image}, + ], + } + ], + [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp", + }, + {"type": "text", "text": "Is it safe to cross the street at this moment?"}, + ], + } + ], ] - single_image_text = self._generate_from_messages(messages_image_text, [red_image]) - assert single_image_text, "Image-text single generation is empty" - - # VQA - messages_vqa, images_vqa = document_to_messages(vqa_document, image_token=self.processor.image_token) - single_vqa = self._generate_from_messages(messages_vqa, images_vqa, num_tokens=self.max_new_tokens) - assert single_vqa, "VQA single generation is empty" - - single_texts = [single_text_only, single_image_text, single_vqa] - # Build batch inputs - prompts = [ - self.processor.apply_chat_template(messages_text_only, tokenize=False, add_generation_prompt=True).strip(), + single_inputs = [ self.processor.apply_chat_template( - messages_image_text, tokenize=False, add_generation_prompt=True - ).strip(), - self.processor.apply_chat_template(messages_vqa, tokenize=False, add_generation_prompt=True).strip(), - ] - images_list = [images_text_only, [red_image], images_vqa] - - # Input-level sanity - assert len(prompts) == len(images_list) == 3 - for i, (p, imgs) in enumerate(zip(prompts, images_list)): - expected_tokens = p.count(self.processor.image_token) - num_imgs = len(imgs) - assert expected_tokens == num_imgs, ( - f"sample {i} image token/image mismatch: {expected_tokens} vs {num_imgs}" + conversation, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", ) + for conversation in conversations + ] + batch_inputs = self.processor.apply_chat_template( + conversations, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + processor_kwargs={"padding_side": "left"}, + ) + batch_input_ids = batch_inputs["input_ids"] + max_length = batch_input_ids.shape[1] pad_id = self.tokenizer.pad_token_id if pad_id is None: pad_id = getattr(self.processor, "pad_token_id", 0) - per_sample_outputs = [ - self.processor(text=prompt, images=imgs or None, return_tensors="pt") - for prompt, imgs in zip(prompts, images_list) - ] - batch_outputs = self.processor(text=prompts, images=images_list, return_tensors="pt") - batch_input_ids = batch_outputs["input_ids"] - batch_packed = batch_outputs - - sample_lengths = [output["input_ids"].squeeze(0).shape[0] for output in per_sample_outputs] - max_length = max(sample_lengths) - - for i, (single_output, batch_ids, single_len) in enumerate( - zip(per_sample_outputs, batch_input_ids, sample_lengths) - ): - single_ids = single_output["input_ids"].squeeze(0) - single_packed = single_output - + sample_lengths = [single_input["input_ids"].squeeze(0).shape[0] for single_input in single_inputs] + for i, (single_input, batch_ids, single_len) in enumerate(zip(single_inputs, batch_input_ids, sample_lengths)): + single_ids = single_input["input_ids"].squeeze(0) torch.testing.assert_close(batch_ids[-single_len:], single_ids) - batch_modality_row = batch_packed["mm_token_type_ids"][i] + batch_modality_row = batch_inputs["mm_token_type_ids"][i] expected_modality = torch.full( (max_length,), batch_modality_row[-1].item(), dtype=batch_modality_row.dtype, device=batch_modality_row.device, ) - expected_modality[-single_len:] = single_packed["mm_token_type_ids"].squeeze(0) + expected_modality[-single_len:] = 
single_input["mm_token_type_ids"].squeeze(0) torch.testing.assert_close(batch_modality_row, expected_modality) - if batch_packed["image_grid_thw"] is not None: - batch_image_mask = batch_packed["image_grid_thw"][i, :, 0].eq(1) + if batch_inputs["image_grid_thw"] is not None: + batch_image_mask = batch_inputs["image_grid_thw"][i, :, 0].eq(1) expected_image_count = int(batch_image_mask.sum().item()) - if single_packed["image_grid_thw"] is None: + if single_input["image_grid_thw"] is None: assert expected_image_count == 0 else: - single_image_mask = single_packed["image_grid_thw"][0, :, 0].eq(1) + single_image_mask = single_input["image_grid_thw"][0, :, 0].eq(1) assert expected_image_count == int(single_image_mask.sum().item()) if expected_image_count > 0: - batch_image_grid_thw = batch_packed["image_grid_thw"][i, batch_image_mask] - single_image_grid_thw = single_packed["image_grid_thw"][0, single_image_mask] - batch_image_metadata = batch_packed["image_metadata"][i, batch_image_mask] - single_image_metadata = single_packed["image_metadata"][0, single_image_mask] + batch_image_grid_thw = batch_inputs["image_grid_thw"][i, batch_image_mask] + single_image_grid_thw = single_input["image_grid_thw"][0, single_image_mask] + batch_image_metadata = batch_inputs["image_metadata"][i, batch_image_mask] + single_image_metadata = single_input["image_metadata"][0, single_image_mask] torch.testing.assert_close(batch_image_grid_thw, single_image_grid_thw) torch.testing.assert_close(batch_image_metadata, single_image_metadata) for batch_pixel_values, single_pixel_values, grid_thw in zip( - batch_packed["pixel_values"][i, batch_image_mask], - single_packed["pixel_values"][0, single_image_mask], + batch_inputs["pixel_values"][i, batch_image_mask], + single_input["pixel_values"][0, single_image_mask], batch_image_grid_thw, strict=True, ): @@ -1419,86 +641,127 @@ def test_batched_generation_matches_individual(self): pad_span = batch_ids[: max_length - single_len] assert torch.all(pad_span == pad_id), f"sample {i} left pad span not padded with pad id" + torch.testing.assert_close( + batch_inputs["attention_mask"][i], + batch_ids.ne(pad_id).long(), + ) - attention_mask = batch_ids.ne(pad_id).long() - assert not torch.any(attention_mask[: max_length - single_len]), f"sample {i} mask ones inside left pad" - assert torch.all(attention_mask[-single_len:]), f"sample {i} mask zeros inside content" - - assert batch_packed["pixel_values"] is not None - assert batch_packed["image_grid_thw"] is not None - assert batch_packed["image_metadata"] is not None + single_texts = [] + for single_input in single_inputs: + single_input = single_input.to(self.device, dtype=self.dtype) + with torch.no_grad(): + outputs = self.model.generate( + **single_input, + max_new_tokens=self.max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + ) + generated_ids = outputs.sequences[:, single_input["input_ids"].shape[1] :] + single_texts.append(self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]) - batch_texts = self._generate_batch(prompts, images_list, num_tokens=100) + batch_inputs = batch_inputs.to(self.device, dtype=self.dtype) + with torch.no_grad(): + batch_outputs = self.model.generate( + **batch_inputs, + max_new_tokens=100, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + ) + batch_generated_ids = batch_outputs.sequences[:, 
batch_inputs["input_ids"].shape[1] :] + batch_texts = self.processor.batch_decode(batch_generated_ids, skip_special_tokens=True) assert len(batch_texts) == len(single_texts) == 3 - for i, (btxt, stxt) in enumerate(zip(batch_texts, single_texts)): - assert stxt in btxt, f"batch[{i}] mismatch: {btxt!r} vs single[{i}] {stxt!r}" + for i, (batch_text, single_text) in enumerate(zip(batch_texts, single_texts)): + assert single_text in batch_text, f"batch[{i}] mismatch: {batch_text!r} vs single[{i}] {single_text!r}" def test_batched_beam_generation_matches_individual(self): - red_image = _load_red_dot_image() - if red_image is None: + image = _load_red_dot_image() + if image is None: pytest.skip("PIL.Image is required for Isaac generation tests.") - vqa_document = [ - { - "type": "image", - "content": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp", - "role": "user", - }, - { - "type": "text", - "content": "Is it safe to cross the street at this moment?", - "role": "user", - }, + conversations = [ + [ + { + "role": "user", + "content": [{"type": "text", "text": "What is the pythogorean theorem?"}], + } + ], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image:"}, + {"type": "image", "image": image}, + ], + } + ], + [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp", + }, + {"type": "text", "text": "Is it safe to cross the street at this moment?"}, + ], + } + ], ] beam_kwargs = {"num_beams": 2} - doc_text_only = [{"type": "text", "content": "What is the pythogorean theorem?", "role": "user"}] - messages_text_only, images_text_only = document_to_messages(doc_text_only) - single_text_only = self._generate_from_messages( - messages_text_only, - images_text_only, - num_tokens=self.max_new_tokens, - generate_kwargs=beam_kwargs, - ) - assert single_text_only, "Text-only beam generation is empty" - - messages_image_text = [ - {"role": "user", "content": "Describe this image:"}, - {"role": "user", "content": self.processor.image_token}, - ] - single_image_text = self._generate_from_messages(messages_image_text, [red_image], generate_kwargs=beam_kwargs) - assert single_image_text, "Image-text beam generation is empty" - - messages_vqa, images_vqa = document_to_messages(vqa_document, image_token=self.processor.image_token) - single_vqa = self._generate_from_messages( - messages_vqa, - images_vqa, - num_tokens=self.max_new_tokens, - generate_kwargs=beam_kwargs, - ) - assert single_vqa, "VQA beam generation is empty" - - single_texts = [single_text_only, single_image_text, single_vqa] - prompts = [ - self.processor.apply_chat_template(messages_text_only, tokenize=False, add_generation_prompt=True).strip(), - self.processor.apply_chat_template( - messages_image_text, tokenize=False, add_generation_prompt=True - ).strip(), - self.processor.apply_chat_template(messages_vqa, tokenize=False, add_generation_prompt=True).strip(), - ] - images_list = [images_text_only, [red_image], images_vqa] - - batch_texts = self._generate_batch( - prompts, - images_list, - num_tokens=100, - generate_kwargs=beam_kwargs, - ) + single_texts = [] + for conversation in conversations: + single_input = self.processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + ).to(self.device, dtype=self.dtype) + with torch.no_grad(): + 
outputs = self.model.generate( + **single_input, + max_new_tokens=self.max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + **beam_kwargs, + ) + generated_ids = outputs.sequences[:, single_input["input_ids"].shape[1] :] + single_texts.append(self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]) + + batch_inputs = self.processor.apply_chat_template( + conversations, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + processor_kwargs={"padding_side": "left"}, + ).to(self.device, dtype=self.dtype) + with torch.no_grad(): + batch_outputs = self.model.generate( + **batch_inputs, + max_new_tokens=100, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + **beam_kwargs, + ) + batch_generated_ids = batch_outputs.sequences[:, batch_inputs["input_ids"].shape[1] :] + batch_texts = self.processor.batch_decode(batch_generated_ids, skip_special_tokens=True) assert len(batch_texts) == len(single_texts) == 3 - for i, (btxt, stxt) in enumerate(zip(batch_texts, single_texts)): - assert stxt in btxt, f"beam batch[{i}] mismatch: {btxt!r} vs single[{i}] {stxt!r}" + for i, (batch_text, single_text) in enumerate(zip(batch_texts, single_texts)): + assert single_text in batch_text, ( + f"beam batch[{i}] mismatch: {batch_text!r} vs single[{i}] {single_text!r}" + ) @require_torch @@ -1525,50 +788,43 @@ def setUp(self): self.model.eval() def test_hf_generate_box_points(self): - document = [ + conversation = [ { - "type": "text", - "content": "BOX", "role": "user", - }, - { - "type": "image", - "content": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp", - "role": "user", - }, - { - "type": "text", - "content": "Determine whether it is safe to cross the street. Look for signage and moving traffic.", - "role": "user", - }, + "content": [ + {"type": "text", "text": "BOX"}, + { + "type": "image", + "url": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp", + }, + { + "type": "text", + "text": "Determine whether it is safe to cross the street. 
Look for signage and moving traffic.", + }, + ], + } ] - messages, images = document_to_messages(document, image_token=self.processor.image_token) - prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip() - processor_output = self.processor(text=prompt, images=images, return_tensors="pt") - input_ids = processor_output["input_ids"].to(self.device) - attention_mask = processor_output.get("attention_mask") - if attention_mask is not None: - attention_mask = attention_mask.to(self.device) - prompt_len = input_ids.shape[1] - multimodal_inputs = to_model_multimodal_inputs(processor_output, self.device) + inputs = self.processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + ).to(self.device, dtype=self.dtype) with torch.no_grad(): outputs = self.model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - **multimodal_inputs, + **inputs, max_new_tokens=self.max_new_tokens, do_sample=False, pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id, - tokenizer=self.tokenizer, return_dict_in_generate=True, ) - generated_ids = outputs.sequences - hf_generated_tail = generated_ids[:, prompt_len:] - hf_generated_text = self.tokenizer.decode(hf_generated_tail[0], skip_special_tokens=True) - clean_text, points = self.processor.post_process_generation(hf_generated_text, expected="box") + generated_ids = outputs.sequences[:, inputs["input_ids"].shape[1] :] + generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + _, points = self.processor.post_process_generation(generated_text, expected="box") assert len(points) == 1 first_point = points[0] assert first_point.top_left.x < first_point.bottom_right.x @@ -1580,38 +836,33 @@ def test_hf_generate_box_points(self): assert first_point.bottom_right.y == 386 def test_hf_generate_polygon_points(self): - document = [ - { - "type": "text", - "content": "POLYGON", - "role": "user", - }, + conversation = [ { - "type": "image", - "content": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp", "role": "user", - }, - { - "type": "text", - "content": "Determine whether it is safe to cross the street. Look for signage and moving traffic.", - "role": "user", - }, + "content": [ + {"type": "text", "text": "POLYGON"}, + { + "type": "image", + "url": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp", + }, + { + "type": "text", + "text": "Determine whether it is safe to cross the street. 
Look for signage and moving traffic.", + }, + ], + } ] - messages, images = document_to_messages(document, image_token=self.processor.image_token) - prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).strip() - processor_output = self.processor(text=prompt, images=images, return_tensors="pt") - input_ids = processor_output["input_ids"].to(self.device) - attention_mask = processor_output.get("attention_mask") - if attention_mask is not None: - attention_mask = attention_mask.to(self.device) - prompt_len = input_ids.shape[1] - multimodal_inputs = to_model_multimodal_inputs(processor_output, self.device) + inputs = self.processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + ).to(self.device, dtype=self.dtype) with torch.no_grad(): outputs = self.model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - **multimodal_inputs, + **inputs, max_new_tokens=self.max_new_tokens, do_sample=False, pad_token_id=self.tokenizer.eos_token_id, @@ -1619,10 +870,9 @@ def test_hf_generate_polygon_points(self): return_dict_in_generate=True, ) - generated_ids = outputs.sequences - hf_generated_tail = generated_ids[:, prompt_len:] - hf_generated_text = self.tokenizer.decode(hf_generated_tail[0], skip_special_tokens=True) - _, polygons = self.processor.post_process_generation(hf_generated_text, expected="polygon") + generated_ids = outputs.sequences[:, inputs["input_ids"].shape[1] :] + generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + _, polygons = self.processor.post_process_generation(generated_text, expected="polygon") assert len(polygons) == 1 first_polygon = polygons[0] xs = [point.x for point in first_polygon.points] diff --git a/tests/models/isaac/test_processing_isaac.py b/tests/models/isaac/test_processing_isaac.py index 431850cdbe64..7c2bbc3f6048 100644 --- a/tests/models/isaac/test_processing_isaac.py +++ b/tests/models/isaac/test_processing_isaac.py @@ -15,466 +15,35 @@ """Testing suite for the Isaac processor.""" import os -import re import unittest from pathlib import Path import numpy as np import pytest -import torch from huggingface_hub import is_offline_mode -from transformers import IsaacConfig, PythonBackend -from transformers.models.isaac.image_processing_isaac import IsaacImageProcessor from transformers.models.isaac.processing_isaac import IsaacProcessor from transformers.testing_utils import require_torch, require_vision -from transformers.tokenization_utils_base import BatchEncoding -from transformers.utils import is_vision_available +from transformers.utils import is_torch_available, is_vision_available from ...test_processing_common import ProcessorTesterMixin +if is_torch_available(): + import torch + if is_vision_available(): from PIL import Image else: Image = None -ISAAC_OUTPUT_KEYS = { - "input_ids", - "attention_mask", - "mm_token_type_ids", - "pixel_values", - "image_grid_thw", - "image_metadata", -} - - -def _simple_tokenizer_call( - tokenizer, - text, - padding=False, - truncation=None, - max_length=None, - pad_to_multiple_of=None, - return_attention_mask=True, - return_overflowing_tokens=False, - return_tensors=None, - add_special_tokens=True, - **kwargs, -): - texts = [text] if isinstance(text, str) else list(text) - rows = [] - row_kinds = [] - overflow_to_sample_mapping = [] - - for sample_idx, sample in enumerate(texts): - token_ids = [tokenizer._convert_token_to_id(token) for token in 
tokenizer._tokenize(sample)] - if add_special_tokens: - token_ids = tokenizer.build_inputs_with_special_tokens(token_ids) - - kept_ids = list(token_ids) - dropped_ids = [] - if truncation and max_length is not None and len(token_ids) > max_length: - if tokenizer.truncation_side == "left": - dropped_ids = token_ids[:-max_length] - kept_ids = token_ids[-max_length:] - else: - kept_ids = token_ids[:max_length] - dropped_ids = token_ids[max_length:] - - rows.append(kept_ids) - row_kinds.append("kept") - overflow_to_sample_mapping.append(sample_idx) - - if return_overflowing_tokens and dropped_ids: - rows.append(dropped_ids) - row_kinds.append("overflow") - overflow_to_sample_mapping.append(sample_idx) - - kept_rows = [row for row, row_kind in zip(rows, row_kinds, strict=True) if row_kind == "kept"] - target_length = None - if padding in (True, "longest"): - target_length = max((len(row) for row in kept_rows), default=0) - elif padding == "max_length": - target_length = max_length - - if target_length is not None and pad_to_multiple_of is not None and target_length % pad_to_multiple_of != 0: - target_length = ((target_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - - padded_rows = [] - attention_masks = [] - for row, row_kind in zip(rows, row_kinds, strict=True): - if row_kind == "kept" and target_length is not None: - pad_len = target_length - len(row) - if tokenizer.padding_side == "left": - padded_row = [tokenizer.pad_token_id] * pad_len + row - attention_mask = [0] * pad_len + [1] * len(row) - else: - padded_row = row + [tokenizer.pad_token_id] * pad_len - attention_mask = [1] * len(row) + [0] * pad_len - else: - padded_row = row - attention_mask = [1] * len(row) - - padded_rows.append(padded_row) - attention_masks.append(attention_mask) - - data = {"input_ids": padded_rows} - if return_attention_mask: - data["attention_mask"] = attention_masks - if return_overflowing_tokens: - data["overflow_to_sample_mapping"] = overflow_to_sample_mapping - - return BatchEncoding(data=data, tensor_type=return_tensors) - - -class SimpleIsaacTokenizer(PythonBackend): - vocab_files_names = {} - model_input_names = ["input_ids"] - - def __init__(self): - self._vocab = { - "": 0, - "": 1, - "": 2, - "": 3, - "": 4, - "<|image_pad|>": 5, - } - self._ids_to_tokens = {idx: tok for tok, idx in self._vocab.items()} - super().__init__( - bos_token="", - eos_token="", - pad_token="", - unk_token="", - additional_special_tokens=[""], - extra_special_tokens={"image_pad_token": "<|image_pad|>"}, - model_max_length=512, - ) - - def get_vocab(self): - return dict(self._vocab) - - def _tokenize(self, text): - clean = text.replace("\n", " ").strip() - if not clean: - return [] - - special_tokens = sorted( - (token for token in self._vocab if token.startswith("<") and token.endswith(">")), - key=len, - reverse=True, - ) - if not special_tokens: - return [token for token in clean.split(" ") if token] - - split_pattern = "(" + "|".join(re.escape(token) for token in special_tokens) + ")" - tokens = [] - for chunk in re.split(split_pattern, clean): - if not chunk or chunk.isspace(): - continue - if chunk in self._vocab: - tokens.append(chunk) - else: - tokens.extend(token for token in chunk.split(" ") if token) - return tokens - - def _convert_token_to_id(self, token): - if token not in self._vocab: - next_id = len(self._vocab) - self._vocab[token] = next_id - self._ids_to_tokens[next_id] = token - return self._vocab[token] - - def _convert_id_to_token(self, index): - return self._ids_to_tokens.get(index, 
self.unk_token) - - @property - def vocab_size(self) -> int: - return len(self._vocab) - - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - if token_ids_1 is not None: - token_ids_0 = token_ids_0 + token_ids_1 - return [self.bos_token_id] + list(token_ids_0) + [self.eos_token_id] - - def save_vocabulary(self, save_directory, filename_prefix=None): - return () - - def __call__(self, text, **kwargs): - return _simple_tokenizer_call(self, text, **kwargs) - - -class SimpleIsaacTokenizerWithNamedImagePad(PythonBackend): - vocab_files_names = {} - model_input_names = ["input_ids"] - - def __init__(self): - self._vocab = { - "": 0, - "": 1, - "": 2, - "": 3, - "": 4, - "": 5, - "<|image_pad|>": 6, - } - self._ids_to_tokens = {idx: tok for tok, idx in self._vocab.items()} - super().__init__( - bos_token="", - eos_token="", - pad_token="", - unk_token="", - extra_special_tokens={"image_pad_token": ""}, - model_max_length=512, - ) - - def get_vocab(self): - return dict(self._vocab) - - def _tokenize(self, text): - clean = text.replace("\n", " ").strip() - if not clean: - return [] - - special_tokens = sorted( - (token for token in self._vocab if token.startswith("<") and token.endswith(">")), - key=len, - reverse=True, - ) - split_pattern = "(" + "|".join(re.escape(token) for token in special_tokens) + ")" - tokens = [] - for chunk in re.split(split_pattern, clean): - if not chunk or chunk.isspace(): - continue - if chunk in self._vocab: - tokens.append(chunk) - else: - tokens.extend(token for token in chunk.split(" ") if token) - return tokens - - def _convert_token_to_id(self, token): - return self._vocab.get(token, self._vocab[""]) - - def _convert_id_to_token(self, index): - return self._ids_to_tokens.get(index, self.unk_token) - - @property - def vocab_size(self) -> int: - return len(self._vocab) - - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - if token_ids_1 is not None: - token_ids_0 = token_ids_0 + token_ids_1 - return [self.bos_token_id] + list(token_ids_0) + [self.eos_token_id] - - def save_vocabulary(self, save_directory, filename_prefix=None): - return () - - def __call__(self, text, **kwargs): - return _simple_tokenizer_call(self, text, **kwargs) - - -class IsaacProcessorTestDouble(IsaacProcessor): - def check_argument_for_proper_class(self, argument_name, argument): - return type(argument) - - def _make_dummy_image(size=(32, 32), color=(255, 0, 0)): if Image is None: raise RuntimeError("PIL.Image is not available in this environment.") return Image.new("RGB", size, color=color) -def _make_processor_with_max_len(tokenizer, base_config, max_len): - config = IsaacConfig(**base_config.to_dict()) - config.max_sequence_length = max_len - vision_config = config.vision_config - image_processor = IsaacImageProcessor( - patch_size=vision_config.patch_size, - max_num_patches=vision_config.num_patches, - pixel_shuffle_scale=vision_config.pixel_shuffle_scale_factor, - rescale_factor=config.vision_rescale_factor, - ) - return IsaacProcessorTestDouble( - image_processor=image_processor, - tokenizer=tokenizer, - max_sequence_length=config.max_sequence_length, - ) - - -def _run_processor(processor, text, images=None): - return processor(text=text, images=images, return_tensors="pt") - - -def _make_post_process_processor(): - return IsaacProcessorTestDouble(image_processor=IsaacImageProcessor(), tokenizer=SimpleIsaacTokenizer()) - - -def test_processor_prefers_named_image_pad_token(): - processor = IsaacProcessorTestDouble( - 
image_processor=IsaacImageProcessor(), tokenizer=SimpleIsaacTokenizerWithNamedImagePad() - ) - - assert processor.image_token == "" - assert processor.image_token_id == processor.tokenizer.image_pad_token_id - assert processor.image_token_id != processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") - - -def _assert_common(outputs, batch_size=1): - assert set(outputs.keys()) == ISAAC_OUTPUT_KEYS - - input_ids = outputs["input_ids"] - attention_mask = outputs["attention_mask"] - mm_token_type_ids = outputs["mm_token_type_ids"] - pixel_values = outputs["pixel_values"] - image_grid_thw = outputs["image_grid_thw"] - image_metadata = outputs["image_metadata"] - - assert input_ids.shape[0] == batch_size - assert attention_mask.shape == input_ids.shape - assert mm_token_type_ids.shape == input_ids.shape - assert input_ids.dtype == torch.long - assert attention_mask.dtype == torch.long - assert mm_token_type_ids.dtype == torch.long - - if pixel_values is None: - assert image_grid_thw is None - assert image_metadata is None - else: - assert pixel_values.ndim == 4 - assert image_grid_thw.shape == (batch_size, pixel_values.shape[1], 3) - assert image_metadata.shape == (batch_size, pixel_values.shape[1], 2) - assert image_grid_thw.dtype == torch.long - assert image_metadata.dtype == torch.long - - active_slots = image_grid_thw[..., 0].eq(1) - assert torch.all(image_grid_thw[~active_slots].eq(0)) - if active_slots.any(): - assert torch.all(image_grid_thw[active_slots, 1:] > 0) - assert torch.all(image_metadata[active_slots] >= 0) - - return outputs - - -def _get_sample_image_mask(outputs, batch_index=0): - image_grid_thw = outputs["image_grid_thw"] - if image_grid_thw is None: - return torch.zeros((0,), dtype=torch.bool) - return image_grid_thw[batch_index, :, 0].eq(1) - - -def _assert_no_vision(outputs, batch_index=0): - assert not _get_sample_image_mask(outputs, batch_index=batch_index).any() - assert not outputs["mm_token_type_ids"][batch_index].eq(1).any() - - -def _assert_vision_segments(outputs, expected_segments, batch_index=0): - sample_image_mask = _get_sample_image_mask(outputs, batch_index=batch_index) - active_segments = int(sample_image_mask.sum().item()) - assert active_segments == expected_segments - assert torch.all(outputs["image_metadata"][batch_index, sample_image_mask, 1] > 0) - assert torch.all(outputs["image_grid_thw"][batch_index, sample_image_mask, 1:].prod(dim=-1) > 0) - - -def _count_modality(outputs, modality_value, batch_index=0): - return int( - (outputs["attention_mask"][batch_index].bool() & outputs["mm_token_type_ids"][batch_index].eq(modality_value)) - .sum() - .item() - ) - - -def _get_active_vision_grids(outputs, batch_index=0): - image_grid_thw = outputs["image_grid_thw"] - if image_grid_thw is None: - return torch.zeros((0, 2), dtype=torch.long) - return image_grid_thw[batch_index, _get_sample_image_mask(outputs, batch_index=batch_index), 1:] - - -def _get_active_vision_offsets(outputs, batch_index=0): - image_metadata = outputs["image_metadata"] - if image_metadata is None: - return torch.zeros((0,), dtype=torch.long) - return image_metadata[batch_index, _get_sample_image_mask(outputs, batch_index=batch_index), 0] - - -def _get_active_vision_lengths(outputs, batch_index=0): - image_metadata = outputs["image_metadata"] - if image_metadata is None: - return torch.zeros((0,), dtype=torch.long) - return image_metadata[batch_index, _get_sample_image_mask(outputs, batch_index=batch_index), 1] - - -def _get_expected_vision_lengths(outputs, pixel_shuffle_scale=1, 
batch_index=0): - grids = _get_active_vision_grids(outputs, batch_index=batch_index) - if grids.numel() == 0: - return grids.new_zeros((0,)) - return torch.prod(grids, dim=-1) // (pixel_shuffle_scale**2) - - -@pytest.fixture -def isaac_tiny_config(): - text_config = { - "bos_token_id": 0, - "eos_token_id": 1, - "pad_token_id": 2, - "hidden_act": "silu", - "head_dim": 32 // 4, - "hidden_size": 32, - "vocab_size": 99, - "intermediate_size": 32 * 3, - "max_position_embeddings": 128, - "model_type": "qwen3", - "num_attention_heads": 4, - "num_hidden_layers": 2, - "num_key_value_heads": 4, - "rope_parameters": {"rope_type": "default", "mrope_section": [2, 1, 1], "mrope_interleaved": True}, - "tie_word_embeddings": True, - } - - vision_config = { - "hidden_size": 32, - "intermediate_size": 32 * 2, - "num_hidden_layers": 1, - "num_attention_heads": 4, - "num_channels": 3, - "num_patches": 64, - "patch_size": 4, - "pixel_shuffle_scale_factor": 1, - "attention_dropout": 0.0, - "layer_norm_eps": 1e-6, - } - - config = IsaacConfig(text_config=text_config, vision_config=vision_config) - config._attn_implementation = "sdpa" - config.text_config._attn_implementation = "sdpa" - config.vision_attn_implementation = "sdpa" - return config - - -@pytest.fixture -def isaac_tokenizer(): - return SimpleIsaacTokenizer() - - -@pytest.fixture -def isaac_processor(isaac_tokenizer, isaac_tiny_config): - vision_config = isaac_tiny_config.vision_config - image_processor = IsaacImageProcessor( - patch_size=vision_config.patch_size, - max_num_patches=vision_config.num_patches, - pixel_shuffle_scale=vision_config.pixel_shuffle_scale_factor, - rescale_factor=isaac_tiny_config.vision_rescale_factor, - ) - return IsaacProcessorTestDouble( - image_processor=image_processor, - tokenizer=isaac_tokenizer, - max_sequence_length=isaac_tiny_config.max_sequence_length, - ) - - BASE_MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base") BASE_MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/3") or None LOCAL_CHECKPOINT = os.environ.get("ISAAC_TEST_MODEL_PATH") @@ -494,7 +63,7 @@ def _checkpoint_or_skip(model_id=BASE_MODEL_ID): @require_torch @require_vision class IsaacProcessorTest(ProcessorTesterMixin, unittest.TestCase): - processor_class = IsaacProcessorTestDouble + processor_class = IsaacProcessor model_id = BASE_MODEL_ID images_input_name = "pixel_values" @@ -523,24 +92,6 @@ def prepare_image_inputs(self, batch_size: int | None = None, nested: bool = Fal return [[image] for image in images] return images - def test_model_input_names(self): - processor = self.get_processor() - inputs = processor( - text=self.prepare_text_inputs(modalities="image"), - images=self.prepare_image_inputs(), - return_tensors="pt", - ) - - self.assertSetEqual(set(inputs.keys()), set(processor.model_input_names)) - - @unittest.skip("IsaacProcessor expands image placeholders into image pad tokens before tokenization") - def test_tokenizer_defaults(self): - pass - - @unittest.skip("IsaacProcessor does not return offset mappings needed for assistant masks") - def test_apply_chat_template_assistant_mask(self): - pass - @unittest.skip("Isaac chat templates emit placeholders but the processor consumes image pad tokens") def test_apply_chat_template_image_0(self): pass @@ -549,6 +100,58 @@ def test_apply_chat_template_image_0(self): def test_apply_chat_template_image_1(self): pass + def test_apply_chat_template_image_placeholder_expands_to_image_pad_tokens(self): + processor = self.get_processor() + image = 
_make_dummy_image(size=(16, 16))
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Describe this."},
+                        {"type": "image", "image": image},
+                    ],
+                }
+            ]
+        ]
+
+        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        self.assertEqual(len(formatted_prompt), 1)
+        self.assertIn("<image>", formatted_prompt[0])
+
+        out_dict = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+        self.assertTrue(
+            all(
+                key in out_dict
+                for key in [
+                    "input_ids",
+                    "attention_mask",
+                    "pixel_values",
+                    "image_grid_thw",
+                    "image_metadata",
+                    "mm_token_type_ids",
+                ]
+            )
+        )
+
+        expected_num_image_tokens = processor._get_num_multimodal_tokens(image_sizes=[(image.height, image.width)])[
+            "num_image_tokens"
+        ][0]
+        actual_num_image_tokens = int(out_dict["input_ids"][0].eq(processor.image_token_id).sum().item())
+
+        self.assertEqual(actual_num_image_tokens, expected_num_image_tokens)
+        self.assertEqual(int(out_dict["mm_token_type_ids"][0].sum().item()), expected_num_image_tokens)
+        self.assertEqual(int(out_dict["image_metadata"][0, 0, 1].item()), expected_num_image_tokens)
+        self.assertTrue(
+            torch.all(out_dict["mm_token_type_ids"][0][out_dict["input_ids"][0].eq(processor.image_token_id)] == 1)
+        )
+
     def test_get_num_multimodal_tokens_matches_processor_call(self):
         processor = self.get_processor()
 
@@ -567,313 +170,3 @@ def test_get_num_multimodal_tokens_matches_processor_call(self):
         num_image_tokens_from_call = inputs.mm_token_type_ids.sum(-1).tolist()
         num_image_tokens_from_helper = processor._get_num_multimodal_tokens(image_sizes=image_sizes)
         self.assertListEqual(num_image_tokens_from_call, num_image_tokens_from_helper["num_image_tokens"])
-
-    def test_single_vs_batched_consistency(self):
-        processor = self.get_processor()
-        prompt = f"hello {processor.image_token} world"
-        image = self.prepare_image_inputs()
-
-        single = _assert_common(processor(text=prompt, images=[image], return_tensors="pt"))
-        batch = _assert_common(
-            processor(text=[prompt, "short"], images=[[image], []], return_tensors="pt"), batch_size=2
-        )
-
-        single_ids = single["input_ids"].squeeze(0)
-        batch_ids = batch["input_ids"][0]
-        self.assertTrue(torch.equal(batch_ids[-single_ids.size(0) :], single_ids))
-
-        image_positions = batch["mm_token_type_ids"][0].eq(1)
-        if image_positions.any():
-            self.assertTrue(torch.all(batch_ids[image_positions] == self.image_pad_token_id))
-            self.assertTrue(torch.all(batch["attention_mask"][0][image_positions] == 1))
-
-        single_image_mask = _get_sample_image_mask(single, batch_index=0)
-        batch_image_mask = _get_sample_image_mask(batch, batch_index=0)
-        torch.testing.assert_close(
-            batch["pixel_values"][0, batch_image_mask],
-            single["pixel_values"][0, single_image_mask],
-        )
-        torch.testing.assert_close(
-            batch["image_grid_thw"][0, batch_image_mask],
-            single["image_grid_thw"][0, single_image_mask],
-        )
-        torch.testing.assert_close(
-            batch["image_metadata"][0, batch_image_mask],
-            single["image_metadata"][0, single_image_mask],
-        )
-
-        _assert_vision_segments(batch, expected_segments=1, batch_index=0)
-        _assert_no_vision(batch, batch_index=1)
-
-
-@require_torch
-@require_vision
-def test_text_only_has_no_vision_fields(isaac_processor):
-    outputs = _assert_common(_run_processor(isaac_processor, text="Hello, how are you?", images=None))
-    assert outputs["pixel_values"] is None
-    assert outputs["image_grid_thw"] is None
-    assert 
outputs["image_metadata"] is None - _assert_no_vision(outputs) - - -@require_torch -def test_post_process_generation_extracts_boxes_and_cleans_text(): - processor = _make_post_process_processor() - - generated_text = ( - "No, it is not safe to cross the street. " - '(808, 247), (863, 386)' - ) - - clean_text, annotations = processor.post_process_generation(generated_text) - - assert clean_text == "No, it is not safe to cross the street." - assert len(annotations) == 1 - box = annotations[0] - assert box.mention == "traffic light" - assert box.t == pytest.approx(0.5) - assert box.top_left.x == 808 - assert box.top_left.y == 247 - assert box.bottom_right.x == 863 - assert box.bottom_right.y == 386 - - -@require_torch -def test_post_process_generation_extracts_polygons_and_filters_by_expected_type(): - processor = _make_post_process_processor() - - generated_text = ( - 'Point (1, 2) ' - 'Box (3, 4), (5, 6) ' - 'Polygon (10, 20), (30, 40), (50, 60)' - ) - - clean_text, annotations = processor.post_process_generation(generated_text, expected="polygon") - - assert clean_text == "Point Box Polygon" - assert len(annotations) == 1 - polygon = annotations[0] - assert polygon.mention == "lane" - assert polygon.t == pytest.approx(0.25) - assert len(polygon.points) == 3 - assert polygon.points[0].x == 10 - assert polygon.points[0].y == 20 - assert polygon.points[1].x == 30 - assert polygon.points[1].y == 40 - assert polygon.points[2].x == 50 - assert polygon.points[2].y == 60 - - _, boxes = processor.post_process_generation(generated_text, expected="box") - assert len(boxes) == 1 - assert boxes[0].mention == "sign" - - -@require_torch -def test_post_process_generation_rejects_polygons_with_fewer_than_three_points(): - processor = _make_post_process_processor() - - with pytest.raises(ValueError, match=r"Malformed tag"): - processor.post_process_generation('(10, 20), (30, 40)', expected="polygon") - - -@require_torch -@require_vision -def test_single_image_returns_offsets_and_lengths(isaac_processor): - image_token = isaac_processor.image_token - outputs = _assert_common( - _run_processor( - isaac_processor, text=f"Look at this {image_token} and describe it.", images=[_make_dummy_image()] - ) - ) - _assert_vision_segments(outputs, expected_segments=1) - - grid_tokens = _get_expected_vision_lengths(outputs, isaac_processor.image_processor.pixel_shuffle_scale) - torch.testing.assert_close(_get_active_vision_lengths(outputs), grid_tokens) - torch.testing.assert_close( - _get_active_vision_offsets(outputs), torch.zeros_like(_get_active_vision_offsets(outputs)) - ) - - -@require_torch -@require_vision -def test_multiple_images_have_matching_offsets_lengths_and_grids(isaac_processor): - image_token = isaac_processor.image_token - images = [_make_dummy_image(color=(255, 0, 0)), _make_dummy_image(color=(0, 255, 0))] - - outputs = _assert_common( - _run_processor(isaac_processor, text=f"First {image_token} then {image_token}", images=images) - ) - _assert_vision_segments(outputs, expected_segments=2) - - grid_tokens = _get_expected_vision_lengths(outputs, isaac_processor.image_processor.pixel_shuffle_scale) - torch.testing.assert_close(_get_active_vision_lengths(outputs), grid_tokens) - torch.testing.assert_close( - _get_active_vision_offsets(outputs), torch.zeros_like(_get_active_vision_offsets(outputs)) - ) - - -@require_torch -@require_vision -def test_error_on_image_mismatch(isaac_processor): - image_token = isaac_processor.image_token - with pytest.raises(ValueError, match="one image per"): - 
_run_processor(isaac_processor, text=f"{image_token} {image_token}", images=[_make_dummy_image()]) - - -@require_torch -@require_vision -def test_consecutive_vision_tokens_allow_empty_text_segments(isaac_processor): - image_token = isaac_processor.image_token - images = [_make_dummy_image(), _make_dummy_image(color=(0, 0, 255))] - - outputs = _assert_common( - _run_processor(isaac_processor, text=f"prefix {image_token}{image_token} suffix", images=images) - ) - _assert_vision_segments(outputs, expected_segments=2) - - torch.testing.assert_close( - _get_active_vision_offsets(outputs), torch.zeros_like(_get_active_vision_offsets(outputs)) - ) - grid_tokens = _get_expected_vision_lengths(outputs, isaac_processor.image_processor.pixel_shuffle_scale) - torch.testing.assert_close(_get_active_vision_lengths(outputs), grid_tokens) - - -@require_torch -@require_vision -def test_device_and_dtype_consistency(isaac_processor): - image_token = isaac_processor.image_token - outputs = _assert_common( - _run_processor(isaac_processor, text=f"Describe this {image_token}", images=[_make_dummy_image()]) - ) - _assert_vision_segments(outputs, expected_segments=1) - - tensors = [ - outputs["input_ids"], - outputs["attention_mask"], - outputs["mm_token_type_ids"], - outputs["image_grid_thw"], - outputs["image_metadata"], - ] - devices = {tensor.device for tensor in tensors} - assert len(devices) == 1 - for tensor in tensors: - assert tensor.dtype == torch.long - - -@require_torch -@require_vision -def test_no_crop_when_total_below_max(isaac_processor): - image_token = isaac_processor.image_token - outputs = _assert_common( - _run_processor(isaac_processor, text=f"hello {image_token} world", images=[_make_dummy_image()]) - ) - _assert_vision_segments(outputs, expected_segments=1) - - grid_tokens = _get_expected_vision_lengths(outputs, isaac_processor.image_processor.pixel_shuffle_scale) - text_tokens = _count_modality(outputs, 0) - assert outputs["input_ids"].shape[1] == grid_tokens.item() + text_tokens - - -@require_torch -@require_vision -def test_exact_fit_keeps_all_tokens(isaac_processor, isaac_tokenizer, isaac_tiny_config): - image_token = isaac_processor.image_token - text = f"hey {image_token} there" - image = _make_dummy_image() - - base_outputs = _assert_common(_run_processor(isaac_processor, text=text, images=[image])) - base_length = base_outputs["input_ids"].shape[1] - base_vision_length = _get_active_vision_lengths(base_outputs).item() - - processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, base_length) - outputs = _assert_common(_run_processor(processor, text=text, images=[image])) - - _assert_vision_segments(outputs, expected_segments=1) - assert outputs["input_ids"].shape[1] == base_length - assert _get_active_vision_lengths(outputs).item() == base_vision_length - - -@require_torch -@require_vision -def test_crop_truncates_text_segment_only(isaac_processor, isaac_tokenizer, isaac_tiny_config): - image_token = isaac_processor.image_token - text_prefix_tokens = " ".join([f"t{i}" for i in range(8)]) - text = f"{text_prefix_tokens} {image_token} tail end" - image = _make_dummy_image() - - base_outputs = _assert_common(_run_processor(isaac_processor, text=text, images=[image])) - full_text_tokens = _count_modality(base_outputs, 0) - vision_length = _get_active_vision_lengths(base_outputs).item() - - max_len = base_outputs["input_ids"].shape[1] - 4 - processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, max_len) - outputs = 
_assert_common(_run_processor(processor, text=text, images=[image])) - - _assert_vision_segments(outputs, expected_segments=1) - assert outputs["input_ids"].shape[1] == max_len - assert _count_modality(outputs, 0) == full_text_tokens - 4 - torch.testing.assert_close( - _get_active_vision_offsets(outputs), torch.zeros_like(_get_active_vision_offsets(outputs)) - ) - assert _get_active_vision_lengths(outputs).item() == vision_length - - -@require_torch -@require_vision -def test_crop_cuts_through_image_segment(isaac_processor, isaac_tokenizer, isaac_tiny_config): - image_token = isaac_processor.image_token - text_before = "hi" - text_after = "bye" - text = f"{text_before} {image_token} {text_after}" - image = _make_dummy_image() - - base_outputs = _assert_common(_run_processor(isaac_processor, text=text, images=[image])) - vision_full = _get_active_vision_lengths(base_outputs).item() - text_before_len = len(isaac_tokenizer.encode(text_before, add_special_tokens=False)) - text_after_len = len(isaac_tokenizer.encode(text_after, add_special_tokens=False)) - total_length = vision_full + text_before_len + text_after_len - - max_len = 40 - start = total_length - max_len - expected_offset = max(0, start - text_before_len) - expected_length = vision_full - expected_offset - - processor = _make_processor_with_max_len(isaac_tokenizer, isaac_tiny_config, max_len) - outputs = _assert_common(_run_processor(processor, text=text, images=[image])) - - _assert_vision_segments(outputs, expected_segments=1) - assert outputs["input_ids"].shape[1] == max_len - assert _get_active_vision_offsets(outputs).item() == expected_offset - assert _get_active_vision_lengths(outputs).item() == expected_length - assert _count_modality(outputs, 0) == text_after_len - - -@require_torch -@require_vision -def test_batch_outputs_match_individual_calls(isaac_processor): - texts = ["hi", "this one is longer"] - - per_sample = [_assert_common(_run_processor(isaac_processor, text=text, images=None)) for text in texts] - batch_outputs = _assert_common(_run_processor(isaac_processor, text=texts, images=None), batch_size=len(texts)) - - pad_id = isaac_processor.pad_token_id - for index, single_output in enumerate(per_sample): - single_ids = single_output["input_ids"].squeeze(0) - single_mask = single_output["attention_mask"].squeeze(0) - single_mm = single_output["mm_token_type_ids"].squeeze(0) - - batch_ids = batch_outputs["input_ids"][index] - batch_mask = batch_outputs["attention_mask"][index] - batch_mm = batch_outputs["mm_token_type_ids"][index] - - single_len = single_ids.shape[0] - assert torch.equal(batch_ids[-single_len:], single_ids) - assert torch.equal(batch_mask[-single_len:], single_mask) - assert torch.equal(batch_mm[-single_len:], single_mm) - - if single_len < batch_ids.shape[0]: - pad_span = batch_ids[: batch_ids.shape[0] - single_len] - assert torch.all(pad_span == pad_id) - assert not torch.any(batch_mask[: batch_ids.shape[0] - single_len]) - - _assert_no_vision(batch_outputs, batch_index=index) diff --git a/utils/check_repo.py b/utils/check_repo.py index 9f259b481eef..4d040fc165c6 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -218,7 +218,7 @@ "Qwen3_5TextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3_5ForConditionalGeneration. "Qwen3_5MoeTextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3_5MoeForConditionalGeneration. "IsaacTextModel", # Building part of bigger (tested) model. Tested implicitly through IsaacForConditionalGeneration. 
- "IsaacVisionTransformer", # Building part of bigger (tested) model. Tested implicitly through IsaacForConditionalGeneration. + "IsaacVisionModel", # Building part of bigger (tested) model. Tested implicitly through IsaacForConditionalGeneration. "Qwen2_5OmniForConditionalGeneration", # Not a regular model. Testted in Qwen2_5OmniModelIntergrationTest "Qwen2_5OmniTalkerForConditionalGeneration", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest. "Qwen2_5OmniTalkerModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest. @@ -470,7 +470,7 @@ "PaddleOCRVisionTransformer", # Building part of bigger (tested) model "PaddleOCRTextModel", # Building part of bigger (tested) model "IsaacTextModel", # Building part of a bigger model - "IsaacVisionTransformer", # Building part of a bigger model + "IsaacVisionModel", # Building part of a bigger model "Qwen2_5OmniTalkerForConditionalGeneration", # Building part of a bigger model "Qwen2_5OmniTalkerModel", # Building part of a bigger model "Qwen2_5OmniThinkerForConditionalGeneration", # Building part of a bigger model @@ -1134,6 +1134,7 @@ def find_all_documented_objects() -> list[str]: "Ernie4_5_VL_MoeImageProcessorFast", # BC Alias "Ernie4_5_VL_MoeImageProcessorPil", # BC Alias "Ernie4_5_VL_MoeModel", # BC Alias + "IsaacVisionModel", # Internal building block tested implicitly through IsaacForConditionalGeneration. "Ernie4_5_VL_MoeTextConfig", # BC Alias "Ernie4_5_VL_MoeTextModel", # BC Alias "Ernie4_5_VL_MoeVariableResolutionResamplerModel", # BC Alias From 3d9e55da54424458d493805ab3cda912a5eeb8f7 Mon Sep 17 00:00:00 2001 From: Philipp Guevorguian Date: Mon, 13 Apr 2026 07:35:27 -0700 Subject: [PATCH 75/77] lint --- src/transformers/models/isaac/configuration_isaac.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py index 9e4c58954ffc..d5e0080d029a 100644 --- a/src/transformers/models/isaac/configuration_isaac.py +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -134,6 +134,7 @@ class IsaacConfig(PretrainedConfig): Rescale factor applied by the image processor before normalization. max_sequence_length (`int`, *optional*, defaults to 16384): Maximum multimodal sequence length produced by the processor and expected by the model. 
+
     Example:
 
     ```python

From 251210fc727f70de23fbfb4f029bbcb285daf526 Mon Sep 17 00:00:00 2001
From: Philipp Guevorguian
Date: Mon, 13 Apr 2026 16:44:53 +0200
Subject: [PATCH 76/77] fix: map isaac_vision to isaac module

---
 src/transformers/models/auto/configuration_auto.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 305aaf5f4e39..95b8378c0861 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -1113,6 +1113,7 @@
     ("gemma4_audio", "gemma4"),
     ("gemma4_text", "gemma4"),
     ("gemma4_vision", "gemma4"),
+    ("isaac_vision", "isaac"),
     ("glm4v_vision", "glm4v"),
     ("glm4v_moe_vision", "glm4v_moe"),
     ("glm4v_text", "glm4v"),

From bbadef8570e2d0a8cb0f847558d56693330ee739 Mon Sep 17 00:00:00 2001
From: Philipp Guevorguian
Date: Mon, 13 Apr 2026 17:45:52 +0200
Subject: [PATCH 77/77] fix: specify required backend

---
 src/transformers/models/isaac/image_processing_isaac.py | 3 ++-
 src/transformers/models/isaac/modular_isaac.py          | 3 +++
 src/transformers/models/isaac/processing_isaac.py       | 3 ++-
 3 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/isaac/image_processing_isaac.py b/src/transformers/models/isaac/image_processing_isaac.py
index 5f5af7905b6b..39750c1bd792 100644
--- a/src/transformers/models/isaac/image_processing_isaac.py
+++ b/src/transformers/models/isaac/image_processing_isaac.py
@@ -29,7 +29,7 @@
 from ...processing_utils import ImagesKwargs, Unpack
 from ...utils import TensorType, auto_docstring
 from ...utils.constants import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
-from ...utils.import_utils import is_torch_available
+from ...utils.import_utils import is_torch_available, requires
 
 
 if is_torch_available():
@@ -174,6 +174,7 @@ def get_image_size_for_max_num_patches(
 
 
 @auto_docstring
+@requires(backends=("vision",))
 class IsaacImageProcessor(TorchvisionBackend):
     model_input_names = ["pixel_values", "image_grid_thw"]
     valid_kwargs = IsaacImageProcessorKwargs
diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py
index 7be4e9bb1b00..5d2935064980 100644
--- a/src/transformers/models/isaac/modular_isaac.py
+++ b/src/transformers/models/isaac/modular_isaac.py
@@ -42,6 +42,7 @@
     is_torch_available,
     is_torchdynamo_compiling,
     is_torchvision_available,
+    requires,
 )
 from ...utils.output_capturing import capture_outputs
 from ..qwen3_vl.modeling_qwen3_vl import (
@@ -1091,6 +1092,7 @@ def get_image_size_for_max_num_patches(
 
 
 @auto_docstring
+@requires(backends=("vision",))
 class IsaacImageProcessor(TorchvisionBackend):
     model_input_names = ["pixel_values", "image_grid_thw"]
     valid_kwargs = IsaacImageProcessorKwargs
@@ -1334,6 +1336,7 @@ class Polygon(NamedTuple):
 
 
 @auto_docstring
+@requires(backends=("vision",))
 class IsaacProcessor(ProcessorMixin):
     def __init__(
         self,
diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py
index 563ea90bde65..44dc5688e2a2 100644
--- a/src/transformers/models/isaac/processing_isaac.py
+++ b/src/transformers/models/isaac/processing_isaac.py
@@ -26,7 +26,7 @@
 from ...image_utils import ImageInput, make_nested_list_of_images
 from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin
 from ...utils import auto_docstring
-from ...utils.import_utils import is_torch_available
+from ...utils.import_utils import is_torch_available, requires
 from .image_processing_isaac import IsaacImageProcessorKwargs
 from .modeling_isaac import BoundingBox, Polygon, SinglePoint
 
@@ -61,6 +61,7 @@ class IsaacProcessorKwargs(ProcessingKwargs, total=False):
 
 
 @auto_docstring
+@requires(backends=("vision",))
 class IsaacProcessor(ProcessorMixin):
     def __init__(
         self,
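
Patch 77 gates the Isaac image processor and processor classes behind `@requires(backends=("vision",))`, so a missing vision backend fails at class construction with a clear error instead of deep inside preprocessing. Below is a minimal sketch of how a reviewer might exercise that guard locally, assuming this branch is installed; the checkpoint id is a placeholder (not a published model) and the exact error wording comes from the `requires` machinery itself:

```python
# Sketch only: exercises the @requires(backends=("vision",)) guard added in patch 77.
# "perceptron/isaac-tiny" is a hypothetical checkpoint id used for illustration.
from transformers import AutoProcessor
from transformers.utils.import_utils import is_vision_available

CHECKPOINT = "perceptron/isaac-tiny"

if is_vision_available():
    # With the vision extras (Pillow/torchvision) installed, loading works as usual.
    processor = AutoProcessor.from_pretrained(CHECKPOINT)
    print(type(processor).__name__)  # expected: IsaacProcessor
else:
    # Without them, the guard should raise ImportError up front,
    # before any image preprocessing code runs.
    try:
        AutoProcessor.from_pretrained(CHECKPOINT)
    except ImportError as err:
        print(f"vision backend guard fired: {err}")
```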