diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index ba69db1c5e78..296df33bf11b 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1237,6 +1237,8 @@ title: InstructBlipVideo - local: model_doc/internvl title: InternVL + - local: model_doc/isaac + title: Isaac - local: model_doc/janus title: Janus - local: model_doc/kosmos-2 diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md new file mode 100644 index 000000000000..91773a4962ed --- /dev/null +++ b/docs/source/en/model_doc/isaac.md @@ -0,0 +1,143 @@ + +*This model was released on {release_date} and added to Hugging Face Transformers on 2026-04-13.* + +
+<div class="flex flex-wrap space-x-1">
+    <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
+    <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
+    <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
+</div>
+
+# Isaac
+
+## Overview
+
+Isaac is Perceptron's vision-language model (VLM) that pairs a SigLIP2 vision encoder with a Qwen3 decoder-only stack. The
+Transformers implementation supports text-only and image-conditioned generation, including prompts with multiple interleaved
+images. Isaac uses variable-resolution image preprocessing and can optionally reduce spatial tokens with pixel shuffle to keep
+long multimodal prompts manageable. For more information, refer to the [technical report](https://github.com/perceptron-ai-inc/perceptron/blob/main/papers/isaac_01.pdf).
+
+Isaac checkpoints are distributed under Perceptron's Non-Production license; please review the license that ships with the
+weights before using them in commercial settings.
+
+## Usage tips
+
+- Batched inputs can mix text-only and multimodal samples. For direct processor/model batching, pass images as a nested
+  list such as `[[], [image_a], [image_b, image_c]]` (see the batched sketch at the end of this section).
+- `image_grid_thw[batch_idx, image_slot] == (0, 0, 0)` marks a padded empty slot. Real image slots have
+  `(T=1, H>0, W>0)`.
+- If truncation is enabled, the processor keeps the rightmost part of the multimodal prompt and updates the slot-local
+  `image_metadata[..., 0]` and `image_metadata[..., 1]` values automatically.
+
+## Usage example
+
+Isaac uses explicit image placeholders in the rendered prompt. Every occurrence of `processor.image_token` must have a matching image in the `images` argument.
+
+```py
+import torch
+from transformers import AutoProcessor, IsaacForConditionalGeneration
+
+model_id = "PerceptronAI/Isaac-0.1"
+processor = AutoProcessor.from_pretrained(model_id)
+model = IsaacForConditionalGeneration.from_pretrained(
+    model_id,
+    dtype=torch.bfloat16,
+    device_map="auto",
+    attn_implementation="flash_attention_2",
+)
+
+conversation = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Compare the two figures and explain what changed."},
+            {"type": "image", "path": "first_image.png"},
+            {"type": "image", "path": "second_image.png"},
+        ],
+    },
+]
+
+# `apply_chat_template` renders the prompt, loads the referenced images, and returns model-ready tensors.
+inputs = processor.apply_chat_template(
+    conversation,
+    tokenize=True,
+    return_dict=True,
+    add_generation_prompt=True,
+    return_tensors="pt",
+).to(model.device)
+
+generated_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)
+
+generated_ids = generated_ids[:, inputs["input_ids"].shape[1] :]
+response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+print(response)
+```
+
+### Post-processing grounded outputs
+
+Isaac can generate grounded points and boxes in tagged text spans. Use `post_process_generation()` to strip the tags and
+recover structured annotations.
+
+```py
+clean_text, annotations = processor.post_process_generation(response, expected="box")
+print(clean_text)
+print(annotations)
+```
+
+Set `expected="point"` to extract point annotations, or leave `expected=None` to collect both points and boxes.
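+
+### Batched text-only and multimodal inputs
+
+The snippet below is a minimal sketch of the mixed batching convention described in the usage tips. The prompt strings, the
+`example.png` file name, and the `padding=True` tokenizer argument are illustrative assumptions rather than part of the Isaac
+API; the nested `images` list and the `image_grid_thw` padding behavior follow the tips above.
+
+```py
+from PIL import Image
+from transformers import AutoProcessor
+
+processor = AutoProcessor.from_pretrained("PerceptronAI/Isaac-0.1")
+
+texts = [
+    # Text-only sample: no image placeholder, matched by an empty slot in the nested image list.
+    "Give a one-sentence definition of a vision-language model.",
+    # Multimodal sample: one placeholder token, matched by one image in its nested list.
+    f"{processor.image_token} Describe this image in one sentence.",
+]
+images = [[], [Image.open("example.png")]]  # hypothetical local file
+
+inputs = processor(text=texts, images=images, padding=True, return_tensors="pt")
+
+# Padded empty slots are reported as (0, 0, 0) grids; real slots carry (T=1, H>0, W>0).
+print(inputs["image_grid_thw"])
+```
+
+The resulting `inputs` can then be passed to `IsaacForConditionalGeneration.generate`, exactly as in the single-sample
+example above.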
+ +## IsaacVisionConfig + +[[autodoc]] IsaacVisionConfig + +## IsaacTextConfig + +[[autodoc]] IsaacTextConfig + +## IsaacConfig + +[[autodoc]] IsaacConfig + +## IsaacVisionModel + +[[autodoc]] IsaacVisionModel + +## IsaacTextModel + +[[autodoc]] IsaacTextModel + - forward + +## IsaacModel + +[[autodoc]] IsaacModel + - forward + +## IsaacForConditionalGeneration + +[[autodoc]] IsaacForConditionalGeneration + - forward + +## IsaacProcessor + +[[autodoc]] IsaacProcessor + +## IsaacImageProcessor + +[[autodoc]] IsaacImageProcessor diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 2a6dc23ba9d0..b3aebcd5cebe 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -133,6 +133,10 @@ def _build_checkpoint_conversion_mapping(): ), WeightRenaming(source_patterns=r"^visual", target_patterns="model.visual"), ], + "isaac": [ + WeightRenaming(source_patterns=r"text_model", target_patterns="language_model"), + WeightRenaming(source_patterns=r"vision_tower", target_patterns="visual"), + ], "colqwen2": [ WeightRenaming(source_patterns=r"vlm.model", target_patterns="vlm"), WeightRenaming(source_patterns=r"vlm(?!\.(language_model|visual))", target_patterns="vlm.language_model"), diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 81ff067b1470..df5f8bc4958d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1003,6 +1003,33 @@ class EmbeddingAccessMixin: _input_embed_layer = "embed_tokens" # default layer that holds input embeddings. + def _resolve_input_embed_layer(self) -> tuple[nn.Module | None, str]: + """ + Returns the parent module and leaf attribute for `_input_embed_layer`. + + Supports both a simple attribute name such as `embed_tokens` and a dotted path such as + `text_model.embed_tokens`. + """ + + name = getattr(self, "_input_embed_layer", "embed_tokens") + if "." not in name: + return None, name + + module_path, _, attribute_name = name.rpartition(".") + try: + module = self.get_submodule(module_path) + except AttributeError as error: + raise NotImplementedError( + f"`_input_embed_layer={name}` could not be resolved for {self.__class__.__name__}." + ) from error + + if not hasattr(module, attribute_name): + raise NotImplementedError( + f"`_input_embed_layer={name}` could not be resolved for {self.__class__.__name__}." + ) + + return module, attribute_name + def get_input_embeddings(self) -> nn.Module: """ Returns the model's input embeddings. @@ -1011,7 +1038,9 @@ def get_input_embeddings(self) -> nn.Module: `nn.Module`: A torch module mapping vocabulary to hidden states. """ - name = getattr(self, "_input_embed_layer", "embed_tokens") + module, name = self._resolve_input_embed_layer() + if module is not None: + return getattr(module, name) # 1) Direct attribute (most NLP models). if (default_embedding := getattr(self, name, None)) is not None: @@ -1044,7 +1073,11 @@ def set_input_embeddings(self, value: nn.Module): should) override for exotic layouts. 
""" - name = getattr(self, "_input_embed_layer", "embed_tokens") + module, name = self._resolve_input_embed_layer() + if module is not None: + setattr(module, name, value) + return + # 1) Direct attribute (most NLP models) if hasattr(self, name): setattr(self, name, value) diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 989be9eb114e..df538531c3ba 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -201,6 +201,7 @@ from .instructblip import * from .instructblipvideo import * from .internvl import * + from .isaac import * from .jais2 import * from .jamba import * from .janus import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 2c0fe88d0e74..95b8378c0861 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -238,6 +238,8 @@ ("instructblipvideo", "InstructBlipVideoConfig"), ("internvl", "InternVLConfig"), ("internvl_vision", "InternVLVisionConfig"), + ("isaac", "IsaacConfig"), + ("isaac_vision", "IsaacVisionConfig"), ("jais2", "Jais2Config"), ("jamba", "JambaConfig"), ("janus", "JanusConfig"), @@ -758,6 +760,8 @@ ("instructblipvideo", "InstructBlipVideo"), ("internvl", "InternVL"), ("internvl_vision", "InternVLVision"), + ("isaac", "Isaac"), + ("isaac_vision", "IsaacVision"), ("jais2", "Jais2"), ("jamba", "Jamba"), ("janus", "Janus"), @@ -1109,6 +1113,7 @@ ("gemma4_audio", "gemma4"), ("gemma4_text", "gemma4"), ("gemma4_vision", "gemma4"), + ("isaac_vision", "isaac"), ("glm4v_vision", "glm4v"), ("glm4v_moe_vision", "glm4v_moe"), ("glm4v_text", "glm4v"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 4ada3ba6b8ed..febd934a38ab 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -145,6 +145,7 @@ ("imagegpt", {"torchvision": "ImageGPTImageProcessor", "pil": "ImageGPTImageProcessorPil"}), ("instructblip", {"torchvision": "BlipImageProcessor", "pil": "BlipImageProcessorPil"}), ("internvl", {"torchvision": "GotOcr2ImageProcessor", "pil": "GotOcr2ImageProcessorPil"}), + ("isaac", {"torchvision": "IsaacImageProcessor"}), ("janus", {"torchvision": "JanusImageProcessor", "pil": "JanusImageProcessorPil"}), ("kosmos-2", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), ("kosmos-2.5", {"torchvision": "Kosmos2_5ImageProcessor", "pil": "Kosmos2_5ImageProcessorPil"}), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index d4cb17cddfa6..120668f4ac8d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -235,6 +235,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("instructblipvideo", "InstructBlipVideoModel"), ("internvl", "InternVLModel"), ("internvl_vision", "InternVLVisionModel"), + ("isaac", "IsaacModel"), + ("isaac_vision", "IsaacVisionModel"), ("jais2", "Jais2Model"), ("jamba", "JambaModel"), ("janus", "JanusModel"), @@ -990,6 +992,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("instructblip", "InstructBlipForConditionalGeneration"), ("instructblipvideo", "InstructBlipVideoForConditionalGeneration"), ("internvl", "InternVLForConditionalGeneration"), + ("isaac", "IsaacForConditionalGeneration"), ("janus", "JanusForConditionalGeneration"), 
("kosmos-2", "Kosmos2ForConditionalGeneration"), ("kosmos-2.5", "Kosmos2_5ForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 262480b71485..61cf836d7c56 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -100,6 +100,7 @@ ("instructblip", "InstructBlipProcessor"), ("instructblipvideo", "InstructBlipVideoProcessor"), ("internvl", "InternVLProcessor"), + ("isaac", "IsaacProcessor"), ("janus", "JanusProcessor"), ("kosmos-2", "Kosmos2Processor"), ("kosmos-2.5", "Kosmos2_5Processor"), diff --git a/src/transformers/models/isaac/__init__.py b/src/transformers/models/isaac/__init__.py new file mode 100644 index 000000000000..bc0f3fcc6d7c --- /dev/null +++ b/src/transformers/models/isaac/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_isaac import * + from .modeling_isaac import * + from .processing_isaac import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py new file mode 100644 index 000000000000..d5e0080d029a --- /dev/null +++ b/src/transformers/models/isaac/configuration_isaac.py @@ -0,0 +1,179 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/isaac/modular_isaac.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_isaac.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 Perceptron, Inc and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig, PretrainedConfig +from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring + + +@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") +@strict +class IsaacVisionConfig(PreTrainedConfig): + r""" + num_patches (`int`, *optional*, defaults to 256): + The number of patches in the image with the size of (`patch_size`, `patch_size`). The image is resized to + fill a maximum of this number of patches while preserving the aspect ratio. If the resulting number of patches + is lower, the image is padded in the patch dimension. + pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1): + Spatial factor applied before pixel shuffle reduces the resolution. + """ + + model_type = "isaac_vision" + base_config_key = "vision_config" + + hidden_size: int = 768 + intermediate_size: int = 3072 + num_hidden_layers: int = 12 + num_attention_heads: int = 12 + num_channels: int = 3 + patch_size: int | list[int] | tuple[int, int] = 16 + hidden_act: str = "gelu_pytorch_tanh" + layer_norm_eps: float = 1e-6 + attention_dropout: float | int = 0.0 + + num_patches: int = 256 + pixel_shuffle_scale_factor: int = 1 + + +@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") +@strict +class IsaacTextConfig(PreTrainedConfig): + r""" + Example: + + ```python + >>> from transformers import IsaacTextConfig, IsaacTextModel + + >>> configuration = IsaacTextConfig() + >>> model = IsaacTextModel(configuration) + >>> configuration = model.config + ``` + """ + + model_type = "isaac_text" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `IsaacText` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + vocab_size: int = 151936 + hidden_size: int = 4096 + intermediate_size: int = 22016 + num_hidden_layers: int = 32 + num_attention_heads: int = 32 + num_key_value_heads: int | None = 32 + head_dim: int = 128 + hidden_act: str = "silu" + max_position_embeddings: int = 32768 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + use_cache: bool = True + tie_word_embeddings: bool = False + rope_parameters: RopeParameters | dict | None = None + attention_bias: bool = False + use_sliding_window: bool = False + max_window_layers: int = 28 + attention_dropout: float | int = 0.0 + pad_token_id: int | None = None + bos_token_id: int | None = None + eos_token_id: int | list[int] | None = None + ignore_keys_at_rope_validation = {"mrope_section", "mrope_interleaved"} + + def __post_init__(self, **kwargs): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + PretrainedConfig.__post_init__(self, **kwargs) + + +@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") +@strict +class IsaacConfig(PretrainedConfig): + r""" + vision_config (`IsaacVisionConfig` or `dict`, *optional*): + Configuration for the Isaac vision 
tower. Dictionaries are converted to [`IsaacVisionConfig`]. If unset, + the default [`IsaacVisionConfig`] is used. + text_config (`IsaacTextConfig` or `dict`, *optional*): + Configuration for the text backbone. Dictionaries are converted to [`IsaacTextConfig`]. + vision_rescale_factor (`float`, *optional*, defaults to 1 / 255): + Rescale factor applied by the image processor before normalization. + max_sequence_length (`int`, *optional*, defaults to 16384): + Maximum multimodal sequence length produced by the processor and expected by the model. + + Example: + + ```python + >>> from transformers import IsaacConfig, IsaacModel + + >>> configuration = IsaacConfig() + >>> model = IsaacModel(configuration) + >>> configuration = model.config + ``` + """ + + model_type = "isaac" + sub_configs = {"vision_config": IsaacVisionConfig, "text_config": IsaacTextConfig} + vision_config: IsaacVisionConfig | dict | None = None + text_config: IsaacTextConfig | dict | None = None + vision_rescale_factor: float = 1 / 255 + max_sequence_length: int = 16384 + + def __post_init__(self, **kwargs): + if isinstance(self.text_config, dict): + self.text_config = self.sub_configs["text_config"](**self.text_config) + elif self.text_config is None: + self.text_config = self.sub_configs["text_config"]() + elif not isinstance(self.text_config, IsaacTextConfig): + raise TypeError( + f"text_config must be a dict or an IsaacTextConfig instance, got {type(self.text_config).__name__}." + ) + + if isinstance(self.vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**self.vision_config) + elif self.vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + elif not isinstance(self.vision_config, IsaacVisionConfig): + raise TypeError( + f"vision_config must be a dict or an IsaacVisionConfig instance, got {type(self.vision_config).__name__}." + ) + + self.vision_rescale_factor = float(self.vision_rescale_factor) + super().__post_init__(**kwargs) + + +__all__ = ["IsaacConfig", "IsaacTextConfig", "IsaacVisionConfig"] diff --git a/src/transformers/models/isaac/image_processing_isaac.py b/src/transformers/models/isaac/image_processing_isaac.py new file mode 100644 index 000000000000..39750c1bd792 --- /dev/null +++ b/src/transformers/models/isaac/image_processing_isaac.py @@ -0,0 +1,375 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/isaac/modular_isaac.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_isaac.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 Perceptron, Inc and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Any + +from ... 
import TorchvisionBackend +from ...feature_extraction_utils import BatchFeature +from ...image_transforms import group_images_by_shape, reorder_images +from ...image_utils import ImageInput, PILImageResampling, SizeDict, make_nested_list_of_images +from ...processing_utils import ImagesKwargs, Unpack +from ...utils import TensorType, auto_docstring +from ...utils.constants import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD +from ...utils.import_utils import is_torch_available, requires + + +if is_torch_available(): + import torch + import torch.nn.functional as F + from torchvision.transforms.v2 import functional as tvF + + +# --------------------------------Isaac Image Processor-------------------------------- + + +class IsaacImageProcessorKwargs(ImagesKwargs, total=False): + """ + patch_size (`int`, *optional*): + Side length (in pixels) for square patches extracted from resized images. + max_num_patches (`int`, *optional*): + Upper bound on extracted patches per image after resizing. + min_num_patches (`int`, *optional*): + Lower bound on extracted patches per image after resizing. + pixel_shuffle_scale (`int`, *optional*): + Pixel-shuffle reduction factor applied in the vision tower. + """ + + patch_size: int + max_num_patches: int + min_num_patches: int + pixel_shuffle_scale: int + + +# Disable as it causes issues with torch.compile +@torch.compiler.disable +def torch_extract_patches(image_tensor, patch_height, patch_width): + """ + Extract patches from image tensor. Returns tensor of shape (batch, rows, columns, patch_height*patch_width*channels). + + Args: + image_tensor (`torch.Tensor`): + Image tensor of shape (batch, channels, height, width). + patch_height (`int`): + Height of patches to extract. + patch_width (`int`): + Width of patches to extract. + """ + batch_size, channels, height, width = image_tensor.shape + patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width)) + patches = patches.reshape(batch_size, channels, patch_height, patch_width, -1) + patches = patches.permute(0, 4, 2, 3, 1).reshape( + batch_size, height // patch_height, width // patch_width, channels * patch_height * patch_width + ) + return patches + + +def get_scaled_image_size( + scale: float, + original_size: int, + patch_size: int, + pixel_shuffle_scale: int, +) -> int: + scaled_size = scale * original_size + divisor = patch_size * pixel_shuffle_scale + scaled_size = math.ceil(scaled_size / divisor) * divisor + scaled_size = max(divisor, scaled_size) + return int(scaled_size) + + +def get_image_size_for_max_num_patches( + image_height: int, + image_width: int, + patch_size: int, + max_num_patches: int, + min_num_patches: int | None = None, + eps: float = 1e-5, + pixel_shuffle_scale: int = 1, +) -> tuple[int, int]: + r"""Compute a target resolution whose patch grid satisfies patching parametrization. + + Args: + image_height (`int`): + Height in pixels of the source image prior to any resizing. + image_width (`int`): + Width in pixels of the source image prior to any resizing. + patch_size (`int`): + Size of the square patch used by the vision encoder. + max_num_patches (`int`): + Upper bound on `(height / patch_size) * (width / patch_size)` after resizing. + min_num_patches (`int`, *optional*): + Lower bound on the number of patches. When provided the image will be scaled up if necessary. + eps (`float`, *optional*, defaults to 1e-5): + Convergence tolerance for the internal binary search to determing the target dimensions. 
+ pixel_shuffle_scale (`int`, *optional*, defaults to 1): + Additional stride multiplier applied when pixel shuffle later reduces spatial resolution. + + Returns: + `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale` + and respect both the maximum and optional minimum patch-count constraints. + """ + + # Ensure divisibility + divisor = patch_size * pixel_shuffle_scale + adjusted_height = math.ceil(image_height / divisor) * divisor + adjusted_height = max(divisor, adjusted_height) + adjusted_width = math.ceil(image_width / divisor) * divisor + adjusted_width = max(divisor, adjusted_width) + + num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size) + + if min_num_patches is not None and num_patches < min_num_patches: + # Scale up via binary search to satisfy the minimum patch budget while + # preserving divisibility by patch_size * pixel_shuffle_scale. + scale_min, scale_max = 1.0, 100.0 + while (scale_max - scale_min) >= eps: + scale = (scale_min + scale_max) / 2 + target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) + target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + num_patches = (target_height / patch_size) * (target_width / patch_size) + if num_patches >= min_num_patches: + scale_max = scale + else: + scale_min = scale + scale = scale_max + target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) + target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + return target_height, target_width + elif num_patches <= max_num_patches: + return adjusted_height, adjusted_width + else: + # Scale down + scale_min, scale_max = eps / 10, 1.0 + while (scale_max - scale_min) >= eps: + scale = (scale_min + scale_max) / 2 + target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) + target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + num_patches = (target_height / patch_size) * (target_width / patch_size) + if num_patches <= max_num_patches: + scale_min = scale + else: + scale_max = scale + scale = scale_min + target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) + target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + return target_height, target_width + + +@auto_docstring +@requires(backends=("vision",)) +class IsaacImageProcessor(TorchvisionBackend): + model_input_names = ["pixel_values", "image_grid_thw"] + valid_kwargs = IsaacImageProcessorKwargs + + resample = PILImageResampling.BILINEAR + do_resize = True + do_center_crop = False + patch_size = 16 + max_num_patches = 256 + min_num_patches = None + pixel_shuffle_scale = 1 + do_pad = True + do_rescale = True + do_normalize = True + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + do_convert_rgb = True + disable_grouping = False + + def __init__(self, **kwargs: Unpack[IsaacImageProcessorKwargs]): + super().__init__(**kwargs) + + def _validate_preprocess_kwargs(self, **kwargs): + # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for. 
+ kwargs.pop("do_resize", None) + return super()._validate_preprocess_kwargs(**kwargs) + + def _prepare_images_structure( + self, + images: ImageInput, + expected_ndims: int = 3, + ) -> ImageInput: + images = self.fetch_images(images) + return make_nested_list_of_images(images, expected_ndims=expected_ndims) + + def resize( + self, + image: torch.Tensor, + size: SizeDict, + **kwargs, + ) -> torch.Tensor: + if image.dtype == torch.uint8: + image = F.interpolate(image.float(), size=(size.height, size.width), mode="bilinear", align_corners=False) + return image.clamp(0, 255).round().to(torch.uint8) + return F.interpolate(image, size=(size.height, size.width), mode="bilinear", align_corners=False) + + def pack_images( + self, + vision_patches: list[list[torch.Tensor]], + vision_token_grids: list[list[torch.Tensor]], + ) -> dict[str, torch.Tensor | None]: + batch_size = len(vision_patches) + flat_patches = [patches for sample_patches in vision_patches for patches in sample_patches] + if len(flat_patches) == 0: + return {"pixel_values": None, "image_grid_thw": None} + + first_patch = flat_patches[0] + max_patches = max(patches.shape[0] for patches in flat_patches) + max_images = max((len(sample_patches) for sample_patches in vision_patches), default=0) + + patch_dim = first_patch.shape[-1] + tensors = { + "pixel_values": torch.zeros( + (batch_size, max_images, max_patches, patch_dim), + device=first_patch.device, + dtype=first_patch.dtype, + ), + "image_grid_thw": torch.zeros((batch_size, max_images, 3), device=first_patch.device, dtype=torch.long), + } + + for batch_idx, (sample_patches, sample_token_grids) in enumerate( + zip(vision_patches, vision_token_grids, strict=True) + ): + for image_slot, (patches, token_grid) in enumerate(zip(sample_patches, sample_token_grids, strict=True)): + patch_count = int(patches.shape[0]) + tensors["pixel_values"][batch_idx, image_slot, :patch_count] = patches + tensors["image_grid_thw"][batch_idx, image_slot, 0] = 1 + tensors["image_grid_thw"][batch_idx, image_slot, 1:] = token_grid + + return tensors + + def _preprocess( + self, + images: list[list[torch.Tensor]], + do_resize: bool, + resample: PILImageResampling | tvF.InterpolationMode | int | None, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: float | list[float] | None, + image_std: float | list[float] | None, + do_pad: bool, + patch_size: int, + max_num_patches: int, + min_num_patches: int, + pixel_shuffle_scale: int, + disable_grouping: bool | None, + return_tensors: str | TensorType | None, + **kwargs, + ) -> BatchFeature: + if all(len(sample_images) == 0 for sample_images in images): + return BatchFeature(data={"pixel_values": None, "image_grid_thw": None}, tensor_type=return_tensors) + + grouped_images, grouped_images_index = group_images_by_shape( + images, disable_grouping=disable_grouping, is_nested=True + ) + grouped_outputs = {} + for shape, stacked_images in grouped_images.items(): + grouped_batch_size, channels, original_height, original_width = stacked_images.shape + if do_resize: + target_height, target_width = get_image_size_for_max_num_patches( + original_height, + original_width, + patch_size, + max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + ) + image_batch = self.resize( + stacked_images, SizeDict(height=target_height, width=target_width), resample=resample + ) + else: + if (original_height % patch_size) or (original_width % patch_size): + raise ValueError( + f"Image dimensions (h={original_height}, 
w={original_width}) must be divisible by patch_size={patch_size} when resize is disabled; enable resizing or adjust the input resolution." + ) + image_batch, target_height, target_width = stacked_images, original_height, original_width + + image_batch = self.rescale_and_normalize( + image_batch, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) + + patches = torch_extract_patches(image_batch, patch_size, patch_size) + _, height_tokens, width_tokens, patch_dim = patches.shape + + token_grid = ( + torch.tensor([height_tokens, width_tokens], device=patches.device).long().expand(grouped_batch_size, 2) + ) + + if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): + raise ValueError( + f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale};" + f" adjust resize/patch parameters or disable pixel shuffle." + ) + + grouped_outputs[shape] = ( + patches.reshape(grouped_batch_size, -1, patch_dim), + token_grid, + ) + + keys = ("vision_patches", "vision_token_grids") + nested_outputs = {} + for i, key in enumerate(keys): + nested_outputs[key] = reorder_images( + {shape: values[i] for shape, values in grouped_outputs.items()}, + dict(grouped_images_index), + is_nested=True, + ) + + if not do_pad: + raise ValueError("IsaacImageProcessor doesn't support `do_pad=False` mode.") + + tensors = self.pack_images( + vision_patches=nested_outputs["vision_patches"], + vision_token_grids=nested_outputs["vision_token_grids"], + ) + + return BatchFeature(data=tensors, tensor_type=return_tensors) + + def get_number_of_image_patches( + self, + image_height: int, + image_width: int, + images_kwargs: dict[str, Any] | None = None, + ) -> int: + images_kwargs = images_kwargs or {} + patch_size = images_kwargs.get("patch_size", self.patch_size) + max_num_patches = images_kwargs.get("max_num_patches", self.max_num_patches) + min_num_patches = images_kwargs.get("min_num_patches", self.min_num_patches) + pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale", self.pixel_shuffle_scale) + + target_height, target_width = get_image_size_for_max_num_patches( + image_height, + image_width, + patch_size, + max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + ) + return (target_height // patch_size) * (target_width // patch_size) + + +__all__ = ["IsaacImageProcessor"] diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py new file mode 100644 index 000000000000..1d293615cc37 --- /dev/null +++ b/src/transformers/models/isaac/modeling_isaac.py @@ -0,0 +1,1743 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/isaac/modular_isaac.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_isaac.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 Perceptron, Inc and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, NamedTuple, Optional + +from ... import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation.utils import GenerationMixin +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func +from ...masking_utils import create_bidirectional_mask, create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + CausalLMOutputWithPast, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, torch_compilable_check +from ...utils.generic import TransformersKwargs, can_return_tuple, maybe_autocast, merge_with_config_defaults +from ...utils.import_utils import is_torch_available, is_torchdynamo_compiling +from ...utils.output_capturing import capture_outputs +from .configuration_isaac import IsaacConfig, IsaacTextConfig, IsaacVisionConfig + + +if is_torch_available(): + import torch + import torch.nn as nn + import torch.nn.functional as F + + +class IsaacVisionEmbeddings(nn.Module): + """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences. + + Isaac accepts variable-resolution vision inputs as a single packed sequence with per-image + `token_grids`; packing/unpacking here reconstructs per-image shapes so we can resize positional + embeddings and build `cu_seqlens` for variable-length attention (not generic generation packing). + """ + + def __init__(self, config: IsaacVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Linear( + in_features=config.num_channels * self.patch_size * self.patch_size, + out_features=self.embed_dim, + ) + + self.num_patches = config.num_patches + self.position_embedding_size = int(self.num_patches**0.5) + self.position_embedding = nn.Parameter( + torch.empty( + self.position_embedding_size, + self.position_embedding_size, + self.embed_dim, + ) + ) + + @staticmethod + def resize_positional_embeddings( + positional_embeddings: torch.Tensor, + spatial_shapes: torch.LongTensor, + max_length: int, + ) -> torch.Tensor: + """ + Resize positional embeddings to image-specific size and pad to a fixed size. 
+ + Args: + positional_embeddings (`torch.Tensor`): + Position embeddings of shape (height, width, embed_dim) + spatial_shapes (`torch.LongTensor`): + Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to + max_length (`int`): + Maximum length of the positional embeddings to pad resized positional embeddings to + + Returns: + `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim) + """ + batch_size = spatial_shapes.shape[0] + embed_dim = positional_embeddings.shape[-1] + source_dtype = positional_embeddings.dtype + + resulted_positional_embeddings = torch.empty( + (batch_size, max_length, embed_dim), + device=positional_embeddings.device, + dtype=source_dtype, + ) + + # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation + positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0) + + # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU + if positional_embeddings.device.type == "cpu": + positional_embeddings = positional_embeddings.to(torch.float32) + + for i in range(batch_size): + # (1, dim, height, width) -> (1, dim, target_height, target_width) + height, width = spatial_shapes[i].tolist() # will be itemized in F.interpolate either way + torch_compilable_check((width > 0), "Width of resized positional embeddings must be positive.") + torch_compilable_check((height > 0), "Height of resized positional embeddings must be positive.") + torch_compilable_check((height * width) <= max_length, "Resized positional embeddings exceed max_length.") + resized_embeddings = F.interpolate( + positional_embeddings, + size=(height, width), + mode="bilinear", + align_corners=False, + antialias=True, + ) + + # (1, dim, target_height, target_width) -> (target_height * target_width, dim) + resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1) + + # Cast to original dtype + resized_embeddings = resized_embeddings.to(source_dtype) + + resulted_positional_embeddings[i, : height * width] = resized_embeddings + resulted_positional_embeddings[i, height * width :] = resized_embeddings[0] + + return resulted_positional_embeddings + + def forward( + self, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + pixel_values (`torch.FloatTensor`): + Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size) + spatial_shapes (`list[tuple[int, int]]`): + Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to + """ + # pixel_values: (num_images, max_patches, patch_dim) + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) + + resized_positional_embeddings = self.resize_positional_embeddings( + self.position_embedding, + image_grid_thw[:, 1:], + max_length=pixel_values.shape[1], + ) + resized_positional_embeddings = resized_positional_embeddings.to( + device=patch_embeds.device, dtype=patch_embeds.dtype + ) + embeddings = patch_embeds + resized_positional_embeddings + + if attention_mask is not None: + embeddings = embeddings * attention_mask.unsqueeze(-1).to(device=embeddings.device, dtype=embeddings.dtype) + + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs, +): + 
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class IsaacVisionAttention(nn.Module): + """Custom attention that supports variable-length sequences with flash/SDPA backends.""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Input shape: Batch x Time x Channel""" + + input_shape = hidden_states.shape[:-1] + + hidden_shape = (*input_shape, -1, self.head_dim) + queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class IsaacMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class IsaacVisionEncoderLayer(GradientCheckpointingLayer): + """Isaac vision encoder layer using the shared attention interfaces.""" + + def __init__(self, config: IsaacVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = IsaacVisionAttention(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = IsaacMLP(config) + + @auto_docstring + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> 
torch.FloatTensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class IsaacVisionEncoder(nn.Module): + """Encoder using Isaac encoder layers.""" + + def __init__(self, config: IsaacVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + @auto_docstring + def forward( + self, + inputs_embeds, + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask, + **kwargs, + ) + + return BaseModelOutput(last_hidden_state=hidden_states) + + +@auto_docstring +class IsaacVisionModel(PreTrainedModel): + config: IsaacVisionConfig + _supports_sdpa = True + _supports_flash_attn = True + _can_record_outputs = { + "hidden_states": IsaacVisionEncoderLayer, + "attentions": IsaacVisionAttention, + } + + def __init__(self, config: IsaacVisionConfig): + super().__init__(config) + self.embeddings = IsaacVisionEmbeddings(config) + self.encoder = IsaacVisionEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor + + self.post_init() + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, IsaacVisionEmbeddings): + init.zeros_(module.position_embedding) + + def pixel_shuffle_padded( + self, + hidden_states: torch.Tensor, + token_grids: torch.Tensor, + ) -> torch.Tensor: + """Apply pixel shuffle per image on padded batched vision embeddings. + + Args: + hidden_states (`torch.Tensor`): + Vision embeddings of shape `(num_images, max_patches, hidden_size)`. + token_grids (`torch.Tensor`): + Grid sizes `(height, width)` per image, shape `(num_images, 2)`. + + Returns: + `torch.Tensor`: Pixel-shuffled embeddings of shape + `(num_images, max_tokens, hidden_size * scale_factor**2)`. 
+ """ + scale_factor = self.pixel_shuffle_scale_factor + num_images, max_patches, embed_dim = hidden_states.shape + output_dim = embed_dim * scale_factor * scale_factor + + token_grids = token_grids.to(device=hidden_states.device, dtype=torch.long) + heights = token_grids[:, 0] + widths = token_grids[:, 1] + full_lengths = heights * widths + + non_empty = full_lengths > 0 + if not is_torchdynamo_compiling(): + divisible = ((heights % scale_factor) == 0) & ((widths % scale_factor) == 0) + torch_compilable_check( + (~non_empty) | divisible, + f"Every non-empty (H, W) grid must be divisible by pixel_shuffle_scale={scale_factor}.", + ) + + output_lengths = (heights // scale_factor) * (widths // scale_factor) + max_output_tokens = output_lengths.max() + shuffled_4d = hidden_states.new_zeros((num_images, max_output_tokens, scale_factor * scale_factor, embed_dim)) + + token_positions = ( + torch.arange(max_patches, device=hidden_states.device, dtype=torch.long) + .unsqueeze(0) + .expand(num_images, -1) + ) + valid_token_mask = token_positions < full_lengths.unsqueeze(1) + + safe_widths = torch.where(widths > 0, widths, torch.ones_like(widths)) + row_index = torch.div(token_positions, safe_widths.unsqueeze(1), rounding_mode="floor") + col_index = token_positions.remainder(safe_widths.unsqueeze(1)) + + output_widths = widths.div(scale_factor, rounding_mode="floor") + output_index = row_index.div(scale_factor, rounding_mode="floor") * output_widths.unsqueeze(1) + output_index = output_index + col_index.div(scale_factor, rounding_mode="floor") + sub_index = row_index.remainder(scale_factor) * scale_factor + col_index.remainder(scale_factor) + + batch_index = ( + torch.arange(num_images, device=hidden_states.device, dtype=torch.long) + .unsqueeze(1) + .expand_as(token_positions) + ) + shuffled_4d[batch_index[valid_token_mask], output_index[valid_token_mask], sub_index[valid_token_mask]] = ( + hidden_states[valid_token_mask] + ) + + shuffled = shuffled_4d.view(num_images, max_output_tokens, output_dim) + return shuffled + + @merge_with_config_defaults + @capture_outputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + r""" + image_grid_thw (`torch.LongTensor`, *optional*): + Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. 
+ """ + full_lengths = image_grid_thw[:, 1] * image_grid_thw[:, 2] + token_positions = torch.arange(pixel_values.shape[1], device=pixel_values.device, dtype=torch.long) + image_patch_attention_mask = token_positions.unsqueeze(0) < full_lengths.unsqueeze(1) + image_patch_attention_mask = image_patch_attention_mask.to(dtype=torch.long) + hidden_states = self.embeddings( + pixel_values, + image_grid_thw, + attention_mask=image_patch_attention_mask, + ) + + encoder_attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=hidden_states, + attention_mask=image_patch_attention_mask, + ) + encoder_outputs = self.encoder(inputs_embeds=hidden_states, attention_mask=encoder_attention_mask, **kwargs) + hidden_states = self.post_layernorm(encoder_outputs.last_hidden_state) + + hidden_states = self.pixel_shuffle_padded( + hidden_states=hidden_states, + token_grids=image_grid_thw[:, 1:], + ) + + return BaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=None, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class IsaacRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: IsaacTextConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + self.mrope_section = config.rope_parameters.get("mrope_section") + if self.mrope_section is None: + weights = (2, 1, 1) + self.mrope_section = [self.inv_freq.shape[0] * w // sum(weights) for w in weights] + self.mrope_section[0] += self.inv_freq.shape[0] - sum(self.mrope_section) + + @staticmethod + def compute_default_rope_parameters( + config: IsaacTextConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, Isaac has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def apply_interleaved_mrope(self, freqs: torch.Tensor, mrope_section: list[int]) -> torch.Tensor: + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THWTHWTHW...TT], preserving frequency continuity. + args: + x: (3, bs, seq_len, head_dim // 2) + mrope_section: (3,) + returns: + x_t: (bs, seq_len, head_dim // 2) + """ + chunks = freqs.split(tuple(mrope_section), dim=-1) + return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) + + +@use_kernel_forward_from_hub("RMSNorm") +class IsaacTextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + IsaacTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +@use_kernelized_func(apply_rotary_pos_emb) +class IsaacTextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: IsaacTextConfig, layer_idx: int): + super().__init__() + self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = IsaacTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
+ self.k_norm = IsaacTextRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class IsaacTextMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class IsaacTextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: IsaacTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = IsaacTextAttention(config=config, layer_idx=layer_idx) + + self.mlp = IsaacTextMLP(config) + self.input_layernorm = IsaacTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = IsaacTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + 
return hidden_states + + +class IsaacVisionRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +@auto_docstring +class IsaacPreTrainedModel(PreTrainedModel): + config: IsaacConfig + base_model_prefix = "model" + input_modalities = ("image", "video", "text") + supports_gradient_checkpointing = True + _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": IsaacTextDecoderLayer, + "attentions": IsaacTextAttention, + } + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, IsaacVisionRotaryEmbedding): + inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim)) + init.copy_(module.inv_freq, inv_freq) + + +@auto_docstring( + custom_intro=( + "Text part of Isaac, " + "not a pure text-only model, as DeepStack integrates visual features into the early hidden states." + ) +) +class IsaacTextModel(IsaacPreTrainedModel): + config: IsaacTextConfig + input_modalities = ("text",) + _no_split_modules = ["IsaacTextDecoderLayer"] + + def __init__(self, config: IsaacTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [IsaacTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = IsaacTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = IsaacRotaryEmbedding(config=config, device=self.device) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + # args for deepstack + visual_pos_masks: torch.Tensor | None = None, + deepstack_visual_embeds: list[torch.Tensor] | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple | BaseModelOutputWithPast: + r""" + visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*): + The mask of the visual positions. + deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): + The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). + The feature is extracted from the different visual encoder layers, and fed to the decoder + hidden states. It's from the paper DeepStack(https://arxiv.org/abs/2406.04334). 
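+            Only the first `len(deepstack_visual_embeds)` decoder layers receive an injection; each tensor is
+            added in place to the hidden states at the positions selected by `visual_pos_masks`.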
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # the hard coded `4` is for text, temporal, height and width. + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.view(1, 1, -1).expand(4, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(4, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = None + + attention_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + for layer_idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=text_position_ids, + past_key_values=past_key_values, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = layer_outputs + + # add visual features to the hidden states of first several layers + if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + def _deepstack_process( + self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor + ): + visual_pos_masks = visual_pos_masks.to(hidden_states.device) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + hidden_states = hidden_states.clone() + local_this = hidden_states[visual_pos_masks, :] + visual_embeds + hidden_states[visual_pos_masks, :] = local_this + return hidden_states + + +class IsaacMultiModalProjector(nn.Module): + """Maps vision tower outputs to the text hidden size with a SiLU MLP.""" + + def __init__(self, config: IsaacConfig): + super().__init__() + text_config = config.get_text_config() + vision_hidden_size = config.vision_config.hidden_size * (config.vision_config.pixel_shuffle_scale_factor**2) + backbone_hidden_size = text_config.hidden_size + self.linear_1 = nn.Linear(vision_hidden_size, 4 * vision_hidden_size, bias=False) + self.silu = nn.SiLU() + self.linear_2 = nn.Linear(4 * vision_hidden_size, backbone_hidden_size, bias=False) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.silu(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@auto_docstring +class IsaacModel(IsaacPreTrainedModel): + base_model_prefix = "model" + # Reference: fix 
gemma3 grad acc #37208
+    accepts_loss_kwargs = False
+    config: IsaacConfig
+    _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"]
+    input_modalities = ("image", "text")
+    supports_gradient_checkpointing = True
+    _can_compile_fullgraph = False
+    _supports_flex_attn = False
+    _tied_weights_keys = {}
+    _input_embed_layer = "language_model.embed_tokens"
+
+    def __init__(self, config: IsaacConfig):
+        super().__init__(config)
+        self.language_model = IsaacTextModel._from_config(config.text_config)
+        self.visual = IsaacVisionModel(config.vision_config)
+        self.multimodal_projector = IsaacMultiModalProjector(config)
+        self.max_sequence_length = config.max_sequence_length
+        self.vision_rescale_factor = config.vision_rescale_factor
+        self.rope_deltas = None
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.language_model.set_input_embeddings(value)
+
+    def get_vision_position_ids(
+        self,
+        start_position: int,
+        grid_thw: torch.LongTensor,
+        image_metadata: torch.LongTensor,
+    ) -> torch.LongTensor:
+        """
+        Compute 3D positional indices for the vision tokens of a single image.
+
+        The positions are generated from the `(T, H, W)` grid after downscaling H and W by the
+        pixel shuffle scale factor. Every token of the image shares the temporal position
+        `start_position`, while the height and width positions follow the downscaled grid. The result
+        is then sliced with the slot-local `(offset, length)` metadata.
+
+        Args:
+            start_position (`int`):
+                Temporal position assigned to all tokens of this image.
+            grid_thw (`torch.LongTensor` of shape `(3,)`):
+                The (T, H, W) grid representing the feature layout of the current image after patch embedding.
+            image_metadata (`torch.LongTensor` of shape `(2,)`):
+                Slot-local `(offset, length)` pair used to slice the computed positions.
+
+        Returns:
+            torch.LongTensor of shape (3, sequence_length):
+                Positional indices for the temporal, height, and width dimensions of the kept vision tokens.
+        """
+        pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor
+        height = grid_thw[1].div(pixel_shuffle_scale, rounding_mode="floor").item()
+        width = grid_thw[2].div(pixel_shuffle_scale, rounding_mode="floor").item()
+        token_positions = torch.arange(height * width, device=grid_thw.device, dtype=torch.long)
+        vision_position_ids = torch.stack(
+            (
+                torch.full((token_positions.shape[0],), start_position, device=grid_thw.device, dtype=torch.long),
+                token_positions.div(width, rounding_mode="floor"),
+                token_positions.remainder(width),
+            ),
+            dim=0,
+        )
+        token_offset = int(image_metadata[0].item())
+        token_length = int(image_metadata[1].item())
+        return vision_position_ids[:, token_offset : token_offset + token_length]
+
+    def get_rope_index(
+        self,
+        input_ids: torch.LongTensor,
+        mm_token_type_ids: torch.Tensor,
+        image_grid_thw: torch.Tensor,
+        image_metadata: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Compute M-RoPE position ids for a batch of interleaved text and image tokens.
+
+        Difference from Qwen2VL/Qwen2.5VL's `get_rope_index`: Isaac is image-only and describes images with
+        batch-major `image_grid_thw` / `image_metadata` slots. Each image occupies a single temporal position,
+        and its spatial positions follow the pixel-shuffled grid.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+                it.
+            mm_token_type_ids (`torch.IntTensor` of shape `(batch_size, sequence_length)`):
+                Token type ids matching each modality to a different value in the input sequence, i.e. text (0) and image (1).
+            image_grid_thw (`torch.LongTensor` of shape `(batch_size, max_images, 3)`):
+                Batch-major per-slot grids with `(T=1, H, W)` entries describing the feature shape of each image in the LLM.
+            image_metadata (`torch.LongTensor` of shape `(batch_size, max_images, 2)`):
+                Batch-major per-slot `(offset, length)` metadata describing which vision tokens of each slot are present.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
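+
+        For illustration: in a sample with three text tokens followed by one image whose metadata keeps the
+        four tokens of a 2x2 downscaled grid, the text tokens get positions `0, 1, 2` on all three axes, the
+        image tokens share temporal position `3` with `(height, width)` positions `(0, 0), (0, 1), (1, 0), (1, 1)`,
+        and any following text token resumes at position `4`.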
+ + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + if attention_mask is None: + if input_ids is None: + attention_mask = mm_token_type_ids.new_ones(mm_token_type_ids.shape, dtype=torch.long) + else: + attention_mask = input_ids.new_ones(input_ids.shape, dtype=torch.long) + + if input_ids is None: + batch_size, seq_len = attention_mask.shape + position_dtype = torch.long + else: + batch_size, seq_len = input_ids.shape + position_dtype = input_ids.dtype + + device = attention_mask.device + mm_token_type_ids = mm_token_type_ids.to(dtype=torch.long) + image_grid_thw = image_grid_thw.to(dtype=torch.long) + image_metadata = image_metadata.to(dtype=torch.long) + attention_mask = attention_mask.to(dtype=torch.long) + active_slot_mask = image_grid_thw[..., 0].eq(1) + + position_ids = torch.zeros((3, batch_size, seq_len), device=device, dtype=position_dtype) + rope_deltas = torch.zeros((batch_size, 1), device=device, dtype=torch.long) + + for batch_idx in range(batch_size): + sample_attention_mask = attention_mask[batch_idx].bool() + sample_token_types = mm_token_type_ids[batch_idx][sample_attention_mask] + sample_grids = image_grid_thw[batch_idx] + sample_metadata = image_metadata[batch_idx] + sample_active_slots = active_slot_mask[batch_idx] + + current_pos = 0 + image_idx = 0 + seq_pos = 0 + llm_pos_ids_list = [] + + while seq_pos < sample_token_types.shape[0]: + modality_type = int(sample_token_types[seq_pos].item()) + if modality_type == 0: + group_end = seq_pos + 1 + while group_end < sample_token_types.shape[0] and sample_token_types[group_end] == 0: + group_end += 1 + group_length = group_end - seq_pos + llm_pos_ids_list.append( + torch.arange(group_length, device=device, dtype=torch.long).view(1, -1).expand(3, -1) + + current_pos + ) + current_pos += group_length + seq_pos = group_end + else: + while image_idx < sample_metadata.shape[0] and ( + not bool(sample_active_slots[image_idx].item()) or sample_metadata[image_idx, 1].item() == 0 + ): + image_idx += 1 + torch_compilable_check( + image_idx < sample_metadata.shape[0], + "Isaac multimodal sequence has more visible image tokens than batch-major image metadata slots.", + ) + token_length = int(sample_metadata[image_idx, 1].item()) + torch_compilable_check( + token_length <= sample_token_types.shape[0] - seq_pos, + "Isaac image metadata length exceeds the remaining multimodal placeholder span.", + ) + llm_pos_ids_list.append( + self.get_vision_position_ids(current_pos, sample_grids[image_idx], sample_metadata[image_idx]) + ) + current_pos += 1 + seq_pos += token_length + image_idx += 1 + + llm_positions = ( + torch.cat(llm_pos_ids_list, dim=1) + if llm_pos_ids_list + else torch.zeros((3, 0), device=device, dtype=torch.long) + ) + position_ids[:, batch_idx, sample_attention_mask] = llm_positions + rope_deltas[batch_idx, 0] = ( + llm_positions.max() + 1 - sample_token_types.shape[0] if llm_positions.numel() > 0 else 0 + ) + + return position_ids, rope_deltas + + @can_return_tuple + @auto_docstring + def get_image_features( + self, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + image_metadata: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values (`torch.FloatTensor`, *optional*): + Batch-major patch vectors shaped `(batch_size, max_images, max_patches, patch_dim)`. 
+ image_grid_thw (`torch.LongTensor`, *optional*): + Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. + image_metadata (`torch.Tensor`, *optional*): + Batch-major per-slot metadata `(offset, length)` shaped `(batch_size, max_images, 2)`. + """ + active_slot_mask = image_grid_thw[..., 0].eq(1) + flat_pixel_values = pixel_values[active_slot_mask] + flat_image_grid_thw = image_grid_thw[active_slot_mask] + + vision_outputs: BaseModelOutputWithPooling = self.visual( + pixel_values=flat_pixel_values, + image_grid_thw=flat_image_grid_thw, + return_dict=True, + **kwargs, + ) + projected_features = self.multimodal_projector(vision_outputs.last_hidden_state) + + # Truncate image features using offset and length + if image_metadata is None: + pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor + downsampled_height = flat_image_grid_thw[:, 1].div(pixel_shuffle_scale, rounding_mode="floor") + downsampled_width = flat_image_grid_thw[:, 2].div(pixel_shuffle_scale, rounding_mode="floor") + lengths = downsampled_height * downsampled_width + offsets = torch.zeros_like(lengths) + else: + torch_compilable_check( + image_metadata.shape[:2] == image_grid_thw.shape[:2], + "IsaacModel.get_image_features expects batch-major metadata aligned with `image_grid_thw`.", + ) + offsets = image_metadata[active_slot_mask][:, 0] + lengths = image_metadata[active_slot_mask][:, 1] + + image_features = tuple( + projected_features[image_idx, offset : offset + length] + for image_idx, (offset, length) in enumerate(zip(offsets.tolist(), lengths.tolist(), strict=True)) + ) + + return BaseModelOutputWithPooling( + last_hidden_state=projected_features, + pooler_output=image_features, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + ) + + def get_placeholder_mask( + self, + mm_token_type_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor, + ) -> torch.BoolTensor: + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + image_token_mask = mm_token_type_ids.to(dtype=torch.long) == 1 + n_image_tokens = image_token_mask.sum() + image_token_mask = image_token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[image_token_mask].numel() == image_features.numel(), + f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", + ) + return image_token_mask + + def compute_3d_position_ids( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | None, + mm_token_type_ids: torch.Tensor | None = None, + image_grid_thw: torch.Tensor | None = None, + image_metadata: torch.Tensor | None = None, + past_key_values: Cache | None = None, + ) -> torch.Tensor: + past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() + has_multimodal = ( + image_grid_thw is not None + and image_metadata is not None + and bool(image_grid_thw[..., 0].eq(1).any().item()) + ) + if has_multimodal and mm_token_type_ids is None and input_ids is not None: + raise ValueError( + "Multimodal data was passed (via `image_grid_thw` or `image_metadata`) but `mm_token_type_ids` is " + "missing. 
Please pass `mm_token_type_ids` to the model so that multimodal RoPE (M-RoPE) can be " + "computed correctly. `mm_token_type_ids` is returned by the processor alongside `input_ids`." + ) + + if has_multimodal and past_seen_tokens == 0: + position_ids, rope_deltas = self.get_rope_index( + input_ids=input_ids, + mm_token_type_ids=mm_token_type_ids, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, + attention_mask=attention_mask, + ) + self.rope_deltas = rope_deltas + return position_ids + + if self.rope_deltas is None: + return None + + rope_deltas = torch.as_tensor(self.rope_deltas, device=inputs_embeds.device, dtype=torch.long).reshape(-1, 1) + if rope_deltas.shape[0] != inputs_embeds.shape[0]: + if inputs_embeds.shape[0] % rope_deltas.shape[0] == 0: + rope_deltas = rope_deltas.repeat_interleave(inputs_embeds.shape[0] // rope_deltas.shape[0], dim=0) + else: + rope_deltas = rope_deltas[:1].expand(inputs_embeds.shape[0], -1) + + if attention_mask is not None and attention_mask.shape[-1] > inputs_embeds.shape[1]: + rope_position = attention_mask.long().cumsum(dim=-1) - 1 + rope_position = rope_position.masked_fill(attention_mask == 0, 0) + rope_position = rope_position[:, -inputs_embeds.shape[1] :] + else: + rope_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + dtype=torch.long, + ).view(1, -1) + rope_position = rope_position.expand(inputs_embeds.shape[0], -1) + + position_ids = rope_position.view(1, inputs_embeds.shape[0], -1).expand(3, -1, -1) + return position_ids + rope_deltas.to(device=inputs_embeds.device).unsqueeze(0) + + @auto_docstring( + custom_intro=""" + Forward pass with multimodal MRoPE position ids. + + When image placeholders are present, Isaac computes vision features, scatters them into the token + embeddings, and runs the shared text backbone on the mixed sequence. + """, + ) + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor | None = None, + mm_token_type_ids: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + image_metadata: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPast: + r""" + mm_token_type_ids (`torch.LongTensor`, *optional*): + Multimodal token type ids aligned with the token sequence, using `0 -> text` and `1 -> image`. + pixel_values (`torch.FloatTensor`, *optional*): + Batch-major patch vectors shaped `(batch_size, max_images, max_patches, patch_dim)`. + image_grid_thw (`torch.LongTensor`, *optional*): + Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. + image_metadata (`torch.LongTensor`, *optional*): + Batch-major per-slot metadata shaped `(batch_size, max_images, 2)` with `(offset, length)`. 
+ """ + if (input_ids is None) == (inputs_embeds is None): + raise ValueError("You must specify exactly one of `input_ids` or `inputs_embeds`.") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_mask = None + if pixel_values is not None and image_grid_thw is not None: + image_outputs = self.get_image_features( + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, + return_dict=True, + ) + image_embeds = image_outputs.pooler_output + if len(image_embeds) > 0: + image_embeds = torch.cat(image_embeds, dim=0).to( + device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + image_mask = self.get_placeholder_mask( + mm_token_type_ids=mm_token_type_ids, + inputs_embeds=inputs_embeds, + image_features=image_embeds, + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if isinstance(attention_mask, dict): + attention_mask = attention_mask["full_attention"] + + past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() + computed_position_ids = self.compute_3d_position_ids( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + mm_token_type_ids=mm_token_type_ids, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, + attention_mask=attention_mask, + past_key_values=past_key_values, + ) + if computed_position_ids is not None: + position_ids = computed_position_ids + elif past_seen_tokens > 0: + position_ids = None + elif position_ids is not None and past_seen_tokens == 0: + position_ids = position_ids.to(device=inputs_embeds.device) + if position_ids.ndim == 2: + position_ids = position_ids.view(1, position_ids.shape[0], -1).expand(3, -1, -1) + + outputs = self.language_model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + visual_pos_masks=image_mask[..., 0] if image_mask is not None else None, + deepstack_visual_embeds=None, + use_cache=use_cache, + **kwargs, + ) + + return BaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@dataclass +class IsaacCausalLMOutputWithPast(CausalLMOutputWithPast): + """ + Base class for Isaac causal language model (or autoregressive) outputs. + + Args: + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + rope_deltas: torch.LongTensor | None = None + + +@dataclass +@auto_docstring +class BaseModelOutputWithDeepstackFeatures(BaseModelOutputWithPooling): + r""" + deepstack_features (`List[torch.FloatTensor]`, *optional*): + List of hidden-states (feature maps) from deepstack layers. 
+ """ + + deepstack_features: list[torch.FloatTensor] | None = None + + +@auto_docstring +class IsaacForConditionalGeneration(IsaacPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: IsaacConfig + config_class = IsaacConfig + input_modalities = ("image", "text") + _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"] + _can_compile_fullgraph = False + + def __init__(self, config: IsaacConfig): + super().__init__(config) + self.model = IsaacModel(config) + self.vocab_size = config.get_text_config().vocab_size + self.lm_head = nn.Linear(config.get_text_config().hidden_size, config.get_text_config().vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @auto_docstring + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithDeepstackFeatures: + r""" + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + return self.model.get_video_features( + pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs + ) + + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithDeepstackFeatures: + r""" + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + mm_token_type_ids: torch.IntTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | IsaacCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. 
+        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+            The temporal, height and width of feature shape of each video in LLM.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoProcessor, IsaacForConditionalGeneration
+
+        >>> model = IsaacForConditionalGeneration.from_pretrained("PerceptronAI/Isaac-0.1")
+        >>> processor = AutoProcessor.from_pretrained("PerceptronAI/Isaac-0.1")
+
+        >>> messages = [
+        ...     {
+        ...         "role": "user",
+        ...         "content": [
+        ...             {
+        ...                 "type": "image",
+        ...                 "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
+        ...             },
+        ...             {"type": "text", "text": "Describe the image."},
+        ...         ],
+        ...     }
+        ... ]
+
+        >>> inputs = processor.apply_chat_template(
+        ...     messages,
+        ...     tokenize=True,
+        ...     add_generation_prompt=True,
+        ...     return_dict=True,
+        ...     return_tensors="pt"
+        ... )
+
+        >>> # Generate
+        >>> generated_ids = model.generate(**inputs, max_new_tokens=1024)
+        >>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+        >>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        >>> print(output_text)
+        ```
+        """
+        outputs = self.model(
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            pixel_values_videos=pixel_values_videos,
+            image_grid_thw=image_grid_thw,
+            video_grid_thw=video_grid_thw,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            mm_token_type_ids=mm_token_type_ids,
+            **kwargs,
+        )
+
+        hidden_states = outputs[0]
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+        return IsaacCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            rope_deltas=self.model.rope_deltas,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values: Cache = None,
+        attention_mask: torch.Tensor | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        mm_token_type_ids: torch.LongTensor | None = None,
+        pixel_values: torch.Tensor | None = None,
+        image_grid_thw: torch.LongTensor | None = None,
+        image_metadata: torch.LongTensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        is_first_iteration: bool = False,
+        use_cache: bool = True,
+        **kwargs,
+    ) -> dict[str, Any]:
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            pixel_values=pixel_values,
+            image_grid_thw=image_grid_thw,
+            is_first_iteration=is_first_iteration,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        multimodal_inputs = {
+            "mm_token_type_ids": mm_token_type_ids,
+            "pixel_values": pixel_values,
+            "image_grid_thw": image_grid_thw,
+            "image_metadata": image_metadata,
+        }
+        is_prefill = is_first_iteration or not use_cache
+        for key, value in multimodal_inputs.items():
+            model_inputs[key] = value if is_prefill else None
+        if model_inputs["mm_token_type_ids"] is not None:
+            sequence_length = None
+            if model_inputs.get("input_ids") is not None:
+                sequence_length = model_inputs["input_ids"].shape[1]
+            elif
model_inputs.get("inputs_embeds") is not None: + sequence_length = model_inputs["inputs_embeds"].shape[1] + + if sequence_length is not None: + current_length = model_inputs["mm_token_type_ids"].shape[1] + if current_length < sequence_length: + padding = model_inputs["mm_token_type_ids"].new_zeros( + (model_inputs["mm_token_type_ids"].shape[0], sequence_length - current_length) + ) + model_inputs["mm_token_type_ids"] = torch.cat([model_inputs["mm_token_type_ids"], padding], dim=1) + elif current_length > sequence_length: + model_inputs["mm_token_type_ids"] = model_inputs["mm_token_type_ids"][:, -sequence_length:] + + return model_inputs + + def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs): + text_positions = super()._prepare_position_ids_for_generation(inputs_tensor, model_kwargs) + + past_length = 0 + if (cache := model_kwargs.get("past_key_values")) is not None: + past_length = cache.get_seq_length() + if past_length != 0 and self.model.rope_deltas is not None: + return text_positions[None, ...] + self.model.rope_deltas + + if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0: + inputs_tensor = model_kwargs["input_ids"] + + is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long] + if ( + is_input_ids + and model_kwargs.get("mm_token_type_ids") is not None + and model_kwargs.get("image_grid_thw") is not None + and model_kwargs.get("image_metadata") is not None + ): + model_kwargs = {k: v for k, v in model_kwargs.items() if k != "input_ids"} + vision_positions, rope_deltas = self.model.get_rope_index(inputs_tensor, **model_kwargs) + self.model.rope_deltas = rope_deltas + else: + vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1) + self.model.rope_deltas = torch.zeros( + inputs_tensor.shape[0], 1, dtype=torch.long, device=inputs_tensor.device + ) + + return torch.cat([text_positions[None, ...], vision_positions], dim=0) + + def _get_image_nums_and_video_nums( + self, + input_ids: torch.LongTensor | None, + inputs_embeds: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. 
+ + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + if inputs_embeds is not None: + vision_start_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + image_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + video_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + else: + vision_start_mask = input_ids == vision_start_token_id + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: torch.LongTensor | None = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + position_ids = model_kwargs.pop("position_ids", None) + if expand_size == 1: + if position_ids is not None: + model_kwargs["position_ids"] = position_ids + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "image_metadata"] + for key in visual_keys: + value = model_kwargs.get(key) + if value is not None: + model_kwargs[key] = value.repeat_interleave(expand_size, dim=0) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + for key, value in list(model_kwargs.items()): + if key == "position_ids" and value is not None and value.ndim == 3: + model_kwargs[key] = value.repeat_interleave(expand_size, dim=1) + elif value is not None and isinstance(value, torch.Tensor) and key not in visual_keys: + model_kwargs[key] = value.repeat_interleave(expand_size, dim=0) + + if position_ids is not None: + dim = 1 if position_ids.ndim == 3 else 0 + model_kwargs["position_ids"] = position_ids.repeat_interleave(expand_size, dim=dim) + return input_ids, model_kwargs + + +class SinglePoint(NamedTuple): + x: int + y: int + mention: str | None = None + t: float | None = None + + +class BoundingBox(NamedTuple): + top_left: Any + bottom_right: Any + mention: str | None = None + t: float | None = None + + +class Polygon(NamedTuple): + points: tuple[Any, ...] + mention: str | None = None + t: float | None = None + + +__all__ = ["IsaacTextModel", "IsaacVisionModel", "IsaacModel", "IsaacPreTrainedModel", "IsaacForConditionalGeneration"] diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py new file mode 100644 index 000000000000..5d2935064980 --- /dev/null +++ b/src/transformers/models/isaac/modular_isaac.py @@ -0,0 +1,1642 @@ +# Copyright 2026 Perceptron, Inc and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +import re +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, NamedTuple + +from huggingface_hub.dataclasses import strict + +from ... import TorchvisionBackend +from ... import initialization as init +from ...cache_utils import Cache +from ...configuration_utils import PretrainedConfig +from ...feature_extraction_utils import BatchFeature +from ...generation.utils import GenerationMixin +from ...image_transforms import group_images_by_shape, reorder_images +from ...image_utils import ImageInput, PILImageResampling, SizeDict, make_nested_list_of_images +from ...masking_utils import create_bidirectional_mask +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...models.qwen3.configuration_qwen3 import Qwen3Config +from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack +from ...utils import TensorType, auto_docstring, torch_compilable_check +from ...utils.constants import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD +from ...utils.generic import TransformersKwargs, can_return_tuple, merge_with_config_defaults +from ...utils.import_utils import ( + is_torch_available, + is_torchdynamo_compiling, + is_torchvision_available, + requires, +) +from ...utils.output_capturing import capture_outputs +from ..qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLForConditionalGeneration, + Qwen3VLModel, + Qwen3VLTextDecoderLayer, + Qwen3VLTextModel, + Qwen3VLTextRotaryEmbedding, +) +from ..siglip2.configuration_siglip2 import Siglip2VisionConfig +from ..siglip2.modeling_siglip2 import ( + Siglip2Attention, + Siglip2Encoder, + Siglip2EncoderLayer, + Siglip2VisionEmbeddings, +) + + +if is_torch_available(): + import torch + import torch.nn as nn + import torch.nn.functional as F + from torchvision.transforms.v2 import functional as tvF + +if is_torchvision_available(): + from ..pix2struct.image_processing_pix2struct import torch_extract_patches + + +@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") +@strict +class IsaacVisionConfig(Siglip2VisionConfig): + r""" + num_patches (`int`, *optional*, defaults to 256): + The number of patches in the image with the size of (`patch_size`, `patch_size`). The image is resized to + fill a maximum of this number of patches while preserving the aspect ratio. If the resulting number of patches + is lower, the image is padded in the patch dimension. + pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1): + Spatial factor applied before pixel shuffle reduces the resolution. 
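+        For example, a value of 2 folds every non-overlapping 2x2 patch block into a single vision token,
+        dividing the number of vision tokens by 4 and multiplying the per-token hidden size by 4.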
+ """ + + model_type = "isaac_vision" + base_config_key = "vision_config" + pixel_shuffle_scale_factor: int = 1 + + +@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") +@strict +class IsaacTextConfig(Qwen3Config): + r""" + Example: + + ```python + >>> from transformers import IsaacTextConfig, IsaacTextModel + + >>> configuration = IsaacTextConfig() + >>> model = IsaacTextModel(configuration) + >>> configuration = model.config + ``` + """ + + model_type = "isaac_text" + ignore_keys_at_rope_validation = {"mrope_section", "mrope_interleaved"} + max_position_embeddings: int = 32768 + sliding_window = AttributeError() + layer_types = AttributeError() + + def __post_init__(self, **kwargs): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + PretrainedConfig.__post_init__(self, **kwargs) + + +@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base") +@strict +class IsaacConfig(PretrainedConfig): + r""" + vision_config (`IsaacVisionConfig` or `dict`, *optional*): + Configuration for the Isaac vision tower. Dictionaries are converted to [`IsaacVisionConfig`]. If unset, + the default [`IsaacVisionConfig`] is used. + text_config (`IsaacTextConfig` or `dict`, *optional*): + Configuration for the text backbone. Dictionaries are converted to [`IsaacTextConfig`]. + vision_rescale_factor (`float`, *optional*, defaults to 1 / 255): + Rescale factor applied by the image processor before normalization. + max_sequence_length (`int`, *optional*, defaults to 16384): + Maximum multimodal sequence length produced by the processor and expected by the model. + + Example: + + ```python + >>> from transformers import IsaacConfig, IsaacModel + + >>> configuration = IsaacConfig() + >>> model = IsaacModel(configuration) + >>> configuration = model.config + ``` + """ + + model_type = "isaac" + sub_configs = {"vision_config": IsaacVisionConfig, "text_config": IsaacTextConfig} + vision_config: IsaacVisionConfig | dict | None = None + text_config: IsaacTextConfig | dict | None = None + vision_rescale_factor: float = 1 / 255 + max_sequence_length: int = 16384 + + def __post_init__(self, **kwargs): + if isinstance(self.text_config, dict): + self.text_config = self.sub_configs["text_config"](**self.text_config) + elif self.text_config is None: + self.text_config = self.sub_configs["text_config"]() + elif not isinstance(self.text_config, IsaacTextConfig): + raise TypeError( + f"text_config must be a dict or an IsaacTextConfig instance, got {type(self.text_config).__name__}." + ) + + if isinstance(self.vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**self.vision_config) + elif self.vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + elif not isinstance(self.vision_config, IsaacVisionConfig): + raise TypeError( + f"vision_config must be a dict or an IsaacVisionConfig instance, got {type(self.vision_config).__name__}." + ) + + self.vision_rescale_factor = float(self.vision_rescale_factor) + super().__post_init__(**kwargs) + + +class IsaacVisionEmbeddings(Siglip2VisionEmbeddings): + """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences. + + Isaac accepts variable-resolution vision inputs as a single packed sequence with per-image + `token_grids`; packing/unpacking here reconstructs per-image shapes so we can resize positional + embeddings and build `cu_seqlens` for variable-length attention (not generic generation packing). 
+ """ + + def __init__(self, config: IsaacVisionConfig): + super().__init__(config) + self.position_embedding = nn.Parameter( + torch.empty( + self.position_embedding_size, + self.position_embedding_size, + self.embed_dim, + ) + ) + + def forward( + self, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + # pixel_values: (num_images, max_patches, patch_dim) + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) + + resized_positional_embeddings = self.resize_positional_embeddings( + self.position_embedding, + image_grid_thw[:, 1:], + max_length=pixel_values.shape[1], + ) + resized_positional_embeddings = resized_positional_embeddings.to( + device=patch_embeds.device, dtype=patch_embeds.dtype + ) + embeddings = patch_embeds + resized_positional_embeddings + + if attention_mask is not None: + embeddings = embeddings * attention_mask.unsqueeze(-1).to(device=embeddings.device, dtype=embeddings.dtype) + + return embeddings + + +class IsaacVisionAttention(Siglip2Attention): + """Custom attention that supports variable-length sequences with flash/SDPA backends.""" + + pass + + +class IsaacVisionEncoderLayer(Siglip2EncoderLayer): + """Isaac vision encoder layer using the shared attention interfaces.""" + + def __init__(self, config: IsaacVisionConfig): + super().__init__(config) + self.self_attn = IsaacVisionAttention(config) + + +class IsaacVisionEncoder(Siglip2Encoder): + """Encoder using Isaac encoder layers.""" + + def __init__(self, config: IsaacVisionConfig): + super().__init__(config) + self.layers = nn.ModuleList([IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + + +@auto_docstring +class IsaacVisionModel(PreTrainedModel): + config: IsaacVisionConfig + _supports_sdpa = True + _supports_flash_attn = True + _can_record_outputs = { + "hidden_states": IsaacVisionEncoderLayer, + "attentions": IsaacVisionAttention, + } + + def __init__(self, config: IsaacVisionConfig): + super().__init__(config) + self.embeddings = IsaacVisionEmbeddings(config) + self.encoder = IsaacVisionEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor + + self.post_init() + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, IsaacVisionEmbeddings): + init.zeros_(module.position_embedding) + + def pixel_shuffle_padded( + self, + hidden_states: torch.Tensor, + token_grids: torch.Tensor, + ) -> torch.Tensor: + """Apply pixel shuffle per image on padded batched vision embeddings. + + Args: + hidden_states (`torch.Tensor`): + Vision embeddings of shape `(num_images, max_patches, hidden_size)`. + token_grids (`torch.Tensor`): + Grid sizes `(height, width)` per image, shape `(num_images, 2)`. + + Returns: + `torch.Tensor`: Pixel-shuffled embeddings of shape + `(num_images, max_tokens, hidden_size * scale_factor**2)`. 
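+
+        Example (illustrative): with `scale_factor=2`, an image with a `(H, W) = (4, 6)` patch grid yields
+        `2 * 3 = 6` output tokens, each concatenating the hidden states of one non-overlapping 2x2 patch
+        block, so the channel dimension grows from `hidden_size` to `4 * hidden_size`.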
+ """ + scale_factor = self.pixel_shuffle_scale_factor + num_images, max_patches, embed_dim = hidden_states.shape + output_dim = embed_dim * scale_factor * scale_factor + + token_grids = token_grids.to(device=hidden_states.device, dtype=torch.long) + heights = token_grids[:, 0] + widths = token_grids[:, 1] + full_lengths = heights * widths + + non_empty = full_lengths > 0 + if not is_torchdynamo_compiling(): + divisible = ((heights % scale_factor) == 0) & ((widths % scale_factor) == 0) + torch_compilable_check( + (~non_empty) | divisible, + f"Every non-empty (H, W) grid must be divisible by pixel_shuffle_scale={scale_factor}.", + ) + + output_lengths = (heights // scale_factor) * (widths // scale_factor) + max_output_tokens = output_lengths.max() + shuffled_4d = hidden_states.new_zeros((num_images, max_output_tokens, scale_factor * scale_factor, embed_dim)) + + token_positions = ( + torch.arange(max_patches, device=hidden_states.device, dtype=torch.long) + .unsqueeze(0) + .expand(num_images, -1) + ) + valid_token_mask = token_positions < full_lengths.unsqueeze(1) + + safe_widths = torch.where(widths > 0, widths, torch.ones_like(widths)) + row_index = torch.div(token_positions, safe_widths.unsqueeze(1), rounding_mode="floor") + col_index = token_positions.remainder(safe_widths.unsqueeze(1)) + + output_widths = widths.div(scale_factor, rounding_mode="floor") + output_index = row_index.div(scale_factor, rounding_mode="floor") * output_widths.unsqueeze(1) + output_index = output_index + col_index.div(scale_factor, rounding_mode="floor") + sub_index = row_index.remainder(scale_factor) * scale_factor + col_index.remainder(scale_factor) + + batch_index = ( + torch.arange(num_images, device=hidden_states.device, dtype=torch.long) + .unsqueeze(1) + .expand_as(token_positions) + ) + shuffled_4d[batch_index[valid_token_mask], output_index[valid_token_mask], sub_index[valid_token_mask]] = ( + hidden_states[valid_token_mask] + ) + + shuffled = shuffled_4d.view(num_images, max_output_tokens, output_dim) + return shuffled + + @merge_with_config_defaults + @capture_outputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + r""" + image_grid_thw (`torch.LongTensor`, *optional*): + Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. 
+ """ + full_lengths = image_grid_thw[:, 1] * image_grid_thw[:, 2] + token_positions = torch.arange(pixel_values.shape[1], device=pixel_values.device, dtype=torch.long) + image_patch_attention_mask = token_positions.unsqueeze(0) < full_lengths.unsqueeze(1) + image_patch_attention_mask = image_patch_attention_mask.to(dtype=torch.long) + hidden_states = self.embeddings( + pixel_values, + image_grid_thw, + attention_mask=image_patch_attention_mask, + ) + + encoder_attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=hidden_states, + attention_mask=image_patch_attention_mask, + ) + encoder_outputs = self.encoder(inputs_embeds=hidden_states, attention_mask=encoder_attention_mask, **kwargs) + hidden_states = self.post_layernorm(encoder_outputs.last_hidden_state) + + hidden_states = self.pixel_shuffle_padded( + hidden_states=hidden_states, + token_grids=image_grid_thw[:, 1:], + ) + + return BaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=None, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class IsaacRotaryEmbedding(Qwen3VLTextRotaryEmbedding): + def __init__(self, config: IsaacTextConfig, device=None): + super().__init__(config, device=device) + self.mrope_section = config.rope_parameters.get("mrope_section") + if self.mrope_section is None: + weights = (2, 1, 1) + self.mrope_section = [self.inv_freq.shape[0] * w // sum(weights) for w in weights] + self.mrope_section[0] += self.inv_freq.shape[0] - sum(self.mrope_section) + + def apply_interleaved_mrope(self, freqs: torch.Tensor, mrope_section: list[int]) -> torch.Tensor: + chunks = freqs.split(tuple(mrope_section), dim=-1) + return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) + + +class IsaacTextDecoderLayer(Qwen3VLTextDecoderLayer): + pass + + +class IsaacTextModel(Qwen3VLTextModel): + def __init__(self, config: IsaacTextConfig): + super().__init__(config) + self.rotary_emb = IsaacRotaryEmbedding(config=config, device=self.device) + + +class IsaacMultiModalProjector(nn.Module): + """Maps vision tower outputs to the text hidden size with a SiLU MLP.""" + + def __init__(self, config: IsaacConfig): + super().__init__() + text_config = config.get_text_config() + vision_hidden_size = config.vision_config.hidden_size * (config.vision_config.pixel_shuffle_scale_factor**2) + backbone_hidden_size = text_config.hidden_size + self.linear_1 = nn.Linear(vision_hidden_size, 4 * vision_hidden_size, bias=False) + self.silu = nn.SiLU() + self.linear_2 = nn.Linear(4 * vision_hidden_size, backbone_hidden_size, bias=False) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.silu(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@auto_docstring +class IsaacModel(Qwen3VLModel): + input_modalities = ("image", "text") + supports_gradient_checkpointing = True + _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"] + _can_compile_fullgraph = False + _supports_flex_attn = False + _tied_weights_keys = {} + _input_embed_layer = "language_model.embed_tokens" + + def __init__(self, config: IsaacConfig): + PreTrainedModel.__init__(self, config) + self.language_model = IsaacTextModel._from_config(config.text_config) + self.visual = IsaacVisionModel(config.vision_config) + self.multimodal_projector = IsaacMultiModalProjector(config) + self.max_sequence_length = config.max_sequence_length + self.vision_rescale_factor = 
config.vision_rescale_factor + self.rope_deltas = None + + self.post_init() + + @can_return_tuple + @auto_docstring + def get_image_features( + self, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + image_metadata: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values (`torch.FloatTensor`, *optional*): + Batch-major patch vectors shaped `(batch_size, max_images, max_patches, patch_dim)`. + image_grid_thw (`torch.LongTensor`, *optional*): + Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. + image_metadata (`torch.Tensor`, *optional*): + Batch-major per-slot metadata `(offset, length)` shaped `(batch_size, max_images, 2)`. + """ + active_slot_mask = image_grid_thw[..., 0].eq(1) + flat_pixel_values = pixel_values[active_slot_mask] + flat_image_grid_thw = image_grid_thw[active_slot_mask] + + vision_outputs: BaseModelOutputWithPooling = self.visual( + pixel_values=flat_pixel_values, + image_grid_thw=flat_image_grid_thw, + return_dict=True, + **kwargs, + ) + projected_features = self.multimodal_projector(vision_outputs.last_hidden_state) + + # Truncate image features using offset and length + if image_metadata is None: + pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor + downsampled_height = flat_image_grid_thw[:, 1].div(pixel_shuffle_scale, rounding_mode="floor") + downsampled_width = flat_image_grid_thw[:, 2].div(pixel_shuffle_scale, rounding_mode="floor") + lengths = downsampled_height * downsampled_width + offsets = torch.zeros_like(lengths) + else: + torch_compilable_check( + image_metadata.shape[:2] == image_grid_thw.shape[:2], + "IsaacModel.get_image_features expects batch-major metadata aligned with `image_grid_thw`.", + ) + offsets = image_metadata[active_slot_mask][:, 0] + lengths = image_metadata[active_slot_mask][:, 1] + + image_features = tuple( + projected_features[image_idx, offset : offset + length] + for image_idx, (offset, length) in enumerate(zip(offsets.tolist(), lengths.tolist(), strict=True)) + ) + + return BaseModelOutputWithPooling( + last_hidden_state=projected_features, + pooler_output=image_features, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + ) + + def get_placeholder_mask( + self, + mm_token_type_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor, + ) -> torch.BoolTensor: + image_token_mask = mm_token_type_ids.to(dtype=torch.long) == 1 + n_image_tokens = image_token_mask.sum() + image_token_mask = image_token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[image_token_mask].numel() == image_features.numel(), + f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", + ) + return image_token_mask + + def get_video_features(self, **super_kwargs): + raise AttributeError("Isaac is image-only and does not support `pixel_values_videos` or `video_grid_thw`.") + + def get_vision_position_ids( + self, + start_position: int, + grid_thw: torch.LongTensor, + image_metadata: torch.LongTensor, + ) -> torch.LongTensor: + pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor + height = grid_thw[1].div(pixel_shuffle_scale, rounding_mode="floor").item() + width = grid_thw[2].div(pixel_shuffle_scale, rounding_mode="floor").item() + token_positions = torch.arange(height * width, device=grid_thw.device, 
dtype=torch.long) + vision_position_ids = torch.stack( + ( + torch.full((token_positions.shape[0],), start_position, device=grid_thw.device, dtype=torch.long), + token_positions.div(width, rounding_mode="floor"), + token_positions.remainder(width), + ), + dim=0, + ) + token_offset = int(image_metadata[0].item()) + token_length = int(image_metadata[1].item()) + return vision_position_ids[:, token_offset : token_offset + token_length] + + def get_rope_index( + self, + input_ids: torch.LongTensor, + mm_token_type_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + image_metadata: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if attention_mask is None: + if input_ids is None: + attention_mask = mm_token_type_ids.new_ones(mm_token_type_ids.shape, dtype=torch.long) + else: + attention_mask = input_ids.new_ones(input_ids.shape, dtype=torch.long) + + if input_ids is None: + batch_size, seq_len = attention_mask.shape + position_dtype = torch.long + else: + batch_size, seq_len = input_ids.shape + position_dtype = input_ids.dtype + + device = attention_mask.device + mm_token_type_ids = mm_token_type_ids.to(dtype=torch.long) + image_grid_thw = image_grid_thw.to(dtype=torch.long) + image_metadata = image_metadata.to(dtype=torch.long) + attention_mask = attention_mask.to(dtype=torch.long) + active_slot_mask = image_grid_thw[..., 0].eq(1) + + position_ids = torch.zeros((3, batch_size, seq_len), device=device, dtype=position_dtype) + rope_deltas = torch.zeros((batch_size, 1), device=device, dtype=torch.long) + + for batch_idx in range(batch_size): + sample_attention_mask = attention_mask[batch_idx].bool() + sample_token_types = mm_token_type_ids[batch_idx][sample_attention_mask] + sample_grids = image_grid_thw[batch_idx] + sample_metadata = image_metadata[batch_idx] + sample_active_slots = active_slot_mask[batch_idx] + + current_pos = 0 + image_idx = 0 + seq_pos = 0 + llm_pos_ids_list = [] + + while seq_pos < sample_token_types.shape[0]: + modality_type = int(sample_token_types[seq_pos].item()) + if modality_type == 0: + group_end = seq_pos + 1 + while group_end < sample_token_types.shape[0] and sample_token_types[group_end] == 0: + group_end += 1 + group_length = group_end - seq_pos + llm_pos_ids_list.append( + torch.arange(group_length, device=device, dtype=torch.long).view(1, -1).expand(3, -1) + + current_pos + ) + current_pos += group_length + seq_pos = group_end + else: + while image_idx < sample_metadata.shape[0] and ( + not bool(sample_active_slots[image_idx].item()) or sample_metadata[image_idx, 1].item() == 0 + ): + image_idx += 1 + torch_compilable_check( + image_idx < sample_metadata.shape[0], + "Isaac multimodal sequence has more visible image tokens than batch-major image metadata slots.", + ) + token_length = int(sample_metadata[image_idx, 1].item()) + torch_compilable_check( + token_length <= sample_token_types.shape[0] - seq_pos, + "Isaac image metadata length exceeds the remaining multimodal placeholder span.", + ) + llm_pos_ids_list.append( + self.get_vision_position_ids(current_pos, sample_grids[image_idx], sample_metadata[image_idx]) + ) + current_pos += 1 + seq_pos += token_length + image_idx += 1 + + llm_positions = ( + torch.cat(llm_pos_ids_list, dim=1) + if llm_pos_ids_list + else torch.zeros((3, 0), device=device, dtype=torch.long) + ) + position_ids[:, batch_idx, sample_attention_mask] = llm_positions + rope_deltas[batch_idx, 0] = ( + llm_positions.max() + 1 - sample_token_types.shape[0] if 
llm_positions.numel() > 0 else 0 + ) + + return position_ids, rope_deltas + + def compute_3d_position_ids( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | None, + mm_token_type_ids: torch.Tensor | None = None, + image_grid_thw: torch.Tensor | None = None, + image_metadata: torch.Tensor | None = None, + past_key_values: Cache | None = None, + ) -> torch.Tensor: + past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() + has_multimodal = ( + image_grid_thw is not None + and image_metadata is not None + and bool(image_grid_thw[..., 0].eq(1).any().item()) + ) + if has_multimodal and mm_token_type_ids is None and input_ids is not None: + raise ValueError( + "Multimodal data was passed (via `image_grid_thw` or `image_metadata`) but `mm_token_type_ids` is " + "missing. Please pass `mm_token_type_ids` to the model so that multimodal RoPE (M-RoPE) can be " + "computed correctly. `mm_token_type_ids` is returned by the processor alongside `input_ids`." + ) + + if has_multimodal and past_seen_tokens == 0: + position_ids, rope_deltas = self.get_rope_index( + input_ids=input_ids, + mm_token_type_ids=mm_token_type_ids, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, + attention_mask=attention_mask, + ) + self.rope_deltas = rope_deltas + return position_ids + + if self.rope_deltas is None: + return None + + rope_deltas = torch.as_tensor(self.rope_deltas, device=inputs_embeds.device, dtype=torch.long).reshape(-1, 1) + if rope_deltas.shape[0] != inputs_embeds.shape[0]: + if inputs_embeds.shape[0] % rope_deltas.shape[0] == 0: + rope_deltas = rope_deltas.repeat_interleave(inputs_embeds.shape[0] // rope_deltas.shape[0], dim=0) + else: + rope_deltas = rope_deltas[:1].expand(inputs_embeds.shape[0], -1) + + if attention_mask is not None and attention_mask.shape[-1] > inputs_embeds.shape[1]: + rope_position = attention_mask.long().cumsum(dim=-1) - 1 + rope_position = rope_position.masked_fill(attention_mask == 0, 0) + rope_position = rope_position[:, -inputs_embeds.shape[1] :] + else: + rope_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + dtype=torch.long, + ).view(1, -1) + rope_position = rope_position.expand(inputs_embeds.shape[0], -1) + + position_ids = rope_position.view(1, inputs_embeds.shape[0], -1).expand(3, -1, -1) + return position_ids + rope_deltas.to(device=inputs_embeds.device).unsqueeze(0) + + @auto_docstring( + custom_intro=""" + Forward pass with multimodal MRoPE position ids. + + When image placeholders are present, Isaac computes vision features, scatters them into the token + embeddings, and runs the shared text backbone on the mixed sequence. + """, + ) + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor | None = None, + mm_token_type_ids: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + image_metadata: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPast: + r""" + mm_token_type_ids (`torch.LongTensor`, *optional*): + Multimodal token type ids aligned with the token sequence, using `0 -> text` and `1 -> image`. 
+ pixel_values (`torch.FloatTensor`, *optional*): + Batch-major patch vectors shaped `(batch_size, max_images, max_patches, patch_dim)`. + image_grid_thw (`torch.LongTensor`, *optional*): + Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries. + image_metadata (`torch.LongTensor`, *optional*): + Batch-major per-slot metadata shaped `(batch_size, max_images, 2)` with `(offset, length)`. + """ + if (input_ids is None) == (inputs_embeds is None): + raise ValueError("You must specify exactly one of `input_ids` or `inputs_embeds`.") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_mask = None + if pixel_values is not None and image_grid_thw is not None: + image_outputs = self.get_image_features( + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, + return_dict=True, + ) + image_embeds = image_outputs.pooler_output + if len(image_embeds) > 0: + image_embeds = torch.cat(image_embeds, dim=0).to( + device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + image_mask = self.get_placeholder_mask( + mm_token_type_ids=mm_token_type_ids, + inputs_embeds=inputs_embeds, + image_features=image_embeds, + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if isinstance(attention_mask, dict): + attention_mask = attention_mask["full_attention"] + + past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() + computed_position_ids = self.compute_3d_position_ids( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + mm_token_type_ids=mm_token_type_ids, + image_grid_thw=image_grid_thw, + image_metadata=image_metadata, + attention_mask=attention_mask, + past_key_values=past_key_values, + ) + if computed_position_ids is not None: + position_ids = computed_position_ids + elif past_seen_tokens > 0: + position_ids = None + elif position_ids is not None and past_seen_tokens == 0: + position_ids = position_ids.to(device=inputs_embeds.device) + if position_ids.ndim == 2: + position_ids = position_ids.view(1, position_ids.shape[0], -1).expand(3, -1, -1) + + outputs = self.language_model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + visual_pos_masks=image_mask[..., 0] if image_mask is not None else None, + deepstack_visual_embeds=None, + use_cache=use_cache, + **kwargs, + ) + + return BaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@dataclass +class IsaacCausalLMOutputWithPast(CausalLMOutputWithPast): + """ + Base class for Isaac causal language model (or autoregressive) outputs. + + Args: + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. 
+ """ + + rope_deltas: torch.LongTensor | None = None + + +@auto_docstring +class IsaacForConditionalGeneration(Qwen3VLForConditionalGeneration, GenerationMixin): + config_class = IsaacConfig + input_modalities = ("image", "text") + _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"] + _can_compile_fullgraph = False + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: IsaacConfig): + PreTrainedModel.__init__(self, config) + self.model = IsaacModel(config) + self.vocab_size = config.get_text_config().vocab_size + self.lm_head = nn.Linear(config.get_text_config().hidden_size, config.get_text_config().vocab_size, bias=False) + self.post_init() + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + mm_token_type_ids: torch.IntTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | IsaacCausalLMOutputWithPast: + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + mm_token_type_ids=mm_token_type_ids, + **kwargs, + ) + + hidden_states = outputs[0] + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return IsaacCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.model.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values: Cache = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + mm_token_type_ids: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + image_metadata: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + is_first_iteration: bool = False, + use_cache: bool = True, + **kwargs, + ) -> dict[str, Any]: + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + is_first_iteration=is_first_iteration, + use_cache=use_cache, + **kwargs, + ) + + multimodal_inputs = { + "mm_token_type_ids": mm_token_type_ids, + "pixel_values": pixel_values, + "image_grid_thw": image_grid_thw, + "image_metadata": image_metadata, + } + is_prefill = is_first_iteration or not use_cache + for key, value in multimodal_inputs.items(): + model_inputs[key] = value if is_prefill else None + 
if model_inputs["mm_token_type_ids"] is not None: + sequence_length = None + if model_inputs.get("input_ids") is not None: + sequence_length = model_inputs["input_ids"].shape[1] + elif model_inputs.get("inputs_embeds") is not None: + sequence_length = model_inputs["inputs_embeds"].shape[1] + + if sequence_length is not None: + current_length = model_inputs["mm_token_type_ids"].shape[1] + if current_length < sequence_length: + padding = model_inputs["mm_token_type_ids"].new_zeros( + (model_inputs["mm_token_type_ids"].shape[0], sequence_length - current_length) + ) + model_inputs["mm_token_type_ids"] = torch.cat([model_inputs["mm_token_type_ids"], padding], dim=1) + elif current_length > sequence_length: + model_inputs["mm_token_type_ids"] = model_inputs["mm_token_type_ids"][:, -sequence_length:] + + return model_inputs + + def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs): + text_positions = GenerationMixin._prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs) + + past_length = 0 + if (cache := model_kwargs.get("past_key_values")) is not None: + past_length = cache.get_seq_length() + if past_length != 0 and self.model.rope_deltas is not None: + return text_positions[None, ...] + self.model.rope_deltas + + if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0: + inputs_tensor = model_kwargs["input_ids"] + + is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long] + if ( + is_input_ids + and model_kwargs.get("mm_token_type_ids") is not None + and model_kwargs.get("image_grid_thw") is not None + and model_kwargs.get("image_metadata") is not None + ): + model_kwargs = {k: v for k, v in model_kwargs.items() if k != "input_ids"} + vision_positions, rope_deltas = self.model.get_rope_index(inputs_tensor, **model_kwargs) + self.model.rope_deltas = rope_deltas + else: + vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1) + self.model.rope_deltas = torch.zeros( + inputs_tensor.shape[0], 1, dtype=torch.long, device=inputs_tensor.device + ) + + return torch.cat([text_positions[None, ...], vision_positions], dim=0) + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: torch.LongTensor | None = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + position_ids = model_kwargs.pop("position_ids", None) + if expand_size == 1: + if position_ids is not None: + model_kwargs["position_ids"] = position_ids + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "image_metadata"] + for key in visual_keys: + value = model_kwargs.get(key) + if value is not None: + model_kwargs[key] = value.repeat_interleave(expand_size, dim=0) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + for key, value in list(model_kwargs.items()): + if key == "position_ids" and value is not None and value.ndim == 3: + model_kwargs[key] = value.repeat_interleave(expand_size, dim=1) + elif value is not None and isinstance(value, torch.Tensor) and key not in visual_keys: + model_kwargs[key] = value.repeat_interleave(expand_size, dim=0) + + if position_ids is not None: + dim = 1 if position_ids.ndim == 3 else 0 + model_kwargs["position_ids"] = position_ids.repeat_interleave(expand_size, dim=dim) + return input_ids, model_kwargs + + +# --------------------------------Isaac Image Processor-------------------------------- + + +class IsaacImageProcessorKwargs(ImagesKwargs, 
total=False):
+    """
+    patch_size (`int`, *optional*):
+        Side length (in pixels) for square patches extracted from resized images.
+    max_num_patches (`int`, *optional*):
+        Upper bound on extracted patches per image after resizing.
+    min_num_patches (`int`, *optional*):
+        Lower bound on extracted patches per image after resizing.
+    pixel_shuffle_scale (`int`, *optional*):
+        Pixel-shuffle reduction factor applied in the vision tower.
+    """
+
+    patch_size: int
+    max_num_patches: int
+    min_num_patches: int
+    pixel_shuffle_scale: int
+
+
+def get_scaled_image_size(
+    scale: float,
+    original_size: int,
+    patch_size: int,
+    pixel_shuffle_scale: int,
+) -> int:
+    scaled_size = scale * original_size
+    divisor = patch_size * pixel_shuffle_scale
+    scaled_size = math.ceil(scaled_size / divisor) * divisor
+    scaled_size = max(divisor, scaled_size)
+    return int(scaled_size)
+
+
+def get_image_size_for_max_num_patches(
+    image_height: int,
+    image_width: int,
+    patch_size: int,
+    max_num_patches: int,
+    min_num_patches: int | None = None,
+    eps: float = 1e-5,
+    pixel_shuffle_scale: int = 1,
+) -> tuple[int, int]:
+    r"""Compute a target resolution whose patch grid satisfies the patching parametrization.
+
+    Args:
+        image_height (`int`):
+            Height in pixels of the source image prior to any resizing.
+        image_width (`int`):
+            Width in pixels of the source image prior to any resizing.
+        patch_size (`int`):
+            Size of the square patch used by the vision encoder.
+        max_num_patches (`int`):
+            Upper bound on `(height / patch_size) * (width / patch_size)` after resizing.
+        min_num_patches (`int`, *optional*):
+            Lower bound on the number of patches. When provided, the image will be scaled up if necessary.
+        eps (`float`, *optional*, defaults to 1e-5):
+            Convergence tolerance for the internal binary search that determines the target dimensions.
+        pixel_shuffle_scale (`int`, *optional*, defaults to 1):
+            Additional stride multiplier applied when pixel shuffle later reduces spatial resolution.
+
+    Returns:
+        `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale`
+        and respect both the maximum and optional minimum patch-count constraints.
+    """
+
+    # Ensure divisibility
+    divisor = patch_size * pixel_shuffle_scale
+    adjusted_height = math.ceil(image_height / divisor) * divisor
+    adjusted_height = max(divisor, adjusted_height)
+    adjusted_width = math.ceil(image_width / divisor) * divisor
+    adjusted_width = max(divisor, adjusted_width)
+
+    num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size)
+
+    if min_num_patches is not None and num_patches < min_num_patches:
+        # Scale up via binary search to satisfy the minimum patch budget while
+        # preserving divisibility by patch_size * pixel_shuffle_scale.
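+        # Every candidate returned by get_scaled_image_size is already rounded up to a multiple of
+        # patch_size * pixel_shuffle_scale, so the search below only needs to find the smallest
+        # scale whose patch count reaches min_num_patches.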
+ scale_min, scale_max = 1.0, 100.0 + while (scale_max - scale_min) >= eps: + scale = (scale_min + scale_max) / 2 + target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) + target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + num_patches = (target_height / patch_size) * (target_width / patch_size) + if num_patches >= min_num_patches: + scale_max = scale + else: + scale_min = scale + scale = scale_max + target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) + target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + return target_height, target_width + elif num_patches <= max_num_patches: + return adjusted_height, adjusted_width + else: + # Scale down + scale_min, scale_max = eps / 10, 1.0 + while (scale_max - scale_min) >= eps: + scale = (scale_min + scale_max) / 2 + target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) + target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + num_patches = (target_height / patch_size) * (target_width / patch_size) + if num_patches <= max_num_patches: + scale_min = scale + else: + scale_max = scale + scale = scale_min + target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) + target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + return target_height, target_width + + +@auto_docstring +@requires(backends=("vision",)) +class IsaacImageProcessor(TorchvisionBackend): + model_input_names = ["pixel_values", "image_grid_thw"] + valid_kwargs = IsaacImageProcessorKwargs + + resample = PILImageResampling.BILINEAR + do_resize = True + do_center_crop = False + patch_size = 16 + max_num_patches = 256 + min_num_patches = None + pixel_shuffle_scale = 1 + do_pad = True + do_rescale = True + do_normalize = True + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + do_convert_rgb = True + disable_grouping = False + + def __init__(self, **kwargs: Unpack[IsaacImageProcessorKwargs]): + super().__init__(**kwargs) + + def _validate_preprocess_kwargs(self, **kwargs): + # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for. 
+ kwargs.pop("do_resize", None) + return super()._validate_preprocess_kwargs(**kwargs) + + def _prepare_images_structure( + self, + images: ImageInput, + expected_ndims: int = 3, + ) -> ImageInput: + images = self.fetch_images(images) + return make_nested_list_of_images(images, expected_ndims=expected_ndims) + + def resize( + self, + image: torch.Tensor, + size: SizeDict, + **kwargs, + ) -> torch.Tensor: + if image.dtype == torch.uint8: + image = F.interpolate(image.float(), size=(size.height, size.width), mode="bilinear", align_corners=False) + return image.clamp(0, 255).round().to(torch.uint8) + return F.interpolate(image, size=(size.height, size.width), mode="bilinear", align_corners=False) + + def pack_images( + self, + vision_patches: list[list[torch.Tensor]], + vision_token_grids: list[list[torch.Tensor]], + ) -> dict[str, torch.Tensor | None]: + batch_size = len(vision_patches) + flat_patches = [patches for sample_patches in vision_patches for patches in sample_patches] + if len(flat_patches) == 0: + return {"pixel_values": None, "image_grid_thw": None} + + first_patch = flat_patches[0] + max_patches = max(patches.shape[0] for patches in flat_patches) + max_images = max((len(sample_patches) for sample_patches in vision_patches), default=0) + + patch_dim = first_patch.shape[-1] + tensors = { + "pixel_values": torch.zeros( + (batch_size, max_images, max_patches, patch_dim), + device=first_patch.device, + dtype=first_patch.dtype, + ), + "image_grid_thw": torch.zeros((batch_size, max_images, 3), device=first_patch.device, dtype=torch.long), + } + + for batch_idx, (sample_patches, sample_token_grids) in enumerate( + zip(vision_patches, vision_token_grids, strict=True) + ): + for image_slot, (patches, token_grid) in enumerate(zip(sample_patches, sample_token_grids, strict=True)): + patch_count = int(patches.shape[0]) + tensors["pixel_values"][batch_idx, image_slot, :patch_count] = patches + tensors["image_grid_thw"][batch_idx, image_slot, 0] = 1 + tensors["image_grid_thw"][batch_idx, image_slot, 1:] = token_grid + + return tensors + + def _preprocess( + self, + images: list[list[torch.Tensor]], + do_resize: bool, + resample: PILImageResampling | tvF.InterpolationMode | int | None, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: float | list[float] | None, + image_std: float | list[float] | None, + do_pad: bool, + patch_size: int, + max_num_patches: int, + min_num_patches: int, + pixel_shuffle_scale: int, + disable_grouping: bool | None, + return_tensors: str | TensorType | None, + **kwargs, + ) -> BatchFeature: + if all(len(sample_images) == 0 for sample_images in images): + return BatchFeature(data={"pixel_values": None, "image_grid_thw": None}, tensor_type=return_tensors) + + grouped_images, grouped_images_index = group_images_by_shape( + images, disable_grouping=disable_grouping, is_nested=True + ) + grouped_outputs = {} + for shape, stacked_images in grouped_images.items(): + grouped_batch_size, channels, original_height, original_width = stacked_images.shape + if do_resize: + target_height, target_width = get_image_size_for_max_num_patches( + original_height, + original_width, + patch_size, + max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + ) + image_batch = self.resize( + stacked_images, SizeDict(height=target_height, width=target_width), resample=resample + ) + else: + if (original_height % patch_size) or (original_width % patch_size): + raise ValueError( + f"Image dimensions (h={original_height}, 
w={original_width}) must be divisible by patch_size={patch_size} when resize is disabled; enable resizing or adjust the input resolution." + ) + image_batch, target_height, target_width = stacked_images, original_height, original_width + + image_batch = self.rescale_and_normalize( + image_batch, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) + + patches = torch_extract_patches(image_batch, patch_size, patch_size) + _, height_tokens, width_tokens, patch_dim = patches.shape + + token_grid = ( + torch.tensor([height_tokens, width_tokens], device=patches.device).long().expand(grouped_batch_size, 2) + ) + + if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale): + raise ValueError( + f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale};" + f" adjust resize/patch parameters or disable pixel shuffle." + ) + + grouped_outputs[shape] = ( + patches.reshape(grouped_batch_size, -1, patch_dim), + token_grid, + ) + + keys = ("vision_patches", "vision_token_grids") + nested_outputs = {} + for i, key in enumerate(keys): + nested_outputs[key] = reorder_images( + {shape: values[i] for shape, values in grouped_outputs.items()}, + dict(grouped_images_index), + is_nested=True, + ) + + if not do_pad: + raise ValueError("IsaacImageProcessor doesn't support `do_pad=False` mode.") + + tensors = self.pack_images( + vision_patches=nested_outputs["vision_patches"], + vision_token_grids=nested_outputs["vision_token_grids"], + ) + + return BatchFeature(data=tensors, tensor_type=return_tensors) + + def get_number_of_image_patches( + self, + image_height: int, + image_width: int, + images_kwargs: dict[str, Any] | None = None, + ) -> int: + images_kwargs = images_kwargs or {} + patch_size = images_kwargs.get("patch_size", self.patch_size) + max_num_patches = images_kwargs.get("max_num_patches", self.max_num_patches) + min_num_patches = images_kwargs.get("min_num_patches", self.min_num_patches) + pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale", self.pixel_shuffle_scale) + + target_height, target_width = get_image_size_for_max_num_patches( + image_height, + image_width, + patch_size, + max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + ) + return (target_height // patch_size) * (target_width // patch_size) + + +# --------------------------------Isaac Processor-------------------------------- + + +class IsaacProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs = IsaacImageProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": True, + "truncation": True, + "truncation_side": "left", + "return_attention_mask": True, + "return_overflowing_tokens": True, + "return_mm_token_type_ids": True, + "add_special_tokens": False, + }, + } + + +class SinglePoint(NamedTuple): + x: int + y: int + mention: str | None = None + t: float | None = None + + +class BoundingBox(NamedTuple): + top_left: Any + bottom_right: Any + mention: str | None = None + t: float | None = None + + +class Polygon(NamedTuple): + points: tuple[Any, ...] 
+    mention: str | None = None
+    t: float | None = None
+
+
+_point_box_or_polygon_tag = re.compile(
+    r"<(?P<tag>point|point_box|polygon)(?P<attrs>[^>]*)>(?P<body>[\s\S]*?)</(?P=tag)>", re.IGNORECASE
+)
+_attr_re = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))")
+_coord_re = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")
+
+
+@auto_docstring
+@requires(backends=("vision",))
+class IsaacProcessor(ProcessorMixin):
+    def __init__(
+        self,
+        image_processor,
+        tokenizer,
+        chat_template: str | dict[str, str] | None = None,
+        max_sequence_length: int = 16384,
+    ):
+        r"""
+        max_sequence_length (`int`, *optional*, defaults to 16384):
+            Maximum packed multimodal sequence length produced by the processor.
+        """
+        if chat_template is None:
+            chat_template = getattr(tokenizer, "chat_template", None)
+
+        self.pad_token_id = tokenizer.pad_token_id
+        self.image_token = getattr(tokenizer, "image_pad_token", None) or getattr(tokenizer, "image_token", None)
+        self.image_token_id = getattr(tokenizer, "image_pad_token_id", None) or getattr(
+            tokenizer, "image_token_id", None
+        )
+        self.max_sequence_length = max_sequence_length
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+    def __call__(
+        self,
+        text: str | list[str],
+        images: ImageInput | None = None,
+        **kwargs,
+    ) -> BatchFeature:
+        output_kwargs = self._merge_kwargs(
+            IsaacProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        # 1. Validate that the numbers of text prompts and images match
+        texts = [text] if isinstance(text, str) else text.copy()
+        rendered_image_token = "<image>"
+        if self.image_token is not None and self.image_token != rendered_image_token:
+            # Isaac's current chat template still renders `<image>`, while the tokenizer exposes
+            # `<|image_pad|>`. Normalize here so apply_chat_template(..., tokenize=True,
+            # return_dict=True) follows the standard ProcessorMixin path.
+            texts = [text_value.replace(rendered_image_token, self.image_token) for text_value in texts]
+        if images is None:
+            batched_images = [[] for _ in texts]
+        else:
+            fetched_images = self.image_processor.fetch_images(images)
+            batched_images = make_nested_list_of_images(fetched_images)
+            if len(batched_images) != len(texts):
+                num_images_in_text = [text_value.count(self.image_token) for text_value in texts]
+                num_images_in_images = [len(sample_images) for sample_images in batched_images]
+                add_message = ""
+                if sum(num_images_in_text) == sum(num_images_in_images):
+                    add_message = " Make sure to pass your images as a nested list, where each sub-list holds images for one text sample."
+                raise ValueError(
+                    f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(texts)}).{add_message}"
+                )
+
+        # 2. Process images
+        image_inputs = self.image_processor(images=batched_images, **output_kwargs["images_kwargs"])
+        image_grid_thw = image_inputs["image_grid_thw"]
+
+        # 3. Expand text with image placeholders
+        merge_length = self.image_processor.pixel_shuffle_scale**2
+        if image_grid_thw is None:
+            vision_segment_lengths = None
+        else:
+            vision_segment_lengths = image_grid_thw.prod(dim=-1) // merge_length
+        for batch_idx in range(len(texts)):
+            image_idx = 0
+            while self.image_token in texts[batch_idx]:
+                num_image_tokens = vision_segment_lengths[batch_idx, image_idx]
+                texts[batch_idx] = texts[batch_idx].replace(
+                    self.image_token, "<|placeholder|>" * num_image_tokens, 1
+                )
+                image_idx += 1
+            texts[batch_idx] = texts[batch_idx].replace("<|placeholder|>", self.image_token)
+
+        # 4. 
Process text + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids") + max_length = output_kwargs["text_kwargs"].pop("max_length", None) + max_length = self.max_sequence_length if max_length is None else max_length + text_inputs = self.tokenizer(texts, max_length=max_length, **output_kwargs["text_kwargs"]) + + truncated_input_ids: list[list[int] | None] = [None] * len(texts) + truncated_attention_mask: list[list[int] | None] = [None] * len(texts) + offset_mappings = text_inputs.get("offset_mapping") + truncated_offset_mapping: list[list[list[int]] | None] | None = None + if offset_mappings is not None: + truncated_offset_mapping = [None] * len(texts) + overflow_input_ids_per_sample = defaultdict(int) + + # 5. Drop overflowing token ids + if offset_mappings is None: + iterator = zip( + text_inputs["overflow_to_sample_mapping"], text_inputs["input_ids"], text_inputs["attention_mask"] + ) + else: + iterator = zip( + text_inputs["overflow_to_sample_mapping"], + text_inputs["input_ids"], + text_inputs["attention_mask"], + offset_mappings, + ) + + for sample in iterator: + if offset_mappings is None: + batch_idx, input_ids, attention_mask = sample + offset_mapping = None + else: + batch_idx, input_ids, attention_mask, offset_mapping = sample + + if truncated_input_ids[batch_idx] is None: + truncated_input_ids[batch_idx] = input_ids + truncated_attention_mask[batch_idx] = attention_mask + if truncated_offset_mapping is not None: + truncated_offset_mapping[batch_idx] = offset_mapping + else: + overflow_input_ids_per_sample[batch_idx] += input_ids.count(self.image_token_id) + + # 6. Do the same for overflowing pixel values. Isaac truncates images based on `max_length` + # We can't really truncate pixels, so we pass over an image offset mask. Model will crop off + # truncated image pixels at run-time using this mask + image_metadata = None + if image_grid_thw is not None: + batch_size, max_images = image_grid_thw.shape[:2] + image_metadata = torch.zeros((batch_size, max_images, 2), dtype=torch.long) + for batch_idx, image_lengths in enumerate(vision_segment_lengths): + remaining_dropped = overflow_input_ids_per_sample[batch_idx] + for image_idx, length in enumerate(image_lengths): + offset = 0 + if 0 < remaining_dropped < length: + offset = remaining_dropped + length -= offset + remaining_dropped = 0 + elif remaining_dropped >= length: + dropped_length = length + length = 0 + remaining_dropped -= dropped_length + + # Record which suffix of this image's placeholder span survives left truncation. + # The model still encodes the full image and uses this window for both feature gathering and vision RoPE. 
+ image_metadata[batch_idx, image_idx, 0] = offset + image_metadata[batch_idx, image_idx, 1] = length + + data = { + "input_ids": torch.tensor(truncated_input_ids, dtype=torch.long), + "attention_mask": torch.tensor(truncated_attention_mask, dtype=torch.long), + "image_metadata": image_metadata, + **image_inputs, + } + if truncated_offset_mapping is not None: + data["offset_mapping"] = torch.tensor(truncated_offset_mapping, dtype=torch.long) + + if return_mm_token_type_ids: + data["mm_token_type_ids"] = self.create_mm_token_type_ids(data["input_ids"]) + + return BatchFeature(data=data, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + vision_data = {} + if image_sizes is not None: + images_kwargs = dict(IsaacProcessorKwargs._defaults.get("images_kwargs", {})) + images_kwargs.update(kwargs) + + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) + for image_size in image_sizes + ] + pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale") or self.image_processor.pixel_shuffle_scale + num_image_tokens = [num_patches // pixel_shuffle_scale**2 for num_patches in num_image_patches] + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + + @property + def model_input_names(self): + return super().model_input_names + ["mm_token_type_ids", "image_metadata"] + + @staticmethod + def _maybe_float(value: str | None) -> float | None: + try: + return float(value) + except (ValueError, TypeError): + return None + + @classmethod + def _parse_attrs(cls, attr_text: str) -> dict[str, str]: + attrs = {} + for match in _attr_re.finditer(attr_text): + key = match.group(1) + value = match.group(2) or match.group(3) or "" + attrs[key] = value + return attrs + + @classmethod + def _parse_point_body( + cls, + body: str, + mention: str | None = None, + t: str | None = None, + ) -> Any: + match = _coord_re.search(body) + if not match: + raise ValueError(f"Malformed tag: {body!r}") + x, y = int(match.group(1)), int(match.group(2)) + return SinglePoint(x=x, y=y, mention=mention, t=cls._maybe_float(t)) + + @classmethod + def _parse_box_body( + cls, + body: str, + mention: str | None = None, + t: str | None = None, + ) -> Any: + coords = list(_coord_re.finditer(body)) + if len(coords) < 2: + raise ValueError(f"Malformed tag: {body!r}") + + top_left = SinglePoint(x=int(coords[0].group(1)), y=int(coords[0].group(2))) + bottom_right = SinglePoint(x=int(coords[1].group(1)), y=int(coords[1].group(2))) + return BoundingBox(top_left=top_left, bottom_right=bottom_right, mention=mention, t=cls._maybe_float(t)) + + @classmethod + def _parse_polygon_body( + cls, + body: str, + mention: str | None = None, + t: str | None = None, + ) -> Any: + coords = list(_coord_re.finditer(body)) + if len(coords) < 3: + raise ValueError(f"Malformed tag: {body!r}") + + points = tuple(SinglePoint(x=int(coord.group(1)), y=int(coord.group(2))) for coord in coords) + return Polygon(points=points, mention=mention, t=cls._maybe_float(t)) + + @classmethod + def clean_text_and_extract_points( + cls, + text: str, + expected: str | None = None, + ) -> tuple[str, list[Any]]: + results: list[Any] = [] + for match in _point_box_or_polygon_tag.finditer(text): + tag = match.group("tag").lower() + attrs = cls._parse_attrs(match.group("attrs")) + mention = attrs.get("mention") + t = attrs.get("t") + if tag == "point": + if expected not in (None, "point"): + continue + 
results.append(cls._parse_point_body(match.group("body"), mention=mention, t=t)) + elif tag == "point_box": + if expected not in (None, "box"): + continue + results.append(cls._parse_box_body(match.group("body"), mention=mention, t=t)) + else: + if expected not in (None, "polygon"): + continue + results.append(cls._parse_polygon_body(match.group("body"), mention=mention, t=t)) + + clean_text = re.sub(r"\s+", " ", _point_box_or_polygon_tag.sub("", text or "")).strip() + return clean_text, results + + def post_process_generation( + self, + text: str, + expected: str | None = None, + cleanup_and_extract: bool = True, + ) -> str | tuple[str, list[Any]]: + if cleanup_and_extract: + return self.clean_text_and_extract_points(text, expected=expected) + return text + + def post_process_image_text_to_text( + self, + generated_outputs, + skip_special_tokens: bool = True, + cleanup_and_extract: bool = False, + expected: str | None = None, + **kwargs, + ): + generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs) + return [ + self.post_process_generation(text, expected=expected, cleanup_and_extract=cleanup_and_extract) + for text in generated_texts + ] + + +__all__ = [ + "IsaacConfig", + "IsaacTextConfig", + "IsaacTextModel", + "IsaacVisionConfig", + "IsaacVisionModel", + "IsaacModel", + "IsaacPreTrainedModel", # noqa: F822 + "IsaacForConditionalGeneration", + "IsaacImageProcessor", + "IsaacProcessor", +] diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py new file mode 100644 index 000000000000..44dc5688e2a2 --- /dev/null +++ b/src/transformers/models/isaac/processing_isaac.py @@ -0,0 +1,356 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/isaac/modular_isaac.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_isaac.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 Perceptron, Inc and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import re
+from collections import defaultdict
+from typing import Any
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput, make_nested_list_of_images
+from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin
+from ...utils import auto_docstring
+from ...utils.import_utils import is_torch_available, requires
+from .image_processing_isaac import IsaacImageProcessorKwargs
+from .modeling_isaac import BoundingBox, Polygon, SinglePoint
+
+
+if is_torch_available():
+    import torch
+
+
+# --------------------------------Isaac Processor--------------------------------
+
+
+class IsaacProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs = IsaacImageProcessorKwargs
+    _defaults = {
+        "text_kwargs": {
+            "padding": True,
+            "truncation": True,
+            "truncation_side": "left",
+            "return_attention_mask": True,
+            "return_overflowing_tokens": True,
+            "return_mm_token_type_ids": True,
+            "add_special_tokens": False,
+        },
+    }
+
+
+_point_box_or_polygon_tag = re.compile(
+    r"<(?P<tag>point|point_box|polygon)(?P<attrs>[^>]*)>(?P<body>[\s\S]*?)</(?P=tag)>", re.IGNORECASE
+)
+_attr_re = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))")
+_coord_re = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")
+
+
+@auto_docstring
+@requires(backends=("vision",))
+class IsaacProcessor(ProcessorMixin):
+    def __init__(
+        self,
+        image_processor,
+        tokenizer,
+        chat_template: str | dict[str, str] | None = None,
+        max_sequence_length: int = 16384,
+    ):
+        r"""
+        max_sequence_length (`int`, *optional*, defaults to 16384):
+            Maximum packed multimodal sequence length produced by the processor.
+        """
+        if chat_template is None:
+            chat_template = getattr(tokenizer, "chat_template", None)
+
+        self.pad_token_id = tokenizer.pad_token_id
+        self.image_token = getattr(tokenizer, "image_pad_token", None) or getattr(tokenizer, "image_token", None)
+        self.image_token_id = getattr(tokenizer, "image_pad_token_id", None) or getattr(
+            tokenizer, "image_token_id", None
+        )
+        self.max_sequence_length = max_sequence_length
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+    def __call__(
+        self,
+        text: str | list[str],
+        images: ImageInput | None = None,
+        **kwargs,
+    ) -> BatchFeature:
+        output_kwargs = self._merge_kwargs(
+            IsaacProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        # 1. Validate that the numbers of text prompts and images match
+        texts = [text] if isinstance(text, str) else text.copy()
+        rendered_image_token = "<image>"
+        if self.image_token is not None and self.image_token != rendered_image_token:
+            # Isaac's current chat template still renders `<image>`, while the tokenizer exposes
+            # `<|image_pad|>`. Normalize here so apply_chat_template(..., tokenize=True,
+            # return_dict=True) follows the standard ProcessorMixin path.
+            texts = [text_value.replace(rendered_image_token, self.image_token) for text_value in texts]
+        if images is None:
+            batched_images = [[] for _ in texts]
+        else:
+            fetched_images = self.image_processor.fetch_images(images)
+            batched_images = make_nested_list_of_images(fetched_images)
+            if len(batched_images) != len(texts):
+                num_images_in_text = [text_value.count(self.image_token) for text_value in texts]
+                num_images_in_images = [len(sample_images) for sample_images in batched_images]
+                add_message = ""
+                if sum(num_images_in_text) == sum(num_images_in_images):
+                    add_message = " Make sure to pass your images as a nested list, where each sub-list holds images for one text sample."
+ raise ValueError( + f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(texts)}).{add_message}" + ) + + # 2. Process images + image_inputs = self.image_processor(images=batched_images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + + # 3. Expand text with image placeholders + merge_length = self.image_processor.pixel_shuffle_scale**2 + if image_grid_thw is None: + vision_segment_lengths = None + else: + vision_segment_lengths = image_grid_thw.prod(dim=-1) // merge_length + for batch_idx in range(len(texts)): + image_idx = 0 + while self.image_token in texts[batch_idx]: + num_image_tokens = vision_segment_lengths[batch_idx, image_idx] + texts[batch_idx] = texts[batch_idx].replace( + self.image_token, "<|placeholder|>" * num_image_tokens, 1 + ) + image_idx += 1 + texts[batch_idx] = texts[batch_idx].replace("<|placeholder|>", self.image_token) + + # 4. Process text + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids") + max_length = output_kwargs["text_kwargs"].pop("max_length", None) + max_length = self.max_sequence_length if max_length is None else max_length + text_inputs = self.tokenizer(texts, max_length=max_length, **output_kwargs["text_kwargs"]) + + truncated_input_ids: list[list[int] | None] = [None] * len(texts) + truncated_attention_mask: list[list[int] | None] = [None] * len(texts) + offset_mappings = text_inputs.get("offset_mapping") + truncated_offset_mapping: list[list[list[int]] | None] | None = None + if offset_mappings is not None: + truncated_offset_mapping = [None] * len(texts) + overflow_input_ids_per_sample = defaultdict(int) + + # 5. Drop overflowing token ids + if offset_mappings is None: + iterator = zip( + text_inputs["overflow_to_sample_mapping"], text_inputs["input_ids"], text_inputs["attention_mask"] + ) + else: + iterator = zip( + text_inputs["overflow_to_sample_mapping"], + text_inputs["input_ids"], + text_inputs["attention_mask"], + offset_mappings, + ) + + for sample in iterator: + if offset_mappings is None: + batch_idx, input_ids, attention_mask = sample + offset_mapping = None + else: + batch_idx, input_ids, attention_mask, offset_mapping = sample + + if truncated_input_ids[batch_idx] is None: + truncated_input_ids[batch_idx] = input_ids + truncated_attention_mask[batch_idx] = attention_mask + if truncated_offset_mapping is not None: + truncated_offset_mapping[batch_idx] = offset_mapping + else: + overflow_input_ids_per_sample[batch_idx] += input_ids.count(self.image_token_id) + + # 6. Do the same for overflowing pixel values. Isaac truncates images based on `max_length` + # We can't really truncate pixels, so we pass over an image offset mask. 
Model will crop off + # truncated image pixels at run-time using this mask + image_metadata = None + if image_grid_thw is not None: + batch_size, max_images = image_grid_thw.shape[:2] + image_metadata = torch.zeros((batch_size, max_images, 2), dtype=torch.long) + for batch_idx, image_lengths in enumerate(vision_segment_lengths): + remaining_dropped = overflow_input_ids_per_sample[batch_idx] + for image_idx, length in enumerate(image_lengths): + offset = 0 + if 0 < remaining_dropped < length: + offset = remaining_dropped + length -= offset + remaining_dropped = 0 + elif remaining_dropped >= length: + dropped_length = length + length = 0 + remaining_dropped -= dropped_length + + # Record which suffix of this image's placeholder span survives left truncation. + # The model still encodes the full image and uses this window for both feature gathering and vision RoPE. + image_metadata[batch_idx, image_idx, 0] = offset + image_metadata[batch_idx, image_idx, 1] = length + + data = { + "input_ids": torch.tensor(truncated_input_ids, dtype=torch.long), + "attention_mask": torch.tensor(truncated_attention_mask, dtype=torch.long), + "image_metadata": image_metadata, + **image_inputs, + } + if truncated_offset_mapping is not None: + data["offset_mapping"] = torch.tensor(truncated_offset_mapping, dtype=torch.long) + + if return_mm_token_type_ids: + data["mm_token_type_ids"] = self.create_mm_token_type_ids(data["input_ids"]) + + return BatchFeature(data=data, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + vision_data = {} + if image_sizes is not None: + images_kwargs = dict(IsaacProcessorKwargs._defaults.get("images_kwargs", {})) + images_kwargs.update(kwargs) + + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) + for image_size in image_sizes + ] + pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale") or self.image_processor.pixel_shuffle_scale + num_image_tokens = [num_patches // pixel_shuffle_scale**2 for num_patches in num_image_patches] + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + + @property + def model_input_names(self): + return super().model_input_names + ["mm_token_type_ids", "image_metadata"] + + @staticmethod + def _maybe_float(value: str | None) -> float | None: + try: + return float(value) + except (ValueError, TypeError): + return None + + @classmethod + def _parse_attrs(cls, attr_text: str) -> dict[str, str]: + attrs = {} + for match in _attr_re.finditer(attr_text): + key = match.group(1) + value = match.group(2) or match.group(3) or "" + attrs[key] = value + return attrs + + @classmethod + def _parse_point_body( + cls, + body: str, + mention: str | None = None, + t: str | None = None, + ) -> Any: + match = _coord_re.search(body) + if not match: + raise ValueError(f"Malformed tag: {body!r}") + x, y = int(match.group(1)), int(match.group(2)) + return SinglePoint(x=x, y=y, mention=mention, t=cls._maybe_float(t)) + + @classmethod + def _parse_box_body( + cls, + body: str, + mention: str | None = None, + t: str | None = None, + ) -> Any: + coords = list(_coord_re.finditer(body)) + if len(coords) < 2: + raise ValueError(f"Malformed tag: {body!r}") + + top_left = SinglePoint(x=int(coords[0].group(1)), y=int(coords[0].group(2))) + bottom_right = SinglePoint(x=int(coords[1].group(1)), y=int(coords[1].group(2))) + return BoundingBox(top_left=top_left, bottom_right=bottom_right, 
mention=mention, t=cls._maybe_float(t)) + + @classmethod + def _parse_polygon_body( + cls, + body: str, + mention: str | None = None, + t: str | None = None, + ) -> Any: + coords = list(_coord_re.finditer(body)) + if len(coords) < 3: + raise ValueError(f"Malformed tag: {body!r}") + + points = tuple(SinglePoint(x=int(coord.group(1)), y=int(coord.group(2))) for coord in coords) + return Polygon(points=points, mention=mention, t=cls._maybe_float(t)) + + @classmethod + def clean_text_and_extract_points( + cls, + text: str, + expected: str | None = None, + ) -> tuple[str, list[Any]]: + results: list[Any] = [] + for match in _point_box_or_polygon_tag.finditer(text): + tag = match.group("tag").lower() + attrs = cls._parse_attrs(match.group("attrs")) + mention = attrs.get("mention") + t = attrs.get("t") + if tag == "point": + if expected not in (None, "point"): + continue + results.append(cls._parse_point_body(match.group("body"), mention=mention, t=t)) + elif tag == "point_box": + if expected not in (None, "box"): + continue + results.append(cls._parse_box_body(match.group("body"), mention=mention, t=t)) + else: + if expected not in (None, "polygon"): + continue + results.append(cls._parse_polygon_body(match.group("body"), mention=mention, t=t)) + + clean_text = re.sub(r"\s+", " ", _point_box_or_polygon_tag.sub("", text or "")).strip() + return clean_text, results + + def post_process_generation( + self, + text: str, + expected: str | None = None, + cleanup_and_extract: bool = True, + ) -> str | tuple[str, list[Any]]: + if cleanup_and_extract: + return self.clean_text_and_extract_points(text, expected=expected) + return text + + def post_process_image_text_to_text( + self, + generated_outputs, + skip_special_tokens: bool = True, + cleanup_and_extract: bool = False, + expected: str | None = None, + **kwargs, + ): + generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs) + return [ + self.post_process_generation(text, expected=expected, cleanup_and_extract=cleanup_and_extract) + for text in generated_texts + ] + + +__all__ = ["IsaacProcessor"] diff --git a/tests/models/isaac/__init__.py b/tests/models/isaac/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/isaac/test_image_processing_isaac.py b/tests/models/isaac/test_image_processing_isaac.py new file mode 100644 index 000000000000..7c7627805a5f --- /dev/null +++ b/tests/models/isaac/test_image_processing_isaac.py @@ -0,0 +1,278 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import unittest + +import numpy as np + +from transformers.models.isaac.image_processing_isaac import get_image_size_for_max_num_patches +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + +def _make_dummy_image(size=(32, 32), color=(255, 0, 0)): + return Image.new("RGB", size, color=color) + + +class IsaacImageProcessingTester: + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + min_resolution=30, + max_resolution=80, + do_resize=True, + do_rescale=True, + rescale_factor=1 / 255, + do_normalize=True, + image_mean=None, + image_std=None, + patch_size=16, + max_num_patches=16, + min_num_patches=4, + pixel_shuffle_scale=1, + do_convert_rgb=True, + ): + if image_mean is None: + image_mean = [0.485, 0.456, 0.406] + if image_std is None: + image_std = [0.229, 0.224, 0.225] + + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.patch_size = patch_size + self.max_num_patches = max_num_patches + self.min_num_patches = min_num_patches + self.pixel_shuffle_scale = pixel_shuffle_scale + self.do_convert_rgb = do_convert_rgb + + @property + def patch_dim(self): + return self.num_channels * self.patch_size * self.patch_size + + def prepare_image_processor_dict(self): + return { + "do_resize": self.do_resize, + "do_rescale": self.do_rescale, + "rescale_factor": self.rescale_factor, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "patch_size": self.patch_size, + "max_num_patches": self.max_num_patches, + "min_num_patches": self.min_num_patches, + "pixel_shuffle_scale": self.pixel_shuffle_scale, + "do_convert_rgb": self.do_convert_rgb, + } + + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + images = prepare_image_inputs( + batch_size=self.batch_size, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + num_channels=self.num_channels, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + return [[image] for image in images] + + def expected_output_image_shape(self, images): + max_images = 0 + max_patches = 0 + for sample_images in images: + if not isinstance(sample_images, (list, tuple)): + sample_images = [sample_images] + + max_images = max(max_images, len(sample_images)) + for image in sample_images: + if isinstance(image, Image.Image): + width, height = image.size + elif isinstance(image, np.ndarray): + height, width = image.shape[:2] + else: + height, width = image.shape[-2:] + + target_height, target_width = get_image_size_for_max_num_patches( + image_height=height, + image_width=width, + patch_size=self.patch_size, + max_num_patches=self.max_num_patches, + min_num_patches=self.min_num_patches, + pixel_shuffle_scale=self.pixel_shuffle_scale, + ) + max_patches = max(max_patches, (target_height // self.patch_size) * (target_width // self.patch_size)) + + return (max_images, max_patches, self.patch_dim) + + +@require_torch +@require_vision 
+class IsaacImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + def setUp(self): + super().setUp() + self.image_processor_tester = IsaacImageProcessingTester(self) + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_call_pil(self): + for image_processing_class in self.image_processing_classes.values(): + image_processing = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) + for sample_images in image_inputs: + self.assertEqual(len(sample_images), 1) + self.assertIsInstance(sample_images[0], Image.Image) + + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) + self.assertEqual( + tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) + ) + + def test_call_numpy(self): + for image_processing_class in self.image_processing_classes.values(): + image_processing = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) + for sample_images in image_inputs: + self.assertEqual(len(sample_images), 1) + self.assertIsInstance(sample_images[0], np.ndarray) + + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) + self.assertEqual( + tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) + ) + + def test_call_pytorch(self): + for image_processing_class in self.image_processing_classes.values(): + image_processing = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + for sample_images in image_inputs: + self.assertEqual(len(sample_images), 1) + self.assertIsInstance(sample_images[0], torch.Tensor) + + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) + self.assertEqual( + tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) + ) + + @unittest.skip(reason="Isaac image processor 4-channel coverage is not defined") + def test_call_numpy_4_channels(self): + pass + + def test_flat_list_is_single_multi_image_sample(self): + for image_processing_class in 
self.image_processing_classes.values(): + image_processor = image_processing_class( + **{ + **self.image_processor_dict, + "do_resize": False, + "patch_size": 16, + "max_num_patches": 64, + "min_num_patches": 1, + "pixel_shuffle_scale": 1, + } + ) + image_inputs = [ + _make_dummy_image(size=(32, 32), color=(255, 0, 0)), + _make_dummy_image(size=(32, 32), color=(0, 255, 0)), + ] + + encoding = image_processor(image_inputs, return_tensors="pt") + self.assertEqual(tuple(encoding["pixel_values"].shape), (1, 2, 4, 768)) + + expected_grids = torch.tensor([[[1, 2, 2], [1, 2, 2]]], dtype=torch.long) + torch.testing.assert_close(encoding["image_grid_thw"], expected_grids) + + def test_nested_multi_image_batch_preserves_grids_and_padding(self): + for image_processing_class in self.image_processing_classes.values(): + image_processor = image_processing_class( + **{ + **self.image_processor_dict, + "do_resize": False, + "patch_size": 16, + "max_num_patches": 64, + "min_num_patches": 1, + "pixel_shuffle_scale": 1, + } + ) + image_inputs = [ + [_make_dummy_image(size=(32, 32), color=(255, 0, 0))], + [ + _make_dummy_image(size=(48, 32), color=(0, 255, 0)), + _make_dummy_image(size=(32, 48), color=(0, 0, 255)), + ], + ] + + encoding = image_processor(image_inputs, return_tensors="pt") + self.assertEqual(tuple(encoding["pixel_values"].shape), (2, 2, 6, 768)) + + expected_grids = torch.tensor( + [ + [[1, 2, 2], [0, 0, 0]], + [[1, 2, 3], [1, 3, 2]], + ], + dtype=torch.long, + ) + + torch.testing.assert_close(encoding["image_grid_thw"], expected_grids) + self.assertTrue(torch.all(encoding["pixel_values"][0, 1] == 0)) + + def test_pixel_shuffle_scale_requires_divisible_token_grid(self): + for image_processing_class in self.image_processing_classes.values(): + image_processor = image_processing_class( + **{ + **self.image_processor_dict, + "do_resize": False, + "patch_size": 16, + "pixel_shuffle_scale": 2, + } + ) + + with self.assertRaisesRegex(ValueError, "must be divisible by pixel_shuffle_scale"): + image_processor([[_make_dummy_image(size=(32, 16))]], return_tensors="pt") diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py new file mode 100644 index 000000000000..ccfb83411f66 --- /dev/null +++ b/tests/models/isaac/test_modeling_isaac.py @@ -0,0 +1,893 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Testing suite for the Isaac model.""" + +import base64 +import io +import os +import unittest +from functools import lru_cache +from pathlib import Path + +import pytest +from huggingface_hub import is_offline_mode + +from tests.generation.test_utils import ( + GenerationTesterMixin, +) +from tests.test_configuration_common import ConfigTester +from tests.test_pipeline_mixin import PipelineTesterMixin +from transformers import ( + IsaacConfig, + IsaacForConditionalGeneration, + IsaacModel, + is_torch_available, +) +from transformers.models.isaac.processing_isaac import IsaacProcessor +from transformers.testing_utils import ( + require_flash_attn, + require_torch, + require_vision, + slow, + torch_device, +) +from transformers.utils import is_vision_available + + +if is_vision_available(): + from PIL import Image +else: + Image = None + +from ...test_modeling_common import ModelTesterMixin, ids_tensor + + +if is_torch_available(): + import torch + + +BASE_MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base") +MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1") + +BASE_MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/3") or None +MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/5") or None + +LOCAL_CHECKPOINT = os.environ.get("ISAAC_TEST_MODEL_PATH") +RED_DOT_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg==" +ISAAC_IMAGE_TOKEN = "<|image_pad|>" + + +def compute_logits_statistics(tensor: torch.Tensor) -> dict[str, object]: + """ + Summarize logits with simple statistics that are stable across minor + implementation changes yet still sensitive to behavioral regressions. + """ + + float_tensor = tensor.detach().to(torch.float32).cpu() + flat = float_tensor.reshape(-1).to(torch.float64) + + def _rounded(value: torch.Tensor | float) -> float: + return round(float(value), 10) + + return { + "shape": list(float_tensor.shape), + "numel": flat.numel(), + "mean": _rounded(flat.mean()), + "std": _rounded(flat.std(unbiased=False)), + "min": _rounded(flat.min()), + "max": _rounded(flat.max()), + "sum": _rounded(flat.sum()), + "l2_norm": _rounded(torch.linalg.vector_norm(flat, ord=2)), + } + + +def pack_image_inputs(pixel_values, image_token_grids, image_token_offsets=None, image_token_lengths=None): + batch_size, max_images, _, _ = pixel_values.shape + device = pixel_values.device + + if image_token_offsets is None: + image_token_offsets = torch.zeros((batch_size, max_images), device=device, dtype=torch.long) + if image_token_lengths is None: + image_token_lengths = image_token_grids[..., 0] * image_token_grids[..., 1] + + image_grid_thw = torch.zeros((batch_size, max_images, 3), device=device, dtype=torch.long) + active_slots = image_token_grids.prod(dim=-1).gt(0) + image_grid_thw[..., 0] = active_slots.to(dtype=torch.long) + image_grid_thw[..., 1:] = image_token_grids + + image_metadata = torch.stack( + ( + image_token_offsets.to(device=device, dtype=torch.long), + image_token_lengths.to(device=device, dtype=torch.long), + ), + dim=-1, + ) + + return pixel_values, image_grid_thw, image_metadata + + +@lru_cache(maxsize=1) +def _load_red_dot_image(): + if Image is None: + return None + data = base64.b64decode(RED_DOT_B64) + return Image.open(io.BytesIO(data)).convert("RGB") + + +def _base_reference_checkpoint_or_skip(): + if LOCAL_CHECKPOINT: + resolved = Path(LOCAL_CHECKPOINT).expanduser() + if not resolved.exists(): + 
pytest.skip(f"Local checkpoint path {resolved} does not exist.") + return str(resolved) + if is_offline_mode(): + pytest.skip("Offline mode: set ISAAC_TEST_MODEL_PATH to a local checkpoint to run these tests.") + return BASE_MODEL_ID + + +def _reference_checkpoint_or_skip(): + if LOCAL_CHECKPOINT: + resolved = Path(LOCAL_CHECKPOINT).expanduser() + if not resolved.exists(): + pytest.skip(f"Local checkpoint path {resolved} does not exist.") + return str(resolved) + if is_offline_mode(): + pytest.skip("Offline mode: set ISAAC_TEST_MODEL_PATH to a local checkpoint to run these tests.") + return MODEL_ID + + +class IsaacModelTester: + def __init__( + self, + parent, + batch_size=2, + seq_length=5, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.is_training = True + self.expected_num_hidden_layers = num_hidden_layers + 1 + + self.text_config = { + "bos_token_id": 0, + "eos_token_id": 1, + "pad_token_id": 2, + "hidden_act": "silu", + "head_dim": hidden_size // num_attention_heads, + "hidden_size": hidden_size, + "vocab_size": vocab_size, + "intermediate_size": hidden_size * 3, + "max_position_embeddings": 128, + "model_type": "qwen3", + "num_attention_heads": num_attention_heads, + "num_hidden_layers": num_hidden_layers, + "num_key_value_heads": num_attention_heads, + # Keep the same multi-RoPE setup as the reference checkpoints but shrink the + # sections so they sum to the rotary half-dimension (4) of this tiny test model. + "rope_parameters": {"rope_type": "default", "mrope_section": [2, 1, 1], "mrope_interleaved": True}, + "tie_word_embeddings": True, + } + + self.vision_config = { + "hidden_size": hidden_size, + "intermediate_size": hidden_size * 2, + "num_hidden_layers": 1, + "num_attention_heads": num_attention_heads, + "num_channels": 3, + "num_patches": 64, + "patch_size": 4, + "pixel_shuffle_scale_factor": 1, + "attention_dropout": 0.0, + "layer_norm_eps": 1e-6, + } + + def get_config(self): + config = IsaacConfig( + text_config=self.text_config, + vision_config=self.vision_config, + ) + # Rely on eager attention so output_attentions tests remain compatible without flash attention. 
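+ # Applied to the top-level config, the text backbone, and the vision tower below.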
+ config._attn_implementation = "eager" + config.text_config._attn_implementation = "eager" + config.vision_attn_implementation = "eager" + return config + + def prepare_config_and_inputs(self): + config = self.get_config() + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = torch.ones( + (self.batch_size, self.seq_length), + dtype=torch.long, + device=torch_device, + ) + labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + return config, input_ids, attention_mask, labels + + def prepare_config_and_inputs_for_common(self): + config, input_ids, attention_mask, labels = self.prepare_config_and_inputs() + patch_size = self.vision_config["patch_size"] + patch_dim = self.vision_config["num_channels"] * patch_size * patch_size + num_image_patches = 4 + vision_patches = torch.randn( + (self.batch_size, 1, num_image_patches, patch_dim), device=torch_device, dtype=torch.float32 + ) + image_token_grids = torch.tensor([[[2, 2]]] * self.batch_size, device=torch_device, dtype=torch.long) + pixel_values, image_grid_thw, image_metadata = pack_image_inputs( + pixel_values=vision_patches, + image_token_grids=image_token_grids, + ) + mm_token_type_ids = torch.zeros((self.batch_size, self.seq_length), device=torch_device, dtype=torch.long) + mm_token_type_ids[:, :num_image_patches] = 1 + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "mm_token_type_ids": mm_token_type_ids, + "pixel_values": pixel_values, + "image_grid_thw": image_grid_thw, + "image_metadata": image_metadata, + } + if labels is not None: + inputs_dict["labels"] = labels + return config, inputs_dict + + +@require_torch +class IsaacModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (IsaacModel, IsaacForConditionalGeneration) if is_torch_available() else () + pipeline_model_mapping = ( + { + "image-to-text": IsaacForConditionalGeneration, + "image-text-to-text": IsaacForConditionalGeneration, + } + if is_torch_available() + else {} + ) + _is_composite = True + test_attention_outputs = False + test_all_params_have_gradient = False + + def setUp(self): + self.model_tester = IsaacModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=IsaacConfig, + has_text_modality=False, + ) + + def test_config(self): + self.maxDiff = None + self.config_tester.run_common_tests() + + def prepare_config_and_inputs_for_generate(self, batch_size=2): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + input_keys_to_ignore = [ + "decoder_input_ids", + "decoder_attention_mask", + "use_cache", + "labels", + ] + + filtered_inputs_dict = { + k: v[:batch_size, ...] 
+ if isinstance(v, torch.Tensor) and k not in ["pixel_values", "image_grid_thw", "image_metadata"] + else v + for k, v in inputs_dict.items() + if k not in input_keys_to_ignore + } + + filtered_inputs_dict["pixel_values"] = inputs_dict["pixel_values"][:batch_size] + filtered_inputs_dict["image_grid_thw"] = inputs_dict["image_grid_thw"][:batch_size] + filtered_inputs_dict["image_metadata"] = inputs_dict["image_metadata"][:batch_size] + + text_gen_config = config.get_text_config(decoder=True) + if text_gen_config.eos_token_id is not None and text_gen_config.pad_token_id is None: + text_gen_config.pad_token_id = ( + text_gen_config.eos_token_id + if isinstance(text_gen_config.eos_token_id, int) + else text_gen_config.eos_token_id[0] + ) + text_gen_config.eos_token_id = None + text_gen_config.forced_eos_token_id = None + + return config, filtered_inputs_dict + + @pytest.mark.generate + def test_left_padding_compatibility(self): + _, inputs_dict = self.prepare_config_and_inputs_for_generate() + mm_token_type_ids = inputs_dict["mm_token_type_ids"] + pad_size = (mm_token_type_ids.shape[0], 32) + padded_mm_token_type_ids = torch.cat( + (torch.zeros(pad_size, dtype=mm_token_type_ids.dtype, device=torch_device), mm_token_type_ids), dim=1 + ) + + super().test_left_padding_compatibility( + unpadded_custom_inputs={"mm_token_type_ids": mm_token_type_ids}, + padded_custom_inputs={"mm_token_type_ids": padded_mm_token_type_ids}, + ) + + @unittest.skip(reason="Assisted decoding not supported; Qwen3 backbone does not implement returning attentions") + def test_assisted_decoding_matches_greedy_search_0_random(self): + pass + + @unittest.skip(reason="Assisted decoding not supported; Qwen3 backbone does not implement returning attentions") + def test_assisted_decoding_matches_greedy_search_1_same(self): + pass + + @unittest.skip(reason="Unsupported") + def test_flash_attn_kernels_inference_equivalence(self): + pass + + @unittest.skip(reason="Isaac is image-only.") + def test_get_video_features_output_0(self): + pass + + @unittest.skip(reason="Isaac is image-only.") + def test_get_video_features_output_1(self): + pass + + @unittest.skip(reason="Isaac is image-only.") + def test_get_video_features_output_2(self): + pass + + @unittest.skip(reason="Isaac is image-only.") + def test_get_video_features_hidden_states(self): + pass + + @unittest.skip(reason="Isaac is image-only.") + def test_get_video_features_attentions(self): + pass + + +@require_torch +@require_vision +@slow +@require_flash_attn +class IsaacGenerationIntegrationTest(unittest.TestCase): + max_new_tokens = 25 + dtype = torch.bfloat16 + + def setUp(self): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.checkpoint = _base_reference_checkpoint_or_skip() + self.hf_config = IsaacConfig.from_pretrained(self.checkpoint, revision=BASE_MODEL_REVISION) + self.processor = IsaacProcessor.from_pretrained(self.checkpoint, revision=BASE_MODEL_REVISION, do_pad=True) + self.tokenizer = self.processor.tokenizer + self.hf_config.vision_config._attn_implementation = "flash_attention_2" + self.hf_config.vision_config.attn_implementation = "flash_attention_2" + self.model = IsaacForConditionalGeneration.from_pretrained( + self.checkpoint, config=self.hf_config, revision=BASE_MODEL_REVISION + ) + self.model = self.model.to(device=self.device, dtype=self.dtype) + self.model.eval() + + def test_generate_from_image_text(self): + image = _load_red_dot_image() + if image is None: + pytest.skip("PIL.Image is required for Isaac generation 
tests.") + + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image:"}, + {"type": "image", "image": image}, + ], + } + ] + inputs = self.processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + ).to(self.device, dtype=self.dtype) + + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=self.max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + ) + + generated_ids = outputs.sequences[:, inputs["input_ids"].shape[1] :] + generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + expected_fragment = "The image is a close-up photograph of a red cross symbol." + assert expected_fragment in generated_text + + def test_generate_from_text_only(self): + conversation = [ + { + "role": "user", + "content": [{"type": "text", "text": "What is the pythogorean theorem?"}], + } + ] + inputs = self.processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + ).to(self.device, dtype=self.dtype) + + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=100, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + ) + + generated_ids = outputs.sequences[:, inputs["input_ids"].shape[1] :] + generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + expected_fragmenet = "The Pythagorean theorem is a fundamental principle in geometry that relates the lengths of the sides of a right-angled triangle. Let's break it down step by step:" + assert expected_fragmenet in generated_text + + def test_vqa_from_image(self): + conversation = [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp", + }, + {"type": "text", "text": "Is it safe to cross the street at this moment?"}, + ], + } + ] + inputs = self.processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + ).to(self.device, dtype=self.dtype) + + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=256, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + ) + + generated_ids = outputs.sequences[:, inputs["input_ids"].shape[1] :] + generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + expected_response = "\nNo, it is not safe to cross the street at this moment. The traffic light for pedestrians is red, indicating that it is not safe to cross." 
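+ # Greedy decoding (do_sample=False), so the decoded answer is checked for exact equality against the reference string.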
+ assert generated_text == expected_response + + def test_logit_equivalence(self): + image = _load_red_dot_image() + if image is None: + pytest.skip("PIL.Image is required for Isaac generation tests.") + + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image:"}, + {"type": "image", "image": image}, + ], + } + ] + inputs = self.processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + ).to(self.device, dtype=self.dtype) + + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=10, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + output_logits=True, + ) + + hf_logits = torch.cat(outputs.logits, dim=0) + logit_stats = compute_logits_statistics(hf_logits) + expected_logit_stats = { + "shape": [10, 151936], + "numel": 1519360, + "mean": 0.0608877803, + "std": 2.8308793244, + "min": -12.0625, + "max": 31.0, + "sum": 92510.4578057677, + "l2_norm": 3490.2146142251, + } + assert logit_stats == expected_logit_stats + + def test_batched_generation_matches_individual(self): + image = _load_red_dot_image() + if image is None: + pytest.skip("PIL.Image is required for Isaac generation tests.") + + conversations = [ + [ + { + "role": "user", + "content": [{"type": "text", "text": "What is the pythogorean theorem?"}], + } + ], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image:"}, + {"type": "image", "image": image}, + ], + } + ], + [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp", + }, + {"type": "text", "text": "Is it safe to cross the street at this moment?"}, + ], + } + ], + ] + + single_inputs = [ + self.processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + ) + for conversation in conversations + ] + batch_inputs = self.processor.apply_chat_template( + conversations, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + processor_kwargs={"padding_side": "left"}, + ) + batch_input_ids = batch_inputs["input_ids"] + max_length = batch_input_ids.shape[1] + + pad_id = self.tokenizer.pad_token_id + if pad_id is None: + pad_id = getattr(self.processor, "pad_token_id", 0) + + sample_lengths = [single_input["input_ids"].squeeze(0).shape[0] for single_input in single_inputs] + for i, (single_input, batch_ids, single_len) in enumerate(zip(single_inputs, batch_input_ids, sample_lengths)): + single_ids = single_input["input_ids"].squeeze(0) + torch.testing.assert_close(batch_ids[-single_len:], single_ids) + + batch_modality_row = batch_inputs["mm_token_type_ids"][i] + expected_modality = torch.full( + (max_length,), + batch_modality_row[-1].item(), + dtype=batch_modality_row.dtype, + device=batch_modality_row.device, + ) + expected_modality[-single_len:] = single_input["mm_token_type_ids"].squeeze(0) + torch.testing.assert_close(batch_modality_row, expected_modality) + + if batch_inputs["image_grid_thw"] is not None: + batch_image_mask = batch_inputs["image_grid_thw"][i, :, 0].eq(1) + expected_image_count = int(batch_image_mask.sum().item()) + if single_input["image_grid_thw"] is None: + assert expected_image_count == 0 + else: + single_image_mask = 
single_input["image_grid_thw"][0, :, 0].eq(1) + assert expected_image_count == int(single_image_mask.sum().item()) + if expected_image_count > 0: + batch_image_grid_thw = batch_inputs["image_grid_thw"][i, batch_image_mask] + single_image_grid_thw = single_input["image_grid_thw"][0, single_image_mask] + batch_image_metadata = batch_inputs["image_metadata"][i, batch_image_mask] + single_image_metadata = single_input["image_metadata"][0, single_image_mask] + + torch.testing.assert_close(batch_image_grid_thw, single_image_grid_thw) + torch.testing.assert_close(batch_image_metadata, single_image_metadata) + + for batch_pixel_values, single_pixel_values, grid_thw in zip( + batch_inputs["pixel_values"][i, batch_image_mask], + single_input["pixel_values"][0, single_image_mask], + batch_image_grid_thw, + strict=True, + ): + valid_patch_count = int((grid_thw[1] * grid_thw[2]).item()) + torch.testing.assert_close( + batch_pixel_values[:valid_patch_count], + single_pixel_values[:valid_patch_count], + ) + + if single_len == max_length: + continue + + pad_span = batch_ids[: max_length - single_len] + assert torch.all(pad_span == pad_id), f"sample {i} left pad span not padded with pad id" + torch.testing.assert_close( + batch_inputs["attention_mask"][i], + batch_ids.ne(pad_id).long(), + ) + + single_texts = [] + for single_input in single_inputs: + single_input = single_input.to(self.device, dtype=self.dtype) + with torch.no_grad(): + outputs = self.model.generate( + **single_input, + max_new_tokens=self.max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + ) + generated_ids = outputs.sequences[:, single_input["input_ids"].shape[1] :] + single_texts.append(self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]) + + batch_inputs = batch_inputs.to(self.device, dtype=self.dtype) + with torch.no_grad(): + batch_outputs = self.model.generate( + **batch_inputs, + max_new_tokens=100, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + ) + batch_generated_ids = batch_outputs.sequences[:, batch_inputs["input_ids"].shape[1] :] + batch_texts = self.processor.batch_decode(batch_generated_ids, skip_special_tokens=True) + assert len(batch_texts) == len(single_texts) == 3 + + for i, (batch_text, single_text) in enumerate(zip(batch_texts, single_texts)): + assert single_text in batch_text, f"batch[{i}] mismatch: {batch_text!r} vs single[{i}] {single_text!r}" + + def test_batched_beam_generation_matches_individual(self): + image = _load_red_dot_image() + if image is None: + pytest.skip("PIL.Image is required for Isaac generation tests.") + + conversations = [ + [ + { + "role": "user", + "content": [{"type": "text", "text": "What is the pythogorean theorem?"}], + } + ], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image:"}, + {"type": "image", "image": image}, + ], + } + ], + [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp", + }, + {"type": "text", "text": "Is it safe to cross the street at this moment?"}, + ], + } + ], + ] + beam_kwargs = {"num_beams": 2} + + single_texts = [] + for conversation in conversations: + single_input = self.processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + add_generation_prompt=True, 
+ return_tensors="pt", + ).to(self.device, dtype=self.dtype) + with torch.no_grad(): + outputs = self.model.generate( + **single_input, + max_new_tokens=self.max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + **beam_kwargs, + ) + generated_ids = outputs.sequences[:, single_input["input_ids"].shape[1] :] + single_texts.append(self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]) + + batch_inputs = self.processor.apply_chat_template( + conversations, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + processor_kwargs={"padding_side": "left"}, + ).to(self.device, dtype=self.dtype) + with torch.no_grad(): + batch_outputs = self.model.generate( + **batch_inputs, + max_new_tokens=100, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + **beam_kwargs, + ) + batch_generated_ids = batch_outputs.sequences[:, batch_inputs["input_ids"].shape[1] :] + batch_texts = self.processor.batch_decode(batch_generated_ids, skip_special_tokens=True) + assert len(batch_texts) == len(single_texts) == 3 + + for i, (batch_text, single_text) in enumerate(zip(batch_texts, single_texts)): + assert single_text in batch_text, ( + f"beam batch[{i}] mismatch: {batch_text!r} vs single[{i}] {single_text!r}" + ) + + +@require_torch +@require_vision +@slow +@require_flash_attn +class IsaacBoxPointingIntegrationTest(unittest.TestCase): + max_new_tokens = 256 + dtype = torch.bfloat16 + + def setUp(self): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.checkpoint = _reference_checkpoint_or_skip() + self.hf_config = IsaacConfig.from_pretrained(self.checkpoint, revision=MODEL_REVISION) + # The current local slow fallback only supports padded packing for this checkpoint. + self.processor = IsaacProcessor.from_pretrained(self.checkpoint, revision=MODEL_REVISION, do_pad=True) + self.tokenizer = self.processor.tokenizer + self.hf_config.vision_config._attn_implementation = "flash_attention_2" + self.hf_config.vision_config.attn_implementation = "flash_attention_2" + self.model = IsaacForConditionalGeneration.from_pretrained( + self.checkpoint, config=self.hf_config, revision=MODEL_REVISION + ) + self.model = self.model.to(device=self.device, dtype=self.dtype) + self.model.eval() + + def test_hf_generate_box_points(self): + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "BOX"}, + { + "type": "image", + "url": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp", + }, + { + "type": "text", + "text": "Determine whether it is safe to cross the street. 
Look for signage and moving traffic.", + }, + ], + } + ] + inputs = self.processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + ).to(self.device, dtype=self.dtype) + + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=self.max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + ) + + generated_ids = outputs.sequences[:, inputs["input_ids"].shape[1] :] + generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + _, points = self.processor.post_process_generation(generated_text, expected="box") + assert len(points) == 1 + first_point = points[0] + assert first_point.top_left.x < first_point.bottom_right.x + assert first_point.top_left.y < first_point.bottom_right.y + assert first_point.mention == "traffic light" + assert first_point.top_left.x == 808 + assert first_point.top_left.y == 247 + assert first_point.bottom_right.x == 863 + assert first_point.bottom_right.y == 386 + + def test_hf_generate_polygon_points(self): + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "POLYGON"}, + { + "type": "image", + "url": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp", + }, + { + "type": "text", + "text": "Determine whether it is safe to cross the street. Look for signage and moving traffic.", + }, + ], + } + ] + inputs = self.processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + ).to(self.device, dtype=self.dtype) + + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=self.max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + ) + + generated_ids = outputs.sequences[:, inputs["input_ids"].shape[1] :] + generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + _, polygons = self.processor.post_process_generation(generated_text, expected="polygon") + assert len(polygons) == 1 + first_polygon = polygons[0] + xs = [point.x for point in first_polygon.points] + ys = [point.y for point in first_polygon.points] + expected_left, expected_top, expected_right, expected_bottom = 808, 247, 863, 386 + + assert len(first_polygon.points) >= 3 + assert first_polygon.mention == "traffic light" + assert min(xs) >= expected_left - 4 + assert max(xs) <= expected_right + 4 + assert min(ys) >= expected_top - 4 + assert max(ys) <= expected_bottom + 4 + assert max(xs) - min(xs) >= 35 + assert max(ys) - min(ys) >= 100 + assert any(abs(x - expected_left) <= 12 for x in xs) + assert any(abs(x - expected_right) <= 12 for x in xs) + assert any(abs(y - expected_top) <= 12 for y in ys) + assert any(abs(y - expected_bottom) <= 12 for y in ys) diff --git a/tests/models/isaac/test_processing_isaac.py b/tests/models/isaac/test_processing_isaac.py new file mode 100644 index 000000000000..7c2bbc3f6048 --- /dev/null +++ b/tests/models/isaac/test_processing_isaac.py @@ -0,0 +1,172 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Testing suite for the Isaac processor.""" + +import os +import unittest +from pathlib import Path + +import numpy as np +import pytest +from huggingface_hub import is_offline_mode + +from transformers.models.isaac.processing_isaac import IsaacProcessor +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image +else: + Image = None + + +def _make_dummy_image(size=(32, 32), color=(255, 0, 0)): + if Image is None: + raise RuntimeError("PIL.Image is not available in this environment.") + return Image.new("RGB", size, color=color) + + +BASE_MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base") +BASE_MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/3") or None +LOCAL_CHECKPOINT = os.environ.get("ISAAC_TEST_MODEL_PATH") + + +def _checkpoint_or_skip(model_id=BASE_MODEL_ID): + if LOCAL_CHECKPOINT: + resolved = Path(LOCAL_CHECKPOINT).expanduser() + if not resolved.exists(): + pytest.skip(f"Local checkpoint path {resolved} does not exist.") + return str(resolved) + if is_offline_mode(): + pytest.skip("Offline mode: set ISAAC_TEST_MODEL_PATH to a local checkpoint to run these tests.") + return model_id + + +@require_torch +@require_vision +class IsaacProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = IsaacProcessor + model_id = BASE_MODEL_ID + images_input_name = "pixel_values" + + @classmethod + def _setup_from_pretrained(cls, model_id, **kwargs): + checkpoint = _checkpoint_or_skip(model_id) + return super()._setup_from_pretrained( + checkpoint, + revision=BASE_MODEL_REVISION, + patch_size=4, + max_num_patches=4, + **kwargs, + ) + + @classmethod + def _setup_test_attributes(cls, processor): + cls.image_token = processor.image_token + cls.pad_token_id = processor.tokenizer.pad_token_id + cls.image_pad_token_id = processor.image_token_id + + def prepare_image_inputs(self, batch_size: int | None = None, nested: bool = False): + if batch_size is None: + return _make_dummy_image(size=(16, 16)) + images = [_make_dummy_image(size=(16, 16), color=(50 * (i + 1), 0, 0)) for i in range(batch_size)] + if nested: + return [[image] for image in images] + return images + + @unittest.skip("Isaac chat templates emit placeholders but the processor consumes image pad tokens") + def test_apply_chat_template_image_0(self): + pass + + @unittest.skip("Isaac chat templates emit placeholders but the processor consumes image pad tokens") + def test_apply_chat_template_image_1(self): + pass + + def test_apply_chat_template_image_placeholder_expands_to_image_pad_tokens(self): + processor = self.get_processor() + image = _make_dummy_image(size=(16, 16)) + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this."}, + {"type": "image", "image": image}, + ], + } + ] + ] + + formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, 
tokenize=False) + self.assertEqual(len(formatted_prompt), 1) + self.assertIn("", formatted_prompt[0]) + + out_dict = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + self.assertTrue( + all( + key in out_dict + for key in [ + "input_ids", + "attention_mask", + "pixel_values", + "image_grid_thw", + "image_metadata", + "mm_token_type_ids", + ] + ) + ) + + expected_num_image_tokens = processor._get_num_multimodal_tokens(image_sizes=[(image.height, image.width)])[ + "num_image_tokens" + ][0] + actual_num_image_tokens = int(out_dict["input_ids"][0].eq(processor.image_token_id).sum().item()) + + self.assertEqual(actual_num_image_tokens, expected_num_image_tokens) + self.assertEqual(int(out_dict["mm_token_type_ids"][0].sum().item()), expected_num_image_tokens) + self.assertEqual(int(out_dict["image_metadata"][0, 0, 1].item()), expected_num_image_tokens) + self.assertTrue( + torch.all(out_dict["mm_token_type_ids"][0][out_dict["input_ids"][0].eq(processor.image_token_id)] == 1) + ) + + def test_get_num_multimodal_tokens_matches_processor_call(self): + processor = self.get_processor() + + image_sizes = [(100, 100), (300, 100), (500, 30), (213, 167)] + image_inputs = [np.random.randint(255, size=(h, w, 3), dtype=np.uint8) for h, w in image_sizes] + + text = [f"This is an image {self.image_token}"] * len(image_inputs) + inputs = processor( + text=text, + images=[[image] for image in image_inputs], + padding=True, + return_mm_token_type_ids=True, + return_tensors="pt", + ) + + num_image_tokens_from_call = inputs.mm_token_type_ids.sum(-1).tolist() + num_image_tokens_from_helper = processor._get_num_multimodal_tokens(image_sizes=image_sizes) + self.assertListEqual(num_image_tokens_from_call, num_image_tokens_from_helper["num_image_tokens"]) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 7366845c4d78..3ac8ff1573f8 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -3246,6 +3246,67 @@ def test_vision_language_model(self): assert dec is model.model.language_model, f"LLaVA get_decoder() should return language_model, got {type(dec)}" +class TestEmbeddingAccessMixin(unittest.TestCase): + def test_get_input_embeddings_supports_dotted_input_embed_layer(self): + class NestedEmbeddingModel(PreTrainedModel): + config_class = PreTrainedConfig + _input_embed_layer = "text_model.embed_tokens" + + def __init__(self, config): + super().__init__(config) + self.text_model = nn.Module() + self.text_model.embed_tokens = nn.Embedding(8, 4) + + def forward(self, input_ids=None): + return input_ids + + model = NestedEmbeddingModel(PreTrainedConfig()) + + assert model.get_input_embeddings() is model.text_model.embed_tokens + + def test_set_input_embeddings_supports_dotted_input_embed_layer(self): + class NestedEmbeddingModel(PreTrainedModel): + config_class = PreTrainedConfig + _input_embed_layer = "text_model.embed_tokens" + + def __init__(self, config): + super().__init__(config) + self.text_model = nn.Module() + self.text_model.embed_tokens = nn.Embedding(8, 4) + + def forward(self, input_ids=None): + return input_ids + + model = NestedEmbeddingModel(PreTrainedConfig()) + new_embeddings = nn.Embedding(10, 4) + + model.set_input_embeddings(new_embeddings) + + assert model.get_input_embeddings() is new_embeddings + assert model.text_model.embed_tokens is new_embeddings + + def test_invalid_dotted_input_embed_layer_raises(self): + class 
NestedEmbeddingModel(PreTrainedModel): + config_class = PreTrainedConfig + _input_embed_layer = "text_model.missing_embed_tokens" + + def __init__(self, config): + super().__init__(config) + self.text_model = nn.Module() + self.text_model.embed_tokens = nn.Embedding(8, 4) + + def forward(self, input_ids=None): + return input_ids + + model = NestedEmbeddingModel(PreTrainedConfig()) + + with self.assertRaises(NotImplementedError): + model.get_input_embeddings() + + with self.assertRaises(NotImplementedError): + model.set_input_embeddings(nn.Embedding(10, 4)) + + class TestGetEncoder(unittest.TestCase): def test_seq2seq_lm_get_encoder_returns_encoder(self): cfg = BartConfig( diff --git a/utils/check_repo.py b/utils/check_repo.py index b1a3d158c716..4d040fc165c6 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -217,6 +217,8 @@ "Qwen3VLMoeTextModel", # Building part of bigger (tested) model. "Qwen3_5TextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3_5ForConditionalGeneration. "Qwen3_5MoeTextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3_5MoeForConditionalGeneration. + "IsaacTextModel", # Building part of bigger (tested) model. Tested implicitly through IsaacForConditionalGeneration. + "IsaacVisionModel", # Building part of bigger (tested) model. Tested implicitly through IsaacForConditionalGeneration. "Qwen2_5OmniForConditionalGeneration", # Not a regular model. Testted in Qwen2_5OmniModelIntergrationTest "Qwen2_5OmniTalkerForConditionalGeneration", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest. "Qwen2_5OmniTalkerModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest. @@ -467,6 +469,8 @@ "PaddleOCRVisionModel", # Building part of bigger (tested) model "PaddleOCRVisionTransformer", # Building part of bigger (tested) model "PaddleOCRTextModel", # Building part of bigger (tested) model + "IsaacTextModel", # Building part of a bigger model + "IsaacVisionModel", # Building part of a bigger model "Qwen2_5OmniTalkerForConditionalGeneration", # Building part of a bigger model "Qwen2_5OmniTalkerModel", # Building part of a bigger model "Qwen2_5OmniThinkerForConditionalGeneration", # Building part of a bigger model @@ -1130,6 +1134,7 @@ def find_all_documented_objects() -> list[str]: "Ernie4_5_VL_MoeImageProcessorFast", # BC Alias "Ernie4_5_VL_MoeImageProcessorPil", # BC Alias "Ernie4_5_VL_MoeModel", # BC Alias + "IsaacVisionModel", # Internal building block tested implicitly through IsaacForConditionalGeneration. "Ernie4_5_VL_MoeTextConfig", # BC Alias "Ernie4_5_VL_MoeTextModel", # BC Alias "Ernie4_5_VL_MoeVariableResolutionResamplerModel", # BC Alias