diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index ba69db1c5e78..296df33bf11b 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1237,6 +1237,8 @@
title: InstructBlipVideo
- local: model_doc/internvl
title: InternVL
+ - local: model_doc/isaac
+ title: Isaac
- local: model_doc/janus
title: Janus
- local: model_doc/kosmos-2
diff --git a/docs/source/en/model_doc/isaac.md b/docs/source/en/model_doc/isaac.md
new file mode 100644
index 000000000000..91773a4962ed
--- /dev/null
+++ b/docs/source/en/model_doc/isaac.md
@@ -0,0 +1,143 @@
+
+*This model was released on {release_date} and added to Hugging Face Transformers on 2026-04-13.*
+
+
+
+# Isaac
+
+## Overview
+
+Isaac is Perceptron's vision-language model (VLM) that pairs a SigLIP2 vision encoder with a Qwen3 decoder-only stack. The
+Transformers implementation supports text-only and image-conditioned generation, including prompts with multiple interleaved
+images. Isaac uses variable-resolution image preprocessing and can optionally reduce spatial tokens with pixel shuffle to keep
+long multimodal prompts manageable. For more information, refer to the [technical report](https://github.com/perceptron-ai-inc/perceptron/blob/main/papers/isaac_01.pdf).
+
+Isaac checkpoints are distributed under Perceptron's Non-Production license; please review the license that ships with the
+weights before using them in commercial settings.
+
+## Usage tips
+
+- Batched inputs can mix text-only and multimodal samples. For direct processor/model batching, pass images as a nested
+ list such as `[[], [image_a], [image_b, image_c]]` (see the sketch after these tips).
+- `image_grid_thw[batch_idx, image_slot] == (0, 0, 0)` marks a padded empty slot. Real image slots have
+ `(T=1, H>0, W>0)`.
+- If truncation is enabled, the processor keeps the rightmost part of the multimodal prompt and updates the slot-local
+ `image_metadata[..., 0]` and `image_metadata[..., 1]` values automatically.
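+
+A minimal sketch of this batching convention (the image path and `padding=True` are illustrative, not required by the API):
+
+```py
+from PIL import Image
+from transformers import AutoProcessor
+
+processor = AutoProcessor.from_pretrained("PerceptronAI/Isaac-0.1")
+image = Image.open("first_image.png")
+
+texts = [
+ "Write a haiku about multimodal models.",  # text-only sample, empty image slot
+ f"{processor.image_token} Describe this image.",  # one placeholder, one image
+]
+inputs = processor(text=texts, images=[[], [image]], padding=True, return_tensors="pt")
+
+# Padded empty slots appear as (0, 0, 0) rows, real images as (T=1, H, W)
+print(inputs["image_grid_thw"])
+```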
+
+## Usage example
+
+Isaac uses explicit image placeholders in the rendered prompt. Every occurrence of `processor.image_token` must be matched by an image, whether the images are supplied through the chat template or passed directly to the processor's `images` argument.
+
+```py
+import torch
+from transformers import AutoProcessor, IsaacForConditionalGeneration
+
+model_id = "PerceptronAI/Isaac-0.1"
+processor = AutoProcessor.from_pretrained(model_id)
+model = IsaacForConditionalGeneration.from_pretrained(
+ model_id,
+ dtype=torch.bfloat16,
+ device_map="auto",
+ attn_implementation="flash_attention_2",
+)
+
+conversation = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Compare the two figures and explain what changed."},
+ {"type": "image", "path": "first_image.png"},
+ {"type": "image", "path": "second_image.png"},
+ ],
+ },
+]
+
+inputs = processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+ add_generation_prompt=True,
+ return_tensors="pt",
+).to(model.device)
+
+generated_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)
+
+generated_ids = generated_ids[:, inputs["input_ids"].shape[1] :]
+response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+print(response)
+```
+
+### Post-processing grounded outputs
+
+Isaac can generate grounded points and boxes in tagged text spans. Use `post_process_generation()` to strip the tags and
+recover structured annotations.
+
+```py
+clean_text, annotations = processor.post_process_generation(response, expected="box")
+print(clean_text)
+print(annotations)
+```
+
+Set `expected="point"` to extract point annotations, or leave `expected=None` to collect both points and boxes.
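+
+For instance, the point-only and combined variants follow the same pattern:
+
+```py
+clean_text, points = processor.post_process_generation(response, expected="point")
+clean_text, annotations = processor.post_process_generation(response, expected=None)
+```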
+
+## IsaacVisionConfig
+
+[[autodoc]] IsaacVisionConfig
+
+## IsaacTextConfig
+
+[[autodoc]] IsaacTextConfig
+
+## IsaacConfig
+
+[[autodoc]] IsaacConfig
+
+## IsaacVisionModel
+
+[[autodoc]] IsaacVisionModel
+
+## IsaacTextModel
+
+[[autodoc]] IsaacTextModel
+ - forward
+
+## IsaacModel
+
+[[autodoc]] IsaacModel
+ - forward
+
+## IsaacForConditionalGeneration
+
+[[autodoc]] IsaacForConditionalGeneration
+ - forward
+
+## IsaacProcessor
+
+[[autodoc]] IsaacProcessor
+
+## IsaacImageProcessor
+
+[[autodoc]] IsaacImageProcessor
diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
index 2a6dc23ba9d0..b3aebcd5cebe 100755
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -133,6 +133,10 @@ def _build_checkpoint_conversion_mapping():
),
WeightRenaming(source_patterns=r"^visual", target_patterns="model.visual"),
],
+ "isaac": [
+ WeightRenaming(source_patterns=r"text_model", target_patterns="language_model"),
+ WeightRenaming(source_patterns=r"vision_tower", target_patterns="visual"),
+ ],
"colqwen2": [
WeightRenaming(source_patterns=r"vlm.model", target_patterns="vlm"),
WeightRenaming(source_patterns=r"vlm(?!\.(language_model|visual))", target_patterns="vlm.language_model"),
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 81ff067b1470..df5f8bc4958d 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1003,6 +1003,33 @@ class EmbeddingAccessMixin:
_input_embed_layer = "embed_tokens" # default layer that holds input embeddings.
+ def _resolve_input_embed_layer(self) -> tuple[nn.Module | None, str]:
+ """
+ Returns the parent module and leaf attribute for `_input_embed_layer`.
+
+ Supports both a simple attribute name such as `embed_tokens` and a dotted path such as
+ `text_model.embed_tokens`.
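+ For example, `_input_embed_layer = "text_model.embed_tokens"` resolves to the parent module `self.text_model` and the leaf attribute `"embed_tokens"`.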
+ """
+
+ name = getattr(self, "_input_embed_layer", "embed_tokens")
+ if "." not in name:
+ return None, name
+
+ module_path, _, attribute_name = name.rpartition(".")
+ try:
+ module = self.get_submodule(module_path)
+ except AttributeError as error:
+ raise NotImplementedError(
+ f"`_input_embed_layer={name}` could not be resolved for {self.__class__.__name__}."
+ ) from error
+
+ if not hasattr(module, attribute_name):
+ raise NotImplementedError(
+ f"`_input_embed_layer={name}` could not be resolved for {self.__class__.__name__}."
+ )
+
+ return module, attribute_name
+
def get_input_embeddings(self) -> nn.Module:
"""
Returns the model's input embeddings.
@@ -1011,7 +1038,9 @@ def get_input_embeddings(self) -> nn.Module:
`nn.Module`: A torch module mapping vocabulary to hidden states.
"""
- name = getattr(self, "_input_embed_layer", "embed_tokens")
+ module, name = self._resolve_input_embed_layer()
+ if module is not None:
+ return getattr(module, name)
# 1) Direct attribute (most NLP models).
if (default_embedding := getattr(self, name, None)) is not None:
@@ -1044,7 +1073,11 @@ def set_input_embeddings(self, value: nn.Module):
should) override for exotic layouts.
"""
- name = getattr(self, "_input_embed_layer", "embed_tokens")
+ module, name = self._resolve_input_embed_layer()
+ if module is not None:
+ setattr(module, name, value)
+ return
+
# 1) Direct attribute (most NLP models)
if hasattr(self, name):
setattr(self, name, value)
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 989be9eb114e..df538531c3ba 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -201,6 +201,7 @@
from .instructblip import *
from .instructblipvideo import *
from .internvl import *
+ from .isaac import *
from .jais2 import *
from .jamba import *
from .janus import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 2c0fe88d0e74..95b8378c0861 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -238,6 +238,8 @@
("instructblipvideo", "InstructBlipVideoConfig"),
("internvl", "InternVLConfig"),
("internvl_vision", "InternVLVisionConfig"),
+ ("isaac", "IsaacConfig"),
+ ("isaac_vision", "IsaacVisionConfig"),
("jais2", "Jais2Config"),
("jamba", "JambaConfig"),
("janus", "JanusConfig"),
@@ -758,6 +760,8 @@
("instructblipvideo", "InstructBlipVideo"),
("internvl", "InternVL"),
("internvl_vision", "InternVLVision"),
+ ("isaac", "Isaac"),
+ ("isaac_vision", "IsaacVision"),
("jais2", "Jais2"),
("jamba", "Jamba"),
("janus", "Janus"),
@@ -1109,6 +1113,7 @@
("gemma4_audio", "gemma4"),
("gemma4_text", "gemma4"),
("gemma4_vision", "gemma4"),
+ ("isaac_vision", "isaac"),
("glm4v_vision", "glm4v"),
("glm4v_moe_vision", "glm4v_moe"),
("glm4v_text", "glm4v"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 4ada3ba6b8ed..febd934a38ab 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -145,6 +145,7 @@
("imagegpt", {"torchvision": "ImageGPTImageProcessor", "pil": "ImageGPTImageProcessorPil"}),
("instructblip", {"torchvision": "BlipImageProcessor", "pil": "BlipImageProcessorPil"}),
("internvl", {"torchvision": "GotOcr2ImageProcessor", "pil": "GotOcr2ImageProcessorPil"}),
+ ("isaac", {"torchvision": "IsaacImageProcessor"}),
("janus", {"torchvision": "JanusImageProcessor", "pil": "JanusImageProcessorPil"}),
("kosmos-2", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}),
("kosmos-2.5", {"torchvision": "Kosmos2_5ImageProcessor", "pil": "Kosmos2_5ImageProcessorPil"}),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index d4cb17cddfa6..120668f4ac8d 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -235,6 +235,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("instructblipvideo", "InstructBlipVideoModel"),
("internvl", "InternVLModel"),
("internvl_vision", "InternVLVisionModel"),
+ ("isaac", "IsaacModel"),
+ ("isaac_vision", "IsaacVisionModel"),
("jais2", "Jais2Model"),
("jamba", "JambaModel"),
("janus", "JanusModel"),
@@ -990,6 +992,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("instructblip", "InstructBlipForConditionalGeneration"),
("instructblipvideo", "InstructBlipVideoForConditionalGeneration"),
("internvl", "InternVLForConditionalGeneration"),
+ ("isaac", "IsaacForConditionalGeneration"),
("janus", "JanusForConditionalGeneration"),
("kosmos-2", "Kosmos2ForConditionalGeneration"),
("kosmos-2.5", "Kosmos2_5ForConditionalGeneration"),
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index 262480b71485..61cf836d7c56 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -100,6 +100,7 @@
("instructblip", "InstructBlipProcessor"),
("instructblipvideo", "InstructBlipVideoProcessor"),
("internvl", "InternVLProcessor"),
+ ("isaac", "IsaacProcessor"),
("janus", "JanusProcessor"),
("kosmos-2", "Kosmos2Processor"),
("kosmos-2.5", "Kosmos2_5Processor"),
diff --git a/src/transformers/models/isaac/__init__.py b/src/transformers/models/isaac/__init__.py
new file mode 100644
index 000000000000..bc0f3fcc6d7c
--- /dev/null
+++ b/src/transformers/models/isaac/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_isaac import *
+ from .modeling_isaac import *
+ from .processing_isaac import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/isaac/configuration_isaac.py b/src/transformers/models/isaac/configuration_isaac.py
new file mode 100644
index 000000000000..d5e0080d029a
--- /dev/null
+++ b/src/transformers/models/isaac/configuration_isaac.py
@@ -0,0 +1,179 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/isaac/modular_isaac.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_isaac.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2026 Perceptron, Inc and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from huggingface_hub.dataclasses import strict
+
+from ...configuration_utils import PreTrainedConfig, PretrainedConfig
+from ...modeling_rope_utils import RopeParameters
+from ...utils import auto_docstring
+
+
+@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base")
+@strict
+class IsaacVisionConfig(PreTrainedConfig):
+ r"""
+ num_patches (`int`, *optional*, defaults to 256):
+ The number of patches of size (`patch_size`, `patch_size`) in the image. The image is resized so that at most
+ this many patches are produced while preserving the aspect ratio. If the resulting number of patches is lower,
+ the image is padded along the patch dimension.
+ pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1):
+ Scale factor used by the pixel shuffle step to reduce the spatial resolution of the vision token grid.
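+
+ Example:
+
+ ```python
+ >>> from transformers import IsaacVisionConfig, IsaacVisionModel
+
+ >>> configuration = IsaacVisionConfig()
+ >>> model = IsaacVisionModel(configuration)
+ >>> configuration = model.config
+ ```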
+ """
+
+ model_type = "isaac_vision"
+ base_config_key = "vision_config"
+
+ hidden_size: int = 768
+ intermediate_size: int = 3072
+ num_hidden_layers: int = 12
+ num_attention_heads: int = 12
+ num_channels: int = 3
+ patch_size: int | list[int] | tuple[int, int] = 16
+ hidden_act: str = "gelu_pytorch_tanh"
+ layer_norm_eps: float = 1e-6
+ attention_dropout: float | int = 0.0
+
+ num_patches: int = 256
+ pixel_shuffle_scale_factor: int = 1
+
+
+@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base")
+@strict
+class IsaacTextConfig(PreTrainedConfig):
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import IsaacTextConfig, IsaacTextModel
+
+ >>> configuration = IsaacTextConfig()
+ >>> model = IsaacTextModel(configuration)
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "isaac_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ # Default tensor parallel plan for base model `IsaacText`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
+ "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ vocab_size: int = 151936
+ hidden_size: int = 4096
+ intermediate_size: int = 22016
+ num_hidden_layers: int = 32
+ num_attention_heads: int = 32
+ num_key_value_heads: int | None = 32
+ head_dim: int = 128
+ hidden_act: str = "silu"
+ max_position_embeddings: int = 32768
+ initializer_range: float = 0.02
+ rms_norm_eps: float = 1e-6
+ use_cache: bool = True
+ tie_word_embeddings: bool = False
+ rope_parameters: RopeParameters | dict | None = None
+ attention_bias: bool = False
+ use_sliding_window: bool = False
+ max_window_layers: int = 28
+ attention_dropout: float | int = 0.0
+ pad_token_id: int | None = None
+ bos_token_id: int | None = None
+ eos_token_id: int | list[int] | None = None
+ ignore_keys_at_rope_validation = {"mrope_section", "mrope_interleaved"}
+
+ def __post_init__(self, **kwargs):
+ if self.num_key_value_heads is None:
+ self.num_key_value_heads = self.num_attention_heads
+
+ PretrainedConfig.__post_init__(self, **kwargs)
+
+
+@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base")
+@strict
+class IsaacConfig(PretrainedConfig):
+ r"""
+ vision_config (`IsaacVisionConfig` or `dict`, *optional*):
+ Configuration for the Isaac vision tower. Dictionaries are converted to [`IsaacVisionConfig`]. If unset,
+ the default [`IsaacVisionConfig`] is used.
+ text_config (`IsaacTextConfig` or `dict`, *optional*):
+ Configuration for the text backbone. Dictionaries are converted to [`IsaacTextConfig`].
+ vision_rescale_factor (`float`, *optional*, defaults to 1 / 255):
+ Rescale factor applied by the image processor before normalization.
+ max_sequence_length (`int`, *optional*, defaults to 16384):
+ Maximum multimodal sequence length produced by the processor and expected by the model.
+
+ Example:
+
+ ```python
+ >>> from transformers import IsaacConfig, IsaacModel
+
+ >>> configuration = IsaacConfig()
+ >>> model = IsaacModel(configuration)
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "isaac"
+ sub_configs = {"vision_config": IsaacVisionConfig, "text_config": IsaacTextConfig}
+ vision_config: IsaacVisionConfig | dict | None = None
+ text_config: IsaacTextConfig | dict | None = None
+ vision_rescale_factor: float = 1 / 255
+ max_sequence_length: int = 16384
+
+ def __post_init__(self, **kwargs):
+ if isinstance(self.text_config, dict):
+ self.text_config = self.sub_configs["text_config"](**self.text_config)
+ elif self.text_config is None:
+ self.text_config = self.sub_configs["text_config"]()
+ elif not isinstance(self.text_config, IsaacTextConfig):
+ raise TypeError(
+ f"text_config must be a dict or an IsaacTextConfig instance, got {type(self.text_config).__name__}."
+ )
+
+ if isinstance(self.vision_config, dict):
+ self.vision_config = self.sub_configs["vision_config"](**self.vision_config)
+ elif self.vision_config is None:
+ self.vision_config = self.sub_configs["vision_config"]()
+ elif not isinstance(self.vision_config, IsaacVisionConfig):
+ raise TypeError(
+ f"vision_config must be a dict or an IsaacVisionConfig instance, got {type(self.vision_config).__name__}."
+ )
+
+ self.vision_rescale_factor = float(self.vision_rescale_factor)
+ super().__post_init__(**kwargs)
+
+
+__all__ = ["IsaacConfig", "IsaacTextConfig", "IsaacVisionConfig"]
diff --git a/src/transformers/models/isaac/image_processing_isaac.py b/src/transformers/models/isaac/image_processing_isaac.py
new file mode 100644
index 000000000000..39750c1bd792
--- /dev/null
+++ b/src/transformers/models/isaac/image_processing_isaac.py
@@ -0,0 +1,375 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/isaac/modular_isaac.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_isaac.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2026 Perceptron, Inc and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+from typing import Any
+
+from ... import TorchvisionBackend
+from ...feature_extraction_utils import BatchFeature
+from ...image_transforms import group_images_by_shape, reorder_images
+from ...image_utils import ImageInput, PILImageResampling, SizeDict, make_nested_list_of_images
+from ...processing_utils import ImagesKwargs, Unpack
+from ...utils import TensorType, auto_docstring
+from ...utils.constants import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
+from ...utils.import_utils import is_torch_available, requires
+
+
+if is_torch_available():
+ import torch
+ import torch.nn.functional as F
+ from torchvision.transforms.v2 import functional as tvF
+
+
+# --------------------------------Isaac Image Processor--------------------------------
+
+
+class IsaacImageProcessorKwargs(ImagesKwargs, total=False):
+ """
+ patch_size (`int`, *optional*):
+ Side length (in pixels) for square patches extracted from resized images.
+ max_num_patches (`int`, *optional*):
+ Upper bound on extracted patches per image after resizing.
+ min_num_patches (`int`, *optional*):
+ Lower bound on extracted patches per image after resizing.
+ pixel_shuffle_scale (`int`, *optional*):
+ Pixel-shuffle reduction factor applied in the vision tower.
+ """
+
+ patch_size: int
+ max_num_patches: int
+ min_num_patches: int
+ pixel_shuffle_scale: int
+
+
+# Disable as it causes issues with torch.compile
+@torch.compiler.disable
+def torch_extract_patches(image_tensor, patch_height, patch_width):
+ """
+ Extract patches from image tensor. Returns tensor of shape (batch, rows, columns, patch_height*patch_width*channels).
+
+ Args:
+ image_tensor (`torch.Tensor`):
+ Image tensor of shape (batch, channels, height, width).
+ patch_height (`int`):
+ Height of patches to extract.
+ patch_width (`int`):
+ Width of patches to extract.
+ """
+ batch_size, channels, height, width = image_tensor.shape
+ patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width))
+ patches = patches.reshape(batch_size, channels, patch_height, patch_width, -1)
+ patches = patches.permute(0, 4, 2, 3, 1).reshape(
+ batch_size, height // patch_height, width // patch_width, channels * patch_height * patch_width
+ )
+ return patches
+
+
+def get_scaled_image_size(
+ scale: float,
+ original_size: int,
+ patch_size: int,
+ pixel_shuffle_scale: int,
+) -> int:
+ scaled_size = scale * original_size
+ divisor = patch_size * pixel_shuffle_scale
+ scaled_size = math.ceil(scaled_size / divisor) * divisor
+ scaled_size = max(divisor, scaled_size)
+ return int(scaled_size)
+
+
+def get_image_size_for_max_num_patches(
+ image_height: int,
+ image_width: int,
+ patch_size: int,
+ max_num_patches: int,
+ min_num_patches: int | None = None,
+ eps: float = 1e-5,
+ pixel_shuffle_scale: int = 1,
+) -> tuple[int, int]:
+ r"""Compute a target resolution whose patch grid satisfies patching parametrization.
+
+ Args:
+ image_height (`int`):
+ Height in pixels of the source image prior to any resizing.
+ image_width (`int`):
+ Width in pixels of the source image prior to any resizing.
+ patch_size (`int`):
+ Size of the square patch used by the vision encoder.
+ max_num_patches (`int`):
+ Upper bound on `(height / patch_size) * (width / patch_size)` after resizing.
+ min_num_patches (`int`, *optional*):
+ Lower bound on the number of patches. When provided the image will be scaled up if necessary.
+ eps (`float`, *optional*, defaults to 1e-5):
+ Convergence tolerance for the internal binary search that determines the target dimensions.
+ pixel_shuffle_scale (`int`, *optional*, defaults to 1):
+ Additional stride multiplier applied when pixel shuffle later reduces spatial resolution.
+
+ Returns:
+ `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale`
+ and respect both the maximum and optional minimum patch-count constraints.
+ """
+
+ # Ensure divisibility
+ divisor = patch_size * pixel_shuffle_scale
+ adjusted_height = math.ceil(image_height / divisor) * divisor
+ adjusted_height = max(divisor, adjusted_height)
+ adjusted_width = math.ceil(image_width / divisor) * divisor
+ adjusted_width = max(divisor, adjusted_width)
+
+ num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size)
+
+ if min_num_patches is not None and num_patches < min_num_patches:
+ # Scale up via binary search to satisfy the minimum patch budget while
+ # preserving divisibility by patch_size * pixel_shuffle_scale.
+ scale_min, scale_max = 1.0, 100.0
+ while (scale_max - scale_min) >= eps:
+ scale = (scale_min + scale_max) / 2
+ target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+ target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+ num_patches = (target_height / patch_size) * (target_width / patch_size)
+ if num_patches >= min_num_patches:
+ scale_max = scale
+ else:
+ scale_min = scale
+ scale = scale_max
+ target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+ target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+ return target_height, target_width
+ elif num_patches <= max_num_patches:
+ return adjusted_height, adjusted_width
+ else:
+ # Scale down
+ scale_min, scale_max = eps / 10, 1.0
+ while (scale_max - scale_min) >= eps:
+ scale = (scale_min + scale_max) / 2
+ target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+ target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+ num_patches = (target_height / patch_size) * (target_width / patch_size)
+ if num_patches <= max_num_patches:
+ scale_min = scale
+ else:
+ scale_max = scale
+ scale = scale_min
+ target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+ target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+ return target_height, target_width
+
+
+@auto_docstring
+@requires(backends=("vision",))
+class IsaacImageProcessor(TorchvisionBackend):
+ model_input_names = ["pixel_values", "image_grid_thw"]
+ valid_kwargs = IsaacImageProcessorKwargs
+
+ resample = PILImageResampling.BILINEAR
+ do_resize = True
+ do_center_crop = False
+ patch_size = 16
+ max_num_patches = 256
+ min_num_patches = None
+ pixel_shuffle_scale = 1
+ do_pad = True
+ do_rescale = True
+ do_normalize = True
+ image_mean = IMAGENET_STANDARD_MEAN
+ image_std = IMAGENET_STANDARD_STD
+ do_convert_rgb = True
+ disable_grouping = False
+
+ def __init__(self, **kwargs: Unpack[IsaacImageProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ def _validate_preprocess_kwargs(self, **kwargs):
+ # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for.
+ kwargs.pop("do_resize", None)
+ return super()._validate_preprocess_kwargs(**kwargs)
+
+ def _prepare_images_structure(
+ self,
+ images: ImageInput,
+ expected_ndims: int = 3,
+ ) -> ImageInput:
+ images = self.fetch_images(images)
+ return make_nested_list_of_images(images, expected_ndims=expected_ndims)
+
+ def resize(
+ self,
+ image: torch.Tensor,
+ size: SizeDict,
+ **kwargs,
+ ) -> torch.Tensor:
+ if image.dtype == torch.uint8:
+ image = F.interpolate(image.float(), size=(size.height, size.width), mode="bilinear", align_corners=False)
+ return image.clamp(0, 255).round().to(torch.uint8)
+ return F.interpolate(image, size=(size.height, size.width), mode="bilinear", align_corners=False)
+
+ def pack_images(
+ self,
+ vision_patches: list[list[torch.Tensor]],
+ vision_token_grids: list[list[torch.Tensor]],
+ ) -> dict[str, torch.Tensor | None]:
+ batch_size = len(vision_patches)
+ flat_patches = [patches for sample_patches in vision_patches for patches in sample_patches]
+ if len(flat_patches) == 0:
+ return {"pixel_values": None, "image_grid_thw": None}
+
+ first_patch = flat_patches[0]
+ max_patches = max(patches.shape[0] for patches in flat_patches)
+ max_images = max((len(sample_patches) for sample_patches in vision_patches), default=0)
+
+ patch_dim = first_patch.shape[-1]
+ tensors = {
+ "pixel_values": torch.zeros(
+ (batch_size, max_images, max_patches, patch_dim),
+ device=first_patch.device,
+ dtype=first_patch.dtype,
+ ),
+ "image_grid_thw": torch.zeros((batch_size, max_images, 3), device=first_patch.device, dtype=torch.long),
+ }
+
+ for batch_idx, (sample_patches, sample_token_grids) in enumerate(
+ zip(vision_patches, vision_token_grids, strict=True)
+ ):
+ for image_slot, (patches, token_grid) in enumerate(zip(sample_patches, sample_token_grids, strict=True)):
+ patch_count = int(patches.shape[0])
+ tensors["pixel_values"][batch_idx, image_slot, :patch_count] = patches
+ tensors["image_grid_thw"][batch_idx, image_slot, 0] = 1
+ tensors["image_grid_thw"][batch_idx, image_slot, 1:] = token_grid
+
+ return tensors
+
+ def _preprocess(
+ self,
+ images: list[list[torch.Tensor]],
+ do_resize: bool,
+ resample: PILImageResampling | tvF.InterpolationMode | int | None,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: float | list[float] | None,
+ image_std: float | list[float] | None,
+ do_pad: bool,
+ patch_size: int,
+ max_num_patches: int,
+ min_num_patches: int,
+ pixel_shuffle_scale: int,
+ disable_grouping: bool | None,
+ return_tensors: str | TensorType | None,
+ **kwargs,
+ ) -> BatchFeature:
+ if all(len(sample_images) == 0 for sample_images in images):
+ return BatchFeature(data={"pixel_values": None, "image_grid_thw": None}, tensor_type=return_tensors)
+
+ grouped_images, grouped_images_index = group_images_by_shape(
+ images, disable_grouping=disable_grouping, is_nested=True
+ )
+ grouped_outputs = {}
+ for shape, stacked_images in grouped_images.items():
+ grouped_batch_size, channels, original_height, original_width = stacked_images.shape
+ if do_resize:
+ target_height, target_width = get_image_size_for_max_num_patches(
+ original_height,
+ original_width,
+ patch_size,
+ max_num_patches,
+ min_num_patches=min_num_patches,
+ pixel_shuffle_scale=pixel_shuffle_scale,
+ )
+ image_batch = self.resize(
+ stacked_images, SizeDict(height=target_height, width=target_width), resample=resample
+ )
+ else:
+ if (original_height % patch_size) or (original_width % patch_size):
+ raise ValueError(
+ f"Image dimensions (h={original_height}, w={original_width}) must be divisible by patch_size={patch_size} when resize is disabled; enable resizing or adjust the input resolution."
+ )
+ image_batch, target_height, target_width = stacked_images, original_height, original_width
+
+ image_batch = self.rescale_and_normalize(
+ image_batch,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ )
+
+ patches = torch_extract_patches(image_batch, patch_size, patch_size)
+ _, height_tokens, width_tokens, patch_dim = patches.shape
+
+ token_grid = (
+ torch.tensor([height_tokens, width_tokens], device=patches.device).long().expand(grouped_batch_size, 2)
+ )
+
+ if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale):
+ raise ValueError(
+ f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale};"
+ f" adjust resize/patch parameters or disable pixel shuffle."
+ )
+
+ grouped_outputs[shape] = (
+ patches.reshape(grouped_batch_size, -1, patch_dim),
+ token_grid,
+ )
+
+ keys = ("vision_patches", "vision_token_grids")
+ nested_outputs = {}
+ for i, key in enumerate(keys):
+ nested_outputs[key] = reorder_images(
+ {shape: values[i] for shape, values in grouped_outputs.items()},
+ dict(grouped_images_index),
+ is_nested=True,
+ )
+
+ if not do_pad:
+ raise ValueError("IsaacImageProcessor doesn't support `do_pad=False` mode.")
+
+ tensors = self.pack_images(
+ vision_patches=nested_outputs["vision_patches"],
+ vision_token_grids=nested_outputs["vision_token_grids"],
+ )
+
+ return BatchFeature(data=tensors, tensor_type=return_tensors)
+
+ def get_number_of_image_patches(
+ self,
+ image_height: int,
+ image_width: int,
+ images_kwargs: dict[str, Any] | None = None,
+ ) -> int:
+ images_kwargs = images_kwargs or {}
+ patch_size = images_kwargs.get("patch_size", self.patch_size)
+ max_num_patches = images_kwargs.get("max_num_patches", self.max_num_patches)
+ min_num_patches = images_kwargs.get("min_num_patches", self.min_num_patches)
+ pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale", self.pixel_shuffle_scale)
+
+ target_height, target_width = get_image_size_for_max_num_patches(
+ image_height,
+ image_width,
+ patch_size,
+ max_num_patches,
+ min_num_patches=min_num_patches,
+ pixel_shuffle_scale=pixel_shuffle_scale,
+ )
+ return (target_height // patch_size) * (target_width // patch_size)
+
+
+__all__ = ["IsaacImageProcessor"]
diff --git a/src/transformers/models/isaac/modeling_isaac.py b/src/transformers/models/isaac/modeling_isaac.py
new file mode 100644
index 000000000000..1d293615cc37
--- /dev/null
+++ b/src/transformers/models/isaac/modeling_isaac.py
@@ -0,0 +1,1743 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/isaac/modular_isaac.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_isaac.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2026 Perceptron, Inc and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any, NamedTuple, Optional
+
+from ... import initialization as init
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation.utils import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+from ...masking_utils import create_bidirectional_mask, create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPast,
+ BaseModelOutputWithPooling,
+ CausalLMOutputWithPast,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, torch_compilable_check
+from ...utils.generic import TransformersKwargs, can_return_tuple, maybe_autocast, merge_with_config_defaults
+from ...utils.import_utils import is_torch_available, is_torchdynamo_compiling
+from ...utils.output_capturing import capture_outputs
+from .configuration_isaac import IsaacConfig, IsaacTextConfig, IsaacVisionConfig
+
+
+if is_torch_available():
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+class IsaacVisionEmbeddings(nn.Module):
+ """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences.
+
+ Isaac accepts variable-resolution vision inputs as a single packed sequence with per-image
+ `token_grids`; packing/unpacking here reconstructs per-image shapes so we can resize positional
+ embeddings and build `cu_seqlens` for variable-length attention (not generic generation packing).
+ """
+
+ def __init__(self, config: IsaacVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Linear(
+ in_features=config.num_channels * self.patch_size * self.patch_size,
+ out_features=self.embed_dim,
+ )
+
+ self.num_patches = config.num_patches
+ self.position_embedding_size = int(self.num_patches**0.5)
+ self.position_embedding = nn.Parameter(
+ torch.empty(
+ self.position_embedding_size,
+ self.position_embedding_size,
+ self.embed_dim,
+ )
+ )
+
+ @staticmethod
+ def resize_positional_embeddings(
+ positional_embeddings: torch.Tensor,
+ spatial_shapes: torch.LongTensor,
+ max_length: int,
+ ) -> torch.Tensor:
+ """
+ Resize positional embeddings to image-specific size and pad to a fixed size.
+
+ Args:
+ positional_embeddings (`torch.Tensor`):
+ Position embeddings of shape (height, width, embed_dim)
+ spatial_shapes (`torch.LongTensor`):
+ Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
+ max_length (`int`):
+ Maximum length of the positional embeddings to pad resized positional embeddings to
+
+ Returns:
+ `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
+ """
+ batch_size = spatial_shapes.shape[0]
+ embed_dim = positional_embeddings.shape[-1]
+ source_dtype = positional_embeddings.dtype
+
+ resulted_positional_embeddings = torch.empty(
+ (batch_size, max_length, embed_dim),
+ device=positional_embeddings.device,
+ dtype=source_dtype,
+ )
+
+ # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
+ positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
+
+ # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
+ if positional_embeddings.device.type == "cpu":
+ positional_embeddings = positional_embeddings.to(torch.float32)
+
+ for i in range(batch_size):
+ # (1, dim, height, width) -> (1, dim, target_height, target_width)
+ height, width = spatial_shapes[i].tolist() # will be itemized in F.interpolate either way
+ torch_compilable_check((width > 0), "Width of resized positional embeddings must be positive.")
+ torch_compilable_check((height > 0), "Height of resized positional embeddings must be positive.")
+ torch_compilable_check((height * width) <= max_length, "Resized positional embeddings exceed max_length.")
+ resized_embeddings = F.interpolate(
+ positional_embeddings,
+ size=(height, width),
+ mode="bilinear",
+ align_corners=False,
+ antialias=True,
+ )
+
+ # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
+ resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
+
+ # Cast to original dtype
+ resized_embeddings = resized_embeddings.to(source_dtype)
+
+ resulted_positional_embeddings[i, : height * width] = resized_embeddings
+ resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
+
+ return resulted_positional_embeddings
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ image_grid_thw: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ pixel_values (`torch.FloatTensor`):
+ Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
+ image_grid_thw (`torch.LongTensor`):
+ Per-image grids of shape (batch_size, 3) with `(T=1, H, W)` entries, used to resize the positional embeddings
+ """
+ # pixel_values: (num_images, max_patches, patch_dim)
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
+
+ resized_positional_embeddings = self.resize_positional_embeddings(
+ self.position_embedding,
+ image_grid_thw[:, 1:],
+ max_length=pixel_values.shape[1],
+ )
+ resized_positional_embeddings = resized_positional_embeddings.to(
+ device=patch_embeds.device, dtype=patch_embeds.dtype
+ )
+ embeddings = patch_embeds + resized_positional_embeddings
+
+ if attention_mask is not None:
+ embeddings = embeddings * attention_mask.unsqueeze(-1).to(device=embeddings.device, dtype=embeddings.dtype)
+
+ return embeddings
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class IsaacVisionAttention(nn.Module):
+ """Custom attention that supports variable-length sequences with flash/SDPA backends."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+ self.is_causal = False
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ """Input shape: Batch x Time x Channel"""
+
+ input_shape = hidden_states.shape[:-1]
+
+ hidden_shape = (*input_shape, -1, self.head_dim)
+ queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+ self.config._attn_implementation, eager_attention_forward
+ )
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ queries,
+ keys,
+ values,
+ attention_mask,
+ is_causal=self.is_causal,
+ scaling=self.scale,
+ dropout=0.0 if not self.training else self.dropout,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class IsaacMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class IsaacVisionEncoderLayer(GradientCheckpointingLayer):
+ """Isaac vision encoder layer using the shared attention interfaces."""
+
+ def __init__(self, config: IsaacVisionConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.self_attn = IsaacVisionAttention(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = IsaacMLP(config)
+
+ @auto_docstring
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class IsaacVisionEncoder(nn.Module):
+ """Encoder using Isaac encoder layers."""
+
+ def __init__(self, config: IsaacVisionConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ # Ignore copy
+ @auto_docstring
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutput:
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ hidden_states = encoder_layer(
+ hidden_states,
+ attention_mask,
+ **kwargs,
+ )
+
+ return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+@auto_docstring
+class IsaacVisionModel(PreTrainedModel):
+ config: IsaacVisionConfig
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _can_record_outputs = {
+ "hidden_states": IsaacVisionEncoderLayer,
+ "attentions": IsaacVisionAttention,
+ }
+
+ def __init__(self, config: IsaacVisionConfig):
+ super().__init__(config)
+ self.embeddings = IsaacVisionEmbeddings(config)
+ self.encoder = IsaacVisionEncoder(config)
+ self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor
+
+ self.post_init()
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, IsaacVisionEmbeddings):
+ init.zeros_(module.position_embedding)
+
+ def pixel_shuffle_padded(
+ self,
+ hidden_states: torch.Tensor,
+ token_grids: torch.Tensor,
+ ) -> torch.Tensor:
+ """Apply pixel shuffle per image on padded batched vision embeddings.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Vision embeddings of shape `(num_images, max_patches, hidden_size)`.
+ token_grids (`torch.Tensor`):
+ Grid sizes `(height, width)` per image, shape `(num_images, 2)`.
+
+ Returns:
+ `torch.Tensor`: Pixel-shuffled embeddings of shape
+ `(num_images, max_tokens, hidden_size * scale_factor**2)`.
+ """
+ scale_factor = self.pixel_shuffle_scale_factor
+ num_images, max_patches, embed_dim = hidden_states.shape
+ output_dim = embed_dim * scale_factor * scale_factor
+
+ token_grids = token_grids.to(device=hidden_states.device, dtype=torch.long)
+ heights = token_grids[:, 0]
+ widths = token_grids[:, 1]
+ full_lengths = heights * widths
+
+ non_empty = full_lengths > 0
+ if not is_torchdynamo_compiling():
+ divisible = ((heights % scale_factor) == 0) & ((widths % scale_factor) == 0)
+ torch_compilable_check(
+ (~non_empty) | divisible,
+ f"Every non-empty (H, W) grid must be divisible by pixel_shuffle_scale={scale_factor}.",
+ )
+
+ output_lengths = (heights // scale_factor) * (widths // scale_factor)
+ max_output_tokens = output_lengths.max()
+ shuffled_4d = hidden_states.new_zeros((num_images, max_output_tokens, scale_factor * scale_factor, embed_dim))
+
+ token_positions = (
+ torch.arange(max_patches, device=hidden_states.device, dtype=torch.long)
+ .unsqueeze(0)
+ .expand(num_images, -1)
+ )
+ valid_token_mask = token_positions < full_lengths.unsqueeze(1)
+
+ safe_widths = torch.where(widths > 0, widths, torch.ones_like(widths))
+ row_index = torch.div(token_positions, safe_widths.unsqueeze(1), rounding_mode="floor")
+ col_index = token_positions.remainder(safe_widths.unsqueeze(1))
+
+ output_widths = widths.div(scale_factor, rounding_mode="floor")
+ output_index = row_index.div(scale_factor, rounding_mode="floor") * output_widths.unsqueeze(1)
+ output_index = output_index + col_index.div(scale_factor, rounding_mode="floor")
+ sub_index = row_index.remainder(scale_factor) * scale_factor + col_index.remainder(scale_factor)
+
+ batch_index = (
+ torch.arange(num_images, device=hidden_states.device, dtype=torch.long)
+ .unsqueeze(1)
+ .expand_as(token_positions)
+ )
+ shuffled_4d[batch_index[valid_token_mask], output_index[valid_token_mask], sub_index[valid_token_mask]] = (
+ hidden_states[valid_token_mask]
+ )
+
+ shuffled = shuffled_4d.view(num_images, max_output_tokens, output_dim)
+ return shuffled
+
+ @merge_with_config_defaults
+ @capture_outputs(tie_last_hidden_states=False)
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ image_grid_thw: torch.Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ image_grid_thw (`torch.LongTensor`, *optional*):
+ Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries.
+ """
+ full_lengths = image_grid_thw[:, 1] * image_grid_thw[:, 2]
+ token_positions = torch.arange(pixel_values.shape[1], device=pixel_values.device, dtype=torch.long)
+ image_patch_attention_mask = token_positions.unsqueeze(0) < full_lengths.unsqueeze(1)
+ image_patch_attention_mask = image_patch_attention_mask.to(dtype=torch.long)
+ hidden_states = self.embeddings(
+ pixel_values,
+ image_grid_thw,
+ attention_mask=image_patch_attention_mask,
+ )
+
+ encoder_attention_mask = create_bidirectional_mask(
+ config=self.config,
+ inputs_embeds=hidden_states,
+ attention_mask=image_patch_attention_mask,
+ )
+ encoder_outputs = self.encoder(inputs_embeds=hidden_states, attention_mask=encoder_attention_mask, **kwargs)
+ hidden_states = self.post_layernorm(encoder_outputs.last_hidden_state)
+
+ hidden_states = self.pixel_shuffle_padded(
+ hidden_states=hidden_states,
+ token_grids=image_grid_thw[:, 1:],
+ )
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=hidden_states,
+ pooler_output=None,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class IsaacRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: IsaacTextConfig, device=None):
+ super().__init__()
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+
+ self.rope_type = self.config.rope_parameters["rope_type"]
+ rope_init_fn: Callable = self.compute_default_rope_parameters
+ if self.rope_type != "default":
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
+ self.mrope_section = config.rope_parameters.get("mrope_section")
+ if self.mrope_section is None:
+ weights = (2, 1, 1)
+ self.mrope_section = [self.inv_freq.shape[0] * w // sum(weights) for w in weights]
+ self.mrope_section[0] += self.inv_freq.shape[0] - sum(self.mrope_section)
+
+ @staticmethod
+ def compute_default_rope_parameters(
+ config: IsaacTextConfig | None = None,
+ device: Optional["torch.device"] = None,
+ seq_len: int | None = None,
+ ) -> tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies according to the original RoPE implementation
+ Args:
+ config ([`~transformers.PreTrainedConfig`]):
+ The model configuration.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+ """
+ base = config.rope_parameters["rope_theta"]
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+ attention_factor = 1.0 # Unused in this type of RoPE
+
+ # Compute the inverse frequencies
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+ )
+ return inv_freq, attention_factor
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ # In contrast to other models, Isaac has different position ids for the grids
+ # So we expand the inv_freq to shape (3, ...)
+ if position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+ def apply_interleaved_mrope(self, freqs: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
+ """Apply interleaved MRoPE to 3D rotary embeddings.
+ Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+ interleaved [THWTHWTHW...TT], preserving frequency continuity.
+ Args:
+ freqs: (3, bs, seq_len, head_dim // 2)
+ mrope_section: (3,)
+ Returns:
+ freqs_t: (bs, seq_len, head_dim // 2)
+ """
+ chunks = freqs.split(tuple(mrope_section), dim=-1)
+ return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class IsaacTextRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps: float = 1e-6) -> None:
+ """
+ IsaacTextRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+@use_kernel_func_from_hub("rotary_pos_emb")
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+@use_kernelized_func(apply_rotary_pos_emb)
+class IsaacTextAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: IsaacTextConfig, layer_idx: int):
+ super().__init__()
+ self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.q_norm = IsaacTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
+ self.k_norm = IsaacTextRMSNorm(
+ self.head_dim, eps=config.rms_norm_eps
+ ) # thus post q_norm does not need reshape
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: torch.Tensor | None,
+ past_key_values: Cache | None = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
+
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+ self.config._attn_implementation, eager_attention_forward
+ )
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class IsaacTextMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class IsaacTextDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: IsaacTextConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = IsaacTextAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = IsaacTextMLP(config)
+ self.input_layernorm = IsaacTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = IsaacTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ use_cache: bool | None = False,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class IsaacVisionRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+
+@auto_docstring
+class IsaacPreTrainedModel(PreTrainedModel):
+ config: IsaacConfig
+ base_model_prefix = "model"
+ input_modalities = ("image", "video", "text")
+ supports_gradient_checkpointing = True
+    _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": IsaacTextDecoderLayer,
+ "attentions": IsaacTextAttention,
+ }
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, IsaacVisionRotaryEmbedding):
+ inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+ init.copy_(module.inv_freq, inv_freq)
+
+
+@auto_docstring(
+ custom_intro=(
+ "Text part of Isaac, "
+ "not a pure text-only model, as DeepStack integrates visual features into the early hidden states."
+ )
+)
+class IsaacTextModel(IsaacPreTrainedModel):
+ config: IsaacTextConfig
+ input_modalities = ("text",)
+ _no_split_modules = ["IsaacTextDecoderLayer"]
+
+ def __init__(self, config: IsaacTextConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [IsaacTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = IsaacTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = IsaacRotaryEmbedding(config=config, device=self.device)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @merge_with_config_defaults
+ @capture_outputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ # args for deepstack
+ visual_pos_masks: torch.Tensor | None = None,
+ deepstack_visual_embeds: list[torch.Tensor] | None = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple | BaseModelOutputWithPast:
+ r"""
+ visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*):
+ The mask of the visual positions.
+ deepstack_visual_embeds (`list[torch.Tensor]`, *optional*):
+ The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim).
+                The features are extracted from different visual encoder layers and added to the decoder
+                hidden states, following DeepStack (https://arxiv.org/abs/2406.04334).
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ # torch.jit.trace() doesn't support cache objects in the output
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # the hard coded `4` is for text, temporal, height and width.
+ if position_ids is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
+ position_ids = position_ids.view(1, 1, -1).expand(4, inputs_embeds.shape[0], -1)
+ elif position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(4, position_ids.shape[0], -1)
+
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
+ text_position_ids = position_ids[0]
+ position_ids = position_ids[1:]
+ else:
+ text_position_ids = None
+
+ attention_mask = create_causal_mask(
+ config=self.config,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ position_ids=text_position_ids,
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ for layer_idx, decoder_layer in enumerate(self.layers):
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=text_position_ids,
+ past_key_values=past_key_values,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = layer_outputs
+
+ # add visual features to the hidden states of first several layers
+ if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)):
+ hidden_states = self._deepstack_process(
+ hidden_states,
+ visual_pos_masks,
+ deepstack_visual_embeds[layer_idx],
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+ def _deepstack_process(
+ self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor
+ ):
+ visual_pos_masks = visual_pos_masks.to(hidden_states.device)
+ visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
+ hidden_states = hidden_states.clone()
+ local_this = hidden_states[visual_pos_masks, :] + visual_embeds
+ hidden_states[visual_pos_masks, :] = local_this
+ return hidden_states
+
+
+class IsaacMultiModalProjector(nn.Module):
+ """Maps vision tower outputs to the text hidden size with a SiLU MLP."""
+
+ def __init__(self, config: IsaacConfig):
+ super().__init__()
+ text_config = config.get_text_config()
+ vision_hidden_size = config.vision_config.hidden_size * (config.vision_config.pixel_shuffle_scale_factor**2)
+ backbone_hidden_size = text_config.hidden_size
+ self.linear_1 = nn.Linear(vision_hidden_size, 4 * vision_hidden_size, bias=False)
+ self.silu = nn.SiLU()
+ self.linear_2 = nn.Linear(4 * vision_hidden_size, backbone_hidden_size, bias=False)
+
+ def forward(self, image_features):
+ hidden_states = self.linear_1(image_features)
+ hidden_states = self.silu(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+@auto_docstring
+class IsaacModel(IsaacPreTrainedModel):
+ base_model_prefix = "model"
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ config: IsaacConfig
+ _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"]
+ input_modalities = ("image", "text")
+ supports_gradient_checkpointing = True
+ _can_compile_fullgraph = False
+ _supports_flex_attn = False
+ _tied_weights_keys = {}
+ _input_embed_layer = "language_model.embed_tokens"
+
+ def __init__(self, config: IsaacConfig):
+ super().__init__(config)
+ self.language_model = IsaacTextModel._from_config(config.text_config)
+ self.visual = IsaacVisionModel(config.vision_config)
+ self.multimodal_projector = IsaacMultiModalProjector(config)
+ self.max_sequence_length = config.max_sequence_length
+ self.vision_rescale_factor = config.vision_rescale_factor
+ self.rope_deltas = None
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def get_vision_position_ids(
+ self,
+ start_position: int,
+ grid_thw: torch.LongTensor,
+ image_metadata: torch.LongTensor,
+ ) -> torch.LongTensor:
+ """
+ Compute 3D positional indices for vision tokens derived from a single image or video input.
+
+        The positions are generated from the input grid defined by temporal (T), height (H), and
+        width (W) dimensions. The spatial dimensions are downscaled by the pixel shuffle scale factor
+        of the vision backbone, and every vision token of the image shares the temporal position
+        `start_position`.
+
+        Args:
+            start_position (`int`):
+                Temporal position assigned to all vision tokens of this image.
+            grid_thw (`torch.LongTensor` of shape `(3,)`):
+                The (T, H, W) grid representing the feature layout of the current image or video after patch embedding.
+            image_metadata (`torch.LongTensor` of shape `(2,)`):
+                Slot-local `(offset, length)` metadata selecting the visible span of vision tokens.
+
+        Returns:
+            torch.LongTensor of shape (3, length):
+                Positional indices for temporal, height, and width dimensions of the visible vision tokens,
+                flattened into sequence form.
+ """
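+        # For example, with start_position=5 and a 2x2 downsampled grid, the full position tensor is
+        # [[5, 5, 5, 5], [0, 0, 1, 1], [0, 1, 0, 1]] before the (offset, length) slice is applied.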
+ pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor
+ height = grid_thw[1].div(pixel_shuffle_scale, rounding_mode="floor").item()
+ width = grid_thw[2].div(pixel_shuffle_scale, rounding_mode="floor").item()
+ token_positions = torch.arange(height * width, device=grid_thw.device, dtype=torch.long)
+ vision_position_ids = torch.stack(
+ (
+ torch.full((token_positions.shape[0],), start_position, device=grid_thw.device, dtype=torch.long),
+ token_positions.div(width, rounding_mode="floor"),
+ token_positions.remainder(width),
+ ),
+ dim=0,
+ )
+ token_offset = int(image_metadata[0].item())
+ token_length = int(image_metadata[1].item())
+ return vision_position_ids[:, token_offset : token_offset + token_length]
+
+ def get_rope_index(
+ self,
+ input_ids: torch.LongTensor,
+ mm_token_type_ids: torch.Tensor,
+ image_grid_thw: torch.Tensor,
+ image_metadata: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+        Compute multimodal RoPE (M-RoPE) position ids for Isaac.
+
+        Differences from Qwen2VL/Qwen2.5VL's get_rope_index:
+        - Isaac handles image inputs only, expects batch-major `(batch_size, max_images, ...)` image tensors,
+          and uses per-slot `image_metadata` to locate the visible span of each image; slots whose temporal
+          entry is not 1 are treated as empty padding.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+                it.
+            mm_token_type_ids (`torch.IntTensor` of shape `(batch_size, sequence_length)`):
+                Token type ids matching each modality to a different value in the input sequence, i.e. text (0), image (1).
+            image_grid_thw (`torch.LongTensor` of shape `(batch_size, max_images, 3)`):
+                Batch-major per-slot `(T=1, H, W)` grids for each image slot.
+            image_metadata (`torch.LongTensor` of shape `(batch_size, max_images, 2)`):
+                Batch-major per-slot `(offset, length)` metadata for the visible span of each image.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ Returns:
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
+ """
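+        # Worked example (illustrative): for [3 text tokens][one 2x2 image][2 text tokens], the text spans
+        # get positions 0-2 and 4-5 on all three axes, the image tokens share temporal position 3 with
+        # (H, W) indices (0,0), (0,1), (1,0), (1,1), and rope_deltas = max_position + 1 - seq_len = 6 - 9 = -3.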
+ if attention_mask is None:
+ if input_ids is None:
+ attention_mask = mm_token_type_ids.new_ones(mm_token_type_ids.shape, dtype=torch.long)
+ else:
+ attention_mask = input_ids.new_ones(input_ids.shape, dtype=torch.long)
+
+ if input_ids is None:
+ batch_size, seq_len = attention_mask.shape
+ position_dtype = torch.long
+ else:
+ batch_size, seq_len = input_ids.shape
+ position_dtype = input_ids.dtype
+
+ device = attention_mask.device
+ mm_token_type_ids = mm_token_type_ids.to(dtype=torch.long)
+ image_grid_thw = image_grid_thw.to(dtype=torch.long)
+ image_metadata = image_metadata.to(dtype=torch.long)
+ attention_mask = attention_mask.to(dtype=torch.long)
+ active_slot_mask = image_grid_thw[..., 0].eq(1)
+
+ position_ids = torch.zeros((3, batch_size, seq_len), device=device, dtype=position_dtype)
+ rope_deltas = torch.zeros((batch_size, 1), device=device, dtype=torch.long)
+
+ for batch_idx in range(batch_size):
+ sample_attention_mask = attention_mask[batch_idx].bool()
+ sample_token_types = mm_token_type_ids[batch_idx][sample_attention_mask]
+ sample_grids = image_grid_thw[batch_idx]
+ sample_metadata = image_metadata[batch_idx]
+ sample_active_slots = active_slot_mask[batch_idx]
+
+ current_pos = 0
+ image_idx = 0
+ seq_pos = 0
+ llm_pos_ids_list = []
+
+ while seq_pos < sample_token_types.shape[0]:
+ modality_type = int(sample_token_types[seq_pos].item())
+ if modality_type == 0:
+ group_end = seq_pos + 1
+ while group_end < sample_token_types.shape[0] and sample_token_types[group_end] == 0:
+ group_end += 1
+ group_length = group_end - seq_pos
+ llm_pos_ids_list.append(
+ torch.arange(group_length, device=device, dtype=torch.long).view(1, -1).expand(3, -1)
+ + current_pos
+ )
+ current_pos += group_length
+ seq_pos = group_end
+ else:
+ while image_idx < sample_metadata.shape[0] and (
+ not bool(sample_active_slots[image_idx].item()) or sample_metadata[image_idx, 1].item() == 0
+ ):
+ image_idx += 1
+ torch_compilable_check(
+ image_idx < sample_metadata.shape[0],
+ "Isaac multimodal sequence has more visible image tokens than batch-major image metadata slots.",
+ )
+ token_length = int(sample_metadata[image_idx, 1].item())
+ torch_compilable_check(
+ token_length <= sample_token_types.shape[0] - seq_pos,
+ "Isaac image metadata length exceeds the remaining multimodal placeholder span.",
+ )
+ llm_pos_ids_list.append(
+ self.get_vision_position_ids(current_pos, sample_grids[image_idx], sample_metadata[image_idx])
+ )
+ current_pos += 1
+ seq_pos += token_length
+ image_idx += 1
+
+ llm_positions = (
+ torch.cat(llm_pos_ids_list, dim=1)
+ if llm_pos_ids_list
+ else torch.zeros((3, 0), device=device, dtype=torch.long)
+ )
+ position_ids[:, batch_idx, sample_attention_mask] = llm_positions
+ rope_deltas[batch_idx, 0] = (
+ llm_positions.max() + 1 - sample_token_types.shape[0] if llm_positions.numel() > 0 else 0
+ )
+
+ return position_ids, rope_deltas
+
+ @can_return_tuple
+ @auto_docstring
+ def get_image_features(
+ self,
+ pixel_values: torch.Tensor,
+ image_grid_thw: torch.Tensor,
+ image_metadata: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | BaseModelOutputWithPooling:
+ r"""
+ pixel_values (`torch.FloatTensor`, *optional*):
+ Batch-major patch vectors shaped `(batch_size, max_images, max_patches, patch_dim)`.
+ image_grid_thw (`torch.LongTensor`, *optional*):
+ Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries.
+ image_metadata (`torch.Tensor`, *optional*):
+ Batch-major per-slot metadata `(offset, length)` shaped `(batch_size, max_images, 2)`.
+ """
+ active_slot_mask = image_grid_thw[..., 0].eq(1)
+ flat_pixel_values = pixel_values[active_slot_mask]
+ flat_image_grid_thw = image_grid_thw[active_slot_mask]
+
+ vision_outputs: BaseModelOutputWithPooling = self.visual(
+ pixel_values=flat_pixel_values,
+ image_grid_thw=flat_image_grid_thw,
+ return_dict=True,
+ **kwargs,
+ )
+ projected_features = self.multimodal_projector(vision_outputs.last_hidden_state)
+
+ # Truncate image features using offset and length
+ if image_metadata is None:
+ pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor
+ downsampled_height = flat_image_grid_thw[:, 1].div(pixel_shuffle_scale, rounding_mode="floor")
+ downsampled_width = flat_image_grid_thw[:, 2].div(pixel_shuffle_scale, rounding_mode="floor")
+ lengths = downsampled_height * downsampled_width
+ offsets = torch.zeros_like(lengths)
+ else:
+ torch_compilable_check(
+ image_metadata.shape[:2] == image_grid_thw.shape[:2],
+ "IsaacModel.get_image_features expects batch-major metadata aligned with `image_grid_thw`.",
+ )
+ offsets = image_metadata[active_slot_mask][:, 0]
+ lengths = image_metadata[active_slot_mask][:, 1]
+
+ image_features = tuple(
+ projected_features[image_idx, offset : offset + length]
+ for image_idx, (offset, length) in enumerate(zip(offsets.tolist(), lengths.tolist(), strict=True))
+ )
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=projected_features,
+ pooler_output=image_features,
+ hidden_states=vision_outputs.hidden_states,
+ attentions=vision_outputs.attentions,
+ )
+
+ def get_placeholder_mask(
+ self,
+ mm_token_type_ids: torch.LongTensor,
+ inputs_embeds: torch.FloatTensor,
+ image_features: torch.FloatTensor,
+ ) -> torch.BoolTensor:
+ """
+        Obtains the multimodal placeholder mask from `mm_token_type_ids` and checks that the number of placeholder
+        tokens matches the number of multimodal features. If the counts differ, an error is raised.
+ """
+ image_token_mask = mm_token_type_ids.to(dtype=torch.long) == 1
+ n_image_tokens = image_token_mask.sum()
+ image_token_mask = image_token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ torch_compilable_check(
+ inputs_embeds[image_token_mask].numel() == image_features.numel(),
+ f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}",
+ )
+ return image_token_mask
+
+ def compute_3d_position_ids(
+ self,
+ input_ids: torch.Tensor,
+ inputs_embeds: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ mm_token_type_ids: torch.Tensor | None = None,
+ image_grid_thw: torch.Tensor | None = None,
+ image_metadata: torch.Tensor | None = None,
+ past_key_values: Cache | None = None,
+ ) -> torch.Tensor:
+ past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
+ has_multimodal = (
+ image_grid_thw is not None
+ and image_metadata is not None
+ and bool(image_grid_thw[..., 0].eq(1).any().item())
+ )
+ if has_multimodal and mm_token_type_ids is None and input_ids is not None:
+ raise ValueError(
+ "Multimodal data was passed (via `image_grid_thw` or `image_metadata`) but `mm_token_type_ids` is "
+ "missing. Please pass `mm_token_type_ids` to the model so that multimodal RoPE (M-RoPE) can be "
+ "computed correctly. `mm_token_type_ids` is returned by the processor alongside `input_ids`."
+ )
+
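+        # Prefill with images: build the full M-RoPE position ids from the placeholder layout and cache the
+        # per-sample `rope_deltas` so later decode steps can extend text positions by a constant offset.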
+ if has_multimodal and past_seen_tokens == 0:
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids=input_ids,
+ mm_token_type_ids=mm_token_type_ids,
+ image_grid_thw=image_grid_thw,
+ image_metadata=image_metadata,
+ attention_mask=attention_mask,
+ )
+ self.rope_deltas = rope_deltas
+ return position_ids
+
+ if self.rope_deltas is None:
+ return None
+
+ rope_deltas = torch.as_tensor(self.rope_deltas, device=inputs_embeds.device, dtype=torch.long).reshape(-1, 1)
+ if rope_deltas.shape[0] != inputs_embeds.shape[0]:
+ if inputs_embeds.shape[0] % rope_deltas.shape[0] == 0:
+ rope_deltas = rope_deltas.repeat_interleave(inputs_embeds.shape[0] // rope_deltas.shape[0], dim=0)
+ else:
+ rope_deltas = rope_deltas[:1].expand(inputs_embeds.shape[0], -1)
+
+ if attention_mask is not None and attention_mask.shape[-1] > inputs_embeds.shape[1]:
+ rope_position = attention_mask.long().cumsum(dim=-1) - 1
+ rope_position = rope_position.masked_fill(attention_mask == 0, 0)
+ rope_position = rope_position[:, -inputs_embeds.shape[1] :]
+ else:
+ rope_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
+ dtype=torch.long,
+ ).view(1, -1)
+ rope_position = rope_position.expand(inputs_embeds.shape[0], -1)
+
+ position_ids = rope_position.view(1, inputs_embeds.shape[0], -1).expand(3, -1, -1)
+ return position_ids + rope_deltas.to(device=inputs_embeds.device).unsqueeze(0)
+
+ @auto_docstring(
+ custom_intro="""
+ Forward pass with multimodal MRoPE position ids.
+
+ When image placeholders are present, Isaac computes vision features, scatters them into the token
+ embeddings, and runs the shared text backbone on the mixed sequence.
+ """,
+ )
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ mm_token_type_ids: torch.LongTensor | None = None,
+ pixel_values: torch.Tensor | None = None,
+ image_grid_thw: torch.LongTensor | None = None,
+ image_metadata: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | BaseModelOutputWithPast:
+ r"""
+ mm_token_type_ids (`torch.LongTensor`, *optional*):
+ Multimodal token type ids aligned with the token sequence, using `0 -> text` and `1 -> image`.
+ pixel_values (`torch.FloatTensor`, *optional*):
+ Batch-major patch vectors shaped `(batch_size, max_images, max_patches, patch_dim)`.
+ image_grid_thw (`torch.LongTensor`, *optional*):
+ Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries.
+ image_metadata (`torch.LongTensor`, *optional*):
+ Batch-major per-slot metadata shaped `(batch_size, max_images, 2)` with `(offset, length)`.
+ """
+ if (input_ids is None) == (inputs_embeds is None):
+ raise ValueError("You must specify exactly one of `input_ids` or `inputs_embeds`.")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ image_mask = None
+ if pixel_values is not None and image_grid_thw is not None:
+ image_outputs = self.get_image_features(
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ image_metadata=image_metadata,
+ return_dict=True,
+ )
+ image_embeds = image_outputs.pooler_output
+ if len(image_embeds) > 0:
+ image_embeds = torch.cat(image_embeds, dim=0).to(
+ device=inputs_embeds.device, dtype=inputs_embeds.dtype
+ )
+ image_mask = self.get_placeholder_mask(
+ mm_token_type_ids=mm_token_type_ids,
+ inputs_embeds=inputs_embeds,
+ image_features=image_embeds,
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if isinstance(attention_mask, dict):
+ attention_mask = attention_mask["full_attention"]
+
+ past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
+ computed_position_ids = self.compute_3d_position_ids(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ mm_token_type_ids=mm_token_type_ids,
+ image_grid_thw=image_grid_thw,
+ image_metadata=image_metadata,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ )
+ if computed_position_ids is not None:
+ position_ids = computed_position_ids
+ elif past_seen_tokens > 0:
+ position_ids = None
+ elif position_ids is not None and past_seen_tokens == 0:
+ position_ids = position_ids.to(device=inputs_embeds.device)
+ if position_ids.ndim == 2:
+ position_ids = position_ids.view(1, position_ids.shape[0], -1).expand(3, -1, -1)
+
+ outputs = self.language_model(
+ input_ids=None,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ visual_pos_masks=image_mask[..., 0] if image_mask is not None else None,
+ deepstack_visual_embeds=None,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@dataclass
+class IsaacCausalLMOutputWithPast(CausalLMOutputWithPast):
+ """
+ Base class for Isaac causal language model (or autoregressive) outputs.
+
+ Args:
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ rope_deltas: torch.LongTensor | None = None
+
+
+@dataclass
+@auto_docstring
+class BaseModelOutputWithDeepstackFeatures(BaseModelOutputWithPooling):
+ r"""
+    deepstack_features (`list[torch.FloatTensor]`, *optional*):
+ List of hidden-states (feature maps) from deepstack layers.
+ """
+
+ deepstack_features: list[torch.FloatTensor] | None = None
+
+
+@auto_docstring
+class IsaacForConditionalGeneration(IsaacPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ config: IsaacConfig
+ config_class = IsaacConfig
+ input_modalities = ("image", "text")
+ _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"]
+ _can_compile_fullgraph = False
+
+ def __init__(self, config: IsaacConfig):
+ super().__init__(config)
+ self.model = IsaacModel(config)
+ self.vocab_size = config.get_text_config().vocab_size
+ self.lm_head = nn.Linear(config.get_text_config().hidden_size, config.get_text_config().vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ @auto_docstring
+ def get_video_features(
+ self,
+ pixel_values_videos: torch.FloatTensor,
+ video_grid_thw: torch.LongTensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | BaseModelOutputWithDeepstackFeatures:
+ r"""
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input videos.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ return self.model.get_video_features(
+ pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs
+ )
+
+ @auto_docstring
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_grid_thw: torch.LongTensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | BaseModelOutputWithDeepstackFeatures:
+ r"""
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs)
+
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ pixel_values: torch.Tensor | None = None,
+ pixel_values_videos: torch.FloatTensor | None = None,
+ image_grid_thw: torch.LongTensor | None = None,
+ video_grid_thw: torch.LongTensor | None = None,
+ mm_token_type_ids: torch.IntTensor | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | IsaacCausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, IsaacForConditionalGeneration
+
+        >>> model = IsaacForConditionalGeneration.from_pretrained("PerceptronAI/Isaac-0.1")
+        >>> processor = AutoProcessor.from_pretrained("PerceptronAI/Isaac-0.1")
+
+        >>> messages = [
+        ...     {
+        ...         "role": "user",
+        ...         "content": [
+        ...             {
+        ...                 "type": "image",
+        ...                 "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
+        ...             },
+        ...             {"type": "text", "text": "Describe the image."},
+        ...         ],
+        ...     }
+        ... ]
+
+        >>> inputs = processor.apply_chat_template(
+        ...     messages,
+        ...     tokenize=True,
+        ...     add_generation_prompt=True,
+        ...     return_dict=True,
+        ...     return_tensors="pt"
+        ... )
+
+ >>> # Generate
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=1024)
+ >>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+ >>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ >>> print(output_text)
+ ```
+ """
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ mm_token_type_ids=mm_token_type_ids,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ return IsaacCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.model.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values: Cache = None,
+ attention_mask: torch.Tensor | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ mm_token_type_ids: torch.LongTensor | None = None,
+ pixel_values: torch.Tensor | None = None,
+ image_grid_thw: torch.LongTensor | None = None,
+ image_metadata: torch.LongTensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ is_first_iteration: bool = False,
+ use_cache: bool = True,
+ **kwargs,
+ ) -> dict[str, Any]:
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ is_first_iteration=is_first_iteration,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ multimodal_inputs = {
+ "mm_token_type_ids": mm_token_type_ids,
+ "pixel_values": pixel_values,
+ "image_grid_thw": image_grid_thw,
+ "image_metadata": image_metadata,
+ }
+ is_prefill = is_first_iteration or not use_cache
+ for key, value in multimodal_inputs.items():
+ model_inputs[key] = value if is_prefill else None
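+        # Keep `mm_token_type_ids` aligned with the tokens actually fed to the model: right-pad with the
+        # text type (0) when it is shorter, or keep only the trailing window when the prompt was trimmed.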
+ if model_inputs["mm_token_type_ids"] is not None:
+ sequence_length = None
+ if model_inputs.get("input_ids") is not None:
+ sequence_length = model_inputs["input_ids"].shape[1]
+ elif model_inputs.get("inputs_embeds") is not None:
+ sequence_length = model_inputs["inputs_embeds"].shape[1]
+
+ if sequence_length is not None:
+ current_length = model_inputs["mm_token_type_ids"].shape[1]
+ if current_length < sequence_length:
+ padding = model_inputs["mm_token_type_ids"].new_zeros(
+ (model_inputs["mm_token_type_ids"].shape[0], sequence_length - current_length)
+ )
+ model_inputs["mm_token_type_ids"] = torch.cat([model_inputs["mm_token_type_ids"], padding], dim=1)
+ elif current_length > sequence_length:
+ model_inputs["mm_token_type_ids"] = model_inputs["mm_token_type_ids"][:, -sequence_length:]
+
+ return model_inputs
+
+ def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs):
+ text_positions = super()._prepare_position_ids_for_generation(inputs_tensor, model_kwargs)
+
+ past_length = 0
+ if (cache := model_kwargs.get("past_key_values")) is not None:
+ past_length = cache.get_seq_length()
+ if past_length != 0 and self.model.rope_deltas is not None:
+ return text_positions[None, ...] + self.model.rope_deltas
+
+ if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
+ inputs_tensor = model_kwargs["input_ids"]
+
+ is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
+ if (
+ is_input_ids
+ and model_kwargs.get("mm_token_type_ids") is not None
+ and model_kwargs.get("image_grid_thw") is not None
+ and model_kwargs.get("image_metadata") is not None
+ ):
+ model_kwargs = {k: v for k, v in model_kwargs.items() if k != "input_ids"}
+ vision_positions, rope_deltas = self.model.get_rope_index(inputs_tensor, **model_kwargs)
+ self.model.rope_deltas = rope_deltas
+ else:
+ vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1)
+ self.model.rope_deltas = torch.zeros(
+ inputs_tensor.shape[0], 1, dtype=torch.long, device=inputs_tensor.device
+ )
+
+ return torch.cat([text_positions[None, ...], vision_positions], dim=0)
+
+ def _get_image_nums_and_video_nums(
+ self,
+ input_ids: torch.LongTensor | None,
+ inputs_embeds: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+        Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
+        These counts are derived directly from the inputs rather than passed through the processor, to avoid
+        unpredictable impacts from interface modifications.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Returns:
+            image_nums (`torch.LongTensor` of shape `(batch_size,)`)
+            video_nums (`torch.LongTensor` of shape `(batch_size,)`)
+ """
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+
+ if inputs_embeds is not None:
+ vision_start_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ image_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ video_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ else:
+ vision_start_mask = input_ids == vision_start_token_id
+ image_mask = input_ids == image_token_id
+ video_mask = input_ids == video_token_id
+
+ vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
+ image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
+ video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
+
+ return image_nums, video_nums
+
+ def _expand_inputs_for_generation(
+ self,
+ expand_size: int = 1,
+ is_encoder_decoder: bool = False,
+ input_ids: torch.LongTensor | None = None,
+ **model_kwargs,
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
+ position_ids = model_kwargs.pop("position_ids", None)
+ if expand_size == 1:
+ if position_ids is not None:
+ model_kwargs["position_ids"] = position_ids
+ return input_ids, model_kwargs
+
+ visual_keys = ["pixel_values", "image_grid_thw", "image_metadata"]
+ for key in visual_keys:
+ value = model_kwargs.get(key)
+ if value is not None:
+ model_kwargs[key] = value.repeat_interleave(expand_size, dim=0)
+
+ if input_ids is not None:
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+ for key, value in list(model_kwargs.items()):
+ if key == "position_ids" and value is not None and value.ndim == 3:
+ model_kwargs[key] = value.repeat_interleave(expand_size, dim=1)
+ elif value is not None and isinstance(value, torch.Tensor) and key not in visual_keys:
+ model_kwargs[key] = value.repeat_interleave(expand_size, dim=0)
+
+ if position_ids is not None:
+ dim = 1 if position_ids.ndim == 3 else 0
+ model_kwargs["position_ids"] = position_ids.repeat_interleave(expand_size, dim=dim)
+ return input_ids, model_kwargs
+
+
+class SinglePoint(NamedTuple):
+ x: int
+ y: int
+ mention: str | None = None
+ t: float | None = None
+
+
+class BoundingBox(NamedTuple):
+ top_left: Any
+ bottom_right: Any
+ mention: str | None = None
+ t: float | None = None
+
+
+class Polygon(NamedTuple):
+ points: tuple[Any, ...]
+ mention: str | None = None
+ t: float | None = None
+
+
+__all__ = ["IsaacTextModel", "IsaacVisionModel", "IsaacModel", "IsaacPreTrainedModel", "IsaacForConditionalGeneration"]
diff --git a/src/transformers/models/isaac/modular_isaac.py b/src/transformers/models/isaac/modular_isaac.py
new file mode 100644
index 000000000000..5d2935064980
--- /dev/null
+++ b/src/transformers/models/isaac/modular_isaac.py
@@ -0,0 +1,1642 @@
+# Copyright 2026 Perceptron, Inc and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import math
+import re
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, NamedTuple
+
+from huggingface_hub.dataclasses import strict
+
+from ... import TorchvisionBackend
+from ... import initialization as init
+from ...cache_utils import Cache
+from ...configuration_utils import PretrainedConfig
+from ...feature_extraction_utils import BatchFeature
+from ...generation.utils import GenerationMixin
+from ...image_transforms import group_images_by_shape, reorder_images
+from ...image_utils import ImageInput, PILImageResampling, SizeDict, make_nested_list_of_images
+from ...masking_utils import create_bidirectional_mask
+from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...models.qwen3.configuration_qwen3 import Qwen3Config
+from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
+from ...utils import TensorType, auto_docstring, torch_compilable_check
+from ...utils.constants import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
+from ...utils.generic import TransformersKwargs, can_return_tuple, merge_with_config_defaults
+from ...utils.import_utils import (
+ is_torch_available,
+ is_torchdynamo_compiling,
+ is_torchvision_available,
+ requires,
+)
+from ...utils.output_capturing import capture_outputs
+from ..qwen3_vl.modeling_qwen3_vl import (
+ Qwen3VLForConditionalGeneration,
+ Qwen3VLModel,
+ Qwen3VLTextDecoderLayer,
+ Qwen3VLTextModel,
+ Qwen3VLTextRotaryEmbedding,
+)
+from ..siglip2.configuration_siglip2 import Siglip2VisionConfig
+from ..siglip2.modeling_siglip2 import (
+ Siglip2Attention,
+ Siglip2Encoder,
+ Siglip2EncoderLayer,
+ Siglip2VisionEmbeddings,
+)
+
+
+if is_torch_available():
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+if is_torchvision_available():
+    from torchvision.transforms.v2 import functional as tvF
+
+    from ..pix2struct.image_processing_pix2struct import torch_extract_patches
+
+
+@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base")
+@strict
+class IsaacVisionConfig(Siglip2VisionConfig):
+ r"""
+ num_patches (`int`, *optional*, defaults to 256):
+ The number of patches in the image with the size of (`patch_size`, `patch_size`). The image is resized to
+ fill a maximum of this number of patches while preserving the aspect ratio. If the resulting number of patches
+ is lower, the image is padded in the patch dimension.
+ pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1):
+        Spatial downsampling factor used by pixel shuffle: each `scale_factor x scale_factor` block of patches
+        is folded into the channel dimension, reducing the number of vision tokens.
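+
+    Example (a minimal sketch mirroring the other Isaac config examples; assumes `IsaacVisionConfig` and
+    `IsaacVisionModel` are exported like the sibling classes):
+
+    ```python
+    >>> from transformers import IsaacVisionConfig, IsaacVisionModel
+
+    >>> # Initializing a default Isaac vision configuration and the corresponding vision tower
+    >>> configuration = IsaacVisionConfig()
+    >>> model = IsaacVisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```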
+ """
+
+ model_type = "isaac_vision"
+ base_config_key = "vision_config"
+ pixel_shuffle_scale_factor: int = 1
+
+
+@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base")
+@strict
+class IsaacTextConfig(Qwen3Config):
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import IsaacTextConfig, IsaacTextModel
+
+ >>> configuration = IsaacTextConfig()
+ >>> model = IsaacTextModel(configuration)
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "isaac_text"
+ ignore_keys_at_rope_validation = {"mrope_section", "mrope_interleaved"}
+ max_position_embeddings: int = 32768
+ sliding_window = AttributeError()
+ layer_types = AttributeError()
+
+ def __post_init__(self, **kwargs):
+ if self.num_key_value_heads is None:
+ self.num_key_value_heads = self.num_attention_heads
+
+ PretrainedConfig.__post_init__(self, **kwargs)
+
+
+@auto_docstring(checkpoint="PerceptronAI/Isaac-0.1-Base")
+@strict
+class IsaacConfig(PretrainedConfig):
+ r"""
+ vision_config (`IsaacVisionConfig` or `dict`, *optional*):
+ Configuration for the Isaac vision tower. Dictionaries are converted to [`IsaacVisionConfig`]. If unset,
+ the default [`IsaacVisionConfig`] is used.
+ text_config (`IsaacTextConfig` or `dict`, *optional*):
+ Configuration for the text backbone. Dictionaries are converted to [`IsaacTextConfig`].
+ vision_rescale_factor (`float`, *optional*, defaults to 1 / 255):
+ Rescale factor applied by the image processor before normalization.
+ max_sequence_length (`int`, *optional*, defaults to 16384):
+ Maximum multimodal sequence length produced by the processor and expected by the model.
+
+ Example:
+
+ ```python
+ >>> from transformers import IsaacConfig, IsaacModel
+
+ >>> configuration = IsaacConfig()
+ >>> model = IsaacModel(configuration)
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "isaac"
+ sub_configs = {"vision_config": IsaacVisionConfig, "text_config": IsaacTextConfig}
+ vision_config: IsaacVisionConfig | dict | None = None
+ text_config: IsaacTextConfig | dict | None = None
+ vision_rescale_factor: float = 1 / 255
+ max_sequence_length: int = 16384
+
+ def __post_init__(self, **kwargs):
+ if isinstance(self.text_config, dict):
+ self.text_config = self.sub_configs["text_config"](**self.text_config)
+ elif self.text_config is None:
+ self.text_config = self.sub_configs["text_config"]()
+ elif not isinstance(self.text_config, IsaacTextConfig):
+ raise TypeError(
+ f"text_config must be a dict or an IsaacTextConfig instance, got {type(self.text_config).__name__}."
+ )
+
+ if isinstance(self.vision_config, dict):
+ self.vision_config = self.sub_configs["vision_config"](**self.vision_config)
+ elif self.vision_config is None:
+ self.vision_config = self.sub_configs["vision_config"]()
+ elif not isinstance(self.vision_config, IsaacVisionConfig):
+ raise TypeError(
+ f"vision_config must be a dict or an IsaacVisionConfig instance, got {type(self.vision_config).__name__}."
+ )
+
+ self.vision_rescale_factor = float(self.vision_rescale_factor)
+ super().__post_init__(**kwargs)
+
+
+class IsaacVisionEmbeddings(Siglip2VisionEmbeddings):
+ """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences.
+
+ Isaac accepts variable-resolution vision inputs as a single packed sequence with per-image
+ `token_grids`; packing/unpacking here reconstructs per-image shapes so we can resize positional
+ embeddings and build `cu_seqlens` for variable-length attention (not generic generation packing).
+ """
+
+ def __init__(self, config: IsaacVisionConfig):
+ super().__init__(config)
+ self.position_embedding = nn.Parameter(
+ torch.empty(
+ self.position_embedding_size,
+ self.position_embedding_size,
+ self.embed_dim,
+ )
+ )
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ image_grid_thw: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ # pixel_values: (num_images, max_patches, patch_dim)
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
+
+ resized_positional_embeddings = self.resize_positional_embeddings(
+ self.position_embedding,
+ image_grid_thw[:, 1:],
+ max_length=pixel_values.shape[1],
+ )
+ resized_positional_embeddings = resized_positional_embeddings.to(
+ device=patch_embeds.device, dtype=patch_embeds.dtype
+ )
+ embeddings = patch_embeds + resized_positional_embeddings
+
+ if attention_mask is not None:
+ embeddings = embeddings * attention_mask.unsqueeze(-1).to(device=embeddings.device, dtype=embeddings.dtype)
+
+ return embeddings
+
+
+class IsaacVisionAttention(Siglip2Attention):
+ """Custom attention that supports variable-length sequences with flash/SDPA backends."""
+
+ pass
+
+
+class IsaacVisionEncoderLayer(Siglip2EncoderLayer):
+ """Isaac vision encoder layer using the shared attention interfaces."""
+
+ def __init__(self, config: IsaacVisionConfig):
+ super().__init__(config)
+ self.self_attn = IsaacVisionAttention(config)
+
+
+class IsaacVisionEncoder(Siglip2Encoder):
+ """Encoder using Isaac encoder layers."""
+
+ def __init__(self, config: IsaacVisionConfig):
+ super().__init__(config)
+ self.layers = nn.ModuleList([IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+
+
+@auto_docstring
+class IsaacVisionModel(PreTrainedModel):
+ config: IsaacVisionConfig
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _can_record_outputs = {
+ "hidden_states": IsaacVisionEncoderLayer,
+ "attentions": IsaacVisionAttention,
+ }
+
+ def __init__(self, config: IsaacVisionConfig):
+ super().__init__(config)
+ self.embeddings = IsaacVisionEmbeddings(config)
+ self.encoder = IsaacVisionEncoder(config)
+ self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor
+
+ self.post_init()
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, IsaacVisionEmbeddings):
+ init.zeros_(module.position_embedding)
+
+ def pixel_shuffle_padded(
+ self,
+ hidden_states: torch.Tensor,
+ token_grids: torch.Tensor,
+ ) -> torch.Tensor:
+ """Apply pixel shuffle per image on padded batched vision embeddings.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Vision embeddings of shape `(num_images, max_patches, hidden_size)`.
+ token_grids (`torch.Tensor`):
+ Grid sizes `(height, width)` per image, shape `(num_images, 2)`.
+
+ Returns:
+ `torch.Tensor`: Pixel-shuffled embeddings of shape
+ `(num_images, max_tokens, hidden_size * scale_factor**2)`.
+ """
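+        # e.g. with scale_factor=2, a 4x4 patch grid becomes 2x2 = 4 output tokens, each of which
+        # concatenates the hidden states of its 2x2 source block along the channel dimension.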
+ scale_factor = self.pixel_shuffle_scale_factor
+ num_images, max_patches, embed_dim = hidden_states.shape
+ output_dim = embed_dim * scale_factor * scale_factor
+
+ token_grids = token_grids.to(device=hidden_states.device, dtype=torch.long)
+ heights = token_grids[:, 0]
+ widths = token_grids[:, 1]
+ full_lengths = heights * widths
+
+ non_empty = full_lengths > 0
+ if not is_torchdynamo_compiling():
+ divisible = ((heights % scale_factor) == 0) & ((widths % scale_factor) == 0)
+ torch_compilable_check(
+ (~non_empty) | divisible,
+ f"Every non-empty (H, W) grid must be divisible by pixel_shuffle_scale={scale_factor}.",
+ )
+
+ output_lengths = (heights // scale_factor) * (widths // scale_factor)
+ max_output_tokens = output_lengths.max()
+ shuffled_4d = hidden_states.new_zeros((num_images, max_output_tokens, scale_factor * scale_factor, embed_dim))
+
+ token_positions = (
+ torch.arange(max_patches, device=hidden_states.device, dtype=torch.long)
+ .unsqueeze(0)
+ .expand(num_images, -1)
+ )
+ valid_token_mask = token_positions < full_lengths.unsqueeze(1)
+
+ safe_widths = torch.where(widths > 0, widths, torch.ones_like(widths))
+ row_index = torch.div(token_positions, safe_widths.unsqueeze(1), rounding_mode="floor")
+ col_index = token_positions.remainder(safe_widths.unsqueeze(1))
+
+ output_widths = widths.div(scale_factor, rounding_mode="floor")
+ output_index = row_index.div(scale_factor, rounding_mode="floor") * output_widths.unsqueeze(1)
+ output_index = output_index + col_index.div(scale_factor, rounding_mode="floor")
+ sub_index = row_index.remainder(scale_factor) * scale_factor + col_index.remainder(scale_factor)
+
+ batch_index = (
+ torch.arange(num_images, device=hidden_states.device, dtype=torch.long)
+ .unsqueeze(1)
+ .expand_as(token_positions)
+ )
+ shuffled_4d[batch_index[valid_token_mask], output_index[valid_token_mask], sub_index[valid_token_mask]] = (
+ hidden_states[valid_token_mask]
+ )
+
+ shuffled = shuffled_4d.view(num_images, max_output_tokens, output_dim)
+ return shuffled
+
+ @merge_with_config_defaults
+ @capture_outputs(tie_last_hidden_states=False)
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ image_grid_thw: torch.Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPooling:
+ r"""
+        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
+            Per-image `(T=1, H, W)` grids describing the packed patch sequences in `pixel_values`.
+ """
+ full_lengths = image_grid_thw[:, 1] * image_grid_thw[:, 2]
+ token_positions = torch.arange(pixel_values.shape[1], device=pixel_values.device, dtype=torch.long)
+ image_patch_attention_mask = token_positions.unsqueeze(0) < full_lengths.unsqueeze(1)
+ image_patch_attention_mask = image_patch_attention_mask.to(dtype=torch.long)
+ hidden_states = self.embeddings(
+ pixel_values,
+ image_grid_thw,
+ attention_mask=image_patch_attention_mask,
+ )
+
+ encoder_attention_mask = create_bidirectional_mask(
+ config=self.config,
+ inputs_embeds=hidden_states,
+ attention_mask=image_patch_attention_mask,
+ )
+ encoder_outputs = self.encoder(inputs_embeds=hidden_states, attention_mask=encoder_attention_mask, **kwargs)
+ hidden_states = self.post_layernorm(encoder_outputs.last_hidden_state)
+
+ hidden_states = self.pixel_shuffle_padded(
+ hidden_states=hidden_states,
+ token_grids=image_grid_thw[:, 1:],
+ )
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=hidden_states,
+ pooler_output=None,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class IsaacRotaryEmbedding(Qwen3VLTextRotaryEmbedding):
+ def __init__(self, config: IsaacTextConfig, device=None):
+ super().__init__(config, device=device)
+ self.mrope_section = config.rope_parameters.get("mrope_section")
+ if self.mrope_section is None:
+ weights = (2, 1, 1)
+ self.mrope_section = [self.inv_freq.shape[0] * w // sum(weights) for w in weights]
+ self.mrope_section[0] += self.inv_freq.shape[0] - sum(self.mrope_section)
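+            # e.g. with 64 inverse frequencies, the (2, 1, 1) weighting yields mrope_section = [32, 16, 16];
+            # any rounding remainder is folded back into the temporal section.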
+
+ def apply_interleaved_mrope(self, freqs: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
+ chunks = freqs.split(tuple(mrope_section), dim=-1)
+ return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
+
+
+class IsaacTextDecoderLayer(Qwen3VLTextDecoderLayer):
+ pass
+
+
+class IsaacTextModel(Qwen3VLTextModel):
+ def __init__(self, config: IsaacTextConfig):
+ super().__init__(config)
+ self.rotary_emb = IsaacRotaryEmbedding(config=config, device=self.device)
+
+
+class IsaacMultiModalProjector(nn.Module):
+ """Maps vision tower outputs to the text hidden size with a SiLU MLP."""
+
+ def __init__(self, config: IsaacConfig):
+ super().__init__()
+ text_config = config.get_text_config()
+ vision_hidden_size = config.vision_config.hidden_size * (config.vision_config.pixel_shuffle_scale_factor**2)
+ backbone_hidden_size = text_config.hidden_size
+ self.linear_1 = nn.Linear(vision_hidden_size, 4 * vision_hidden_size, bias=False)
+ self.silu = nn.SiLU()
+ self.linear_2 = nn.Linear(4 * vision_hidden_size, backbone_hidden_size, bias=False)
+
+ def forward(self, image_features):
+ hidden_states = self.linear_1(image_features)
+ hidden_states = self.silu(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+@auto_docstring
+class IsaacModel(Qwen3VLModel):
+ input_modalities = ("image", "text")
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"]
+ _can_compile_fullgraph = False
+ _supports_flex_attn = False
+ _tied_weights_keys = {}
+ _input_embed_layer = "language_model.embed_tokens"
+
+ def __init__(self, config: IsaacConfig):
+ PreTrainedModel.__init__(self, config)
+ self.language_model = IsaacTextModel._from_config(config.text_config)
+ self.visual = IsaacVisionModel(config.vision_config)
+ self.multimodal_projector = IsaacMultiModalProjector(config)
+ self.max_sequence_length = config.max_sequence_length
+ self.vision_rescale_factor = config.vision_rescale_factor
+ self.rope_deltas = None
+
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def get_image_features(
+ self,
+ pixel_values: torch.Tensor,
+ image_grid_thw: torch.Tensor,
+ image_metadata: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | BaseModelOutputWithPooling:
+ r"""
+ pixel_values (`torch.FloatTensor`, *optional*):
+ Batch-major patch vectors shaped `(batch_size, max_images, max_patches, patch_dim)`.
+ image_grid_thw (`torch.LongTensor`, *optional*):
+ Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries.
+ image_metadata (`torch.Tensor`, *optional*):
+ Batch-major per-slot metadata `(offset, length)` shaped `(batch_size, max_images, 2)`.
+ """
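+        # Select only real image slots (T == 1) and flatten them; padded slots carry all-zero grids.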
+ active_slot_mask = image_grid_thw[..., 0].eq(1)
+ flat_pixel_values = pixel_values[active_slot_mask]
+ flat_image_grid_thw = image_grid_thw[active_slot_mask]
+
+ vision_outputs: BaseModelOutputWithPooling = self.visual(
+ pixel_values=flat_pixel_values,
+ image_grid_thw=flat_image_grid_thw,
+ return_dict=True,
+ **kwargs,
+ )
+ projected_features = self.multimodal_projector(vision_outputs.last_hidden_state)
+
+ # Truncate image features using offset and length
+ if image_metadata is None:
+ pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor
+ downsampled_height = flat_image_grid_thw[:, 1].div(pixel_shuffle_scale, rounding_mode="floor")
+ downsampled_width = flat_image_grid_thw[:, 2].div(pixel_shuffle_scale, rounding_mode="floor")
+ lengths = downsampled_height * downsampled_width
+ offsets = torch.zeros_like(lengths)
+ else:
+ torch_compilable_check(
+ image_metadata.shape[:2] == image_grid_thw.shape[:2],
+ "IsaacModel.get_image_features expects batch-major metadata aligned with `image_grid_thw`.",
+ )
+ offsets = image_metadata[active_slot_mask][:, 0]
+ lengths = image_metadata[active_slot_mask][:, 1]
+
+ image_features = tuple(
+ projected_features[image_idx, offset : offset + length]
+ for image_idx, (offset, length) in enumerate(zip(offsets.tolist(), lengths.tolist(), strict=True))
+ )
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=projected_features,
+ pooler_output=image_features,
+ hidden_states=vision_outputs.hidden_states,
+ attentions=vision_outputs.attentions,
+ )
+
+ def get_placeholder_mask(
+ self,
+ mm_token_type_ids: torch.LongTensor,
+ inputs_embeds: torch.FloatTensor,
+ image_features: torch.FloatTensor,
+ ) -> torch.BoolTensor:
+ image_token_mask = mm_token_type_ids.to(dtype=torch.long) == 1
+ n_image_tokens = image_token_mask.sum()
+ image_token_mask = image_token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ torch_compilable_check(
+ inputs_embeds[image_token_mask].numel() == image_features.numel(),
+ f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}",
+ )
+ return image_token_mask
+
+ def get_video_features(self, **super_kwargs):
+ raise AttributeError("Isaac is image-only and does not support `pixel_values_videos` or `video_grid_thw`.")
+
+ def get_vision_position_ids(
+ self,
+ start_position: int,
+ grid_thw: torch.LongTensor,
+ image_metadata: torch.LongTensor,
+ ) -> torch.LongTensor:
+ pixel_shuffle_scale = self.config.vision_config.pixel_shuffle_scale_factor
+ height = grid_thw[1].div(pixel_shuffle_scale, rounding_mode="floor").item()
+ width = grid_thw[2].div(pixel_shuffle_scale, rounding_mode="floor").item()
+ token_positions = torch.arange(height * width, device=grid_thw.device, dtype=torch.long)
+ vision_position_ids = torch.stack(
+ (
+ torch.full((token_positions.shape[0],), start_position, device=grid_thw.device, dtype=torch.long),
+ token_positions.div(width, rounding_mode="floor"),
+ token_positions.remainder(width),
+ ),
+ dim=0,
+ )
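+        # Restrict to the (offset, length) window from `image_metadata`, i.e. the placeholder span that survived truncation.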
+ token_offset = int(image_metadata[0].item())
+ token_length = int(image_metadata[1].item())
+ return vision_position_ids[:, token_offset : token_offset + token_length]
+
+ def get_rope_index(
+ self,
+ input_ids: torch.LongTensor,
+ mm_token_type_ids: torch.Tensor,
+ image_grid_thw: torch.Tensor,
+ image_metadata: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ if attention_mask is None:
+ if input_ids is None:
+ attention_mask = mm_token_type_ids.new_ones(mm_token_type_ids.shape, dtype=torch.long)
+ else:
+ attention_mask = input_ids.new_ones(input_ids.shape, dtype=torch.long)
+
+ if input_ids is None:
+ batch_size, seq_len = attention_mask.shape
+ position_dtype = torch.long
+ else:
+ batch_size, seq_len = input_ids.shape
+ position_dtype = input_ids.dtype
+
+ device = attention_mask.device
+ mm_token_type_ids = mm_token_type_ids.to(dtype=torch.long)
+ image_grid_thw = image_grid_thw.to(dtype=torch.long)
+ image_metadata = image_metadata.to(dtype=torch.long)
+ attention_mask = attention_mask.to(dtype=torch.long)
+ active_slot_mask = image_grid_thw[..., 0].eq(1)
+
+ position_ids = torch.zeros((3, batch_size, seq_len), device=device, dtype=position_dtype)
+ rope_deltas = torch.zeros((batch_size, 1), device=device, dtype=torch.long)
+
+ for batch_idx in range(batch_size):
+ sample_attention_mask = attention_mask[batch_idx].bool()
+ sample_token_types = mm_token_type_ids[batch_idx][sample_attention_mask]
+ sample_grids = image_grid_thw[batch_idx]
+ sample_metadata = image_metadata[batch_idx]
+ sample_active_slots = active_slot_mask[batch_idx]
+
+ current_pos = 0
+ image_idx = 0
+ seq_pos = 0
+ llm_pos_ids_list = []
+
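+            # Walk the visible tokens: text runs advance the same position on all three RoPE axes,
+            # while image runs take (t, row, col) grid positions from `get_vision_position_ids`.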
+ while seq_pos < sample_token_types.shape[0]:
+ modality_type = int(sample_token_types[seq_pos].item())
+ if modality_type == 0:
+ group_end = seq_pos + 1
+ while group_end < sample_token_types.shape[0] and sample_token_types[group_end] == 0:
+ group_end += 1
+ group_length = group_end - seq_pos
+ llm_pos_ids_list.append(
+ torch.arange(group_length, device=device, dtype=torch.long).view(1, -1).expand(3, -1)
+ + current_pos
+ )
+ current_pos += group_length
+ seq_pos = group_end
+ else:
+ while image_idx < sample_metadata.shape[0] and (
+ not bool(sample_active_slots[image_idx].item()) or sample_metadata[image_idx, 1].item() == 0
+ ):
+ image_idx += 1
+ torch_compilable_check(
+ image_idx < sample_metadata.shape[0],
+ "Isaac multimodal sequence has more visible image tokens than batch-major image metadata slots.",
+ )
+ token_length = int(sample_metadata[image_idx, 1].item())
+ torch_compilable_check(
+ token_length <= sample_token_types.shape[0] - seq_pos,
+ "Isaac image metadata length exceeds the remaining multimodal placeholder span.",
+ )
+ llm_pos_ids_list.append(
+ self.get_vision_position_ids(current_pos, sample_grids[image_idx], sample_metadata[image_idx])
+ )
+ current_pos += 1
+ seq_pos += token_length
+ image_idx += 1
+
+ llm_positions = (
+ torch.cat(llm_pos_ids_list, dim=1)
+ if llm_pos_ids_list
+ else torch.zeros((3, 0), device=device, dtype=torch.long)
+ )
+ position_ids[:, batch_idx, sample_attention_mask] = llm_positions
+ rope_deltas[batch_idx, 0] = (
+ llm_positions.max() + 1 - sample_token_types.shape[0] if llm_positions.numel() > 0 else 0
+ )
+
+ return position_ids, rope_deltas
+
+ def compute_3d_position_ids(
+ self,
+ input_ids: torch.Tensor,
+ inputs_embeds: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ mm_token_type_ids: torch.Tensor | None = None,
+ image_grid_thw: torch.Tensor | None = None,
+ image_metadata: torch.Tensor | None = None,
+ past_key_values: Cache | None = None,
+ ) -> torch.Tensor:
+ past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
+ has_multimodal = (
+ image_grid_thw is not None
+ and image_metadata is not None
+ and bool(image_grid_thw[..., 0].eq(1).any().item())
+ )
+ if has_multimodal and mm_token_type_ids is None and input_ids is not None:
+ raise ValueError(
+ "Multimodal data was passed (via `image_grid_thw` or `image_metadata`) but `mm_token_type_ids` is "
+ "missing. Please pass `mm_token_type_ids` to the model so that multimodal RoPE (M-RoPE) can be "
+ "computed correctly. `mm_token_type_ids` is returned by the processor alongside `input_ids`."
+ )
+
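+        # Prefill with image tokens: compute full M-RoPE positions and cache per-sample deltas for later decode steps.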
+ if has_multimodal and past_seen_tokens == 0:
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids=input_ids,
+ mm_token_type_ids=mm_token_type_ids,
+ image_grid_thw=image_grid_thw,
+ image_metadata=image_metadata,
+ attention_mask=attention_mask,
+ )
+ self.rope_deltas = rope_deltas
+ return position_ids
+
+ if self.rope_deltas is None:
+ return None
+
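+        # Decode path: shift plain sequential positions by the cached per-sample rope deltas.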
+ rope_deltas = torch.as_tensor(self.rope_deltas, device=inputs_embeds.device, dtype=torch.long).reshape(-1, 1)
+ if rope_deltas.shape[0] != inputs_embeds.shape[0]:
+ if inputs_embeds.shape[0] % rope_deltas.shape[0] == 0:
+ rope_deltas = rope_deltas.repeat_interleave(inputs_embeds.shape[0] // rope_deltas.shape[0], dim=0)
+ else:
+ rope_deltas = rope_deltas[:1].expand(inputs_embeds.shape[0], -1)
+
+ if attention_mask is not None and attention_mask.shape[-1] > inputs_embeds.shape[1]:
+ rope_position = attention_mask.long().cumsum(dim=-1) - 1
+ rope_position = rope_position.masked_fill(attention_mask == 0, 0)
+ rope_position = rope_position[:, -inputs_embeds.shape[1] :]
+ else:
+ rope_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
+ dtype=torch.long,
+ ).view(1, -1)
+ rope_position = rope_position.expand(inputs_embeds.shape[0], -1)
+
+ position_ids = rope_position.view(1, inputs_embeds.shape[0], -1).expand(3, -1, -1)
+ return position_ids + rope_deltas.to(device=inputs_embeds.device).unsqueeze(0)
+
+ @auto_docstring(
+ custom_intro="""
+ Forward pass with multimodal MRoPE position ids.
+
+ When image placeholders are present, Isaac computes vision features, scatters them into the token
+ embeddings, and runs the shared text backbone on the mixed sequence.
+ """,
+ )
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ mm_token_type_ids: torch.LongTensor | None = None,
+ pixel_values: torch.Tensor | None = None,
+ image_grid_thw: torch.LongTensor | None = None,
+ image_metadata: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | BaseModelOutputWithPast:
+ r"""
+ mm_token_type_ids (`torch.LongTensor`, *optional*):
+ Multimodal token type ids aligned with the token sequence, using `0 -> text` and `1 -> image`.
+ pixel_values (`torch.FloatTensor`, *optional*):
+ Batch-major patch vectors shaped `(batch_size, max_images, max_patches, patch_dim)`.
+ image_grid_thw (`torch.LongTensor`, *optional*):
+ Batch-major per-slot grids shaped `(batch_size, max_images, 3)` with `(T=1, H, W)` entries.
+ image_metadata (`torch.LongTensor`, *optional*):
+ Batch-major per-slot metadata shaped `(batch_size, max_images, 2)` with `(offset, length)`.
+ """
+ if (input_ids is None) == (inputs_embeds is None):
+ raise ValueError("You must specify exactly one of `input_ids` or `inputs_embeds`.")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ image_mask = None
+ if pixel_values is not None and image_grid_thw is not None:
+ image_outputs = self.get_image_features(
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ image_metadata=image_metadata,
+ return_dict=True,
+ )
+ image_embeds = image_outputs.pooler_output
+ if len(image_embeds) > 0:
+ image_embeds = torch.cat(image_embeds, dim=0).to(
+ device=inputs_embeds.device, dtype=inputs_embeds.dtype
+ )
+ image_mask = self.get_placeholder_mask(
+ mm_token_type_ids=mm_token_type_ids,
+ inputs_embeds=inputs_embeds,
+ image_features=image_embeds,
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if isinstance(attention_mask, dict):
+ attention_mask = attention_mask["full_attention"]
+
+ past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
+ computed_position_ids = self.compute_3d_position_ids(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ mm_token_type_ids=mm_token_type_ids,
+ image_grid_thw=image_grid_thw,
+ image_metadata=image_metadata,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ )
+ if computed_position_ids is not None:
+ position_ids = computed_position_ids
+ elif past_seen_tokens > 0:
+ position_ids = None
+ elif position_ids is not None and past_seen_tokens == 0:
+ position_ids = position_ids.to(device=inputs_embeds.device)
+ if position_ids.ndim == 2:
+ position_ids = position_ids.view(1, position_ids.shape[0], -1).expand(3, -1, -1)
+
+ outputs = self.language_model(
+ input_ids=None,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ visual_pos_masks=image_mask[..., 0] if image_mask is not None else None,
+ deepstack_visual_embeds=None,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@dataclass
+class IsaacCausalLMOutputWithPast(CausalLMOutputWithPast):
+ """
+ Base class for Isaac causal language model (or autoregressive) outputs.
+
+ Args:
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ rope_deltas: torch.LongTensor | None = None
+
+
+@auto_docstring
+class IsaacForConditionalGeneration(Qwen3VLForConditionalGeneration, GenerationMixin):
+ config_class = IsaacConfig
+ input_modalities = ("image", "text")
+ _no_split_modules = ["IsaacTextDecoderLayer", "IsaacVisionEncoderLayer"]
+ _can_compile_fullgraph = False
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+
+ def __init__(self, config: IsaacConfig):
+ PreTrainedModel.__init__(self, config)
+ self.model = IsaacModel(config)
+ self.vocab_size = config.get_text_config().vocab_size
+ self.lm_head = nn.Linear(config.get_text_config().hidden_size, config.get_text_config().vocab_size, bias=False)
+ self.post_init()
+
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ pixel_values: torch.Tensor | None = None,
+ pixel_values_videos: torch.FloatTensor | None = None,
+ image_grid_thw: torch.LongTensor | None = None,
+ video_grid_thw: torch.LongTensor | None = None,
+ mm_token_type_ids: torch.IntTensor | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | IsaacCausalLMOutputWithPast:
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ mm_token_type_ids=mm_token_type_ids,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ return IsaacCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.model.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values: Cache = None,
+ attention_mask: torch.Tensor | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ mm_token_type_ids: torch.LongTensor | None = None,
+ pixel_values: torch.Tensor | None = None,
+ image_grid_thw: torch.LongTensor | None = None,
+ image_metadata: torch.LongTensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ is_first_iteration: bool = False,
+ use_cache: bool = True,
+ **kwargs,
+ ) -> dict[str, Any]:
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ is_first_iteration=is_first_iteration,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ multimodal_inputs = {
+ "mm_token_type_ids": mm_token_type_ids,
+ "pixel_values": pixel_values,
+ "image_grid_thw": image_grid_thw,
+ "image_metadata": image_metadata,
+ }
+ is_prefill = is_first_iteration or not use_cache
+ for key, value in multimodal_inputs.items():
+ model_inputs[key] = value if is_prefill else None
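+        # Keep `mm_token_type_ids` aligned with the sequence actually passed this step: pad with text type ids or left-crop.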
+ if model_inputs["mm_token_type_ids"] is not None:
+ sequence_length = None
+ if model_inputs.get("input_ids") is not None:
+ sequence_length = model_inputs["input_ids"].shape[1]
+ elif model_inputs.get("inputs_embeds") is not None:
+ sequence_length = model_inputs["inputs_embeds"].shape[1]
+
+ if sequence_length is not None:
+ current_length = model_inputs["mm_token_type_ids"].shape[1]
+ if current_length < sequence_length:
+ padding = model_inputs["mm_token_type_ids"].new_zeros(
+ (model_inputs["mm_token_type_ids"].shape[0], sequence_length - current_length)
+ )
+ model_inputs["mm_token_type_ids"] = torch.cat([model_inputs["mm_token_type_ids"], padding], dim=1)
+ elif current_length > sequence_length:
+ model_inputs["mm_token_type_ids"] = model_inputs["mm_token_type_ids"][:, -sequence_length:]
+
+ return model_inputs
+
+ def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs):
+ text_positions = GenerationMixin._prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs)
+
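+        # With a populated cache, decoding only needs the cached rope deltas to offset the text positions.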
+ past_length = 0
+ if (cache := model_kwargs.get("past_key_values")) is not None:
+ past_length = cache.get_seq_length()
+ if past_length != 0 and self.model.rope_deltas is not None:
+ return text_positions[None, ...] + self.model.rope_deltas
+
+ if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
+ inputs_tensor = model_kwargs["input_ids"]
+
+ is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
+ if (
+ is_input_ids
+ and model_kwargs.get("mm_token_type_ids") is not None
+ and model_kwargs.get("image_grid_thw") is not None
+ and model_kwargs.get("image_metadata") is not None
+ ):
+ model_kwargs = {k: v for k, v in model_kwargs.items() if k != "input_ids"}
+ vision_positions, rope_deltas = self.model.get_rope_index(inputs_tensor, **model_kwargs)
+ self.model.rope_deltas = rope_deltas
+ else:
+ vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1)
+ self.model.rope_deltas = torch.zeros(
+ inputs_tensor.shape[0], 1, dtype=torch.long, device=inputs_tensor.device
+ )
+
+ return torch.cat([text_positions[None, ...], vision_positions], dim=0)
+
+ def _expand_inputs_for_generation(
+ self,
+ expand_size: int = 1,
+ is_encoder_decoder: bool = False,
+ input_ids: torch.LongTensor | None = None,
+ **model_kwargs,
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
+ position_ids = model_kwargs.pop("position_ids", None)
+ if expand_size == 1:
+ if position_ids is not None:
+ model_kwargs["position_ids"] = position_ids
+ return input_ids, model_kwargs
+
+ visual_keys = ["pixel_values", "image_grid_thw", "image_metadata"]
+ for key in visual_keys:
+ value = model_kwargs.get(key)
+ if value is not None:
+ model_kwargs[key] = value.repeat_interleave(expand_size, dim=0)
+
+ if input_ids is not None:
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+ for key, value in list(model_kwargs.items()):
+ if key == "position_ids" and value is not None and value.ndim == 3:
+ model_kwargs[key] = value.repeat_interleave(expand_size, dim=1)
+ elif value is not None and isinstance(value, torch.Tensor) and key not in visual_keys:
+ model_kwargs[key] = value.repeat_interleave(expand_size, dim=0)
+
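+        # 3D M-RoPE position ids are shaped (3, batch, seq_len), so the batch dimension to repeat is dim=1.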
+ if position_ids is not None:
+ dim = 1 if position_ids.ndim == 3 else 0
+ model_kwargs["position_ids"] = position_ids.repeat_interleave(expand_size, dim=dim)
+ return input_ids, model_kwargs
+
+
+# --------------------------------Isaac Image Processor--------------------------------
+
+
+class IsaacImageProcessorKwargs(ImagesKwargs, total=False):
+ """
+ patch_size (`int`, *optional*):
+ Side length (in pixels) for square patches extracted from resized images.
+ max_num_patches (`int`, *optional*):
+ Upper bound on extracted patches per image after resizing.
+ min_num_patches (`int`, *optional*):
+ Lower bound on extracted patches per image after resizing.
+ pixel_shuffle_scale (`int`, *optional*):
+ Pixel-shuffle reduction factor applied in the vision tower.
+ """
+
+ patch_size: int
+ max_num_patches: int
+ min_num_patches: int
+ pixel_shuffle_scale: int
+
+
+def get_scaled_image_size(
+ scale: float,
+ original_size: int,
+ patch_size: int,
+ pixel_shuffle_scale: int,
+) -> int:
+ scaled_size = scale * original_size
+ divisor = patch_size * pixel_shuffle_scale
+ scaled_size = math.ceil(scaled_size / divisor) * divisor
+ scaled_size = max(divisor, scaled_size)
+ return int(scaled_size)
+
+
+def get_image_size_for_max_num_patches(
+ image_height: int,
+ image_width: int,
+ patch_size: int,
+ max_num_patches: int,
+ min_num_patches: int | None = None,
+ eps: float = 1e-5,
+ pixel_shuffle_scale: int = 1,
+) -> tuple[int, int]:
+ r"""Compute a target resolution whose patch grid satisfies patching parametrization.
+
+ Args:
+ image_height (`int`):
+ Height in pixels of the source image prior to any resizing.
+ image_width (`int`):
+ Width in pixels of the source image prior to any resizing.
+ patch_size (`int`):
+ Size of the square patch used by the vision encoder.
+ max_num_patches (`int`):
+ Upper bound on `(height / patch_size) * (width / patch_size)` after resizing.
+ min_num_patches (`int`, *optional*):
+ Lower bound on the number of patches. When provided the image will be scaled up if necessary.
+ eps (`float`, *optional*, defaults to 1e-5):
+            Convergence tolerance for the internal binary search that determines the target dimensions.
+ pixel_shuffle_scale (`int`, *optional*, defaults to 1):
+ Additional stride multiplier applied when pixel shuffle later reduces spatial resolution.
+
+ Returns:
+ `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale`
+ and respect both the maximum and optional minimum patch-count constraints.
+ """
+
+ # Ensure divisibility
+ divisor = patch_size * pixel_shuffle_scale
+ adjusted_height = math.ceil(image_height / divisor) * divisor
+ adjusted_height = max(divisor, adjusted_height)
+ adjusted_width = math.ceil(image_width / divisor) * divisor
+ adjusted_width = max(divisor, adjusted_width)
+
+ num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size)
+
+ if min_num_patches is not None and num_patches < min_num_patches:
+ # Scale up via binary search to satisfy the minimum patch budget while
+ # preserving divisibility by patch_size * pixel_shuffle_scale.
+ scale_min, scale_max = 1.0, 100.0
+ while (scale_max - scale_min) >= eps:
+ scale = (scale_min + scale_max) / 2
+ target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+ target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+ num_patches = (target_height / patch_size) * (target_width / patch_size)
+ if num_patches >= min_num_patches:
+ scale_max = scale
+ else:
+ scale_min = scale
+ scale = scale_max
+ target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+ target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+ return target_height, target_width
+ elif num_patches <= max_num_patches:
+ return adjusted_height, adjusted_width
+ else:
+ # Scale down
+ scale_min, scale_max = eps / 10, 1.0
+ while (scale_max - scale_min) >= eps:
+ scale = (scale_min + scale_max) / 2
+ target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+ target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+ num_patches = (target_height / patch_size) * (target_width / patch_size)
+ if num_patches <= max_num_patches:
+ scale_min = scale
+ else:
+ scale_max = scale
+ scale = scale_min
+ target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale)
+ target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale)
+ return target_height, target_width
+
+
+@auto_docstring
+@requires(backends=("vision",))
+class IsaacImageProcessor(TorchvisionBackend):
+ model_input_names = ["pixel_values", "image_grid_thw"]
+ valid_kwargs = IsaacImageProcessorKwargs
+
+ resample = PILImageResampling.BILINEAR
+ do_resize = True
+ do_center_crop = False
+ patch_size = 16
+ max_num_patches = 256
+ min_num_patches = None
+ pixel_shuffle_scale = 1
+ do_pad = True
+ do_rescale = True
+ do_normalize = True
+ image_mean = IMAGENET_STANDARD_MEAN
+ image_std = IMAGENET_STANDARD_STD
+ do_convert_rgb = True
+ disable_grouping = False
+
+ def __init__(self, **kwargs: Unpack[IsaacImageProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ def _validate_preprocess_kwargs(self, **kwargs):
+ # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for.
+ kwargs.pop("do_resize", None)
+ return super()._validate_preprocess_kwargs(**kwargs)
+
+ def _prepare_images_structure(
+ self,
+ images: ImageInput,
+ expected_ndims: int = 3,
+ ) -> ImageInput:
+ images = self.fetch_images(images)
+ return make_nested_list_of_images(images, expected_ndims=expected_ndims)
+
+ def resize(
+ self,
+ image: torch.Tensor,
+ size: SizeDict,
+ **kwargs,
+ ) -> torch.Tensor:
+ if image.dtype == torch.uint8:
+ image = F.interpolate(image.float(), size=(size.height, size.width), mode="bilinear", align_corners=False)
+ return image.clamp(0, 255).round().to(torch.uint8)
+ return F.interpolate(image, size=(size.height, size.width), mode="bilinear", align_corners=False)
+
+ def pack_images(
+ self,
+ vision_patches: list[list[torch.Tensor]],
+ vision_token_grids: list[list[torch.Tensor]],
+ ) -> dict[str, torch.Tensor | None]:
+ batch_size = len(vision_patches)
+ flat_patches = [patches for sample_patches in vision_patches for patches in sample_patches]
+ if len(flat_patches) == 0:
+ return {"pixel_values": None, "image_grid_thw": None}
+
+ first_patch = flat_patches[0]
+ max_patches = max(patches.shape[0] for patches in flat_patches)
+ max_images = max((len(sample_patches) for sample_patches in vision_patches), default=0)
+
+ patch_dim = first_patch.shape[-1]
+ tensors = {
+ "pixel_values": torch.zeros(
+ (batch_size, max_images, max_patches, patch_dim),
+ device=first_patch.device,
+ dtype=first_patch.dtype,
+ ),
+ "image_grid_thw": torch.zeros((batch_size, max_images, 3), device=first_patch.device, dtype=torch.long),
+ }
+
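+        # Fill each image's padded slot and mark it active with T=1; unused slots keep all-zero grids.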
+ for batch_idx, (sample_patches, sample_token_grids) in enumerate(
+ zip(vision_patches, vision_token_grids, strict=True)
+ ):
+ for image_slot, (patches, token_grid) in enumerate(zip(sample_patches, sample_token_grids, strict=True)):
+ patch_count = int(patches.shape[0])
+ tensors["pixel_values"][batch_idx, image_slot, :patch_count] = patches
+ tensors["image_grid_thw"][batch_idx, image_slot, 0] = 1
+ tensors["image_grid_thw"][batch_idx, image_slot, 1:] = token_grid
+
+ return tensors
+
+ def _preprocess(
+ self,
+ images: list[list[torch.Tensor]],
+ do_resize: bool,
+ resample: PILImageResampling | tvF.InterpolationMode | int | None,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: float | list[float] | None,
+ image_std: float | list[float] | None,
+ do_pad: bool,
+ patch_size: int,
+ max_num_patches: int,
+ min_num_patches: int,
+ pixel_shuffle_scale: int,
+ disable_grouping: bool | None,
+ return_tensors: str | TensorType | None,
+ **kwargs,
+ ) -> BatchFeature:
+ if all(len(sample_images) == 0 for sample_images in images):
+ return BatchFeature(data={"pixel_values": None, "image_grid_thw": None}, tensor_type=return_tensors)
+
+ grouped_images, grouped_images_index = group_images_by_shape(
+ images, disable_grouping=disable_grouping, is_nested=True
+ )
+ grouped_outputs = {}
+ for shape, stacked_images in grouped_images.items():
+ grouped_batch_size, channels, original_height, original_width = stacked_images.shape
+ if do_resize:
+ target_height, target_width = get_image_size_for_max_num_patches(
+ original_height,
+ original_width,
+ patch_size,
+ max_num_patches,
+ min_num_patches=min_num_patches,
+ pixel_shuffle_scale=pixel_shuffle_scale,
+ )
+ image_batch = self.resize(
+ stacked_images, SizeDict(height=target_height, width=target_width), resample=resample
+ )
+ else:
+ if (original_height % patch_size) or (original_width % patch_size):
+ raise ValueError(
+ f"Image dimensions (h={original_height}, w={original_width}) must be divisible by patch_size={patch_size} when resize is disabled; enable resizing or adjust the input resolution."
+ )
+ image_batch, target_height, target_width = stacked_images, original_height, original_width
+
+ image_batch = self.rescale_and_normalize(
+ image_batch,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ )
+
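+            # Split every resized image into non-overlapping patch_size x patch_size patches, flattened row-major.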
+ patches = torch_extract_patches(image_batch, patch_size, patch_size)
+ _, height_tokens, width_tokens, patch_dim = patches.shape
+
+ token_grid = (
+ torch.tensor([height_tokens, width_tokens], device=patches.device).long().expand(grouped_batch_size, 2)
+ )
+
+ if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale):
+ raise ValueError(
+ f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale};"
+ f" adjust resize/patch parameters or disable pixel shuffle."
+ )
+
+ grouped_outputs[shape] = (
+ patches.reshape(grouped_batch_size, -1, patch_dim),
+ token_grid,
+ )
+
+ keys = ("vision_patches", "vision_token_grids")
+ nested_outputs = {}
+ for i, key in enumerate(keys):
+ nested_outputs[key] = reorder_images(
+ {shape: values[i] for shape, values in grouped_outputs.items()},
+ dict(grouped_images_index),
+ is_nested=True,
+ )
+
+ if not do_pad:
+ raise ValueError("IsaacImageProcessor doesn't support `do_pad=False` mode.")
+
+ tensors = self.pack_images(
+ vision_patches=nested_outputs["vision_patches"],
+ vision_token_grids=nested_outputs["vision_token_grids"],
+ )
+
+ return BatchFeature(data=tensors, tensor_type=return_tensors)
+
+ def get_number_of_image_patches(
+ self,
+ image_height: int,
+ image_width: int,
+ images_kwargs: dict[str, Any] | None = None,
+ ) -> int:
+ images_kwargs = images_kwargs or {}
+ patch_size = images_kwargs.get("patch_size", self.patch_size)
+ max_num_patches = images_kwargs.get("max_num_patches", self.max_num_patches)
+ min_num_patches = images_kwargs.get("min_num_patches", self.min_num_patches)
+ pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale", self.pixel_shuffle_scale)
+
+ target_height, target_width = get_image_size_for_max_num_patches(
+ image_height,
+ image_width,
+ patch_size,
+ max_num_patches,
+ min_num_patches=min_num_patches,
+ pixel_shuffle_scale=pixel_shuffle_scale,
+ )
+ return (target_height // patch_size) * (target_width // patch_size)
+
+
+# --------------------------------Isaac Processor--------------------------------
+
+
+class IsaacProcessorKwargs(ProcessingKwargs, total=False):
+ images_kwargs = IsaacImageProcessorKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": True,
+ "truncation": True,
+ "truncation_side": "left",
+ "return_attention_mask": True,
+ "return_overflowing_tokens": True,
+ "return_mm_token_type_ids": True,
+ "add_special_tokens": False,
+ },
+ }
+
+
+class SinglePoint(NamedTuple):
+ x: int
+ y: int
+ mention: str | None = None
+ t: float | None = None
+
+
+class BoundingBox(NamedTuple):
+ top_left: Any
+ bottom_right: Any
+ mention: str | None = None
+ t: float | None = None
+
+
+class Polygon(NamedTuple):
+ points: tuple[Any, ...]
+ mention: str | None = None
+ t: float | None = None
+
+
+_point_box_or_polygon_tag = re.compile(
+ r"<(?Ppoint|point_box|polygon)(?P[^>]*)>(?P[\s\S]*?)(?P=tag)>", re.IGNORECASE
+)
+_attr_re = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))")
+_coord_re = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")
+
+
+@auto_docstring
+@requires(backends=("vision",))
+class IsaacProcessor(ProcessorMixin):
+ def __init__(
+ self,
+ image_processor,
+ tokenizer,
+ chat_template: str | dict[str, str] | None = None,
+ max_sequence_length: int = 16384,
+ ):
+ r"""
+ max_sequence_length (`int`, *optional*, defaults to 16384):
+ Maximum packed multimodal sequence length produced by the processor.
+ """
+ if chat_template is None:
+ chat_template = getattr(tokenizer, "chat_template", None)
+
+ self.pad_token_id = tokenizer.pad_token_id
+ self.image_token = getattr(tokenizer, "image_pad_token", None) or getattr(tokenizer, "image_token", None)
+ self.image_token_id = getattr(tokenizer, "image_pad_token_id", None) or getattr(
+ tokenizer, "image_token_id", None
+ )
+ self.max_sequence_length = max_sequence_length
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ text: str | list[str],
+ images: ImageInput | None = None,
+ **kwargs,
+ ) -> BatchFeature:
+ output_kwargs = self._merge_kwargs(
+ IsaacProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+        # 1. Validate that the number of text and image inputs match
+ texts = [text] if isinstance(text, str) else text.copy()
+ rendered_image_token = ""
+ if self.image_token is not None and self.image_token != rendered_image_token:
+            # Isaac's current chat template still renders `<image>`, while the tokenizer exposes
+ # `<|image_pad|>`. Normalize here so apply_chat_template(..., tokenize=True,
+ # return_dict=True) follows the standard ProcessorMixin path.
+ texts = [text_value.replace(rendered_image_token, self.image_token) for text_value in texts]
+ if images is None:
+ batched_images = [[] for _ in texts]
+ else:
+ fetched_images = self.image_processor.fetch_images(images)
+ batched_images = make_nested_list_of_images(fetched_images)
+ if len(batched_images) != len(texts):
+ num_images_in_text = [text_value.count(self.image_token) for text_value in texts]
+ num_images_in_images = [len(sample_images) for sample_images in batched_images]
+ add_message = ""
+ if sum(num_images_in_text) == sum(num_images_in_images):
+ add_message = " Make sure to pass your images as a nested list, where each sub-list holds images for one text sample."
+ raise ValueError(
+ f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(texts)}).{add_message}"
+ )
+
+ # 2. Process images
+ image_inputs = self.image_processor(images=batched_images, **output_kwargs["images_kwargs"])
+ image_grid_thw = image_inputs["image_grid_thw"]
+
+ # 3. Expand text with image placeholders
+ merge_length = self.image_processor.pixel_shuffle_scale**2
+ if image_grid_thw is None:
+ vision_segment_lengths = None
+ else:
+ vision_segment_lengths = image_grid_thw.prod(dim=-1) // merge_length
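+            # Expand each image placeholder to its post-pixel-shuffle token count, consuming image slots in order.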
+ for batch_idx in range(len(texts)):
+ image_idx = 0
+ while self.image_token in texts[batch_idx]:
+ num_image_tokens = vision_segment_lengths[batch_idx, image_idx]
+ texts[batch_idx] = texts[batch_idx].replace(
+ self.image_token, "<|placeholder|>" * num_image_tokens, 1
+ )
+ image_idx += 1
+ texts[batch_idx] = texts[batch_idx].replace("<|placeholder|>", self.image_token)
+
+ # 4. Process text
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids")
+ max_length = output_kwargs["text_kwargs"].pop("max_length", None)
+ max_length = self.max_sequence_length if max_length is None else max_length
+ text_inputs = self.tokenizer(texts, max_length=max_length, **output_kwargs["text_kwargs"])
+
+ truncated_input_ids: list[list[int] | None] = [None] * len(texts)
+ truncated_attention_mask: list[list[int] | None] = [None] * len(texts)
+ offset_mappings = text_inputs.get("offset_mapping")
+ truncated_offset_mapping: list[list[list[int]] | None] | None = None
+ if offset_mappings is not None:
+ truncated_offset_mapping = [None] * len(texts)
+ overflow_input_ids_per_sample = defaultdict(int)
+
+ # 5. Drop overflowing token ids
+ if offset_mappings is None:
+ iterator = zip(
+ text_inputs["overflow_to_sample_mapping"], text_inputs["input_ids"], text_inputs["attention_mask"]
+ )
+ else:
+ iterator = zip(
+ text_inputs["overflow_to_sample_mapping"],
+ text_inputs["input_ids"],
+ text_inputs["attention_mask"],
+ offset_mappings,
+ )
+
+ for sample in iterator:
+ if offset_mappings is None:
+ batch_idx, input_ids, attention_mask = sample
+ offset_mapping = None
+ else:
+ batch_idx, input_ids, attention_mask, offset_mapping = sample
+
+ if truncated_input_ids[batch_idx] is None:
+ truncated_input_ids[batch_idx] = input_ids
+ truncated_attention_mask[batch_idx] = attention_mask
+ if truncated_offset_mapping is not None:
+ truncated_offset_mapping[batch_idx] = offset_mapping
+ else:
+ overflow_input_ids_per_sample[batch_idx] += input_ids.count(self.image_token_id)
+
+        # 6. Do the same for overflowing pixel values. Isaac truncates images based on `max_length`.
+        # We can't truncate the pixels themselves, so we record per-image (offset, length) metadata instead;
+        # the model crops off the truncated image features at run time using this window.
+ image_metadata = None
+ if image_grid_thw is not None:
+ batch_size, max_images = image_grid_thw.shape[:2]
+ image_metadata = torch.zeros((batch_size, max_images, 2), dtype=torch.long)
+ for batch_idx, image_lengths in enumerate(vision_segment_lengths):
+ remaining_dropped = overflow_input_ids_per_sample[batch_idx]
+ for image_idx, length in enumerate(image_lengths):
+ offset = 0
+ if 0 < remaining_dropped < length:
+ offset = remaining_dropped
+ length -= offset
+ remaining_dropped = 0
+ elif remaining_dropped >= length:
+ dropped_length = length
+ length = 0
+ remaining_dropped -= dropped_length
+
+ # Record which suffix of this image's placeholder span survives left truncation.
+ # The model still encodes the full image and uses this window for both feature gathering and vision RoPE.
+ image_metadata[batch_idx, image_idx, 0] = offset
+ image_metadata[batch_idx, image_idx, 1] = length
+
+ data = {
+ "input_ids": torch.tensor(truncated_input_ids, dtype=torch.long),
+ "attention_mask": torch.tensor(truncated_attention_mask, dtype=torch.long),
+ "image_metadata": image_metadata,
+ **image_inputs,
+ }
+ if truncated_offset_mapping is not None:
+ data["offset_mapping"] = torch.tensor(truncated_offset_mapping, dtype=torch.long)
+
+ if return_mm_token_type_ids:
+ data["mm_token_type_ids"] = self.create_mm_token_type_ids(data["input_ids"])
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+ vision_data = {}
+ if image_sizes is not None:
+ images_kwargs = dict(IsaacProcessorKwargs._defaults.get("images_kwargs", {}))
+ images_kwargs.update(kwargs)
+
+ num_image_patches = [
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
+ for image_size in image_sizes
+ ]
+ pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale") or self.image_processor.pixel_shuffle_scale
+ num_image_tokens = [num_patches // pixel_shuffle_scale**2 for num_patches in num_image_patches]
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+
+ return MultiModalData(**vision_data)
+
+ @property
+ def model_input_names(self):
+ return super().model_input_names + ["mm_token_type_ids", "image_metadata"]
+
+ @staticmethod
+ def _maybe_float(value: str | None) -> float | None:
+ try:
+ return float(value)
+ except (ValueError, TypeError):
+ return None
+
+ @classmethod
+ def _parse_attrs(cls, attr_text: str) -> dict[str, str]:
+ attrs = {}
+ for match in _attr_re.finditer(attr_text):
+ key = match.group(1)
+ value = match.group(2) or match.group(3) or ""
+ attrs[key] = value
+ return attrs
+
+ @classmethod
+ def _parse_point_body(
+ cls,
+ body: str,
+ mention: str | None = None,
+ t: str | None = None,
+ ) -> Any:
+ match = _coord_re.search(body)
+ if not match:
+ raise ValueError(f"Malformed tag: {body!r}")
+ x, y = int(match.group(1)), int(match.group(2))
+ return SinglePoint(x=x, y=y, mention=mention, t=cls._maybe_float(t))
+
+ @classmethod
+ def _parse_box_body(
+ cls,
+ body: str,
+ mention: str | None = None,
+ t: str | None = None,
+ ) -> Any:
+ coords = list(_coord_re.finditer(body))
+ if len(coords) < 2:
+ raise ValueError(f"Malformed tag: {body!r}")
+
+ top_left = SinglePoint(x=int(coords[0].group(1)), y=int(coords[0].group(2)))
+ bottom_right = SinglePoint(x=int(coords[1].group(1)), y=int(coords[1].group(2)))
+ return BoundingBox(top_left=top_left, bottom_right=bottom_right, mention=mention, t=cls._maybe_float(t))
+
+ @classmethod
+ def _parse_polygon_body(
+ cls,
+ body: str,
+ mention: str | None = None,
+ t: str | None = None,
+ ) -> Any:
+ coords = list(_coord_re.finditer(body))
+ if len(coords) < 3:
+ raise ValueError(f"Malformed tag: {body!r}")
+
+ points = tuple(SinglePoint(x=int(coord.group(1)), y=int(coord.group(2))) for coord in coords)
+ return Polygon(points=points, mention=mention, t=cls._maybe_float(t))
+
+ @classmethod
+ def clean_text_and_extract_points(
+ cls,
+ text: str,
+ expected: str | None = None,
+ ) -> tuple[str, list[Any]]:
+ results: list[Any] = []
+ for match in _point_box_or_polygon_tag.finditer(text):
+ tag = match.group("tag").lower()
+ attrs = cls._parse_attrs(match.group("attrs"))
+ mention = attrs.get("mention")
+ t = attrs.get("t")
+ if tag == "point":
+ if expected not in (None, "point"):
+ continue
+ results.append(cls._parse_point_body(match.group("body"), mention=mention, t=t))
+ elif tag == "point_box":
+ if expected not in (None, "box"):
+ continue
+ results.append(cls._parse_box_body(match.group("body"), mention=mention, t=t))
+ else:
+ if expected not in (None, "polygon"):
+ continue
+ results.append(cls._parse_polygon_body(match.group("body"), mention=mention, t=t))
+
+ clean_text = re.sub(r"\s+", " ", _point_box_or_polygon_tag.sub("", text or "")).strip()
+ return clean_text, results
+
+ def post_process_generation(
+ self,
+ text: str,
+ expected: str | None = None,
+ cleanup_and_extract: bool = True,
+ ) -> str | tuple[str, list[Any]]:
+ if cleanup_and_extract:
+ return self.clean_text_and_extract_points(text, expected=expected)
+ return text
+
+ def post_process_image_text_to_text(
+ self,
+ generated_outputs,
+ skip_special_tokens: bool = True,
+ cleanup_and_extract: bool = False,
+ expected: str | None = None,
+ **kwargs,
+ ):
+ generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs)
+ return [
+ self.post_process_generation(text, expected=expected, cleanup_and_extract=cleanup_and_extract)
+ for text in generated_texts
+ ]
+
+
+__all__ = [
+ "IsaacConfig",
+ "IsaacTextConfig",
+ "IsaacTextModel",
+ "IsaacVisionConfig",
+ "IsaacVisionModel",
+ "IsaacModel",
+ "IsaacPreTrainedModel", # noqa: F822
+ "IsaacForConditionalGeneration",
+ "IsaacImageProcessor",
+ "IsaacProcessor",
+]
diff --git a/src/transformers/models/isaac/processing_isaac.py b/src/transformers/models/isaac/processing_isaac.py
new file mode 100644
index 000000000000..44dc5688e2a2
--- /dev/null
+++ b/src/transformers/models/isaac/processing_isaac.py
@@ -0,0 +1,356 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/isaac/modular_isaac.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_isaac.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2026 Perceptron, Inc and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from collections import defaultdict
+from typing import Any
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput, make_nested_list_of_images
+from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin
+from ...utils import auto_docstring
+from ...utils.import_utils import is_torch_available, requires
+from .image_processing_isaac import IsaacImageProcessorKwargs
+from .modeling_isaac import BoundingBox, Polygon, SinglePoint
+
+
+if is_torch_available():
+ import torch
+
+
+# --------------------------------Isaac Processor--------------------------------
+
+
+class IsaacProcessorKwargs(ProcessingKwargs, total=False):
+ images_kwargs = IsaacImageProcessorKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": True,
+ "truncation": True,
+ "truncation_side": "left",
+ "return_attention_mask": True,
+ "return_overflowing_tokens": True,
+ "return_mm_token_type_ids": True,
+ "add_special_tokens": False,
+ },
+ }
+
+
+_point_box_or_polygon_tag = re.compile(
+ r"<(?Ppoint|point_box|polygon)(?P[^>]*)>(?P[\s\S]*?)(?P=tag)>", re.IGNORECASE
+)
+_attr_re = re.compile(r"(\w+)\s*=\s*(?:\"([^\"]*)\"|([^\s>]+))")
+_coord_re = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")
+
+
+@auto_docstring
+@requires(backends=("vision",))
+class IsaacProcessor(ProcessorMixin):
+ def __init__(
+ self,
+ image_processor,
+ tokenizer,
+ chat_template: str | dict[str, str] | None = None,
+ max_sequence_length: int = 16384,
+ ):
+ r"""
+ max_sequence_length (`int`, *optional*, defaults to 16384):
+ Maximum packed multimodal sequence length produced by the processor.
+ """
+ if chat_template is None:
+ chat_template = getattr(tokenizer, "chat_template", None)
+
+ self.pad_token_id = tokenizer.pad_token_id
+ self.image_token = getattr(tokenizer, "image_pad_token", None) or getattr(tokenizer, "image_token", None)
+ self.image_token_id = getattr(tokenizer, "image_pad_token_id", None) or getattr(
+ tokenizer, "image_token_id", None
+ )
+ self.max_sequence_length = max_sequence_length
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ text: str | list[str],
+ images: ImageInput | None = None,
+ **kwargs,
+ ) -> BatchFeature:
+ output_kwargs = self._merge_kwargs(
+ IsaacProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+        # 1. Validate that the number of text and image inputs match
+ texts = [text] if isinstance(text, str) else text.copy()
+ rendered_image_token = ""
+ if self.image_token is not None and self.image_token != rendered_image_token:
+            # Isaac's current chat template still renders `<image>`, while the tokenizer exposes
+ # `<|image_pad|>`. Normalize here so apply_chat_template(..., tokenize=True,
+ # return_dict=True) follows the standard ProcessorMixin path.
+ texts = [text_value.replace(rendered_image_token, self.image_token) for text_value in texts]
+ if images is None:
+ batched_images = [[] for _ in texts]
+ else:
+ fetched_images = self.image_processor.fetch_images(images)
+ batched_images = make_nested_list_of_images(fetched_images)
+ if len(batched_images) != len(texts):
+ num_images_in_text = [text_value.count(self.image_token) for text_value in texts]
+ num_images_in_images = [len(sample_images) for sample_images in batched_images]
+ add_message = ""
+ if sum(num_images_in_text) == sum(num_images_in_images):
+ add_message = " Make sure to pass your images as a nested list, where each sub-list holds images for one text sample."
+ raise ValueError(
+ f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(texts)}).{add_message}"
+ )
+
+ # 2. Process images
+ image_inputs = self.image_processor(images=batched_images, **output_kwargs["images_kwargs"])
+ image_grid_thw = image_inputs["image_grid_thw"]
+
+ # 3. Expand text with image placeholders
+ merge_length = self.image_processor.pixel_shuffle_scale**2
+ if image_grid_thw is None:
+ vision_segment_lengths = None
+ else:
+ vision_segment_lengths = image_grid_thw.prod(dim=-1) // merge_length
+ for batch_idx in range(len(texts)):
+ image_idx = 0
+ while self.image_token in texts[batch_idx]:
+ num_image_tokens = vision_segment_lengths[batch_idx, image_idx]
+ texts[batch_idx] = texts[batch_idx].replace(
+ self.image_token, "<|placeholder|>" * num_image_tokens, 1
+ )
+ image_idx += 1
+ texts[batch_idx] = texts[batch_idx].replace("<|placeholder|>", self.image_token)
+
+ # 4. Process text
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids")
+ max_length = output_kwargs["text_kwargs"].pop("max_length", None)
+ max_length = self.max_sequence_length if max_length is None else max_length
+ text_inputs = self.tokenizer(texts, max_length=max_length, **output_kwargs["text_kwargs"])
+
+ truncated_input_ids: list[list[int] | None] = [None] * len(texts)
+ truncated_attention_mask: list[list[int] | None] = [None] * len(texts)
+ offset_mappings = text_inputs.get("offset_mapping")
+ truncated_offset_mapping: list[list[list[int]] | None] | None = None
+ if offset_mappings is not None:
+ truncated_offset_mapping = [None] * len(texts)
+ overflow_input_ids_per_sample = defaultdict(int)
+
+ # 5. Drop overflowing token ids
+ if offset_mappings is None:
+ iterator = zip(
+ text_inputs["overflow_to_sample_mapping"], text_inputs["input_ids"], text_inputs["attention_mask"]
+ )
+ else:
+ iterator = zip(
+ text_inputs["overflow_to_sample_mapping"],
+ text_inputs["input_ids"],
+ text_inputs["attention_mask"],
+ offset_mappings,
+ )
+
+ for sample in iterator:
+ if offset_mappings is None:
+ batch_idx, input_ids, attention_mask = sample
+ offset_mapping = None
+ else:
+ batch_idx, input_ids, attention_mask, offset_mapping = sample
+
+ if truncated_input_ids[batch_idx] is None:
+ truncated_input_ids[batch_idx] = input_ids
+ truncated_attention_mask[batch_idx] = attention_mask
+ if truncated_offset_mapping is not None:
+ truncated_offset_mapping[batch_idx] = offset_mapping
+ else:
+ overflow_input_ids_per_sample[batch_idx] += input_ids.count(self.image_token_id)
+
+        # 6. Do the same for overflowing pixel values. Isaac truncates images based on `max_length`.
+        # We can't truncate the pixels themselves, so we record per-image (offset, length) metadata instead;
+        # the model crops off the truncated image features at run time using this window.
+ image_metadata = None
+ if image_grid_thw is not None:
+ batch_size, max_images = image_grid_thw.shape[:2]
+ image_metadata = torch.zeros((batch_size, max_images, 2), dtype=torch.long)
+ for batch_idx, image_lengths in enumerate(vision_segment_lengths):
+ remaining_dropped = overflow_input_ids_per_sample[batch_idx]
+ for image_idx, length in enumerate(image_lengths):
+ offset = 0
+ if 0 < remaining_dropped < length:
+ offset = remaining_dropped
+ length -= offset
+ remaining_dropped = 0
+ elif remaining_dropped >= length:
+ dropped_length = length
+ length = 0
+ remaining_dropped -= dropped_length
+
+ # Record which suffix of this image's placeholder span survives left truncation.
+ # The model still encodes the full image and uses this window for both feature gathering and vision RoPE.
+ image_metadata[batch_idx, image_idx, 0] = offset
+ image_metadata[batch_idx, image_idx, 1] = length
+
+ data = {
+ "input_ids": torch.tensor(truncated_input_ids, dtype=torch.long),
+ "attention_mask": torch.tensor(truncated_attention_mask, dtype=torch.long),
+ "image_metadata": image_metadata,
+ **image_inputs,
+ }
+ if truncated_offset_mapping is not None:
+ data["offset_mapping"] = torch.tensor(truncated_offset_mapping, dtype=torch.long)
+
+ if return_mm_token_type_ids:
+ data["mm_token_type_ids"] = self.create_mm_token_type_ids(data["input_ids"])
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+ vision_data = {}
+ if image_sizes is not None:
+ images_kwargs = dict(IsaacProcessorKwargs._defaults.get("images_kwargs", {}))
+ images_kwargs.update(kwargs)
+
+ num_image_patches = [
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
+ for image_size in image_sizes
+ ]
+ pixel_shuffle_scale = images_kwargs.get("pixel_shuffle_scale") or self.image_processor.pixel_shuffle_scale
+ num_image_tokens = [num_patches // pixel_shuffle_scale**2 for num_patches in num_image_patches]
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+
+ return MultiModalData(**vision_data)
+
+ @property
+ def model_input_names(self):
+ return super().model_input_names + ["mm_token_type_ids", "image_metadata"]
+
+ @staticmethod
+ def _maybe_float(value: str | None) -> float | None:
+ try:
+ return float(value)
+ except (ValueError, TypeError):
+ return None
+
+ @classmethod
+ def _parse_attrs(cls, attr_text: str) -> dict[str, str]:
+ attrs = {}
+ for match in _attr_re.finditer(attr_text):
+ key = match.group(1)
+ value = match.group(2) or match.group(3) or ""
+ attrs[key] = value
+ return attrs
+
+ @classmethod
+ def _parse_point_body(
+ cls,
+ body: str,
+ mention: str | None = None,
+ t: str | None = None,
+ ) -> Any:
+ match = _coord_re.search(body)
+ if not match:
+ raise ValueError(f"Malformed tag: {body!r}")
+ x, y = int(match.group(1)), int(match.group(2))
+ return SinglePoint(x=x, y=y, mention=mention, t=cls._maybe_float(t))
+
+ @classmethod
+ def _parse_box_body(
+ cls,
+ body: str,
+ mention: str | None = None,
+ t: str | None = None,
+ ) -> Any:
+ coords = list(_coord_re.finditer(body))
+ if len(coords) < 2:
+ raise ValueError(f"Malformed tag: {body!r}")
+
+ top_left = SinglePoint(x=int(coords[0].group(1)), y=int(coords[0].group(2)))
+ bottom_right = SinglePoint(x=int(coords[1].group(1)), y=int(coords[1].group(2)))
+ return BoundingBox(top_left=top_left, bottom_right=bottom_right, mention=mention, t=cls._maybe_float(t))
+
+ @classmethod
+ def _parse_polygon_body(
+ cls,
+ body: str,
+ mention: str | None = None,
+ t: str | None = None,
+ ) -> Any:
+ coords = list(_coord_re.finditer(body))
+ if len(coords) < 3:
+ raise ValueError(f"Malformed tag: {body!r}")
+
+ points = tuple(SinglePoint(x=int(coord.group(1)), y=int(coord.group(2))) for coord in coords)
+ return Polygon(points=points, mention=mention, t=cls._maybe_float(t))
+
+ @classmethod
+ def clean_text_and_extract_points(
+ cls,
+ text: str,
+ expected: str | None = None,
+ ) -> tuple[str, list[Any]]:
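+        """Strip grounded point / point_box / polygon tags from `text` and return (clean_text, annotations).
+
+        `expected` limits which tag type is parsed into annotations; every tag is still removed from the cleaned text.
+        """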
+ results: list[Any] = []
+ for match in _point_box_or_polygon_tag.finditer(text):
+ tag = match.group("tag").lower()
+ attrs = cls._parse_attrs(match.group("attrs"))
+ mention = attrs.get("mention")
+ t = attrs.get("t")
+ if tag == "point":
+ if expected not in (None, "point"):
+ continue
+ results.append(cls._parse_point_body(match.group("body"), mention=mention, t=t))
+ elif tag == "point_box":
+ if expected not in (None, "box"):
+ continue
+ results.append(cls._parse_box_body(match.group("body"), mention=mention, t=t))
+ else:
+ if expected not in (None, "polygon"):
+ continue
+ results.append(cls._parse_polygon_body(match.group("body"), mention=mention, t=t))
+
+ clean_text = re.sub(r"\s+", " ", _point_box_or_polygon_tag.sub("", text or "")).strip()
+ return clean_text, results
+
+ def post_process_generation(
+ self,
+ text: str,
+ expected: str | None = None,
+ cleanup_and_extract: bool = True,
+ ) -> str | tuple[str, list[Any]]:
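+        """Return (clean_text, annotations) via clean_text_and_extract_points when cleanup_and_extract is True, otherwise the raw text."""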
+ if cleanup_and_extract:
+ return self.clean_text_and_extract_points(text, expected=expected)
+ return text
+
+ def post_process_image_text_to_text(
+ self,
+ generated_outputs,
+ skip_special_tokens: bool = True,
+ cleanup_and_extract: bool = False,
+ expected: str | None = None,
+ **kwargs,
+ ):
+ generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs)
+ return [
+ self.post_process_generation(text, expected=expected, cleanup_and_extract=cleanup_and_extract)
+ for text in generated_texts
+ ]
+
+
+__all__ = ["IsaacProcessor"]
diff --git a/tests/models/isaac/__init__.py b/tests/models/isaac/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/isaac/test_image_processing_isaac.py b/tests/models/isaac/test_image_processing_isaac.py
new file mode 100644
index 000000000000..7c7627805a5f
--- /dev/null
+++ b/tests/models/isaac/test_image_processing_isaac.py
@@ -0,0 +1,278 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import numpy as np
+
+from transformers.models.isaac.image_processing_isaac import get_image_size_for_max_num_patches
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ from PIL import Image
+
+
+def _make_dummy_image(size=(32, 32), color=(255, 0, 0)):
+ return Image.new("RGB", size, color=color)
+
+
+class IsaacImageProcessingTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ min_resolution=30,
+ max_resolution=80,
+ do_resize=True,
+ do_rescale=True,
+ rescale_factor=1 / 255,
+ do_normalize=True,
+ image_mean=None,
+ image_std=None,
+ patch_size=16,
+ max_num_patches=16,
+ min_num_patches=4,
+ pixel_shuffle_scale=1,
+ do_convert_rgb=True,
+ ):
+ if image_mean is None:
+ image_mean = [0.485, 0.456, 0.406]
+ if image_std is None:
+ image_std = [0.229, 0.224, 0.225]
+
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.do_resize = do_resize
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.patch_size = patch_size
+ self.max_num_patches = max_num_patches
+ self.min_num_patches = min_num_patches
+ self.pixel_shuffle_scale = pixel_shuffle_scale
+ self.do_convert_rgb = do_convert_rgb
+
+ @property
+ def patch_dim(self):
+ return self.num_channels * self.patch_size * self.patch_size
+
+ def prepare_image_processor_dict(self):
+ return {
+ "do_resize": self.do_resize,
+ "do_rescale": self.do_rescale,
+ "rescale_factor": self.rescale_factor,
+ "do_normalize": self.do_normalize,
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ "patch_size": self.patch_size,
+ "max_num_patches": self.max_num_patches,
+ "min_num_patches": self.min_num_patches,
+ "pixel_shuffle_scale": self.pixel_shuffle_scale,
+ "do_convert_rgb": self.do_convert_rgb,
+ }
+
+ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
+ images = prepare_image_inputs(
+ batch_size=self.batch_size,
+ min_resolution=self.min_resolution,
+ max_resolution=self.max_resolution,
+ num_channels=self.num_channels,
+ equal_resolution=equal_resolution,
+ numpify=numpify,
+ torchify=torchify,
+ )
+ return [[image] for image in images]
+
+ def expected_output_image_shape(self, images):
+ max_images = 0
+ max_patches = 0
+ for sample_images in images:
+ if not isinstance(sample_images, (list, tuple)):
+ sample_images = [sample_images]
+
+ max_images = max(max_images, len(sample_images))
+ for image in sample_images:
+ if isinstance(image, Image.Image):
+ width, height = image.size
+ elif isinstance(image, np.ndarray):
+ height, width = image.shape[:2]
+ else:
+ height, width = image.shape[-2:]
+
+ target_height, target_width = get_image_size_for_max_num_patches(
+ image_height=height,
+ image_width=width,
+ patch_size=self.patch_size,
+ max_num_patches=self.max_num_patches,
+ min_num_patches=self.min_num_patches,
+ pixel_shuffle_scale=self.pixel_shuffle_scale,
+ )
+ max_patches = max(max_patches, (target_height // self.patch_size) * (target_width // self.patch_size))
+
+ return (max_images, max_patches, self.patch_dim)
+
+
+@require_torch
+@require_vision
+class IsaacImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ self.image_processor_tester = IsaacImageProcessingTester(self)
+
+ @property
+ def image_processor_dict(self):
+ return self.image_processor_tester.prepare_image_processor_dict()
+
+ def test_call_pil(self):
+ for image_processing_class in self.image_processing_classes.values():
+ image_processing = image_processing_class(**self.image_processor_dict)
+ image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
+ for sample_images in image_inputs:
+ self.assertEqual(len(sample_images), 1)
+ self.assertIsInstance(sample_images[0], Image.Image)
+
+ encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+ expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+ self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+ encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+ expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+ self.assertEqual(
+ tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
+ )
+
+ def test_call_numpy(self):
+ for image_processing_class in self.image_processing_classes.values():
+ image_processing = image_processing_class(**self.image_processor_dict)
+ image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
+ for sample_images in image_inputs:
+ self.assertEqual(len(sample_images), 1)
+ self.assertIsInstance(sample_images[0], np.ndarray)
+
+ encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+ expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+ self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+ encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+ expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+ self.assertEqual(
+ tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
+ )
+
+ def test_call_pytorch(self):
+ for image_processing_class in self.image_processing_classes.values():
+ image_processing = image_processing_class(**self.image_processor_dict)
+ image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
+ for sample_images in image_inputs:
+ self.assertEqual(len(sample_images), 1)
+ self.assertIsInstance(sample_images[0], torch.Tensor)
+
+ encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+ expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+ self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+ encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+ expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+ self.assertEqual(
+ tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
+ )
+
+ @unittest.skip(reason="Isaac image processor 4-channel coverage is not defined")
+ def test_call_numpy_4_channels(self):
+ pass
+
+ def test_flat_list_is_single_multi_image_sample(self):
+ for image_processing_class in self.image_processing_classes.values():
+ image_processor = image_processing_class(
+ **{
+ **self.image_processor_dict,
+ "do_resize": False,
+ "patch_size": 16,
+ "max_num_patches": 64,
+ "min_num_patches": 1,
+ "pixel_shuffle_scale": 1,
+ }
+ )
+ image_inputs = [
+ _make_dummy_image(size=(32, 32), color=(255, 0, 0)),
+ _make_dummy_image(size=(32, 32), color=(0, 255, 0)),
+ ]
+
+ encoding = image_processor(image_inputs, return_tensors="pt")
+ self.assertEqual(tuple(encoding["pixel_values"].shape), (1, 2, 4, 768))
+
+ expected_grids = torch.tensor([[[1, 2, 2], [1, 2, 2]]], dtype=torch.long)
+ torch.testing.assert_close(encoding["image_grid_thw"], expected_grids)
+
+ def test_nested_multi_image_batch_preserves_grids_and_padding(self):
+ for image_processing_class in self.image_processing_classes.values():
+ image_processor = image_processing_class(
+ **{
+ **self.image_processor_dict,
+ "do_resize": False,
+ "patch_size": 16,
+ "max_num_patches": 64,
+ "min_num_patches": 1,
+ "pixel_shuffle_scale": 1,
+ }
+ )
+ image_inputs = [
+ [_make_dummy_image(size=(32, 32), color=(255, 0, 0))],
+ [
+ _make_dummy_image(size=(48, 32), color=(0, 255, 0)),
+ _make_dummy_image(size=(32, 48), color=(0, 0, 255)),
+ ],
+ ]
+
+ encoding = image_processor(image_inputs, return_tensors="pt")
+ self.assertEqual(tuple(encoding["pixel_values"].shape), (2, 2, 6, 768))
+
+ expected_grids = torch.tensor(
+ [
+ [[1, 2, 2], [0, 0, 0]],
+ [[1, 2, 3], [1, 3, 2]],
+ ],
+ dtype=torch.long,
+ )
+
+ torch.testing.assert_close(encoding["image_grid_thw"], expected_grids)
+ self.assertTrue(torch.all(encoding["pixel_values"][0, 1] == 0))
+
+ def test_pixel_shuffle_scale_requires_divisible_token_grid(self):
+ for image_processing_class in self.image_processing_classes.values():
+ image_processor = image_processing_class(
+ **{
+ **self.image_processor_dict,
+ "do_resize": False,
+ "patch_size": 16,
+ "pixel_shuffle_scale": 2,
+ }
+ )
+
+ with self.assertRaisesRegex(ValueError, "must be divisible by pixel_shuffle_scale"):
+ image_processor([[_make_dummy_image(size=(32, 16))]], return_tensors="pt")
diff --git a/tests/models/isaac/test_modeling_isaac.py b/tests/models/isaac/test_modeling_isaac.py
new file mode 100644
index 000000000000..ccfb83411f66
--- /dev/null
+++ b/tests/models/isaac/test_modeling_isaac.py
@@ -0,0 +1,893 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Testing suite for the Isaac model."""
+
+import base64
+import io
+import os
+import unittest
+from functools import lru_cache
+from pathlib import Path
+
+import pytest
+from huggingface_hub import is_offline_mode
+
+from tests.generation.test_utils import (
+ GenerationTesterMixin,
+)
+from tests.test_configuration_common import ConfigTester
+from tests.test_pipeline_mixin import PipelineTesterMixin
+from transformers import (
+ IsaacConfig,
+ IsaacForConditionalGeneration,
+ IsaacModel,
+ is_torch_available,
+)
+from transformers.models.isaac.processing_isaac import IsaacProcessor
+from transformers.testing_utils import (
+ require_flash_attn,
+ require_torch,
+ require_vision,
+ slow,
+ torch_device,
+)
+from transformers.utils import is_vision_available
+
+
+if is_vision_available():
+ from PIL import Image
+else:
+ Image = None
+
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
+
+
+if is_torch_available():
+ import torch
+
+
+BASE_MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base")
+MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1")
+
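+# An explicitly empty ISAAC_TEST_MODEL_REVISION resolves to None, i.e. the checkpoint's default branch.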
+BASE_MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/3") or None
+MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/5") or None
+
+LOCAL_CHECKPOINT = os.environ.get("ISAAC_TEST_MODEL_PATH")
+RED_DOT_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
+ISAAC_IMAGE_TOKEN = "<|image_pad|>"
+
+
+def compute_logits_statistics(tensor: torch.Tensor) -> dict[str, object]:
+ """
+ Summarize logits with simple statistics that are stable across minor
+ implementation changes yet still sensitive to behavioral regressions.
+ """
+
+ float_tensor = tensor.detach().to(torch.float32).cpu()
+ flat = float_tensor.reshape(-1).to(torch.float64)
+
+ def _rounded(value: torch.Tensor | float) -> float:
+ return round(float(value), 10)
+
+ return {
+ "shape": list(float_tensor.shape),
+ "numel": flat.numel(),
+ "mean": _rounded(flat.mean()),
+ "std": _rounded(flat.std(unbiased=False)),
+ "min": _rounded(flat.min()),
+ "max": _rounded(flat.max()),
+ "sum": _rounded(flat.sum()),
+ "l2_norm": _rounded(torch.linalg.vector_norm(flat, ord=2)),
+ }
+
+
+def pack_image_inputs(pixel_values, image_token_grids, image_token_offsets=None, image_token_lengths=None):
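+    """Pad per-image token grids into image_grid_thw (T, H, W) and image_metadata (offset, length) tensors; T=0 marks an empty slot."""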
+ batch_size, max_images, _, _ = pixel_values.shape
+ device = pixel_values.device
+
+ if image_token_offsets is None:
+ image_token_offsets = torch.zeros((batch_size, max_images), device=device, dtype=torch.long)
+ if image_token_lengths is None:
+ image_token_lengths = image_token_grids[..., 0] * image_token_grids[..., 1]
+
+ image_grid_thw = torch.zeros((batch_size, max_images, 3), device=device, dtype=torch.long)
+ active_slots = image_token_grids.prod(dim=-1).gt(0)
+ image_grid_thw[..., 0] = active_slots.to(dtype=torch.long)
+ image_grid_thw[..., 1:] = image_token_grids
+
+ image_metadata = torch.stack(
+ (
+ image_token_offsets.to(device=device, dtype=torch.long),
+ image_token_lengths.to(device=device, dtype=torch.long),
+ ),
+ dim=-1,
+ )
+
+ return pixel_values, image_grid_thw, image_metadata
+
+
+@lru_cache(maxsize=1)
+def _load_red_dot_image():
+ if Image is None:
+ return None
+ data = base64.b64decode(RED_DOT_B64)
+ return Image.open(io.BytesIO(data)).convert("RGB")
+
+
+def _base_reference_checkpoint_or_skip():
+ if LOCAL_CHECKPOINT:
+ resolved = Path(LOCAL_CHECKPOINT).expanduser()
+ if not resolved.exists():
+ pytest.skip(f"Local checkpoint path {resolved} does not exist.")
+ return str(resolved)
+ if is_offline_mode():
+ pytest.skip("Offline mode: set ISAAC_TEST_MODEL_PATH to a local checkpoint to run these tests.")
+ return BASE_MODEL_ID
+
+
+def _reference_checkpoint_or_skip():
+ if LOCAL_CHECKPOINT:
+ resolved = Path(LOCAL_CHECKPOINT).expanduser()
+ if not resolved.exists():
+ pytest.skip(f"Local checkpoint path {resolved} does not exist.")
+ return str(resolved)
+ if is_offline_mode():
+ pytest.skip("Offline mode: set ISAAC_TEST_MODEL_PATH to a local checkpoint to run these tests.")
+ return MODEL_ID
+
+
+class IsaacModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=2,
+ seq_length=5,
+ vocab_size=99,
+ hidden_size=32,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.is_training = True
+ self.expected_num_hidden_layers = num_hidden_layers + 1
+
+ self.text_config = {
+ "bos_token_id": 0,
+ "eos_token_id": 1,
+ "pad_token_id": 2,
+ "hidden_act": "silu",
+ "head_dim": hidden_size // num_attention_heads,
+ "hidden_size": hidden_size,
+ "vocab_size": vocab_size,
+ "intermediate_size": hidden_size * 3,
+ "max_position_embeddings": 128,
+ "model_type": "qwen3",
+ "num_attention_heads": num_attention_heads,
+ "num_hidden_layers": num_hidden_layers,
+ "num_key_value_heads": num_attention_heads,
+ # Keep the same multi-RoPE setup as the reference checkpoints but shrink the
+ # sections so they sum to the rotary half-dimension (4) of this tiny test model.
+ "rope_parameters": {"rope_type": "default", "mrope_section": [2, 1, 1], "mrope_interleaved": True},
+ "tie_word_embeddings": True,
+ }
+
+ self.vision_config = {
+ "hidden_size": hidden_size,
+ "intermediate_size": hidden_size * 2,
+ "num_hidden_layers": 1,
+ "num_attention_heads": num_attention_heads,
+ "num_channels": 3,
+ "num_patches": 64,
+ "patch_size": 4,
+ "pixel_shuffle_scale_factor": 1,
+ "attention_dropout": 0.0,
+ "layer_norm_eps": 1e-6,
+ }
+
+ def get_config(self):
+ config = IsaacConfig(
+ text_config=self.text_config,
+ vision_config=self.vision_config,
+ )
+ # Rely on eager attention so output_attentions tests remain compatible without flash attention.
+ config._attn_implementation = "eager"
+ config.text_config._attn_implementation = "eager"
+ config.vision_attn_implementation = "eager"
+ return config
+
+ def prepare_config_and_inputs(self):
+ config = self.get_config()
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ attention_mask = torch.ones(
+ (self.batch_size, self.seq_length),
+ dtype=torch.long,
+ device=torch_device,
+ )
+ labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ return config, input_ids, attention_mask, labels
+
+ def prepare_config_and_inputs_for_common(self):
+ config, input_ids, attention_mask, labels = self.prepare_config_and_inputs()
+ patch_size = self.vision_config["patch_size"]
+ patch_dim = self.vision_config["num_channels"] * patch_size * patch_size
+ num_image_patches = 4
+ vision_patches = torch.randn(
+ (self.batch_size, 1, num_image_patches, patch_dim), device=torch_device, dtype=torch.float32
+ )
+ image_token_grids = torch.tensor([[[2, 2]]] * self.batch_size, device=torch_device, dtype=torch.long)
+ pixel_values, image_grid_thw, image_metadata = pack_image_inputs(
+ pixel_values=vision_patches,
+ image_token_grids=image_token_grids,
+ )
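+        # Mark the first num_image_patches positions as image tokens (type 1); the remaining positions stay text (type 0).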
+ mm_token_type_ids = torch.zeros((self.batch_size, self.seq_length), device=torch_device, dtype=torch.long)
+ mm_token_type_ids[:, :num_image_patches] = 1
+ inputs_dict = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "mm_token_type_ids": mm_token_type_ids,
+ "pixel_values": pixel_values,
+ "image_grid_thw": image_grid_thw,
+ "image_metadata": image_metadata,
+ }
+ if labels is not None:
+ inputs_dict["labels"] = labels
+ return config, inputs_dict
+
+
+@require_torch
+class IsaacModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ all_model_classes = (IsaacModel, IsaacForConditionalGeneration) if is_torch_available() else ()
+ pipeline_model_mapping = (
+ {
+ "image-to-text": IsaacForConditionalGeneration,
+ "image-text-to-text": IsaacForConditionalGeneration,
+ }
+ if is_torch_available()
+ else {}
+ )
+ _is_composite = True
+ test_attention_outputs = False
+ test_all_params_have_gradient = False
+
+ def setUp(self):
+ self.model_tester = IsaacModelTester(self)
+ self.config_tester = ConfigTester(
+ self,
+ config_class=IsaacConfig,
+ has_text_modality=False,
+ )
+
+ def test_config(self):
+ self.maxDiff = None
+ self.config_tester.run_common_tests()
+
+ def prepare_config_and_inputs_for_generate(self, batch_size=2):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ input_keys_to_ignore = [
+ "decoder_input_ids",
+ "decoder_attention_mask",
+ "use_cache",
+ "labels",
+ ]
+
+ filtered_inputs_dict = {
+ k: v[:batch_size, ...]
+ if isinstance(v, torch.Tensor) and k not in ["pixel_values", "image_grid_thw", "image_metadata"]
+ else v
+ for k, v in inputs_dict.items()
+ if k not in input_keys_to_ignore
+ }
+
+ filtered_inputs_dict["pixel_values"] = inputs_dict["pixel_values"][:batch_size]
+ filtered_inputs_dict["image_grid_thw"] = inputs_dict["image_grid_thw"][:batch_size]
+ filtered_inputs_dict["image_metadata"] = inputs_dict["image_metadata"][:batch_size]
+
+ text_gen_config = config.get_text_config(decoder=True)
+ if text_gen_config.eos_token_id is not None and text_gen_config.pad_token_id is None:
+ text_gen_config.pad_token_id = (
+ text_gen_config.eos_token_id
+ if isinstance(text_gen_config.eos_token_id, int)
+ else text_gen_config.eos_token_id[0]
+ )
+ text_gen_config.eos_token_id = None
+ text_gen_config.forced_eos_token_id = None
+
+ return config, filtered_inputs_dict
+
+ @pytest.mark.generate
+ def test_left_padding_compatibility(self):
+ _, inputs_dict = self.prepare_config_and_inputs_for_generate()
+ mm_token_type_ids = inputs_dict["mm_token_type_ids"]
+ pad_size = (mm_token_type_ids.shape[0], 32)
+ padded_mm_token_type_ids = torch.cat(
+ (torch.zeros(pad_size, dtype=mm_token_type_ids.dtype, device=torch_device), mm_token_type_ids), dim=1
+ )
+
+ super().test_left_padding_compatibility(
+ unpadded_custom_inputs={"mm_token_type_ids": mm_token_type_ids},
+ padded_custom_inputs={"mm_token_type_ids": padded_mm_token_type_ids},
+ )
+
+ @unittest.skip(reason="Assisted decoding not supported; Qwen3 backbone does not implement returning attentions")
+ def test_assisted_decoding_matches_greedy_search_0_random(self):
+ pass
+
+ @unittest.skip(reason="Assisted decoding not supported; Qwen3 backbone does not implement returning attentions")
+ def test_assisted_decoding_matches_greedy_search_1_same(self):
+ pass
+
+ @unittest.skip(reason="Unsupported")
+ def test_flash_attn_kernels_inference_equivalence(self):
+ pass
+
+ @unittest.skip(reason="Isaac is image-only.")
+ def test_get_video_features_output_0(self):
+ pass
+
+ @unittest.skip(reason="Isaac is image-only.")
+ def test_get_video_features_output_1(self):
+ pass
+
+ @unittest.skip(reason="Isaac is image-only.")
+ def test_get_video_features_output_2(self):
+ pass
+
+ @unittest.skip(reason="Isaac is image-only.")
+ def test_get_video_features_hidden_states(self):
+ pass
+
+ @unittest.skip(reason="Isaac is image-only.")
+ def test_get_video_features_attentions(self):
+ pass
+
+
+@require_torch
+@require_vision
+@slow
+@require_flash_attn
+class IsaacGenerationIntegrationTest(unittest.TestCase):
+ max_new_tokens = 25
+ dtype = torch.bfloat16
+
+ def setUp(self):
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ self.checkpoint = _base_reference_checkpoint_or_skip()
+ self.hf_config = IsaacConfig.from_pretrained(self.checkpoint, revision=BASE_MODEL_REVISION)
+ self.processor = IsaacProcessor.from_pretrained(self.checkpoint, revision=BASE_MODEL_REVISION, do_pad=True)
+ self.tokenizer = self.processor.tokenizer
+ self.hf_config.vision_config._attn_implementation = "flash_attention_2"
+ self.hf_config.vision_config.attn_implementation = "flash_attention_2"
+ self.model = IsaacForConditionalGeneration.from_pretrained(
+ self.checkpoint, config=self.hf_config, revision=BASE_MODEL_REVISION
+ )
+ self.model = self.model.to(device=self.device, dtype=self.dtype)
+ self.model.eval()
+
+ def test_generate_from_image_text(self):
+ image = _load_red_dot_image()
+ if image is None:
+ pytest.skip("PIL.Image is required for Isaac generation tests.")
+
+ conversation = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Describe this image:"},
+ {"type": "image", "image": image},
+ ],
+ }
+ ]
+ inputs = self.processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ ).to(self.device, dtype=self.dtype)
+
+ with torch.no_grad():
+ outputs = self.model.generate(
+ **inputs,
+ max_new_tokens=self.max_new_tokens,
+ do_sample=False,
+ pad_token_id=self.tokenizer.eos_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ return_dict_in_generate=True,
+ )
+
+ generated_ids = outputs.sequences[:, inputs["input_ids"].shape[1] :]
+ generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ expected_fragment = "The image is a close-up photograph of a red cross symbol."
+ assert expected_fragment in generated_text
+
+ def test_generate_from_text_only(self):
+ conversation = [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": "What is the pythogorean theorem?"}],
+ }
+ ]
+ inputs = self.processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ ).to(self.device, dtype=self.dtype)
+
+ with torch.no_grad():
+ outputs = self.model.generate(
+ **inputs,
+ max_new_tokens=100,
+ do_sample=False,
+ pad_token_id=self.tokenizer.eos_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ return_dict_in_generate=True,
+ )
+
+ generated_ids = outputs.sequences[:, inputs["input_ids"].shape[1] :]
+ generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        expected_fragment = "The Pythagorean theorem is a fundamental principle in geometry that relates the lengths of the sides of a right-angled triangle. Let's break it down step by step:"
+        assert expected_fragment in generated_text
+
+ def test_vqa_from_image(self):
+ conversation = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "url": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp",
+ },
+ {"type": "text", "text": "Is it safe to cross the street at this moment?"},
+ ],
+ }
+ ]
+ inputs = self.processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ ).to(self.device, dtype=self.dtype)
+
+ with torch.no_grad():
+ outputs = self.model.generate(
+ **inputs,
+ max_new_tokens=256,
+ do_sample=False,
+ pad_token_id=self.tokenizer.eos_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ return_dict_in_generate=True,
+ )
+
+ generated_ids = outputs.sequences[:, inputs["input_ids"].shape[1] :]
+ generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ expected_response = "\nNo, it is not safe to cross the street at this moment. The traffic light for pedestrians is red, indicating that it is not safe to cross."
+ assert generated_text == expected_response
+
+ def test_logit_equivalence(self):
+ image = _load_red_dot_image()
+ if image is None:
+ pytest.skip("PIL.Image is required for Isaac generation tests.")
+
+ conversation = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Describe this image:"},
+ {"type": "image", "image": image},
+ ],
+ }
+ ]
+ inputs = self.processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ ).to(self.device, dtype=self.dtype)
+
+ with torch.no_grad():
+ outputs = self.model.generate(
+ **inputs,
+ max_new_tokens=10,
+ do_sample=False,
+ pad_token_id=self.tokenizer.eos_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ return_dict_in_generate=True,
+ output_logits=True,
+ )
+
+ hf_logits = torch.cat(outputs.logits, dim=0)
+ logit_stats = compute_logits_statistics(hf_logits)
+ expected_logit_stats = {
+ "shape": [10, 151936],
+ "numel": 1519360,
+ "mean": 0.0608877803,
+ "std": 2.8308793244,
+ "min": -12.0625,
+ "max": 31.0,
+ "sum": 92510.4578057677,
+ "l2_norm": 3490.2146142251,
+ }
+ assert logit_stats == expected_logit_stats
+
+ def test_batched_generation_matches_individual(self):
+ image = _load_red_dot_image()
+ if image is None:
+ pytest.skip("PIL.Image is required for Isaac generation tests.")
+
+ conversations = [
+ [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": "What is the pythogorean theorem?"}],
+ }
+ ],
+ [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Describe this image:"},
+ {"type": "image", "image": image},
+ ],
+ }
+ ],
+ [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "url": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp",
+ },
+ {"type": "text", "text": "Is it safe to cross the street at this moment?"},
+ ],
+ }
+ ],
+ ]
+
+ single_inputs = [
+ self.processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ )
+ for conversation in conversations
+ ]
+ batch_inputs = self.processor.apply_chat_template(
+ conversations,
+ tokenize=True,
+ return_dict=True,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ processor_kwargs={"padding_side": "left"},
+ )
+ batch_input_ids = batch_inputs["input_ids"]
+ max_length = batch_input_ids.shape[1]
+
+ pad_id = self.tokenizer.pad_token_id
+ if pad_id is None:
+ pad_id = getattr(self.processor, "pad_token_id", 0)
+
+ sample_lengths = [single_input["input_ids"].squeeze(0).shape[0] for single_input in single_inputs]
+ for i, (single_input, batch_ids, single_len) in enumerate(zip(single_inputs, batch_input_ids, sample_lengths)):
+ single_ids = single_input["input_ids"].squeeze(0)
+ torch.testing.assert_close(batch_ids[-single_len:], single_ids)
+
+ batch_modality_row = batch_inputs["mm_token_type_ids"][i]
+ expected_modality = torch.full(
+ (max_length,),
+ batch_modality_row[-1].item(),
+ dtype=batch_modality_row.dtype,
+ device=batch_modality_row.device,
+ )
+ expected_modality[-single_len:] = single_input["mm_token_type_ids"].squeeze(0)
+ torch.testing.assert_close(batch_modality_row, expected_modality)
+
+ if batch_inputs["image_grid_thw"] is not None:
+ batch_image_mask = batch_inputs["image_grid_thw"][i, :, 0].eq(1)
+ expected_image_count = int(batch_image_mask.sum().item())
+ if single_input["image_grid_thw"] is None:
+ assert expected_image_count == 0
+ else:
+ single_image_mask = single_input["image_grid_thw"][0, :, 0].eq(1)
+ assert expected_image_count == int(single_image_mask.sum().item())
+ if expected_image_count > 0:
+ batch_image_grid_thw = batch_inputs["image_grid_thw"][i, batch_image_mask]
+ single_image_grid_thw = single_input["image_grid_thw"][0, single_image_mask]
+ batch_image_metadata = batch_inputs["image_metadata"][i, batch_image_mask]
+ single_image_metadata = single_input["image_metadata"][0, single_image_mask]
+
+ torch.testing.assert_close(batch_image_grid_thw, single_image_grid_thw)
+ torch.testing.assert_close(batch_image_metadata, single_image_metadata)
+
+ for batch_pixel_values, single_pixel_values, grid_thw in zip(
+ batch_inputs["pixel_values"][i, batch_image_mask],
+ single_input["pixel_values"][0, single_image_mask],
+ batch_image_grid_thw,
+ strict=True,
+ ):
+ valid_patch_count = int((grid_thw[1] * grid_thw[2]).item())
+ torch.testing.assert_close(
+ batch_pixel_values[:valid_patch_count],
+ single_pixel_values[:valid_patch_count],
+ )
+
+ if single_len == max_length:
+ continue
+
+ pad_span = batch_ids[: max_length - single_len]
+ assert torch.all(pad_span == pad_id), f"sample {i} left pad span not padded with pad id"
+ torch.testing.assert_close(
+ batch_inputs["attention_mask"][i],
+ batch_ids.ne(pad_id).long(),
+ )
+
+ single_texts = []
+ for single_input in single_inputs:
+ single_input = single_input.to(self.device, dtype=self.dtype)
+ with torch.no_grad():
+ outputs = self.model.generate(
+ **single_input,
+ max_new_tokens=self.max_new_tokens,
+ do_sample=False,
+ pad_token_id=self.tokenizer.eos_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ return_dict_in_generate=True,
+ )
+ generated_ids = outputs.sequences[:, single_input["input_ids"].shape[1] :]
+ single_texts.append(self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
+
+ batch_inputs = batch_inputs.to(self.device, dtype=self.dtype)
+ with torch.no_grad():
+ batch_outputs = self.model.generate(
+ **batch_inputs,
+ max_new_tokens=100,
+ do_sample=False,
+ pad_token_id=self.tokenizer.eos_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ return_dict_in_generate=True,
+ )
+ batch_generated_ids = batch_outputs.sequences[:, batch_inputs["input_ids"].shape[1] :]
+ batch_texts = self.processor.batch_decode(batch_generated_ids, skip_special_tokens=True)
+ assert len(batch_texts) == len(single_texts) == 3
+
+ for i, (batch_text, single_text) in enumerate(zip(batch_texts, single_texts)):
+ assert single_text in batch_text, f"batch[{i}] mismatch: {batch_text!r} vs single[{i}] {single_text!r}"
+
+ def test_batched_beam_generation_matches_individual(self):
+ image = _load_red_dot_image()
+ if image is None:
+ pytest.skip("PIL.Image is required for Isaac generation tests.")
+
+ conversations = [
+ [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": "What is the pythogorean theorem?"}],
+ }
+ ],
+ [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Describe this image:"},
+ {"type": "image", "image": image},
+ ],
+ }
+ ],
+ [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "url": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp",
+ },
+ {"type": "text", "text": "Is it safe to cross the street at this moment?"},
+ ],
+ }
+ ],
+ ]
+ beam_kwargs = {"num_beams": 2}
+
+ single_texts = []
+ for conversation in conversations:
+ single_input = self.processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ ).to(self.device, dtype=self.dtype)
+ with torch.no_grad():
+ outputs = self.model.generate(
+ **single_input,
+ max_new_tokens=self.max_new_tokens,
+ do_sample=False,
+ pad_token_id=self.tokenizer.eos_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ return_dict_in_generate=True,
+ **beam_kwargs,
+ )
+ generated_ids = outputs.sequences[:, single_input["input_ids"].shape[1] :]
+ single_texts.append(self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
+
+ batch_inputs = self.processor.apply_chat_template(
+ conversations,
+ tokenize=True,
+ return_dict=True,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ processor_kwargs={"padding_side": "left"},
+ ).to(self.device, dtype=self.dtype)
+ with torch.no_grad():
+ batch_outputs = self.model.generate(
+ **batch_inputs,
+ max_new_tokens=100,
+ do_sample=False,
+ pad_token_id=self.tokenizer.eos_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ return_dict_in_generate=True,
+ **beam_kwargs,
+ )
+ batch_generated_ids = batch_outputs.sequences[:, batch_inputs["input_ids"].shape[1] :]
+ batch_texts = self.processor.batch_decode(batch_generated_ids, skip_special_tokens=True)
+ assert len(batch_texts) == len(single_texts) == 3
+
+ for i, (batch_text, single_text) in enumerate(zip(batch_texts, single_texts)):
+ assert single_text in batch_text, (
+ f"beam batch[{i}] mismatch: {batch_text!r} vs single[{i}] {single_text!r}"
+ )
+
+
+@require_torch
+@require_vision
+@slow
+@require_flash_attn
+class IsaacBoxPointingIntegrationTest(unittest.TestCase):
+ max_new_tokens = 256
+ dtype = torch.bfloat16
+
+ def setUp(self):
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ self.checkpoint = _reference_checkpoint_or_skip()
+ self.hf_config = IsaacConfig.from_pretrained(self.checkpoint, revision=MODEL_REVISION)
+ # The current local slow fallback only supports padded packing for this checkpoint.
+ self.processor = IsaacProcessor.from_pretrained(self.checkpoint, revision=MODEL_REVISION, do_pad=True)
+ self.tokenizer = self.processor.tokenizer
+ self.hf_config.vision_config._attn_implementation = "flash_attention_2"
+ self.hf_config.vision_config.attn_implementation = "flash_attention_2"
+ self.model = IsaacForConditionalGeneration.from_pretrained(
+ self.checkpoint, config=self.hf_config, revision=MODEL_REVISION
+ )
+ self.model = self.model.to(device=self.device, dtype=self.dtype)
+ self.model.eval()
+
+ def test_hf_generate_box_points(self):
+ conversation = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "BOX"},
+ {
+ "type": "image",
+ "url": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp",
+ },
+ {
+ "type": "text",
+ "text": "Determine whether it is safe to cross the street. Look for signage and moving traffic.",
+ },
+ ],
+ }
+ ]
+ inputs = self.processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ ).to(self.device, dtype=self.dtype)
+
+ with torch.no_grad():
+ outputs = self.model.generate(
+ **inputs,
+ max_new_tokens=self.max_new_tokens,
+ do_sample=False,
+ pad_token_id=self.tokenizer.eos_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ return_dict_in_generate=True,
+ )
+
+ generated_ids = outputs.sequences[:, inputs["input_ids"].shape[1] :]
+ generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ _, points = self.processor.post_process_generation(generated_text, expected="box")
+ assert len(points) == 1
+ first_point = points[0]
+ assert first_point.top_left.x < first_point.bottom_right.x
+ assert first_point.top_left.y < first_point.bottom_right.y
+ assert first_point.mention == "traffic light"
+ assert first_point.top_left.x == 808
+ assert first_point.top_left.y == 247
+ assert first_point.bottom_right.x == 863
+ assert first_point.bottom_right.y == 386
+
+ def test_hf_generate_polygon_points(self):
+ conversation = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "POLYGON"},
+ {
+ "type": "image",
+ "url": "https://raw.githubusercontent.com/perceptron-ai-inc/perceptron/refs/heads/main/huggingface/assets/example.webp",
+ },
+ {
+ "type": "text",
+ "text": "Determine whether it is safe to cross the street. Look for signage and moving traffic.",
+ },
+ ],
+ }
+ ]
+ inputs = self.processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ ).to(self.device, dtype=self.dtype)
+
+ with torch.no_grad():
+ outputs = self.model.generate(
+ **inputs,
+ max_new_tokens=self.max_new_tokens,
+ do_sample=False,
+ pad_token_id=self.tokenizer.eos_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ return_dict_in_generate=True,
+ )
+
+ generated_ids = outputs.sequences[:, inputs["input_ids"].shape[1] :]
+ generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ _, polygons = self.processor.post_process_generation(generated_text, expected="polygon")
+ assert len(polygons) == 1
+ first_polygon = polygons[0]
+ xs = [point.x for point in first_polygon.points]
+ ys = [point.y for point in first_polygon.points]
+ expected_left, expected_top, expected_right, expected_bottom = 808, 247, 863, 386
+
+ assert len(first_polygon.points) >= 3
+ assert first_polygon.mention == "traffic light"
+ assert min(xs) >= expected_left - 4
+ assert max(xs) <= expected_right + 4
+ assert min(ys) >= expected_top - 4
+ assert max(ys) <= expected_bottom + 4
+ assert max(xs) - min(xs) >= 35
+ assert max(ys) - min(ys) >= 100
+ assert any(abs(x - expected_left) <= 12 for x in xs)
+ assert any(abs(x - expected_right) <= 12 for x in xs)
+ assert any(abs(y - expected_top) <= 12 for y in ys)
+ assert any(abs(y - expected_bottom) <= 12 for y in ys)
diff --git a/tests/models/isaac/test_processing_isaac.py b/tests/models/isaac/test_processing_isaac.py
new file mode 100644
index 000000000000..7c2bbc3f6048
--- /dev/null
+++ b/tests/models/isaac/test_processing_isaac.py
@@ -0,0 +1,172 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Testing suite for the Isaac processor."""
+
+import os
+import unittest
+from pathlib import Path
+
+import numpy as np
+import pytest
+from huggingface_hub import is_offline_mode
+
+from transformers.models.isaac.processing_isaac import IsaacProcessor
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ from PIL import Image
+else:
+ Image = None
+
+
+def _make_dummy_image(size=(32, 32), color=(255, 0, 0)):
+ if Image is None:
+ raise RuntimeError("PIL.Image is not available in this environment.")
+ return Image.new("RGB", size, color=color)
+
+
+BASE_MODEL_ID = os.environ.get("ISAAC_TEST_MODEL_ID", "PerceptronAI/Isaac-0.1-Base")
+BASE_MODEL_REVISION = os.environ.get("ISAAC_TEST_MODEL_REVISION", "refs/pr/3") or None
+LOCAL_CHECKPOINT = os.environ.get("ISAAC_TEST_MODEL_PATH")
+
+
+def _checkpoint_or_skip(model_id=BASE_MODEL_ID):
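+    """Prefer ISAAC_TEST_MODEL_PATH when set; skip in offline mode without it; otherwise fall back to the Hub model id."""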
+ if LOCAL_CHECKPOINT:
+ resolved = Path(LOCAL_CHECKPOINT).expanduser()
+ if not resolved.exists():
+ pytest.skip(f"Local checkpoint path {resolved} does not exist.")
+ return str(resolved)
+ if is_offline_mode():
+ pytest.skip("Offline mode: set ISAAC_TEST_MODEL_PATH to a local checkpoint to run these tests.")
+ return model_id
+
+
+@require_torch
+@require_vision
+class IsaacProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+ processor_class = IsaacProcessor
+ model_id = BASE_MODEL_ID
+ images_input_name = "pixel_values"
+
+ @classmethod
+ def _setup_from_pretrained(cls, model_id, **kwargs):
+ checkpoint = _checkpoint_or_skip(model_id)
+ return super()._setup_from_pretrained(
+ checkpoint,
+ revision=BASE_MODEL_REVISION,
+ patch_size=4,
+ max_num_patches=4,
+ **kwargs,
+ )
+
+ @classmethod
+ def _setup_test_attributes(cls, processor):
+ cls.image_token = processor.image_token
+ cls.pad_token_id = processor.tokenizer.pad_token_id
+ cls.image_pad_token_id = processor.image_token_id
+
+ def prepare_image_inputs(self, batch_size: int | None = None, nested: bool = False):
+ if batch_size is None:
+ return _make_dummy_image(size=(16, 16))
+ images = [_make_dummy_image(size=(16, 16), color=(50 * (i + 1), 0, 0)) for i in range(batch_size)]
+ if nested:
+ return [[image] for image in images]
+ return images
+
+ @unittest.skip("Isaac chat templates emit placeholders but the processor consumes image pad tokens")
+ def test_apply_chat_template_image_0(self):
+ pass
+
+ @unittest.skip("Isaac chat templates emit placeholders but the processor consumes image pad tokens")
+ def test_apply_chat_template_image_1(self):
+ pass
+
+ def test_apply_chat_template_image_placeholder_expands_to_image_pad_tokens(self):
+ processor = self.get_processor()
+ image = _make_dummy_image(size=(16, 16))
+ messages = [
+ [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Describe this."},
+ {"type": "image", "image": image},
+ ],
+ }
+ ]
+ ]
+
+ formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+ self.assertEqual(len(formatted_prompt), 1)
+ self.assertIn("", formatted_prompt[0])
+
+ out_dict = processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ )
+ self.assertTrue(
+ all(
+ key in out_dict
+ for key in [
+ "input_ids",
+ "attention_mask",
+ "pixel_values",
+ "image_grid_thw",
+ "image_metadata",
+ "mm_token_type_ids",
+ ]
+ )
+ )
+
+ expected_num_image_tokens = processor._get_num_multimodal_tokens(image_sizes=[(image.height, image.width)])[
+ "num_image_tokens"
+ ][0]
+ actual_num_image_tokens = int(out_dict["input_ids"][0].eq(processor.image_token_id).sum().item())
+
+ self.assertEqual(actual_num_image_tokens, expected_num_image_tokens)
+ self.assertEqual(int(out_dict["mm_token_type_ids"][0].sum().item()), expected_num_image_tokens)
+ self.assertEqual(int(out_dict["image_metadata"][0, 0, 1].item()), expected_num_image_tokens)
+ self.assertTrue(
+ torch.all(out_dict["mm_token_type_ids"][0][out_dict["input_ids"][0].eq(processor.image_token_id)] == 1)
+ )
+
+ def test_get_num_multimodal_tokens_matches_processor_call(self):
+ processor = self.get_processor()
+
+ image_sizes = [(100, 100), (300, 100), (500, 30), (213, 167)]
+ image_inputs = [np.random.randint(255, size=(h, w, 3), dtype=np.uint8) for h, w in image_sizes]
+
+ text = [f"This is an image {self.image_token}"] * len(image_inputs)
+ inputs = processor(
+ text=text,
+ images=[[image] for image in image_inputs],
+ padding=True,
+ return_mm_token_type_ids=True,
+ return_tensors="pt",
+ )
+
+ num_image_tokens_from_call = inputs.mm_token_type_ids.sum(-1).tolist()
+ num_image_tokens_from_helper = processor._get_num_multimodal_tokens(image_sizes=image_sizes)
+ self.assertListEqual(num_image_tokens_from_call, num_image_tokens_from_helper["num_image_tokens"])
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 7366845c4d78..3ac8ff1573f8 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -3246,6 +3246,67 @@ def test_vision_language_model(self):
assert dec is model.model.language_model, f"LLaVA get_decoder() should return language_model, got {type(dec)}"
+
+
+class TestEmbeddingAccessMixin(unittest.TestCase):
+ def test_get_input_embeddings_supports_dotted_input_embed_layer(self):
+ class NestedEmbeddingModel(PreTrainedModel):
+ config_class = PreTrainedConfig
+ _input_embed_layer = "text_model.embed_tokens"
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.text_model = nn.Module()
+ self.text_model.embed_tokens = nn.Embedding(8, 4)
+
+ def forward(self, input_ids=None):
+ return input_ids
+
+ model = NestedEmbeddingModel(PreTrainedConfig())
+
+ assert model.get_input_embeddings() is model.text_model.embed_tokens
+
+ def test_set_input_embeddings_supports_dotted_input_embed_layer(self):
+ class NestedEmbeddingModel(PreTrainedModel):
+ config_class = PreTrainedConfig
+ _input_embed_layer = "text_model.embed_tokens"
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.text_model = nn.Module()
+ self.text_model.embed_tokens = nn.Embedding(8, 4)
+
+ def forward(self, input_ids=None):
+ return input_ids
+
+ model = NestedEmbeddingModel(PreTrainedConfig())
+ new_embeddings = nn.Embedding(10, 4)
+
+ model.set_input_embeddings(new_embeddings)
+
+ assert model.get_input_embeddings() is new_embeddings
+ assert model.text_model.embed_tokens is new_embeddings
+
+ def test_invalid_dotted_input_embed_layer_raises(self):
+ class NestedEmbeddingModel(PreTrainedModel):
+ config_class = PreTrainedConfig
+ _input_embed_layer = "text_model.missing_embed_tokens"
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.text_model = nn.Module()
+ self.text_model.embed_tokens = nn.Embedding(8, 4)
+
+ def forward(self, input_ids=None):
+ return input_ids
+
+ model = NestedEmbeddingModel(PreTrainedConfig())
+
+ with self.assertRaises(NotImplementedError):
+ model.get_input_embeddings()
+
+ with self.assertRaises(NotImplementedError):
+ model.set_input_embeddings(nn.Embedding(10, 4))
+
+
class TestGetEncoder(unittest.TestCase):
def test_seq2seq_lm_get_encoder_returns_encoder(self):
cfg = BartConfig(
diff --git a/utils/check_repo.py b/utils/check_repo.py
index b1a3d158c716..4d040fc165c6 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -217,6 +217,8 @@
"Qwen3VLMoeTextModel", # Building part of bigger (tested) model.
"Qwen3_5TextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3_5ForConditionalGeneration.
"Qwen3_5MoeTextModel", # Building part of bigger (tested) model. Tested implicitly through Qwen3_5MoeForConditionalGeneration.
+ "IsaacTextModel", # Building part of bigger (tested) model. Tested implicitly through IsaacForConditionalGeneration.
+ "IsaacVisionModel", # Building part of bigger (tested) model. Tested implicitly through IsaacForConditionalGeneration.
"Qwen2_5OmniForConditionalGeneration", # Not a regular model. Testted in Qwen2_5OmniModelIntergrationTest
"Qwen2_5OmniTalkerForConditionalGeneration", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest.
"Qwen2_5OmniTalkerModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5OmniModelIntergrationTest.
@@ -467,6 +469,8 @@
"PaddleOCRVisionModel", # Building part of bigger (tested) model
"PaddleOCRVisionTransformer", # Building part of bigger (tested) model
"PaddleOCRTextModel", # Building part of bigger (tested) model
+ "IsaacTextModel", # Building part of a bigger model
+ "IsaacVisionModel", # Building part of a bigger model
"Qwen2_5OmniTalkerForConditionalGeneration", # Building part of a bigger model
"Qwen2_5OmniTalkerModel", # Building part of a bigger model
"Qwen2_5OmniThinkerForConditionalGeneration", # Building part of a bigger model
@@ -1130,6 +1134,7 @@ def find_all_documented_objects() -> list[str]:
"Ernie4_5_VL_MoeImageProcessorFast", # BC Alias
"Ernie4_5_VL_MoeImageProcessorPil", # BC Alias
"Ernie4_5_VL_MoeModel", # BC Alias
+ "IsaacVisionModel", # Internal building block tested implicitly through IsaacForConditionalGeneration.
"Ernie4_5_VL_MoeTextConfig", # BC Alias
"Ernie4_5_VL_MoeTextModel", # BC Alias
"Ernie4_5_VL_MoeVariableResolutionResamplerModel", # BC Alias