diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index ba69db1c5e78..496b4d95c51c 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1351,6 +1351,8 @@
title: SAM3
- local: model_doc/sam3_video
title: SAM3 Video
+ - local: model_doc/sam3_lite_text
+ title: SAM3-LiteText
- local: model_doc/shieldgemma2
title: ShieldGemma2
- local: model_doc/siglip
diff --git a/docs/source/en/model_doc/nomic_bert.md b/docs/source/en/model_doc/nomic_bert.md
index 73b3adc8a35f..2017805fe42a 100644
--- a/docs/source/en/model_doc/nomic_bert.md
+++ b/docs/source/en/model_doc/nomic_bert.md
@@ -23,7 +23,7 @@ limitations under the License.
## Overview
-NomicBERT was proposed in [Nomic Embed: Training a Reproducible Long Context Text Embedder](https://arxiv.org/abs/2402.01613) by
+NomicBERT was proposed in [Nomic Embed: Training a Reproducible Long Context Text Embedder](https://huggingface.co/papers/2402.01613) by
Zach Nussbaum, John X. Morris, Brandon Duderstadt, and Andriy Mulyar. It is BERT-inspired with the most notable extension applying
[Rotary Position Embeddings](https://huggingface.co/papers/2104.09864.pdf) to an encoder model.
diff --git a/docs/source/en/model_doc/sam3_lite_text.md b/docs/source/en/model_doc/sam3_lite_text.md
new file mode 100644
index 000000000000..79ebb9a1701c
--- /dev/null
+++ b/docs/source/en/model_doc/sam3_lite_text.md
@@ -0,0 +1,118 @@
+
+*This model was released on 2026-02-12 and added to Hugging Face Transformers on 2026-04-12.*
+
+# SAM3-LiteText
+
+## Overview
+
+SAM3-LiteText was proposed in [SAM3-LiteText: An Anatomical Study of the SAM3 Text Encoder for Efficient Vision-Language Segmentation](https://huggingface.co/papers/2602.12173) by Chengxi Zeng, Yuxuan Jiang, Ge Gao, Shuai Wang, Duolikun Danier, Bin Zhu, Stevan Rudinac, David Bull, and Fan Zhang.
+
+SAM3-LiteText is a lightweight variant of [SAM3](sam3) that replaces the heavy SAM3 text encoder (353M parameters) with a compact MobileCLIP-based text encoder optimized through knowledge distillation. The SAM3 ViT-H image encoder is kept intact. This reduces text encoder parameters by up to 88% while maintaining segmentation performance comparable to the original model.
+
+The abstract from the paper is the following:
+
+*Vision-language segmentation models such as SAM3 enable flexible, prompt-driven visual grounding, but inherit large, general-purpose text encoders originally designed for open-ended language understanding. In practice, segmentation prompts are short, structured, and semantically constrained, leading to substantial over-provisioning in text encoder capacity and persistent computational and memory overhead. In this paper, we perform a large-scale anatomical analysis of text prompting in vision-language segmentation, covering 404,796 real prompts across multiple benchmarks. Our analysis reveals severe redundancy: most context windows are underutilized, vocabulary usage is highly sparse, and text embeddings lie on a low-dimensional manifold despite high-dimensional representations. Motivated by these findings, we propose SAM3-LiteText, a lightweight text encoding framework that replaces the original SAM3 text encoder with a compact MobileCLIP student that is optimized by knowledge distillation. Extensive experiments on image and video segmentation benchmarks show that SAM3-LiteText reduces text encoder parameters by up to 88%, substantially reducing static memory footprint, while maintaining segmentation performance comparable to the original model.*
+
+The text encoder architecture is based on [MobileCLIP](https://huggingface.co/papers/2311.17049) and comes in three variants:
+
+| Variant | Text Encoder | Text Params | Reduction |
+|---|---|---|---|
+| SAM3-LiteText-S0-16 | MobileCLIP-S0 | 42.54M | ~88% |
+| SAM3-LiteText-S1-16 | MobileCLIP-S1 | 63.53M | ~82% |
+| SAM3-LiteText-L-16 | MobileCLIP2-L | 123.80M | ~65% |
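+
+To get a quick sense of the parameter reduction reported above, you can count the text encoder parameters of a converted checkpoint directly. This is a minimal sketch: it assumes the compact text encoder is exposed as the `text_encoder` submodule (the naming used by the conversion script) and uses the S0 checkpoint from the usage example below.
+
+```python
+from transformers import AutoModel
+
+model = AutoModel.from_pretrained("yonigozlan/sam3-litetext-s0")
+
+# Parameter count of the compact MobileCLIP-based text encoder
+text_params = sum(p.numel() for p in model.text_encoder.parameters())
+print(f"Text encoder parameters: {text_params / 1e6:.2f}M")  # expected to land near the 42.54M listed for S0
+```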
+
+This model was contributed by [nielsr](https://huggingface.co/nielsr) and [yonigozlan](https://huggingface.co/yonigozlan).
+The original code can be found [here](https://github.com/SimonZeng7108/efficientsam3/tree/sam3_litetext).
+
+## Usage
+
+SAM3-LiteText is a drop-in replacement for SAM3 with a lightweight text encoder. It uses the same processor ([`Sam3Processor`]) and supports the same prompting interface. Refer to the [SAM3 documentation](sam3) for detailed usage examples including text prompts, box prompts, batched inference, and more.
+
+```python
+from io import BytesIO
+
+import httpx
+from transformers import AutoModel, AutoProcessor
+from PIL import Image
+
+model = AutoModel.from_pretrained("yonigozlan/sam3-litetext-s0", device_map="auto")
+processor = AutoProcessor.from_pretrained("yonigozlan/sam3-litetext-s0")
+
+image_url = "http://images.cocodataset.org/val2017/000000077595.jpg"
+image = Image.open(BytesIO(httpx.get(image_url).content)).convert("RGB")
+
+inputs = processor(images=image, text="ear", return_tensors="pt").to(model.device)
+
+outputs = model(**inputs)
+
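+# Convert raw model outputs into instance masks at the original image resolution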
+results = processor.post_process_instance_segmentation(
+ outputs,
+ threshold=0.5,
+ mask_threshold=0.5,
+ target_sizes=inputs.get("original_sizes").tolist(),
+)[0]
+
+print(f"Found {len(results['masks'])} objects")
+```
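+
+The `Auto` classes above resolve to the concrete classes added for this model, so the pipeline can also be written with them explicitly, for example in half precision. This is a minimal sketch; the processing and post-processing steps are identical to the snippet above, and `dtype`/`device_map` follow the usual `from_pretrained` options:
+
+```python
+import torch
+
+from transformers import Sam3LiteTextModel, Sam3Processor
+
+model = Sam3LiteTextModel.from_pretrained(
+    "yonigozlan/sam3-litetext-s0", dtype=torch.float16, device_map="auto"
+)
+processor = Sam3Processor.from_pretrained("yonigozlan/sam3-litetext-s0")
+```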
+
+## Sam3LiteTextConfig
+
+[[autodoc]] Sam3LiteTextConfig
+
+## Sam3LiteTextTextConfig
+
+[[autodoc]] Sam3LiteTextTextConfig
+
+## Sam3LiteTextGeometryEncoderConfig
+
+[[autodoc]] Sam3LiteTextGeometryEncoderConfig
+
+## Sam3LiteTextDETREncoderConfig
+
+[[autodoc]] Sam3LiteTextDETREncoderConfig
+
+## Sam3LiteTextDETRDecoderConfig
+
+[[autodoc]] Sam3LiteTextDETRDecoderConfig
+
+## Sam3LiteTextMaskDecoderConfig
+
+[[autodoc]] Sam3LiteTextMaskDecoderConfig
+
+## Sam3LiteTextTextModel
+
+[[autodoc]] Sam3LiteTextTextModel
+ - forward
+
+## Sam3LiteTextModel
+
+[[autodoc]] Sam3LiteTextModel
+ - forward
+
+## Sam3LiteTextPreTrainedModel
+
+[[autodoc]] Sam3LiteTextPreTrainedModel
+ - forward
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 989be9eb114e..acc5e2fdeac0 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -367,6 +367,7 @@
from .sam2 import *
from .sam2_video import *
from .sam3 import *
+ from .sam3_lite_text import *
from .sam3_tracker import *
from .sam3_tracker_video import *
from .sam3_video import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 2c0fe88d0e74..07b0ec3854ef 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -421,6 +421,8 @@
("sam2_video", "Sam2VideoConfig"),
("sam2_vision_model", "Sam2VisionConfig"),
("sam3", "Sam3Config"),
+ ("sam3_lite_text", "Sam3LiteTextConfig"),
+ ("sam3_lite_text_text_model", "Sam3LiteTextTextConfig"),
("sam3_tracker", "Sam3TrackerConfig"),
("sam3_tracker_video", "Sam3TrackerVideoConfig"),
("sam3_video", "Sam3VideoConfig"),
@@ -954,6 +956,8 @@
("sam2_video", "Sam2VideoModel"),
("sam2_vision_model", "Sam2VisionModel"),
("sam3", "SAM3"),
+ ("sam3_lite_text", "SAM3-LiteText"),
+ ("sam3_lite_text_text_model", "SAM3-LiteText Text Model"),
("sam3_tracker", "Sam3Tracker"),
("sam3_tracker_video", "Sam3TrackerVideo"),
("sam3_video", "Sam3VideoModel"),
@@ -1142,6 +1146,7 @@
("sam_vision_model", "sam"),
("sam2_vision_model", "sam2"),
("sam2_hiera_det_model", "sam2"),
+ ("sam3_lite_text_text_model", "sam3_lite_text"),
("sam3_vit_model", "sam3"),
("sam3_vision_model", "sam3"),
("edgetam_vision_model", "edgetam"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 4ada3ba6b8ed..e554249cac4a 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -226,6 +226,7 @@
("sam2", {"torchvision": "Sam2ImageProcessor"}),
("sam2_video", {"torchvision": "Sam2ImageProcessor"}),
("sam3", {"torchvision": "Sam3ImageProcessor"}),
+ ("sam3_lite_text", {"torchvision": "Sam3ImageProcessor"}),
("sam3_tracker", {"torchvision": "Sam3ImageProcessor"}),
("sam3_tracker_video", {"torchvision": "Sam3ImageProcessor"}),
("sam3_video", {"torchvision": "Sam3ImageProcessor"}),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index d4cb17cddfa6..50bbd5721413 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -399,6 +399,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("sam2_video", "Sam2VideoModel"),
("sam2_vision_model", "Sam2VisionModel"),
("sam3", "Sam3Model"),
+ ("sam3_lite_text", "Sam3LiteTextModel"),
+ ("sam3_lite_text_text_model", "Sam3LiteTextTextModel"),
("sam3_tracker", "Sam3TrackerModel"),
("sam3_tracker", "Sam3TrackerModel"),
("sam3_tracker_video", "Sam3TrackerVideoModel"),
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index 262480b71485..1b6384b41364 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -152,6 +152,7 @@
("sam", "SamProcessor"),
("sam2", "Sam2Processor"),
("sam3", "Sam3Processor"),
+ ("sam3_lite_text", "Sam3Processor"),
("sam_hq", "SamHQProcessor"),
("seamless_m4t", "SeamlessM4TProcessor"),
("sew", "Wav2Vec2Processor"),
diff --git a/src/transformers/models/sam3/convert_sam3_to_hf.py b/src/transformers/models/sam3/convert_sam3_to_hf.py
index 927b1cc5f208..af1786c94c7b 100644
--- a/src/transformers/models/sam3/convert_sam3_to_hf.py
+++ b/src/transformers/models/sam3/convert_sam3_to_hf.py
@@ -25,7 +25,7 @@
import regex as re
import torch
-from transformers import CLIPTokenizerFast, Sam3Config, Sam3ImageProcessorFast, Sam3Model, Sam3Processor
+from transformers import CLIPTokenizerFast, Sam3Config, Sam3ImageProcessor, Sam3Model, Sam3Processor
from transformers.utils import logging
@@ -383,7 +383,7 @@ def convert_sam3_checkpoint(
# Save processor
print("Creating and saving processor...")
- image_processor = Sam3ImageProcessorFast()
+ image_processor = Sam3ImageProcessor()
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", max_length=32, model_max_length=32)
processor = Sam3Processor(image_processor=image_processor, tokenizer=tokenizer)
processor.save_pretrained(output_path)
diff --git a/src/transformers/models/sam3/modeling_sam3.py b/src/transformers/models/sam3/modeling_sam3.py
index 5a9aa329daa0..0dee80edf5db 100644
--- a/src/transformers/models/sam3/modeling_sam3.py
+++ b/src/transformers/models/sam3/modeling_sam3.py
@@ -279,7 +279,7 @@ def box_cxcywh_to_xyxy(x):
class Sam3MLP(nn.Module):
- def __init__(self, config: Sam3ViTConfig):
+ def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
diff --git a/src/transformers/models/sam3_lite_text/__init__.py b/src/transformers/models/sam3_lite_text/__init__.py
new file mode 100644
index 000000000000..753f93a87c0e
--- /dev/null
+++ b/src/transformers/models/sam3_lite_text/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2026 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_sam3_lite_text import *
+ from .modeling_sam3_lite_text import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/sam3_lite_text/configuration_sam3_lite_text.py b/src/transformers/models/sam3_lite_text/configuration_sam3_lite_text.py
new file mode 100644
index 000000000000..696751075611
--- /dev/null
+++ b/src/transformers/models/sam3_lite_text/configuration_sam3_lite_text.py
@@ -0,0 +1,330 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/sam3_lite_text/modular_sam3_lite_text.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_sam3_lite_text.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2026 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from huggingface_hub.dataclasses import strict
+
+from ...configuration_utils import PreTrainedConfig
+from ...utils import auto_docstring
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+@auto_docstring(checkpoint="facebook/sam3_lite_text")
+@strict
+class Sam3LiteTextViTConfig(PreTrainedConfig):
+ r"""
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ Base frequency for RoPE.
+ window_size (`int`, *optional*, defaults to 24):
+ Window size for windowed attention.
+ global_attn_indexes (`list[int]`, *optional*, defaults to `[7, 15, 23, 31]`):
+ Indexes of layers with global attention.
+ pretrain_image_size (`int`, *optional*, defaults to 336):
+ Pretrained model image size for position embedding initialization.
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
+ Dropout probability for hidden states.
+ """
+
+ base_config_key = "backbone_config"
+ model_type = "sam3_vit_model"
+
+ hidden_size: int = 1024
+ intermediate_size: int = 4736
+ num_hidden_layers: int = 32
+ num_attention_heads: int = 16
+ num_channels: int = 3
+ image_size: int | list[int] | tuple[int, int] = 1008
+ patch_size: int | list[int] | tuple[int, int] = 14
+ hidden_act: str = "gelu"
+ layer_norm_eps: float = 1e-6
+ attention_dropout: float | int = 0.0
+ rope_theta: float = 10000.0
+ window_size: int = 24
+ global_attn_indexes: list[int] | None = None
+ layer_scale_init_value: float | None = None
+ pretrain_image_size: int | list[int] | tuple[int, int] = 336
+ hidden_dropout: float | int = 0.0
+ initializer_range: float = 0.02
+
+ def __post_init__(self, **kwargs):
+ super().__post_init__(**kwargs)
+ if self.global_attn_indexes is None:
+ self.global_attn_indexes = [7, 15, 23, 31]
+
+
+@auto_docstring(checkpoint="facebook/sam3_lite_text")
+@strict
+class Sam3LiteTextVisionConfig(PreTrainedConfig):
+ r"""
+ fpn_hidden_size (`int`, *optional*, defaults to 256):
+ The hidden dimension of the FPN.
+ backbone_feature_sizes (`list[list[int]]`, *optional*, defaults to `[[288, 288], [144, 144], [72, 72]]`):
+ The spatial sizes (height, width) of the feature maps from the backbone at different scales.
+ scale_factors (`list[float]`, *optional*, defaults to `[4.0, 2.0, 1.0, 0.5]`):
+ Scale factors for FPN multi-scale features. List of scaling factors for each FPN level.
+ """
+
+ base_config_key = "vision_config"
+ model_type = "sam3_vision_model"
+ sub_configs = {"backbone_config": AutoConfig}
+
+ backbone_config: dict | PreTrainedConfig | None = None
+ fpn_hidden_size: int = 256
+ backbone_feature_sizes: list | None = None
+ scale_factors: list[float] | None = None
+ hidden_act: str = "gelu"
+ layer_norm_eps: float = 1e-6
+ initializer_range: float = 0.02
+
+ def __post_init__(self, **kwargs):
+ self.scale_factors = [4.0, 2.0, 1.0, 0.5] if self.scale_factors is None else self.scale_factors
+ if self.backbone_feature_sizes is None:
+ self.backbone_feature_sizes = [[288, 288], [144, 144], [72, 72]]
+
+ if isinstance(self.backbone_config, dict):
+ self.backbone_config["model_type"] = self.backbone_config.get("model_type", "sam3_vit_model")
+ self.backbone_config = CONFIG_MAPPING[self.backbone_config["model_type"]](**self.backbone_config)
+ elif self.backbone_config is None:
+ self.backbone_config = CONFIG_MAPPING["sam3_vit_model"]()
+
+ super().__post_init__(**kwargs)
+
+ @property
+ def image_size(self):
+ """Image size for the vision encoder."""
+ return self.backbone_config.image_size
+
+ @image_size.setter
+ def image_size(self, value):
+ """Set the image size and propagate to backbone."""
+ self.backbone_config.image_size = value
+
+
+@auto_docstring(checkpoint="facebook/sam3_lite_text")
+@strict
+class Sam3LiteTextGeometryEncoderConfig(PreTrainedConfig):
+ r"""
+ roi_size (`int`, *optional*, defaults to 7):
+ ROI size for box pooling operations.
+ """
+
+ model_type = "sam3_lite_text_geometry_encoder"
+
+ hidden_size: int = 256
+ num_layers: int = 3
+ num_attention_heads: int = 8
+ intermediate_size: int = 2048
+ dropout: float | int = 0.1
+ hidden_act: str = "relu"
+ hidden_dropout: float | int = 0.0
+ layer_norm_eps: float = 1e-6
+ roi_size: int = 7
+ initializer_range: float = 0.02
+
+
+@auto_docstring(checkpoint="facebook/sam3_lite_text")
+@strict
+class Sam3LiteTextDETREncoderConfig(PreTrainedConfig):
+ r"""
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
+ Dropout probability for hidden states.
+ """
+
+ model_type = "sam3_lite_text_detr_encoder"
+
+ hidden_size: int = 256
+ num_layers: int = 6
+ num_attention_heads: int = 8
+ intermediate_size: int = 2048
+ dropout: float | int = 0.1
+ hidden_act: str = "relu"
+ hidden_dropout: float | int = 0.0
+ layer_norm_eps: float = 1e-6
+ initializer_range: float = 0.02
+
+
+@auto_docstring(checkpoint="facebook/sam3_lite_text")
+@strict
+class Sam3LiteTextDETRDecoderConfig(PreTrainedConfig):
+ r"""
+ num_queries (`int`, *optional*, defaults to 200):
+ Number of object queries.
+ """
+
+ model_type = "sam3_lite_text_detr_decoder"
+
+ hidden_size: int = 256
+ num_layers: int = 6
+ num_queries: int = 200
+ num_attention_heads: int = 8
+ intermediate_size: int = 2048
+ dropout: float | int = 0.1
+ hidden_act: str = "relu"
+ hidden_dropout: float | int = 0.0
+ layer_norm_eps: float = 1e-6
+ initializer_range: float = 0.02
+
+
+@auto_docstring(checkpoint="facebook/sam3_lite_text")
+@strict
+class Sam3LiteTextMaskDecoderConfig(PreTrainedConfig):
+ r"""
+ num_upsampling_stages (`int`, *optional*, defaults to 3):
+ Number of upsampling stages in the pixel decoder (FPN).
+ """
+
+ model_type = "sam3_lite_text_mask_decoder"
+
+ hidden_size: int = 256
+ num_upsampling_stages: int = 3
+ layer_norm_eps: float = 1e-6
+ dropout: float | int = 0.0
+ num_attention_heads: int = 8
+ initializer_range: float = 0.02
+
+
+@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0")
+@strict
+class Sam3LiteTextTextConfig(PreTrainedConfig):
+ r"""
+ use_repmixer_blocks (`bool`, *optional*, defaults to `True`):
+ Whether to use RepMixer blocks (MobileCLIP-style) for the first and last encoder layers.
+ When `False`, all layers are standard Transformer encoder layers.
+ layer_scale_init_value (`float`, *optional*, defaults to `1e-5`):
+ Initial value for the learnable layer-scale parameters in RepMixer blocks (residual branches).
+ repmixer_kernel_size (`int`, *optional*, defaults to `11`):
+ Kernel size for depthwise convolutions in RepMixer blocks (token mixer and convolutional feed-forward path).
+ """
+
+ model_type = "sam3_lite_text_text_model"
+
+ vocab_size: int = 49408
+ hidden_size: int = 512
+ intermediate_size: int = 2048
+ projection_dim: int = 512
+ num_hidden_layers: int = 12
+ num_attention_heads: int = 8
+ max_position_embeddings: int = 77
+ hidden_act: str = "gelu"
+ layer_norm_eps: float = 1e-5
+ attention_dropout: float = 0.0
+ use_repmixer_blocks: bool = True
+ layer_scale_init_value: float = 1e-5
+ repmixer_kernel_size: int = 11
+
+
+@auto_docstring(checkpoint="facebook/sam3_lite_text")
+@strict
+class Sam3LiteTextConfig(PreTrainedConfig):
+ r"""
+ geometry_encoder_config (`dict` or `Sam3LiteTextGeometryEncoderConfig`, *optional*):
+ Configuration for the geometry encoder.
+ detr_encoder_config (`dict` or `Sam3LiteTextDETREncoderConfig`, *optional*):
+ Configuration for the DETR encoder.
+ detr_decoder_config (`dict` or `Sam3LiteTextDETRDecoderConfig`, *optional*):
+ Configuration for the DETR decoder.
+ mask_decoder_config (`dict` or `Sam3LiteTextMaskDecoderConfig`, *optional*):
+ Configuration for the mask decoder.
+
+ Example:
+ ```python
+ >>> from transformers import Sam3LiteTextConfig, Sam3LiteTextModel
+
+ >>> # Initializing a SAM3_LITE_TEXT configuration
+ >>> configuration = Sam3LiteTextConfig()
+
+ >>> # Initializing a model from the configuration
+ >>> model = Sam3LiteTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "sam3_lite_text"
+ sub_configs = {
+ "vision_config": AutoConfig,
+ "text_config": Sam3LiteTextTextConfig,
+ "geometry_encoder_config": Sam3LiteTextGeometryEncoderConfig,
+ "detr_encoder_config": Sam3LiteTextDETREncoderConfig,
+ "detr_decoder_config": Sam3LiteTextDETRDecoderConfig,
+ "mask_decoder_config": Sam3LiteTextMaskDecoderConfig,
+ }
+
+ vision_config: dict | PreTrainedConfig | None = None
+ text_config: dict | PreTrainedConfig | None = None
+ geometry_encoder_config: dict | PreTrainedConfig | None = None
+ detr_encoder_config: dict | PreTrainedConfig | None = None
+ detr_decoder_config: dict | PreTrainedConfig | None = None
+ mask_decoder_config: dict | PreTrainedConfig | None = None
+ initializer_range: float = 0.02
+
+ def __post_init__(self, **kwargs):
+ if isinstance(self.vision_config, dict):
+ self.vision_config["model_type"] = self.vision_config.get("model_type", "sam3_vision_model")
+ self.vision_config = CONFIG_MAPPING[self.vision_config["model_type"]](**self.vision_config)
+ elif self.vision_config is None:
+ self.vision_config = CONFIG_MAPPING["sam3_vision_model"]()
+
+ if self.text_config is None:
+ self.text_config = Sam3LiteTextTextConfig()
+ if isinstance(self.text_config, dict):
+ self.text_config = Sam3LiteTextTextConfig(**self.text_config)
+
+ if self.geometry_encoder_config is None:
+ self.geometry_encoder_config = Sam3LiteTextGeometryEncoderConfig()
+ if isinstance(self.geometry_encoder_config, dict):
+ self.geometry_encoder_config = Sam3LiteTextGeometryEncoderConfig(**self.geometry_encoder_config)
+
+ if self.detr_encoder_config is None:
+ self.detr_encoder_config = Sam3LiteTextDETREncoderConfig()
+ if isinstance(self.detr_encoder_config, dict):
+ self.detr_encoder_config = Sam3LiteTextDETREncoderConfig(**self.detr_encoder_config)
+
+ if self.detr_decoder_config is None:
+ self.detr_decoder_config = Sam3LiteTextDETRDecoderConfig()
+ if isinstance(self.detr_decoder_config, dict):
+ self.detr_decoder_config = Sam3LiteTextDETRDecoderConfig(**self.detr_decoder_config)
+
+ if self.mask_decoder_config is None:
+ self.mask_decoder_config = Sam3LiteTextMaskDecoderConfig()
+ if isinstance(self.mask_decoder_config, dict):
+ self.mask_decoder_config = Sam3LiteTextMaskDecoderConfig(**self.mask_decoder_config)
+
+ super().__post_init__(**kwargs)
+
+ @property
+ def image_size(self):
+ """Image size for the SAM3_LITE_TEXT model."""
+ return self.vision_config.image_size
+
+ @image_size.setter
+ def image_size(self, value):
+ """Set the image size and propagate to vision config."""
+ self.vision_config.image_size = value
+
+
+__all__ = [
+ "Sam3LiteTextConfig",
+ "Sam3LiteTextTextConfig",
+ "Sam3LiteTextGeometryEncoderConfig",
+ "Sam3LiteTextDETREncoderConfig",
+ "Sam3LiteTextDETRDecoderConfig",
+ "Sam3LiteTextMaskDecoderConfig",
+]
diff --git a/src/transformers/models/sam3_lite_text/convert_sam3_lite_text_to_hf.py b/src/transformers/models/sam3_lite_text/convert_sam3_lite_text_to_hf.py
new file mode 100644
index 000000000000..8dac30f41302
--- /dev/null
+++ b/src/transformers/models/sam3_lite_text/convert_sam3_lite_text_to_hf.py
@@ -0,0 +1,496 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert EfficientSAM3 LiteText checkpoints to Hugging Face format."""
+
+import argparse
+import gc
+import os
+
+import regex as re
+import torch
+from huggingface_hub import hf_hub_download
+
+from transformers import CLIPTokenizerFast, Sam3ImageProcessor, Sam3Processor
+from transformers.models.sam3_lite_text.configuration_sam3_lite_text import Sam3LiteTextConfig, Sam3LiteTextTextConfig
+from transformers.models.sam3_lite_text.modeling_sam3_lite_text import Sam3LiteTextModel
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+# fmt: off
+ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
+ # Strip the "detector." prefix that wraps all components
+ r"^detector\.": r"",
+
+ # ============================================================================
+ # Vision Encoder - ViT Backbone
+ # ============================================================================
+ r"^backbone\.vision_backbone\.trunk\.": r"vision_encoder.backbone.",
+ r"^vision_encoder\.backbone\.pos_embed": r"vision_encoder.backbone.embeddings.position_embeddings",
+ r"^vision_encoder\.backbone\.patch_embed\.proj\.": r"vision_encoder.backbone.embeddings.patch_embeddings.projection.",
+ r"^vision_encoder\.backbone\.ln_pre\.": r"vision_encoder.backbone.layer_norm.",
+ r"^vision_encoder\.backbone\.blocks\.(\d+)\.norm1\.": r"vision_encoder.backbone.layers.\1.layer_norm1.",
+ r"^vision_encoder\.backbone\.blocks\.(\d+)\.norm2\.": r"vision_encoder.backbone.layers.\1.layer_norm2.",
+ r"^vision_encoder\.backbone\.blocks\.(\d+)\.attn\.qkv\.": r"vision_encoder.backbone.layers.\1.attention.qkv.",
+ r"^vision_encoder\.backbone\.blocks\.(\d+)\.attn\.proj\.": r"vision_encoder.backbone.layers.\1.attention.o_proj.",
+ r"^vision_encoder\.backbone\.blocks\.(\d+)\.attn\.freqs_cis": r"vision_encoder.backbone.layers.\1.rotary_emb.rope_embeddings",
+ r"^vision_encoder\.backbone\.blocks\.(\d+)\.mlp\.fc1\.": r"vision_encoder.backbone.layers.\1.mlp.fc1.",
+ r"^vision_encoder\.backbone\.blocks\.(\d+)\.mlp\.fc2\.": r"vision_encoder.backbone.layers.\1.mlp.fc2.",
+
+ # Vision Encoder - FPN Neck
+ r"^backbone\.vision_backbone\.neck\.fpn\.(\d+)\.": r"vision_encoder.neck.fpn_layers.\1.",
+ r"^backbone\.vision_backbone\.convs\.(\d+)\.dconv_2x2_0\.": r"vision_encoder.neck.fpn_layers.\1.scale_layers.0.",
+ r"^backbone\.vision_backbone\.convs\.(\d+)\.dconv_2x2_1\.": r"vision_encoder.neck.fpn_layers.\1.scale_layers.2.",
+ r"^backbone\.vision_backbone\.convs\.(\d+)\.dconv_2x2\.": r"vision_encoder.neck.fpn_layers.\1.scale_layers.0.",
+ r"^backbone\.vision_backbone\.convs\.(\d+)\.maxpool_2x2\.": r"vision_encoder.neck.fpn_layers.\1.scale_layers.0.",
+ r"^backbone\.vision_backbone\.convs\.(\d+)\.conv_1x1\.": r"vision_encoder.neck.fpn_layers.\1.proj1.",
+ r"^backbone\.vision_backbone\.convs\.(\d+)\.conv_3x3\.": r"vision_encoder.neck.fpn_layers.\1.proj2.",
+
+ # ============================================================================
+ # Text Encoder - LiteText (MobileCLIP student)
+ # ============================================================================
+ # Embeddings
+ r"^backbone\.language_backbone\.encoder\.embedding_layer\.": r"text_encoder.embeddings.token_embedding.",
+ r"^backbone\.language_backbone\.encoder\.positional_embedding\.pos_embed\.pos_embed$": r"text_encoder.embeddings.position_embedding.position_embedding",
+ r"^backbone\.language_backbone\.encoder\.final_layer_norm\.": r"text_encoder.final_layer_norm.",
+ r"^backbone\.language_backbone\.encoder\.projection_layer$": r"text_encoder.projection.weight",
+ # text_projection: projects from text hidden-dim to DETR hidden-dim
+ r"^backbone\.language_backbone\.projector\.": r"text_projection.",
+ # RepMixer blocks (layer 0 and the last layer in the mct variant)
+ r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.layer_scale$": r"text_encoder.layers.\1.layer_scale",
+ r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.token_mixer\.layer_scale$": r"text_encoder.layers.\1.token_mixer.layer_scale",
+ r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.token_mixer\.norm\.rbr_skip\.": r"text_encoder.layers.\1.token_mixer.reference_batchnorm.",
+ r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.token_mixer\.mixer\.rbr_skip\.": r"text_encoder.layers.\1.token_mixer.mixer.batchnorm_skip.",
+ r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.token_mixer\.mixer\.rbr_conv\.0\.conv\.": r"text_encoder.layers.\1.token_mixer.mixer.conv.",
+ r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.token_mixer\.mixer\.rbr_conv\.0\.bn\.": r"text_encoder.layers.\1.token_mixer.mixer.batchnorm_conv.",
+ r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.convffn\.conv\.conv\.": r"text_encoder.layers.\1.conv_feed_forward.depthwise_conv.",
+ r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.convffn\.conv\.bn\.": r"text_encoder.layers.\1.conv_feed_forward.depthwise_batchnorm.",
+ r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.convffn\.fc1\.": r"text_encoder.layers.\1.conv_feed_forward.mlp.fc1.",
+ r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.convffn\.fc2\.": r"text_encoder.layers.\1.conv_feed_forward.mlp.fc2.",
+ # Standard transformer layers (pre-norm MHA + FFN)
+ r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.pre_norm_mha\.0\.": r"text_encoder.layers.\1.layer_norm1.",
+ r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.pre_norm_mha\.1\.qkv_proj\.": r"text_encoder.layers.\1.self_attn.in_proj_",
+ r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.pre_norm_mha\.1\.out_proj\.": r"text_encoder.layers.\1.self_attn.out_proj.",
+ r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.pre_norm_ffn\.0\.": r"text_encoder.layers.\1.layer_norm2.",
+ r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.pre_norm_ffn\.1\.": r"text_encoder.layers.\1.mlp.fc1.",
+ r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.pre_norm_ffn\.4\.": r"text_encoder.layers.\1.mlp.fc2.",
+
+ # ============================================================================
+ # Geometry Encoder
+ # ============================================================================
+ r"^geometry_encoder\.points_direct_project\.": r"geometry_encoder.boxes_direct_project.",
+ r"^geometry_encoder\.points_pool_project\.": r"geometry_encoder.boxes_pool_project.",
+ r"^geometry_encoder\.points_pos_enc_project\.": r"geometry_encoder.boxes_pos_enc_project.",
+ r"^geometry_encoder\.encode\.(\d+)\.cross_attn_image\.out_proj\.": r"geometry_encoder.layers.\1.cross_attn.o_proj.",
+ r"^geometry_encoder\.encode\.(\d+)\.cross_attn_image\.": r"geometry_encoder.layers.\1.cross_attn.",
+ r"^geometry_encoder\.encode\.(\d+)\.self_attn\.out_proj\.": r"geometry_encoder.layers.\1.self_attn.o_proj.",
+ r"^geometry_encoder\.encode\.(\d+)\.self_attn\.": r"geometry_encoder.layers.\1.self_attn.",
+ r"^geometry_encoder\.encode\.(\d+)\.linear1\.": r"geometry_encoder.layers.\1.mlp.fc1.",
+ r"^geometry_encoder\.encode\.(\d+)\.linear2\.": r"geometry_encoder.layers.\1.mlp.fc2.",
+ r"^geometry_encoder\.encode\.(\d+)\.norm1\.": r"geometry_encoder.layers.\1.layer_norm1.",
+ r"^geometry_encoder\.encode\.(\d+)\.norm2\.": r"geometry_encoder.layers.\1.layer_norm2.",
+ r"^geometry_encoder\.encode\.(\d+)\.norm3\.": r"geometry_encoder.layers.\1.layer_norm3.",
+ r"^geometry_encoder\.img_pre_norm\.": r"geometry_encoder.vision_layer_norm.",
+ r"^geometry_encoder\.norm\.": r"geometry_encoder.prompt_layer_norm.",
+ r"^geometry_encoder\.encode_norm\.": r"geometry_encoder.output_layer_norm.",
+
+ # ============================================================================
+ # DETR Encoder
+ # ============================================================================
+ r"^transformer\.encoder\.layers\.(\d+)\.cross_attn_image\.out_proj\.": r"detr_encoder.layers.\1.cross_attn.o_proj.",
+ r"^transformer\.encoder\.layers\.(\d+)\.cross_attn_image\.": r"detr_encoder.layers.\1.cross_attn.",
+ r"^transformer\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.": r"detr_encoder.layers.\1.self_attn.o_proj.",
+ r"^transformer\.encoder\.layers\.(\d+)\.self_attn\.": r"detr_encoder.layers.\1.self_attn.",
+ r"^transformer\.encoder\.layers\.(\d+)\.cross_attn\.out_proj\.": r"detr_encoder.layers.\1.cross_attn.o_proj.",
+ r"^transformer\.encoder\.layers\.(\d+)\.cross_attn\.": r"detr_encoder.layers.\1.cross_attn.",
+ r"^transformer\.encoder\.layers\.(\d+)\.linear1\.": r"detr_encoder.layers.\1.mlp.fc1.",
+ r"^transformer\.encoder\.layers\.(\d+)\.linear2\.": r"detr_encoder.layers.\1.mlp.fc2.",
+ r"^transformer\.encoder\.layers\.(\d+)\.norm1\.": r"detr_encoder.layers.\1.layer_norm1.",
+ r"^transformer\.encoder\.layers\.(\d+)\.norm2\.": r"detr_encoder.layers.\1.layer_norm2.",
+ r"^transformer\.encoder\.layers\.(\d+)\.norm3\.": r"detr_encoder.layers.\1.layer_norm3.",
+
+ # ============================================================================
+ # DETR Decoder
+ # ============================================================================
+ r"^transformer\.decoder\.query_embed\.": r"detr_decoder.query_embed.",
+ r"^transformer\.decoder\.reference_points\.": r"detr_decoder.reference_points.",
+ r"^transformer\.decoder\.instance_query_embed\.": r"detr_decoder.instance_query_embed.",
+ r"^transformer\.decoder\.instance_reference_points\.": r"detr_decoder.instance_reference_points.",
+ r"^transformer\.decoder\.presence_token\.": r"detr_decoder.presence_token.",
+ r"^transformer\.decoder\.presence_token_head\.layers\.0\.": r"detr_decoder.presence_head.layer1.",
+ r"^transformer\.decoder\.presence_token_head\.layers\.1\.": r"detr_decoder.presence_head.layer2.",
+ r"^transformer\.decoder\.presence_token_head\.layers\.2\.": r"detr_decoder.presence_head.layer3.",
+ r"^transformer\.decoder\.presence_token_out_norm\.": r"detr_decoder.presence_layer_norm.",
+ r"^transformer\.decoder\.norm\.": r"detr_decoder.output_layer_norm.",
+ r"^transformer\.decoder\.bbox_embed\.layers\.0\.": r"detr_decoder.box_head.layer1.",
+ r"^transformer\.decoder\.bbox_embed\.layers\.1\.": r"detr_decoder.box_head.layer2.",
+ r"^transformer\.decoder\.bbox_embed\.layers\.2\.": r"detr_decoder.box_head.layer3.",
+ r"^transformer\.decoder\.instance_bbox_embed\.layers\.0\.": r"detr_decoder.instance_box_head.layer1.",
+ r"^transformer\.decoder\.instance_bbox_embed\.layers\.1\.": r"detr_decoder.instance_box_head.layer2.",
+ r"^transformer\.decoder\.instance_bbox_embed\.layers\.2\.": r"detr_decoder.instance_box_head.layer3.",
+ r"^transformer\.decoder\.ref_point_head\.layers\.0\.": r"detr_decoder.ref_point_head.layer1.",
+ r"^transformer\.decoder\.ref_point_head\.layers\.1\.": r"detr_decoder.ref_point_head.layer2.",
+ r"^transformer\.decoder\.boxRPB_embed_x\.layers\.0\.": r"detr_decoder.box_rpb_embed_x.layer1.",
+ r"^transformer\.decoder\.boxRPB_embed_x\.layers\.1\.": r"detr_decoder.box_rpb_embed_x.layer2.",
+ r"^transformer\.decoder\.boxRPB_embed_y\.layers\.0\.": r"detr_decoder.box_rpb_embed_y.layer1.",
+ r"^transformer\.decoder\.boxRPB_embed_y\.layers\.1\.": r"detr_decoder.box_rpb_embed_y.layer2.",
+ r"^transformer\.decoder\.layers\.(\d+)\.self_attn\.out_proj\.": r"detr_decoder.layers.\1.self_attn.o_proj.",
+ r"^transformer\.decoder\.layers\.(\d+)\.self_attn\.": r"detr_decoder.layers.\1.self_attn.",
+ r"^transformer\.decoder\.layers\.(\d+)\.ca_text\.out_proj\.": r"detr_decoder.layers.\1.text_cross_attn.o_proj.",
+ r"^transformer\.decoder\.layers\.(\d+)\.ca_text\.": r"detr_decoder.layers.\1.text_cross_attn.",
+ r"^transformer\.decoder\.layers\.(\d+)\.cross_attn\.out_proj\.": r"detr_decoder.layers.\1.vision_cross_attn.o_proj.",
+ r"^transformer\.decoder\.layers\.(\d+)\.cross_attn\.": r"detr_decoder.layers.\1.vision_cross_attn.",
+ r"^transformer\.decoder\.layers\.(\d+)\.linear1\.": r"detr_decoder.layers.\1.mlp.fc1.",
+ r"^transformer\.decoder\.layers\.(\d+)\.linear2\.": r"detr_decoder.layers.\1.mlp.fc2.",
+ r"^transformer\.decoder\.layers\.(\d+)\.norm1\.": r"detr_decoder.layers.\1.vision_cross_attn_layer_norm.",
+ r"^transformer\.decoder\.layers\.(\d+)\.catext_norm\.": r"detr_decoder.layers.\1.text_cross_attn_layer_norm.",
+ r"^transformer\.decoder\.layers\.(\d+)\.norm2\.": r"detr_decoder.layers.\1.self_attn_layer_norm.",
+ r"^transformer\.decoder\.layers\.(\d+)\.norm3\.": r"detr_decoder.layers.\1.mlp_layer_norm.",
+
+ # ============================================================================
+ # Dot Product Scoring
+ # ============================================================================
+ r"^dot_prod_scoring\.prompt_mlp\.layers\.0\.": r"dot_product_scoring.text_mlp.layer1.",
+ r"^dot_prod_scoring\.prompt_mlp\.layers\.1\.": r"dot_product_scoring.text_mlp.layer2.",
+ r"^dot_prod_scoring\.prompt_mlp\.out_norm\.": r"dot_product_scoring.text_mlp_out_norm.",
+ r"^dot_prod_scoring\.prompt_proj\.": r"dot_product_scoring.text_proj.",
+ r"^dot_prod_scoring\.hs_proj\.": r"dot_product_scoring.query_proj.",
+
+ # ============================================================================
+ # Mask Decoder
+ # ============================================================================
+ r"^segmentation_head\.pixel_decoder\.conv_layers\.(\d+)\.": r"mask_decoder.pixel_decoder.conv_layers.\1.",
+ r"^segmentation_head\.pixel_decoder\.norms\.(\d+)\.": r"mask_decoder.pixel_decoder.norms.\1.",
+ r"^segmentation_head\.mask_embed\.layers\.(\d+)\.": r"mask_decoder.mask_embedder.layers.\1.",
+ r"^segmentation_head\.mask_predictor\.mask_embed\.layers\.(\d+)\.": r"mask_decoder.mask_embedder.layers.\1.",
+ r"^segmentation_head\.instance_seg_head\.": r"mask_decoder.instance_projection.",
+ r"^segmentation_head\.semantic_seg_head\.": r"mask_decoder.semantic_projection.",
+ r"^segmentation_head\.cross_attend_prompt\.out_proj\.": r"mask_decoder.prompt_cross_attn.o_proj.",
+ r"^segmentation_head\.cross_attend_prompt\.": r"mask_decoder.prompt_cross_attn.",
+ r"^segmentation_head\.cross_attn_norm\.": r"mask_decoder.prompt_cross_attn_norm.",
+}
+# fmt: on
+
+
+def convert_old_keys_to_new_keys(state_dict_keys: list[str]) -> dict[str, str]:
+ """
+ Convert original SAM3 LiteText checkpoint keys to HuggingFace format.
+
+ Applies all regex patterns in `ORIGINAL_TO_CONVERTED_KEY_MAPPING` at once
+ using a multiline bulk-substitution.
+ """
+ output_dict = {}
+ if state_dict_keys is not None:
+ old_text = "\n".join(state_dict_keys)
+ new_text = old_text
+ for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
+ new_text = re.sub(pattern, replacement, new_text, flags=re.MULTILINE)
+ output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
+ return output_dict
+
+
+def split_qkv(state_dict: dict) -> dict:
+ """Split combined QKV projections into separate Q, K, V projections."""
+ # Vision backbone: .attention.qkv.* → .attention.{q,k,v}_proj.*
+ for key in [k for k in state_dict if ".attention.qkv." in k]:
+ qkv = state_dict.pop(key)
+ q, k, v = torch.chunk(qkv, 3, dim=0)
+ state_dict[key.replace(".qkv.", ".q_proj.")] = q
+ state_dict[key.replace(".qkv.", ".k_proj.")] = k
+ state_dict[key.replace(".qkv.", ".v_proj.")] = v
+
+ # Text encoder & attention layers: .in_proj_weight/bias → .{q,k,v}_proj.*
+ for key in [k for k in state_dict if ".in_proj_" in k]:
+ in_proj = state_dict.pop(key)
+ q, k, v = torch.chunk(in_proj, 3, dim=0)
+ if key.endswith("in_proj_weight"):
+ base = key.replace("in_proj_weight", "")
+ state_dict[base + "q_proj.weight"] = q
+ state_dict[base + "k_proj.weight"] = k
+ state_dict[base + "v_proj.weight"] = v
+ elif key.endswith("in_proj_bias"):
+ base = key.replace("in_proj_bias", "")
+ state_dict[base + "q_proj.bias"] = q
+ state_dict[base + "k_proj.bias"] = k
+ state_dict[base + "v_proj.bias"] = v
+
+ return state_dict
+
+
+def load_original_state_dict(checkpoint_path: str) -> dict[str, torch.Tensor]:
+ """Load the original EfficientSAM3 LiteText checkpoint."""
+ print(f"Loading original checkpoint from {checkpoint_path}")
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
+ if "model" in checkpoint:
+ state_dict = checkpoint["model"]
+ elif "state_dict" in checkpoint:
+ state_dict = checkpoint["state_dict"]
+ else:
+ state_dict = checkpoint
+ print(f"Loaded {len(state_dict)} keys from checkpoint")
+ return state_dict
+
+
+def _infer_text_config(state_dict: dict[str, torch.Tensor]) -> Sam3LiteTextTextConfig:
+ """Infer LiteText encoder hyper-parameters from the raw checkpoint."""
+ prefix = "detector.backbone.language_backbone.encoder."
+ hidden_size = state_dict[f"{prefix}embedding_layer.weight"].shape[1]
+ context_length = state_dict[f"{prefix}positional_embedding.pos_embed.pos_embed"].shape[2]
+ use_repmixer_blocks = any(f"{prefix}transformer.0.token_mixer" in k for k in state_dict)
+ if use_repmixer_blocks:
+ num_hidden_layers = 6
+ else:
+ layer_ids = {
+ int(k.split("transformer.")[1].split(".")[0])
+ for k in state_dict
+ if f"{prefix}transformer." in k and ".pre_norm_mha." in k
+ }
+ num_hidden_layers = max(layer_ids) + 1
+
+ return Sam3LiteTextTextConfig(
+ vocab_size=49408,
+ hidden_size=hidden_size,
+ intermediate_size=hidden_size * 4,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=hidden_size // 64,
+ max_position_embeddings=context_length,
+ projection_dim=hidden_size,
+ use_repmixer_blocks=use_repmixer_blocks,
+ )
+
+
+def get_sam3_lite_text_config(state_dict: dict[str, torch.Tensor]) -> Sam3LiteTextConfig:
+ """Build a Sam3LiteTextConfig inferred from the raw checkpoint."""
+ text_config = _infer_text_config(state_dict)
+ config = Sam3LiteTextConfig()
+ config.text_config = text_config
+ return config
+
+
+def convert_sam3_lite_text_checkpoint(
+ checkpoint_path: str,
+ output_path: str,
+ config: Sam3LiteTextConfig | None = None,
+ push_to_hub: bool = False,
+ repo_id: str | None = None,
+):
+ """
+ Convert an EfficientSAM3 LiteText checkpoint to HuggingFace format.
+
+ Args:
+ checkpoint_path: Path to the original `.pt` checkpoint file.
+ output_path: Directory where the converted model will be saved.
+ config: Optional pre-built `Sam3LiteTextConfig` (defaults to auto-inferred).
+ push_to_hub: Whether to push the model to the Hugging Face Hub.
+ repo_id: Hub repository ID (required when ``push_to_hub=True``).
+ """
+ os.makedirs(output_path, exist_ok=True)
+
+ # Load original checkpoint
+ state_dict_old = load_original_state_dict(checkpoint_path)
+
+ # Build config from checkpoint
+ if config is None:
+ config = get_sam3_lite_text_config(state_dict_old)
+
+ config.architectures = ["Sam3LiteTextModel"]
+ config.save_pretrained(output_path)
+ print("Model config saved successfully")
+
+ # Convert keys
+ print("Converting checkpoint keys...")
+ all_keys = list(state_dict_old.keys())
+ key_mapping = convert_old_keys_to_new_keys(all_keys)
+
+ state_dict_new = {}
+ for old_key in all_keys:
+ new_key = key_mapping.get(old_key, old_key)
+ # num_batches_tracked from BatchNorm is not needed
+ if "num_batches_tracked" in new_key:
+ continue
+ # Parallel SAM2 neck branch in the original checkpoint; HF vision uses `convs` only.
+ if "vision_backbone.sam2_convs" in new_key:
+ continue
+ # Drop keys whose names were not transformed (unrecognised / legacy keys)
+ if new_key == old_key:
+ continue
+ # Strip the first position (cls token) from ViT position embeddings
+ if new_key == "vision_encoder.backbone.embeddings.position_embeddings":
+ state_dict_new[new_key] = state_dict_old[old_key][:, 1:, :]
+ else:
+ state_dict_new[new_key] = state_dict_old[old_key]
+
+ del state_dict_old
+ gc.collect()
+
+ # Split combined QKV projections
+ print("Splitting QKV projections...")
+ state_dict_new = split_qkv(state_dict_new)
+
+ # HF models compute the RoPE table on the fly
+ for k in list(state_dict_new.keys()):
+ if k.endswith("rotary_emb.rope_embeddings"):
+ state_dict_new.pop(k)
+
+ print(
+ "Converted key counts:",
+ {
+ prefix: sum(1 for k in state_dict_new if k.startswith(prefix))
+ for prefix in (
+ "vision_encoder.",
+ "text_encoder.",
+ "geometry_encoder.",
+ "detr_encoder.",
+ "detr_decoder.",
+ "mask_decoder.",
+ )
+ },
+ )
+
+ # Load weights into HF model
+ print("Loading weights into Sam3LiteTextModel...")
+ model = Sam3LiteTextModel(config)
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict_new, strict=False)
+
+ if missing_keys:
+ logger.warning(f"Missing keys ({len(missing_keys)}):")
+ for key in missing_keys:
+ logger.warning(f" - {key}")
+
+ if unexpected_keys:
+ logger.warning(f"Unexpected keys ({len(unexpected_keys)}):")
+ for key in unexpected_keys:
+ logger.warning(f" - {key}")
+
+ # Save model
+ print(f"Saving converted model to {output_path}")
+ model.save_pretrained(output_path)
+
+ # Save processor
+ print("Creating and saving processor...")
+ image_processor = Sam3ImageProcessor()
+ tokenizer = CLIPTokenizerFast.from_pretrained(
+ "openai/clip-vit-base-patch32",
+ max_length=config.text_config.max_position_embeddings,
+ model_max_length=config.text_config.max_position_embeddings,
+ )
+ processor = Sam3Processor(image_processor=image_processor, tokenizer=tokenizer)
+ processor.save_pretrained(output_path)
+
+ if push_to_hub:
+ if repo_id is None:
+ raise ValueError("repo_id must be provided when push_to_hub=True")
+ print(f"Pushing model to Hub: {repo_id}")
+ model.push_to_hub(repo_id)
+ processor.push_to_hub(repo_id)
+
+ print("Conversion complete!")
+
+ # Cleanup and verify
+ del state_dict_new, model
+ gc.collect()
+
+ print("\nVerifying converted checkpoint can be loaded...")
+ try:
+ model = Sam3LiteTextModel.from_pretrained(output_path)
+ param_count = sum(p.numel() for p in model.parameters())
+ print(f"Successfully loaded model with {param_count:,} parameters")
+ del model
+ gc.collect()
+ except Exception as e:
+ print(f"Failed to reload model: {e}")
+
+ print("\n" + "=" * 80)
+ print("Conversion finished!")
+ print("=" * 80)
+ print(f"Output directory: {output_path}")
+ print("\nTo use the model:")
+ print(">>> from transformers import Sam3LiteTextModel, Sam3Processor")
+ print(f">>> model = Sam3LiteTextModel.from_pretrained('{output_path}')")
+ print("=" * 80)
+
+
+MODEL_VARIANTS = {
+ "s0": "sam3_litetext/efficient_sam3_image_encoder_mobileclip_s0_ctx16.pt",
+ "s1": "sam3_litetext/efficient_sam3_image_encoder_mobileclip_s1_ctx16.pt",
+ "l": "sam3_litetext/efficient_sam3_image_encoder_mobileclip2_l_ctx16.pt",
+}
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Convert EfficientSAM3 LiteText checkpoint to HuggingFace format")
+ parser.add_argument(
+ "--model_variant",
+ type=str,
+ choices=list(MODEL_VARIANTS),
+ default=None,
+ help="Model variant to download and convert: 's0' (MobileCLIP-S0, 42M), 's1' (MobileCLIP-S1, 63M), "
+ "or 'l' (MobileCLIP2-L, 124M). Takes precedence over --filename when set.",
+ )
+ parser.add_argument(
+ "--checkpoint_path",
+ type=str,
+ default=None,
+ help="Path to the original .pt checkpoint file. If omitted, the checkpoint is downloaded from the Hub.",
+ )
+ parser.add_argument(
+ "--output_path",
+ type=str,
+ required=True,
+ help="Directory where the converted checkpoint will be saved.",
+ )
+ parser.add_argument(
+ "--repo_id",
+ type=str,
+ default="Simon7108528/EfficientSAM3",
+ help="Hub repository ID to download the checkpoint from (used when --checkpoint_path is not provided).",
+ )
+ parser.add_argument(
+ "--filename",
+ type=str,
+ default="sam3_litetext/efficient_sam3_image_encoder_mobileclip_s0_ctx16.pt",
+ help="Filename within the Hub repository to download (ignored when --model_variant is set).",
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ action="store_true",
+ help="Whether to push the converted model to the Hugging Face Hub.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="Hub repository ID to push to (e.g. 'my-org/sam3-litetext-s0').",
+ )
+ args = parser.parse_args()
+
+ filename = MODEL_VARIANTS[args.model_variant] if args.model_variant else args.filename
+
+ checkpoint_path = args.checkpoint_path
+ if checkpoint_path is None:
+ print(f"Downloading checkpoint {filename} from {args.repo_id}...")
+ checkpoint_path = hf_hub_download(args.repo_id, filename)
+
+ convert_sam3_lite_text_checkpoint(
+ checkpoint_path=checkpoint_path,
+ output_path=args.output_path,
+ push_to_hub=args.push_to_hub,
+ repo_id=args.hub_model_id,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/transformers/models/sam3_lite_text/modeling_sam3_lite_text.py b/src/transformers/models/sam3_lite_text/modeling_sam3_lite_text.py
new file mode 100644
index 000000000000..5a7b02880edd
--- /dev/null
+++ b/src/transformers/models/sam3_lite_text/modeling_sam3_lite_text.py
@@ -0,0 +1,2351 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/sam3_lite_text/modular_sam3_lite_text.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_sam3_lite_text.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2026 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections.abc import Callable, Iterable
+from dataclasses import dataclass
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from ... import initialization as init
+from ...activations import ACT2FN
+from ...masking_utils import create_bidirectional_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
+from ...utils import auto_docstring, can_return_tuple, is_torchvision_available, logging
+from ...utils.generic import TransformersKwargs, is_flash_attention_requested, merge_with_config_defaults
+from ...utils.output_capturing import capture_outputs
+from ..auto import AutoModel
+from .configuration_sam3_lite_text import (
+ Sam3LiteTextConfig,
+ Sam3LiteTextDETRDecoderConfig,
+ Sam3LiteTextDETREncoderConfig,
+ Sam3LiteTextGeometryEncoderConfig,
+ Sam3LiteTextMaskDecoderConfig,
+ Sam3LiteTextTextConfig,
+ Sam3LiteTextViTConfig,
+)
+
+
+if is_torchvision_available():
+ import torchvision
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class Sam3LiteTextTextEncoderOutput(BaseModelOutputWithPooling):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Full sequence of hidden states from the text encoder.
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
+ EOT-pooled output projected to `projection_dim` via the internal CLIP-style projection.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+ Tuple of hidden states at each layer, returned when `output_hidden_states=True`.
+ attentions (`tuple(torch.FloatTensor)`, *optional*):
+ Tuple of attention weights at each transformer layer, returned when `output_attentions=True`.
+ """
+
+
+class Sam3LiteTextTextPositionEmbedding(nn.Module):
+ """Learnable positional embedding with bilinear interpolation for variable sequence lengths."""
+
+ def __init__(self, max_position_embeddings: int, hidden_size: int):
+ super().__init__()
+ self.position_embedding = nn.Parameter(torch.empty(1, 1, max_position_embeddings, hidden_size))
+
+ def forward(self, seq_len: int) -> torch.Tensor:
+ position_embedding = self.position_embedding
+ if seq_len != position_embedding.shape[2]:
+ position_embedding = F.interpolate(
+ position_embedding,
+ size=(seq_len, position_embedding.shape[-1]),
+ mode="bilinear",
+ )
+ return position_embedding.reshape(1, seq_len, -1)
+
+
+class Sam3LiteTextMobileOneBlock(nn.Module):
+ """Depthwise conv branch with batch norm on the skip path and after the conv (MobileOne-style)."""
+
+ def __init__(self, hidden_size: int, kernel_size: int = 3):
+ super().__init__()
+ self.batchnorm_skip = nn.BatchNorm2d(hidden_size)
+ self.conv = nn.Conv2d(
+ hidden_size,
+ hidden_size,
+ kernel_size=(1, kernel_size),
+ stride=1,
+ padding=(0, kernel_size // 2),
+ groups=hidden_size,
+ bias=False,
+ )
+ self.batchnorm_conv = nn.BatchNorm2d(hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.batchnorm_conv(self.conv(hidden_states))
+ hidden_states = hidden_states + self.batchnorm_skip(residual)
+ return hidden_states
+
+
+class Sam3LiteTextConvMLP(nn.Module):
+ """Pointwise MLP using 1×1 convolutions, compatible with 4-D (B, C, H, W) feature maps."""
+
+ def __init__(self, config: Sam3LiteTextTextConfig):
+ super().__init__()
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Conv2d(config.hidden_size, config.intermediate_size, kernel_size=1)
+ self.fc2 = nn.Conv2d(config.intermediate_size, config.hidden_size, kernel_size=1)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Sam3LiteTextConvolutionalFeedForward(nn.Module):
+ """Convolutional feed-forward network: depthwise conv + two pointwise projections."""
+
+ def __init__(self, config: Sam3LiteTextTextConfig):
+ super().__init__()
+ self.depthwise_conv = nn.Conv2d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=(1, config.repmixer_kernel_size),
+ padding=(0, config.repmixer_kernel_size // 2),
+ groups=config.hidden_size,
+ bias=False,
+ )
+ self.depthwise_batchnorm = nn.BatchNorm2d(config.hidden_size)
+ self.mlp = Sam3LiteTextConvMLP(config)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.depthwise_batchnorm(self.depthwise_conv(hidden_states))
+ return self.mlp(hidden_states)
+
+
+class Sam3LiteTextLayerScaledResidual(nn.Module):
+ """Common layer-scale residual pattern shared by the RepMixer and feed-forward branches."""
+
+ def __init__(self, hidden_size: int, layer_scale_init_value: float):
+ super().__init__()
+ self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((hidden_size, 1, 1)), requires_grad=True)
+
+ def layer_scale_residual(self, hidden_states: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
+ return hidden_states + self.layer_scale * update
+
+
+class Sam3LiteTextRepMixer(Sam3LiteTextLayerScaledResidual):
+ """Re-parameterisable depthwise-conv token mixer operating on 1D sequence data."""
+
+ def __init__(self, config: Sam3LiteTextTextConfig):
+ super().__init__(config.hidden_size, config.layer_scale_init_value)
+ self.reference_batchnorm = nn.BatchNorm2d(config.hidden_size)
+ self.mixer = Sam3LiteTextMobileOneBlock(config.hidden_size, kernel_size=config.repmixer_kernel_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return self.layer_scale_residual(
+ hidden_states, self.mixer(hidden_states) - self.reference_batchnorm(hidden_states)
+ )
+
+
+class Sam3LiteTextRepMixerBlock(Sam3LiteTextLayerScaledResidual):
+ """Token-mixing RepMixer plus a convolutional feed-forward path, each with layer scale."""
+
+ def __init__(self, config: Sam3LiteTextTextConfig):
+ super().__init__(config.hidden_size, config.layer_scale_init_value)
+ self.token_mixer = Sam3LiteTextRepMixer(config)
+ self.conv_feed_forward = Sam3LiteTextConvolutionalFeedForward(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
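+        # Lay the sequence out as a (batch, channels, 1, seq_len) feature map so the depthwise convs mix along the
+        # sequence axis; attention_mask is accepted only for interface parity with attention layers and is not used here.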
+ hidden_states = hidden_states.transpose(1, 2).unsqueeze(2)
+ hidden_states = self.token_mixer(hidden_states)
+ hidden_states = self.layer_scale_residual(hidden_states, self.conv_feed_forward(hidden_states))
+ return hidden_states.squeeze(2).transpose(1, 2)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class Sam3LiteTextTextAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+ self.is_causal = False
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ """Input shape: Batch x Time x Channel"""
+
+ input_shape = hidden_states.shape[:-1]
+
+ hidden_shape = (*input_shape, -1, self.head_dim)
+ queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+ self.config._attn_implementation, eager_attention_forward
+ )
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ queries,
+ keys,
+ values,
+ attention_mask,
+ is_causal=self.is_causal,
+ scaling=self.scale,
+ dropout=0.0 if not self.training else self.dropout,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class Sam3LiteTextTextMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Sam3LiteTextTextEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Sam3LiteTextTextConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.self_attn = Sam3LiteTextTextAttention(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = Sam3LiteTextTextMLP(config)
+
+ @auto_docstring
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class Sam3LiteTextTextEmbeddings(nn.Module):
+ """Token embedding + interpolatable positional embedding for the text encoder."""
+
+ def __init__(self, config: Sam3LiteTextTextConfig):
+ super().__init__()
+ self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.position_embedding = Sam3LiteTextTextPositionEmbedding(config.max_position_embeddings, config.hidden_size)
+
+ def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
+ hidden_states = self.token_embedding(input_ids)
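+        # Position embeddings are resampled to the current sequence length (see Sam3LiteTextTextPositionEmbedding).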
+ hidden_states = hidden_states + self.position_embedding(input_ids.shape[1]).to(hidden_states.dtype)
+ return hidden_states
+
+
+class Sam3LiteTextViTRotaryEmbedding(nn.Module):
+ """
+    Vision Rotary Position Embedding for SAM3-LiteText, following Transformers library conventions.
+ Supports 2D (axial) rotary embeddings for spatial dimensions.
+ """
+
+ def __init__(self, config: Sam3LiteTextViTConfig, end_x: int, end_y: int, scale: float = 1.0):
+ super().__init__()
+ dim = config.hidden_size // config.num_attention_heads
+        # The head dimension must be divisible by 4 so it can be split across the x and y axes
+ if dim % 4 != 0:
+ raise ValueError("Dimension must be divisible by 4 for axial RoPE")
+ self.end_x, self.end_y = end_x, end_y
+ self.dim = dim
+ self.rope_theta = config.rope_theta
+ self.scale = scale
+ freqs = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+ flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
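+        # Flattened index i corresponds to grid position (x, y) = (i % end_x, i // end_x); coordinates are rescaled by `scale`.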
+ x_positions = (flattened_indices % end_x) * scale
+ y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * scale
+ freqs_x = torch.outer(x_positions, freqs).float()
+ freqs_y = torch.outer(y_positions, freqs).float()
+ inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+ inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+ # directly register the cos and sin embeddings as we have a fixed feature shape
+ self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
+ self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)
+
+ @torch.no_grad()
+ def forward(self) -> tuple[torch.Tensor, torch.Tensor]:
+ # As the feature map size is fixed for each stage, we can just return the pre-computed embeddings.
+ return self.rope_embeddings_cos, self.rope_embeddings_sin
+
+
+class Sam3LiteTextViTPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config: Sam3LiteTextViTConfig):
+ super().__init__()
+ image_size, patch_size = config.pretrain_image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=False)
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ embeddings = self.projection(pixel_values.to(self.projection.weight.dtype)).flatten(2).transpose(1, 2)
+ return embeddings
+
+
+class Sam3LiteTextViTEmbeddings(nn.Module):
+ """
+    Construct the patch embeddings and position embeddings for the SAM3-LiteText ViT.
+
+ Position embeddings are tiled (not interpolated) when resizing to match different input sizes.
+ """
+
+ def __init__(self, config: Sam3LiteTextViTConfig):
+ super().__init__()
+
+ self.patch_embeddings = Sam3LiteTextViTPatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(
+ torch.randn(1, num_patches, config.hidden_size)
+        )  # note: the CLS token is dropped during weight conversion
+
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.patch_size = config.patch_size
+
+ def _tile_position_embeddings(
+ self,
+ position_embeddings: torch.Tensor,
+ height: int,
+ width: int,
+ ) -> torch.Tensor:
+ """
+ Tile position embeddings to match target spatial dimensions.
+ Args:
+ position_embeddings: Shape [1, num_pretrain_patches, hidden_size]
+ height: Target height in patches
+ width: Target width in patches
+
+ Returns:
+ Shape [1, height * width, hidden_size]
+ """
+ pretrain_size = int(position_embeddings.shape[1] ** 0.5)
+
+ # Skip tiling if sizes match (but always tile during tracing for consistent graph)
+ if not torch.jit.is_tracing() and pretrain_size == height and pretrain_size == width:
+ return position_embeddings.reshape(1, height * width, -1)
+
+ # Tile position embeddings to match target spatial dimensions
+ hidden_size = position_embeddings.shape[-1]
+ pos_embed = position_embeddings.reshape(1, pretrain_size, pretrain_size, hidden_size).permute(0, 3, 1, 2)
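+        # Over-tile by at most one extra copy per axis, then crop to exactly (height, width) patches.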
+ repeat_h = height // pretrain_size + 1
+ repeat_w = width // pretrain_size + 1
+ pos_embed = pos_embed.tile([1, 1, repeat_h, repeat_w])[:, :, :height, :width]
+ return pos_embed.permute(0, 2, 3, 1).reshape(1, height * width, hidden_size)
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.Tensor:
+ height, width = pixel_values.shape[-2:]
+ embeddings = self.patch_embeddings(pixel_values)
+
+ # Calculate spatial dimensions in patches
+ height_patches = height // self.patch_size
+ width_patches = width // self.patch_size
+
+ position_embeddings = self._tile_position_embeddings(
+ self.position_embeddings,
+ height_patches,
+ width_patches,
+ )
+ embeddings = embeddings + position_embeddings
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+@auto_docstring
+class Sam3LiteTextPreTrainedModel(PreTrainedModel):
+ config_class = Sam3LiteTextConfig
+ base_model_prefix = "model"
+ main_input_name = "pixel_values"
+ input_modalities = ["image", "text"]
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+ supports_gradient_checkpointing = True
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, Sam3LiteTextViTEmbeddings):
+ init.normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, Sam3LiteTextViTRotaryEmbedding):
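+            # The cos/sin tables are non-persistent buffers (absent from checkpoints), so rebuild them deterministically here.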
+ end_x, end_y = module.end_x, module.end_y
+ dim = module.dim
+ freqs = 1.0 / (module.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+ flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
+ x_positions = (flattened_indices % end_x) * module.scale
+ y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * module.scale
+ freqs_x = torch.outer(x_positions, freqs).float()
+ freqs_y = torch.outer(y_positions, freqs).float()
+ inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+ inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+ init.copy_(module.rope_embeddings_cos, inv_freq.cos())
+ init.copy_(module.rope_embeddings_sin, inv_freq.sin())
+ if isinstance(module, Sam3LiteTextTextPositionEmbedding):
+ init.normal_(module.position_embedding, std=module.position_embedding.shape[-1] ** -0.5)
+ elif isinstance(module, Sam3LiteTextTextModel):
+ init.normal_(module.projection.weight, std=module.config.hidden_size**-0.5)
+
+
+@auto_docstring(
+ custom_intro="""
+    MobileCLIP MCT text encoder used in SAM3-LiteText.
+
+ When `config.use_repmixer_blocks` is `True`, the first and last layers are
+ `Sam3LiteTextRepMixerBlock` modules; the rest are standard `Sam3LiteTextTextEncoderLayer` layers.
+"""
+)
+class Sam3LiteTextTextModel(Sam3LiteTextPreTrainedModel):
+ config_class = Sam3LiteTextTextConfig
+ config: Sam3LiteTextTextConfig
+ _can_record_outputs = {
+ "hidden_states": Sam3LiteTextTextEncoderLayer,
+ "attentions": Sam3LiteTextTextAttention,
+ }
+
+ def __init__(self, config: Sam3LiteTextTextConfig):
+ super().__init__(config)
+ self.embeddings = Sam3LiteTextTextEmbeddings(config)
+ repmixer_positions = {0, config.num_hidden_layers - 1} if config.use_repmixer_blocks else set()
+ self.layers = nn.ModuleList(
+ [
+ Sam3LiteTextRepMixerBlock(config) if i in repmixer_positions else Sam3LiteTextTextEncoderLayer(config)
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
+ self.post_init()
+
+ @merge_with_config_defaults
+ @capture_outputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | Sam3LiteTextTextEncoderOutput:
+ hidden_states = self.embeddings(input_ids)
+ attention_mask = create_bidirectional_mask(self.config, hidden_states, attention_mask)
+
+ for layer in self.layers:
+ hidden_states = layer(hidden_states, attention_mask=attention_mask, **kwargs)
+
+ hidden_states = self.final_layer_norm(hidden_states)
+
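+        # CLIP-style pooling: take the hidden state at the position of the highest token id (conventionally the end-of-text token).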
+ pooled = hidden_states[
+ torch.arange(hidden_states.shape[0], device=hidden_states.device), input_ids.argmax(dim=-1)
+ ]
+ pooled = self.projection(pooled)
+ return Sam3LiteTextTextEncoderOutput(
+ last_hidden_state=hidden_states,
+ pooler_output=pooled,
+ )
+
+
+@dataclass
+@auto_docstring
+class Sam3LiteTextVisionEncoderOutput(BaseModelOutputWithPooling):
+ r"""
+ fpn_hidden_states (`tuple[torch.FloatTensor]`):
+ Tuple of multi-level FPN feature maps.
+ fpn_position_encoding (`tuple[torch.FloatTensor]`):
+ Tuple of position encodings for each FPN level.
+ """
+
+ fpn_hidden_states: tuple[torch.FloatTensor, ...] = None
+ fpn_position_encoding: tuple[torch.FloatTensor, ...] = None
+
+
+@dataclass
+@auto_docstring
+class Sam3LiteTextGeometryEncoderOutput(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_prompts, hidden_size)`):
+ Encoded geometry prompt features (boxes).
+ attention_mask (`torch.BoolTensor` of shape `(batch_size, num_prompts)`, *optional*):
+ Attention mask for geometry prompts where True indicates valid positions and False indicates padding.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ attention_mask: torch.BoolTensor | None = None
+
+
+@dataclass
+@auto_docstring
+class Sam3LiteTextDETREncoderOutput(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Encoded vision features (flattened from multi-level features).
+ pos_embeds_flattened (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Flattened position embeddings for the vision features.
+ text_features (`torch.FloatTensor` of shape `(batch_size, text_seq_len, hidden_size)`, *optional*):
+ Text features (may be pooled after encoder processing).
+ spatial_shapes (`torch.LongTensor` of shape `(num_levels, 2)`, *optional*):
+ Spatial shapes (height, width) for each feature pyramid level.
+ hidden_states (`tuple[torch.FloatTensor]`, *optional*):
+ Tuple of hidden states from all encoder layers.
+ attentions (`tuple[torch.FloatTensor]`, *optional*):
+ Tuple of attention weights from all encoder layers.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ pos_embeds_flattened: torch.FloatTensor | None = None
+ text_features: torch.FloatTensor | None = None
+ spatial_shapes: torch.LongTensor | None = None
+ hidden_states: tuple[torch.FloatTensor] | None = None
+ attentions: tuple[torch.FloatTensor] | None = None
+
+
+@dataclass
+@auto_docstring
+class Sam3LiteTextDETRDecoderOutput(ModelOutput):
+ r"""
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, hidden_size)`):
+ Decoder hidden states from all layers.
+ reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`):
+ Predicted reference boxes from all decoder layers in (cx, cy, w, h) format.
+ presence_logits (`torch.FloatTensor` of shape `(num_layers, batch_size, 1)`):
+ Presence logits from all decoder layers indicating object presence confidence.
+ hidden_states (`tuple[torch.FloatTensor]`, *optional*):
+ Tuple of hidden states from all decoder layers.
+ attentions (`tuple[torch.FloatTensor]`, *optional*):
+ Tuple of attention weights from all decoder layers (self-attention and cross-attention).
+ """
+
+ intermediate_hidden_states: torch.FloatTensor = None
+ reference_boxes: torch.FloatTensor = None
+ presence_logits: torch.FloatTensor = None
+ hidden_states: tuple[torch.FloatTensor] | None = None
+ attentions: tuple[torch.FloatTensor] | None = None
+
+
+@dataclass
+@auto_docstring
+class Sam3LiteTextMaskDecoderOutput(ModelOutput):
+ r"""
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
+ Predicted segmentation masks for each query.
+ semantic_seg (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*):
+ Semantic segmentation output.
+ attentions (`tuple[torch.FloatTensor]`, *optional*):
+ Tuple of attention weights from mask decoder cross-attention layers.
+ """
+
+ pred_masks: torch.FloatTensor = None
+ semantic_seg: torch.FloatTensor | None = None
+ attentions: tuple[torch.FloatTensor] | None = None
+
+
+@dataclass
+@auto_docstring
+class Sam3LiteTextImageSegmentationOutput(ModelOutput):
+ r"""
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
+ Predicted segmentation masks for each query.
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Predicted bounding boxes in (x1, y1, x2, y2) format.
+ pred_logits (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
+ Classification confidence scores for each query, computed via dot product between
+ decoder query features and text features.
+ presence_logits (`torch.FloatTensor` of shape `(batch_size, 1)`, *optional*):
+ Presence logits from the DETR decoder presence token (last layer only). These indicate whether objects
+ are present in the scene. Can be used to compute final scores by multiplying with pred_logits:
+ `final_scores = pred_logits.sigmoid() * presence_logits.sigmoid()`.
+ semantic_seg (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*):
+ Semantic segmentation output.
+ decoder_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
+ Tuple of hidden states from all DETR decoder layers. Each tensor has shape `(batch_size, num_queries, hidden_size)`.
+ decoder_reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`, *optional*):
+ Reference boxes from all DETR decoder layers.
+ encoder_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
+ Tuple of hidden states from all DETR encoder layers.
+ vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
+ Tuple of hidden states from all vision encoder (ViT) layers.
+ vision_attentions (`tuple[torch.FloatTensor]`, *optional*):
+ Attention weights from vision encoder (ViT) layers.
+ detr_encoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
+ Attention weights from DETR encoder layers.
+ detr_decoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
+ Attention weights from DETR decoder layers (self-attention and cross-attention).
+ mask_decoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
+ Attention weights from mask decoder layers.
+ """
+
+ pred_masks: torch.FloatTensor = None
+ pred_boxes: torch.FloatTensor = None
+ pred_logits: torch.FloatTensor | None = None
+ presence_logits: torch.FloatTensor | None = None
+ semantic_seg: torch.FloatTensor | None = None
+ decoder_hidden_states: tuple[torch.FloatTensor] | None = None
+ decoder_reference_boxes: torch.FloatTensor | None = None
+ encoder_hidden_states: tuple[torch.FloatTensor] | None = None
+ vision_hidden_states: tuple[torch.FloatTensor] | None = None
+ vision_attentions: tuple[torch.FloatTensor] | None = None
+ detr_encoder_attentions: tuple[torch.FloatTensor] | None = None
+ detr_decoder_attentions: tuple[torch.FloatTensor] | None = None
+ mask_decoder_attentions: tuple[torch.FloatTensor] | None = None
+
+
+class Sam3LiteTextMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Sam3LiteTextAttention(nn.Module):
+ """
+ Multi-head attention.
+ Handles standard [batch_size, seq_len, hidden_size] tensors.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_attention_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // config.num_attention_heads
+ self.scaling = self.head_dim**-0.5
+ self.is_causal = False
+
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
+ self.k_proj = nn.Linear(self.hidden_size, self.hidden_size)
+ self.v_proj = nn.Linear(self.hidden_size, self.hidden_size)
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ query: [batch_size, query_len, hidden_size]
+ key: [batch_size, key_len, hidden_size]
+ value: [batch_size, value_len, hidden_size]
+ attention_mask: [batch_size, num_heads, query_len, key_len] or broadcastable
+
+ Returns:
+ Tuple of (output, attention_weights)
+ output: [batch_size, query_len, hidden_size]
+ attention_weights: [batch_size, num_heads, query_len, key_len]
+ """
+ batch_size = query.shape[0]
+ query_len = query.shape[1]
+ key_len = key.shape[1]
+
+ query = self.q_proj(query).view(batch_size, query_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
+ key = self.k_proj(key).view(batch_size, key_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
+ value = self.v_proj(value).view(batch_size, key_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
+
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+ self.config._attn_implementation, eager_attention_forward
+ )
+
+ if (
+ is_flash_attention_requested(self.config)
+ and attention_mask is not None
+ and attention_mask.dtype != torch.bool
+ ):
+ # Relative position bias tensors are represented as float masks and are incompatible with Flash Attention
+            # Fall back to SDPA for this call only so the rest of the model can still benefit from FA
+ attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
+ logger.warning_once(
+ "Sam3LiteTextAttention: falling back to SDPA for relative-position cross-attention because "
+ "Flash Attention does not support additive bias masks."
+ )
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=attention_mask,
+ dropout=0.0,
+ scaling=self.scaling,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(batch_size, query_len, self.num_attention_heads * self.head_dim).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class Sam3LiteTextSinePositionEmbedding(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
+ need paper, generalized to work on images.
+ """
+
+ def __init__(
+ self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: float | None = None
+ ):
+ super().__init__()
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ self.scale = 2 * math.pi if scale is None else scale
+
+ def encode_1d_positions(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Encode 1D coordinate pairs using sine/cosine positional embeddings.
+
+ Args:
+ x: 1D tensor of x coordinates (flattened)
+ y: 1D tensor of y coordinates (flattened)
+
+ Returns:
+ Tuple of (pos_x, pos_y) positional embeddings
+ """
+ x_embed = x * self.scale
+ y_embed = y * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=x.device).to(x.dtype)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, None] / dim_t
+ pos_y = y_embed[:, None] / dim_t
+ pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
+ pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
+ return pos_x, pos_y
+
+ def encode_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+ """
+ Encode 4D box coordinates (x, y, w, h) for decoder conditioning using sine/cosine embeddings.
+
+ Args:
+ boxes: Box coordinates [batch_size, num_queries, 4] in (x, y, w, h) format
+
+ Returns:
+ Position embeddings [batch_size, num_queries, num_pos_feats*4]
+ """
+ assert boxes.size(-1) == 4, f"Expected 4D box coordinates (x, y, w, h), got shape {boxes.shape}"
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=boxes.device).to(boxes.dtype)
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
+
+ x_embed = boxes[:, :, 0] * self.scale
+ y_embed = boxes[:, :, 1] * self.scale
+ w_embed = boxes[:, :, 2] * self.scale
+ h_embed = boxes[:, :, 3] * self.scale
+
+ pos_x = x_embed[:, :, None] / dim_t
+ pos_y = y_embed[:, :, None] / dim_t
+ pos_w = w_embed[:, :, None] / dim_t
+ pos_h = h_embed[:, :, None] / dim_t
+
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
+ pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
+ pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
+ pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
+
+ pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
+
+ return pos
+
+ @compile_compatible_method_lru_cache(maxsize=4)
+ def forward(
+ self,
+ shape: torch.Size,
+ device: torch.device | str,
+ dtype: torch.dtype,
+ mask: Tensor | None = None,
+ ) -> Tensor:
+ if mask is None:
+ mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
+ not_mask = (~mask).to(dtype)
+ y_embed = not_mask.cumsum(1)
+ x_embed = not_mask.cumsum(2)
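+        # cumsum over valid (unmasked) positions yields 1-based row/column coordinates for every pixel.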
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+
+class Sam3LiteTextGeometryEncoderLayer(nn.Module):
+ def __init__(self, config: Sam3LiteTextGeometryEncoderConfig):
+ super().__init__()
+ self.layer_norm1 = nn.LayerNorm(config.hidden_size)
+ self.self_attn = Sam3LiteTextAttention(config)
+ self.dropout = nn.Dropout(config.dropout)
+
+ self.cross_attn = Sam3LiteTextAttention(config)
+ self.layer_norm2 = nn.LayerNorm(config.hidden_size)
+
+ self.mlp = Sam3LiteTextMLP(config)
+ self.layer_norm3 = nn.LayerNorm(config.hidden_size)
+
+ def forward(
+ self,
+ prompt_feats: Tensor,
+ vision_feats: Tensor,
+ vision_pos_encoding: Tensor,
+ prompt_mask: Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ residual = prompt_feats
+ hidden_states = self.layer_norm1(prompt_feats)
+ hidden_states, _ = self.self_attn(
+ query=hidden_states, key=hidden_states, value=hidden_states, attention_mask=prompt_mask, **kwargs
+ )
+ hidden_states = self.dropout(hidden_states) + residual
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ key = vision_feats + vision_pos_encoding
+ hidden_states, _ = self.cross_attn(query=hidden_states, key=key, value=vision_feats, **kwargs)
+ hidden_states = self.dropout(hidden_states) + residual
+ residual = hidden_states
+ hidden_states = self.layer_norm3(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.dropout(hidden_states) + residual
+
+ return hidden_states
+
+
+def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
+ """
+ Concatenates two right-padded sequences, such that the resulting sequence
+ is contiguous and also right-padded.
+
+ Tensors are batch-first, masks are batch-first with True=valid, False=padding.
+
+ Args:
+ seq1: A tensor of shape (batch_size, seq1_length, hidden_size).
+ mask1: A tensor of shape (batch_size, seq1_length) with True=valid, False=padding.
+ seq2: A tensor of shape (batch_size, seq2_length, hidden_size).
+ mask2: A tensor of shape (batch_size, seq2_length) with True=valid, False=padding.
+        return_index: If True, also returns the indices of the seq2 elements in the concatenated
+            sequence, which can be used to retrieve them afterwards.
+
+ Returns:
+ A tuple (concatenated_sequence, concatenated_mask) if return_index is False,
+ otherwise (concatenated_sequence, concatenated_mask, index).
+ The concatenated_mask uses True=valid, False=padding convention.
+ """
+ batch_size, seq1_length, hidden_size = seq1.shape
+ batch_size2, seq2_length, hidden_size2 = seq2.shape
+
+ assert batch_size == batch_size2 == mask1.size(0) == mask2.size(0)
+ assert hidden_size == hidden_size2
+ assert seq1_length == mask1.size(1)
+ assert seq2_length == mask2.size(1)
+
+ actual_seq1_lengths = mask1.sum(dim=-1)
+ actual_seq2_lengths = mask2.sum(dim=-1)
+
+ final_lengths = actual_seq1_lengths + actual_seq2_lengths
+ max_length = seq1_length + seq2_length
+
+ concatenated_mask = (
+ torch.arange(max_length, device=seq2.device)[None].repeat(batch_size, 1) < final_lengths[:, None]
+ )
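+    # Valid slots in the output are exactly the first (len1_i + len2_i) positions of each row.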
+
+ concatenated_sequence = torch.zeros((batch_size, max_length, hidden_size), device=seq2.device, dtype=seq2.dtype)
+ concatenated_sequence[:, :seq1_length, :] = seq1
+
+ # Shift seq2 elements to start at the end of valid seq1
+ index = torch.arange(seq2_length, device=seq2.device)[None].repeat(batch_size, 1)
+ index = index + actual_seq1_lengths[:, None]
+
+ # Scatter seq2 into the right positions
+ concatenated_sequence = concatenated_sequence.scatter(1, index[:, :, None].expand(-1, -1, hidden_size), seq2)
+
+ if return_index:
+ return concatenated_sequence, concatenated_mask, index
+
+ return concatenated_sequence, concatenated_mask
+
+
+def box_cxcywh_to_xyxy(x):
+ """Convert boxes from (cx, cy, w, h) format to (x1, y1, x2, y2) format."""
+ x_c, y_c, w, h = x.unbind(-1)
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+ return torch.stack(b, dim=-1)
+
+
+class Sam3LiteTextGeometryEncoder(nn.Module):
+ """
+ Encoder for geometric prompts (boxes).
+
+ Boxes are encoded using three approaches:
+ - Direct projection: linear projection from coordinate space to hidden_size
+ - Pooling: pool features from the backbone at the specified location (ROI align for boxes)
+    - Position encoding: sine encoding of the box center combined with the raw width and height
+
+ These encodings are combined additively and further processed with transformer layers.
+ """
+
+ def __init__(self, config: Sam3LiteTextGeometryEncoderConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.roi_size = config.roi_size
+
+ self.position_encoding = Sam3LiteTextSinePositionEmbedding(
+ num_pos_feats=config.hidden_size // 2, normalize=True
+ )
+ self.label_embed = nn.Embedding(2, self.hidden_size)
+ self.cls_embed = nn.Embedding(1, self.hidden_size)
+
+ # Box encoding layers
+ self.boxes_direct_project = nn.Linear(4, self.hidden_size)
+ self.boxes_pool_project = nn.Conv2d(self.hidden_size, self.hidden_size, self.roi_size)
+ self.boxes_pos_enc_project = nn.Linear(self.hidden_size + 2, self.hidden_size)
+
+ # Image feature normalization
+ self.vision_layer_norm = nn.LayerNorm(self.hidden_size)
+
+ # Prompt projection and normalization
+ self.final_proj = nn.Linear(self.hidden_size, self.hidden_size)
+ self.prompt_layer_norm = nn.LayerNorm(self.hidden_size)
+
+ # Transformer layers
+ self.layers = nn.ModuleList([Sam3LiteTextGeometryEncoderLayer(config) for _ in range(config.num_layers)])
+ self.output_layer_norm = nn.LayerNorm(self.hidden_size)
+
+ def _encode_box_coordinates(
+ self, center_x: torch.Tensor, center_y: torch.Tensor, width: torch.Tensor, height: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Encode box coordinates by combining position-encoded centers with raw width/height.
+
+ Args:
+ center_x: 1D tensor of box center x coordinates
+ center_y: 1D tensor of box center y coordinates
+ width: 1D tensor of box widths
+ height: 1D tensor of box heights
+
+ Returns:
+ Encoded box coordinates [N, embedding_dim]
+ """
+ pos_x, pos_y = self.position_encoding.encode_1d_positions(center_x, center_y)
+ pos = torch.cat((pos_y, pos_x, height[:, None], width[:, None]), dim=1)
+ return pos
+
+ def _encode_boxes(self, boxes, boxes_mask, boxes_labels, vision_features):
+ """Encode box prompts. Mask convention: True=valid, False=padding."""
+ batch_size, num_boxes = boxes.shape[:2]
+ height, width = vision_features.shape[-2:]
+ boxes_embed = self.boxes_direct_project(boxes)
+
+ # Pool features using ROI align
+ # Convert boxes from CxCyWH to xyxy format and denormalize
+ boxes_xyxy = box_cxcywh_to_xyxy(boxes)
+ scale = torch.tensor([width, height, width, height], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device)
+ scale = scale.view(1, 1, 4)
+ boxes_xyxy = boxes_xyxy * scale
+ # ROI align expects list of boxes per batch element,
+ # convert from bfloat16 to float16 as roi_align only supports float16 and float32
+ dtype = torch.float16 if vision_features.dtype == torch.bfloat16 else vision_features.dtype
+ sampled_features = torchvision.ops.roi_align(
+ vision_features.to(dtype), boxes_xyxy.to(dtype).unbind(0), self.roi_size
+ ).to(vision_features.dtype)
+
+ pooled_projection = self.boxes_pool_project(sampled_features)
+ pooled_projection = pooled_projection.view(batch_size, num_boxes, self.hidden_size)
+ boxes_embed = boxes_embed + pooled_projection
+
+ # Add position encoding
+ center_x, center_y, box_width, box_height = boxes.unbind(-1)
+ pos_enc = self._encode_box_coordinates(
+ center_x.flatten(), center_y.flatten(), box_width.flatten(), box_height.flatten()
+ )
+ pos_enc = pos_enc.view(batch_size, num_boxes, pos_enc.shape[-1])
+ pos_projection = self.boxes_pos_enc_project(pos_enc)
+ boxes_embed = boxes_embed + pos_projection
+
+ # Add label embeddings (positive/negative)
+ label_embed = self.label_embed(boxes_labels.long())
+ return label_embed + boxes_embed, boxes_mask
+
+ def forward(
+ self,
+ box_embeddings: torch.Tensor,
+ box_mask: torch.Tensor,
+ box_labels: torch.Tensor,
+ img_feats: tuple[torch.Tensor, ...],
+ img_pos_embeds: tuple[torch.Tensor, ...] | None = None,
+ ):
+ """
+ Forward pass for encoding geometric prompts.
+
+ Args:
+ box_embeddings: Box coordinates in CxCyWH format [batch_size, num_boxes, 4]
+ box_mask: Attention mask for boxes [batch_size, num_boxes]
+ box_labels: Labels for boxes (positive/negative) [batch_size, num_boxes]
+ img_feats: Image features from vision encoder
+ img_pos_embeds: Optional position embeddings for image features
+
+ Returns:
+ Sam3LiteTextGeometryEncoderOutput containing encoded geometry features and attention mask.
+ """
+ batch_size = box_embeddings.shape[0]
+
+ # Prepare vision features for cross-attention: flatten spatial dimensions
+ vision_feats = img_feats[-1] # [B, C, H, W]
+ vision_pos_embeds = img_pos_embeds[-1] if img_pos_embeds is not None else torch.zeros_like(vision_feats)
+ vision_feats_flat = vision_feats.flatten(2).transpose(1, 2) # [B, H*W, C]
+ vision_pos_embeds_flat = vision_pos_embeds.flatten(2).transpose(1, 2) # [B, H*W, C]
+
+ # Normalize image features for pooling operations
+ img_feats_last = img_feats[-1] # [B, C, H, W]
+ img_feats_last = img_feats_last.permute(0, 2, 3, 1) # [B, H, W, C]
+ normalized_img_feats = self.vision_layer_norm(img_feats_last)
+ normalized_img_feats = normalized_img_feats.permute(0, 3, 1, 2) # [B, C, H, W]
+
+ prompt_embeds, prompt_mask = self._encode_boxes(box_embeddings, box_mask, box_labels, normalized_img_feats)
+
+ # Add CLS token (always valid)
+ cls_embed = self.cls_embed.weight.view(1, self.hidden_size).unsqueeze(0).expand(batch_size, -1, -1)
+ cls_mask = torch.ones(batch_size, 1, dtype=prompt_mask.dtype, device=prompt_mask.device)
+ prompt_embeds, prompt_mask = concat_padded_sequences(prompt_embeds, prompt_mask, cls_embed, cls_mask)
+
+ prompt_embeds = self.prompt_layer_norm(self.final_proj(prompt_embeds))
+
+ # Create bidirectional attention mask for transformer layers
+ prompt_attention_mask = None
+ if prompt_mask is not None:
+ prompt_attention_mask = create_bidirectional_mask(
+ config=self.config,
+ inputs_embeds=prompt_embeds,
+ attention_mask=prompt_mask,
+ )
+
+ # Apply transformer layers with cross-attention to vision features
+ for layer in self.layers:
+ prompt_embeds = layer(
+ prompt_feats=prompt_embeds,
+ vision_feats=vision_feats_flat,
+ vision_pos_encoding=vision_pos_embeds_flat,
+ prompt_mask=prompt_attention_mask,
+ )
+
+ # Final output normalization
+ prompt_embeds = self.output_layer_norm(prompt_embeds)
+
+ return Sam3LiteTextGeometryEncoderOutput(
+ last_hidden_state=prompt_embeds,
+ attention_mask=prompt_mask,
+ )
+
+
+class Sam3LiteTextDetrEncoderLayer(nn.Module):
+ """DETR encoder layer with self-attention and cross-attention."""
+
+ def __init__(self, config: Sam3LiteTextDETREncoderConfig):
+ super().__init__()
+ self.config = config
+ self.layer_norm1 = nn.LayerNorm(config.hidden_size)
+ self.self_attn = Sam3LiteTextAttention(config)
+ self.dropout = nn.Dropout(config.dropout)
+
+ self.cross_attn = Sam3LiteTextAttention(config)
+ self.layer_norm2 = nn.LayerNorm(config.hidden_size)
+
+ self.mlp = Sam3LiteTextMLP(config)
+ self.layer_norm3 = nn.LayerNorm(config.hidden_size)
+
+ def forward(
+ self,
+ vision_feats: Tensor,
+ prompt_feats: Tensor,
+ vision_pos_encoding: Tensor,
+ prompt_cross_attn_mask: Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ """
+ Forward pass for DETR encoder layer.
+
+ Args:
+ vision_feats: Vision features [batch_size, vision_len, hidden_size] (main hidden states)
+ prompt_feats: Text prompt features [batch_size, text_len, hidden_size]
+ vision_pos_encoding: Position encoding for vision [batch_size, vision_len, hidden_size]
+ prompt_cross_attn_mask: Cross-attention mask for prompt features
+
+ Returns:
+ Updated vision features [batch_size, vision_len, hidden_size]
+ """
+ # Self-attention on vision features with position encoding
+ residual = vision_feats
+ hidden_states = self.layer_norm1(vision_feats)
+ hidden_states_with_pos = hidden_states + vision_pos_encoding
+ hidden_states, _ = self.self_attn(
+ query=hidden_states_with_pos,
+ key=hidden_states_with_pos,
+ value=hidden_states,
+ **kwargs,
+ )
+ hidden_states = self.dropout(hidden_states) + residual
+
+ # Cross-attention: vision queries attend to text/prompt features
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+
+ hidden_states, _ = self.cross_attn(
+ query=hidden_states,
+ key=prompt_feats,
+ value=prompt_feats,
+ attention_mask=prompt_cross_attn_mask,
+ **kwargs,
+ )
+ hidden_states = self.dropout(hidden_states) + residual
+
+ # MLP
+ residual = hidden_states
+ hidden_states = self.layer_norm3(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.dropout(hidden_states) + residual
+
+ return hidden_states
+
+
+class Sam3LiteTextDetrEncoder(Sam3LiteTextPreTrainedModel):
+ """
+ DETR-style encoder that processes multi-level vision features with text fusion.
+
+ This encoder processes vision features from multiple levels (e.g., FPN features at different
+ resolutions) and fuses them with text prompts through a stack of transformer encoder layers.
+ """
+
+ _can_record_outputs = {
+ "hidden_states": Sam3LiteTextDetrEncoderLayer,
+ "attentions": Sam3LiteTextAttention,
+ }
+
+ def __init__(self, config: Sam3LiteTextDETREncoderConfig):
+ super().__init__(config)
+ self.config = config
+ self.hidden_size = config.hidden_size
+
+ self.layers = nn.ModuleList([Sam3LiteTextDetrEncoderLayer(config) for _ in range(config.num_layers)])
+
+ self.post_init()
+
+ def _prepare_multilevel_features(
+ self,
+ vision_features: list[torch.Tensor],
+ vision_pos_embeds: list[torch.Tensor],
+ ):
+ """
+ Prepare multi-level vision features by flattening spatial dimensions and adding level embeddings.
+
+ Args:
+ vision_features: List of vision features at different levels [batch_size, channels, height, width]
+ vision_pos_embeds: List of position embeddings for each level [batch_size, channels, height, width]
+
+ Returns:
+ Tuple containing flattened features, position embeddings, and spatial metadata
+ """
+ features_flattened = []
+ pos_embeds_flattened = []
+ spatial_shapes = []
+
+ for features, pos_embed in zip(vision_features, vision_pos_embeds):
+ height, width = features.shape[-2:]
+ spatial_shapes.append((height, width))
+
+ # Flatten spatial dimensions: [batch_size, channels, height, width] -> [batch_size, height*width, channels]
+ features = features.flatten(2).transpose(1, 2)
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
+
+ features_flattened.append(features)
+ pos_embeds_flattened.append(pos_embed)
+
+ # Concatenate all levels into single sequence
+ features_flattened = torch.cat(features_flattened, dim=1)
+ pos_embeds_flattened = torch.cat(pos_embeds_flattened, dim=1)
+
+ spatial_shapes = torch.tensor(spatial_shapes, dtype=torch.long, device=features_flattened.device)
+
+ return (
+ features_flattened,
+ pos_embeds_flattened,
+ spatial_shapes,
+ )
+
+ @merge_with_config_defaults
+ @capture_outputs
+ def forward(
+ self,
+ vision_features: list[torch.Tensor],
+ text_features: torch.Tensor,
+ vision_pos_embeds: list[torch.Tensor] | None = None,
+ text_mask: torch.Tensor | None = None,
+ spatial_sizes: list[tuple[int, int]] | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | Sam3LiteTextDETREncoderOutput:
+ """
+ Forward pass for the DETR encoder.
+
+ Args:
+ vision_features: List of vision features at different levels
+ text_features: Text prompt features [batch_size, seq_len, hidden_size]
+ vision_pos_embeds: Optional list of position embeddings for each level
+ text_mask: Optional text padding mask [batch_size, seq_len]
+ spatial_sizes: Optional list of (height, width) tuples for reshaping
+
+ Returns:
+ Sam3LiteTextDETREncoderOutput containing encoded features and metadata.
+ """
+ batch_size = vision_features[0].shape[0] if vision_features[0].dim() == 4 else vision_features[0].shape[1]
+
+ # TODO: See if we can remove that reshaping and just use the features as is.
+ if spatial_sizes is not None:
+ for i, (height, width) in enumerate(spatial_sizes):
+ # Reshape from [height*width, batch_size, channels] to [batch_size, channels, height, width]
+ vision_features[i] = vision_features[i].reshape(height, width, batch_size, -1).permute(2, 3, 0, 1)
+ vision_pos_embeds[i] = vision_pos_embeds[i].reshape(height, width, batch_size, -1).permute(2, 3, 0, 1)
+
+ # Flatten multi-level features for encoder processing
+ (
+ features_flattened,
+ pos_embeds_flattened,
+ spatial_shapes,
+ ) = self._prepare_multilevel_features(vision_features, vision_pos_embeds)
+
+ prompt_cross_attn_mask = None
+ if text_mask is not None:
+ prompt_cross_attn_mask = create_bidirectional_mask(
+ config=self.config,
+ inputs_embeds=features_flattened,
+ attention_mask=text_mask,
+ encoder_hidden_states=text_features,
+ )
+
+ hidden_states = features_flattened
+ for layer in self.layers:
+ hidden_states = layer(
+ hidden_states,
+ prompt_feats=text_features,
+ vision_pos_encoding=pos_embeds_flattened,
+ prompt_cross_attn_mask=prompt_cross_attn_mask,
+ **kwargs,
+ )
+ return Sam3LiteTextDETREncoderOutput(
+ last_hidden_state=hidden_states,
+ pos_embeds_flattened=pos_embeds_flattened,
+ text_features=text_features,
+ spatial_shapes=spatial_shapes,
+ )
+
+
+class Sam3LiteTextDecoderMLP(nn.Module):
+ """Simple 2 or 3-layer MLP for decoder components."""
+
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 2):
+ super().__init__()
+ if num_layers == 2:
+ self.layer1 = nn.Linear(input_dim, hidden_dim)
+ self.layer2 = nn.Linear(hidden_dim, output_dim)
+ self.layer3 = None
+ elif num_layers == 3:
+ self.layer1 = nn.Linear(input_dim, hidden_dim)
+ self.layer2 = nn.Linear(hidden_dim, hidden_dim)
+ self.layer3 = nn.Linear(hidden_dim, output_dim)
+ else:
+ raise ValueError(f"Only 2 or 3 layers supported, got {num_layers}")
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = F.relu(self.layer1(x))
+ if self.layer3 is not None:
+ x = F.relu(self.layer2(x))
+ x = self.layer3(x)
+ else:
+ x = self.layer2(x)
+ return x
+
+
+class Sam3LiteTextDetrDecoderLayer(nn.Module):
+ """DETR decoder layer with self-attention, text cross-attention, and vision cross-attention."""
+
+ def __init__(self, config: Sam3LiteTextDETRDecoderConfig):
+ super().__init__()
+ self.config = config
+ self.self_attn = Sam3LiteTextAttention(config)
+ self.self_attn_dropout = nn.Dropout(config.dropout)
+ self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size)
+
+ self.text_cross_attn = Sam3LiteTextAttention(config)
+ self.text_cross_attn_dropout = nn.Dropout(config.dropout)
+ self.text_cross_attn_layer_norm = nn.LayerNorm(config.hidden_size)
+
+ self.vision_cross_attn = Sam3LiteTextAttention(config)
+ self.vision_cross_attn_dropout = nn.Dropout(config.dropout)
+ self.vision_cross_attn_layer_norm = nn.LayerNorm(config.hidden_size)
+
+ self.mlp = Sam3LiteTextMLP(config)
+ self.mlp_layer_norm = nn.LayerNorm(config.hidden_size)
+ self.mlp_dropout = nn.Dropout(config.dropout)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ query_pos: torch.Tensor,
+ text_features: torch.Tensor,
+ vision_features: torch.Tensor,
+ vision_pos_encoding: torch.Tensor,
+ text_cross_attn_mask: torch.Tensor | None = None,
+ vision_cross_attn_mask: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ """
+ Forward pass for decoder layer.
+
+ Args:
+ hidden_states: Query features [batch_size, num_queries + 1, hidden_size] (includes presence token at position 0)
+ query_pos: Query position embeddings [batch_size, num_queries, hidden_size]
+ text_features: Text features [batch_size, seq_len, hidden_size]
+ vision_features: Vision features [batch_size, height*width, hidden_size]
+ vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
+ text_cross_attn_mask: Text cross-attention mask
+ vision_cross_attn_mask: Vision cross-attention mask, already expanded for presence token
+
+ Returns:
+ Updated hidden states (including presence token at position 0)
+ """
+ # Prepend zeros to query_pos for presence token
+ query_pos = F.pad(query_pos, (0, 0, 1, 0), mode="constant", value=0)
+
+ # Self-attention with query position encoding
+ residual = hidden_states
+ query_with_pos = hidden_states + query_pos
+ attn_output, _ = self.self_attn(
+ query=query_with_pos,
+ key=query_with_pos,
+ value=hidden_states,
+ attention_mask=None,
+ **kwargs,
+ )
+ hidden_states = residual + self.self_attn_dropout(attn_output)
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Text cross-attention: queries attend to text features
+ residual = hidden_states
+ query_with_pos = hidden_states + query_pos
+
+ attn_output, _ = self.text_cross_attn(
+ query=query_with_pos,
+ key=text_features,
+ value=text_features,
+ attention_mask=text_cross_attn_mask,
+ **kwargs,
+ )
+ hidden_states = residual + self.text_cross_attn_dropout(attn_output)
+ hidden_states = self.text_cross_attn_layer_norm(hidden_states)
+
+ # Vision cross-attention: queries attend to vision features (with RPB)
+ residual = hidden_states
+ query_with_pos = hidden_states + query_pos
+ key_with_pos = vision_features + vision_pos_encoding
+ attn_output, _ = self.vision_cross_attn(
+ query=query_with_pos,
+ key=key_with_pos,
+ value=vision_features,
+ attention_mask=vision_cross_attn_mask,
+ **kwargs,
+ )
+ hidden_states = residual + self.vision_cross_attn_dropout(attn_output)
+ hidden_states = self.vision_cross_attn_layer_norm(hidden_states)
+
+ # MLP
+ residual = hidden_states
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + self.mlp_dropout(hidden_states)
+ hidden_states = self.mlp_layer_norm(hidden_states)
+
+ return hidden_states
+
+
+def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
+ """The inverse function for sigmoid activation function."""
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1 / x2)
+
+
+class Sam3LiteTextDetrDecoder(Sam3LiteTextPreTrainedModel):
+ """
+ DETR-style decoder with box refinement and presence token.
+
+ Simplified version that assumes:
+ - Box refinement is always enabled
+ - Intermediate outputs are always returned
+ - BoxRPB (relative position bias) with log-scale encoding
+ - Presence token is used
+ """
+
+ _can_record_outputs = {
+ "hidden_states": Sam3LiteTextDetrDecoderLayer,
+ "attentions": Sam3LiteTextAttention,
+ }
+
+ def __init__(
+ self,
+ config: Sam3LiteTextDETRDecoderConfig,
+ ):
+ super().__init__(config)
+ self.config = config
+ self.hidden_size = config.hidden_size
+
+ self.layers = nn.ModuleList([Sam3LiteTextDetrDecoderLayer(config) for _ in range(config.num_layers)])
+
+ self.output_layer_norm = nn.LayerNorm(config.hidden_size)
+
+ self.box_head = Sam3LiteTextDecoderMLP(config.hidden_size, config.hidden_size, 4, 3)
+
+ self.query_embed = nn.Embedding(config.num_queries, config.hidden_size)
+ self.reference_points = nn.Embedding(config.num_queries, 4)
+
+ self.presence_token = nn.Embedding(1, config.hidden_size)
+ self.presence_head = Sam3LiteTextDecoderMLP(config.hidden_size, config.hidden_size, 1, 3)
+ self.presence_layer_norm = nn.LayerNorm(config.hidden_size)
+ self.clamp_presence_logit_max_val = 10.0
+
+ self.ref_point_head = Sam3LiteTextDecoderMLP(2 * config.hidden_size, config.hidden_size, config.hidden_size, 2)
+
+ self.box_rpb_embed_x = Sam3LiteTextDecoderMLP(2, config.hidden_size, config.num_attention_heads, 2)
+ self.box_rpb_embed_y = Sam3LiteTextDecoderMLP(2, config.hidden_size, config.num_attention_heads, 2)
+
+ self.position_encoding = Sam3LiteTextSinePositionEmbedding(
+ num_pos_feats=config.hidden_size // 2, normalize=False
+ )
+
+ self.post_init()
+
+ @compile_compatible_method_lru_cache(maxsize=1)
+ def _get_coords(
+ self, height: torch.Tensor, width: torch.Tensor, dtype: torch.dtype, device: torch.device
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Generate normalized coordinate grids."""
+ coords_h = torch.arange(0, height, device=device, dtype=dtype) / height
+ coords_w = torch.arange(0, width, device=device, dtype=dtype) / width
+ return coords_h, coords_w
+
+ def _get_rpb_matrix(
+ self, reference_boxes: torch.Tensor, spatial_shape: tuple[torch.Tensor, torch.Tensor]
+ ) -> torch.Tensor:
+ """
+ Compute box relative position bias (RPB) matrix using log-scale encoding.
+ RPB helps the decoder attend to relevant spatial locations based on predicted box positions.
+
+ Args:
+ reference_boxes: Reference boxes [batch_size, num_queries, 4] in sigmoid space
+ spatial_shape: (height, width) of the vision features as tensors
+
+ Returns:
+ RPB matrix [batch_size, num_heads, num_queries, height*width]
+ """
+ height, width = spatial_shape
+ boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes)
+ batch_size, num_queries, _ = boxes_xyxy.shape
+
+ # Generate coordinate grids
+ coords_h, coords_w = self._get_coords(
+ height, width, dtype=reference_boxes.dtype, device=reference_boxes.device
+ )
+
+ # Compute deltas between coordinates and box boundaries
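+        # [:, :, 1:4:2] selects (y1, y2) and [:, :, 0:3:2] selects (x1, x2) from the xyxy boxes.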
+ deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
+ deltas_y = deltas_y.view(batch_size, num_queries, -1, 2)
+ deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
+ deltas_x = deltas_x.view(batch_size, num_queries, -1, 2)
+
+ # Apply log-scale encoding
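+        # delta -> sign(delta) * log2(|8 * delta| + 1) / log2(8): offsets near the box keep resolution, distant offsets saturate slowly.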
+ deltas_x_log = deltas_x * 8
+ deltas_x_log = torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / math.log2(8)
+ deltas_y_log = deltas_y * 8
+ deltas_y_log = torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 1.0) / math.log2(8)
+
+ # Embed deltas
+ deltas_x = self.box_rpb_embed_x(deltas_x_log) # [batch_size, num_queries, width, num_heads]
+ deltas_y = self.box_rpb_embed_y(deltas_y_log) # [batch_size, num_queries, height, num_heads]
+
+        # Combine into a 2D bias matrix: [batch_size, num_queries, height, width, num_heads]
+        rpb_matrix = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(2)
+ rpb_matrix = rpb_matrix.flatten(2, 3) # [batch_size, num_queries, height*width, num_heads]
+ rpb_matrix = rpb_matrix.permute(0, 3, 1, 2).contiguous() # [batch_size, num_heads, num_queries, height*width]
+ return rpb_matrix
+
+ @merge_with_config_defaults
+ @capture_outputs
+ def forward(
+ self,
+ vision_features: torch.Tensor,
+ text_features: torch.Tensor,
+ vision_pos_encoding: torch.Tensor,
+ text_mask: torch.Tensor | None = None,
+ spatial_shapes: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | Sam3LiteTextDETRDecoderOutput:
+ """
+ Forward pass for the DETR decoder.
+
+ Args:
+ vision_features: Vision features [batch_size, height*width, hidden_size]
+ text_features: Text features [batch_size, seq_len, hidden_size]
+ vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
+ text_mask: Text padding mask [batch_size, seq_len] where True=valid, False=padding
+ spatial_shapes: Spatial shapes [num_levels, 2]
+
+ Returns:
+ Sam3LiteTextDETRDecoderOutput containing decoder outputs from all layers.
+ """
+ batch_size = vision_features.shape[0]
+
+ query_embeds = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
+ reference_boxes = self.reference_points.weight.unsqueeze(0).expand(batch_size, -1, -1)
+ reference_boxes = reference_boxes.sigmoid()
+ presence_token = self.presence_token.weight.unsqueeze(0).expand(batch_size, -1, -1)
+
+ # Concatenate presence token with query embeddings
+ hidden_states = torch.cat([presence_token, query_embeds], dim=1)
+
+ text_cross_attn_mask = None
+ if text_mask is not None:
+ text_cross_attn_mask = create_bidirectional_mask(
+ config=self.config,
+ inputs_embeds=hidden_states,
+ attention_mask=text_mask,
+ encoder_hidden_states=text_features,
+ )
+
+ intermediate_outputs = []
+ intermediate_boxes = [reference_boxes]
+ intermediate_presence_logits = []
+
+ for layer in self.layers:
+ # Generate sine embeddings for conditional queries
+ reference_points_input = reference_boxes.unsqueeze(2)
+ query_sine_embed = self.position_encoding.encode_boxes(reference_points_input[:, :, 0, :])
+ query_pos = self.ref_point_head(query_sine_embed)
+
+ # Compute box relative position bias (RPB) attention mask
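+ # The RPB is a per-head additive bias on the vision cross-attention logits, derived from each query's current reference box (only when a single feature level is used)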
+ vision_cross_attn_mask = None
+ if spatial_shapes is not None and spatial_shapes.shape[0] == 1:
+ spatial_shape = (spatial_shapes[0, 0], spatial_shapes[0, 1])
+ rpb_matrix = self._get_rpb_matrix(reference_boxes, spatial_shape)
+ # Prepend zeros row for presence token (it attends to all vision tokens equally)
+ vision_cross_attn_mask = F.pad(rpb_matrix, (0, 0, 1, 0), mode="constant", value=0)
+
+ hidden_states = layer(
+ hidden_states,
+ query_pos=query_pos,
+ text_features=text_features,
+ vision_features=vision_features,
+ vision_pos_encoding=vision_pos_encoding,
+ text_cross_attn_mask=text_cross_attn_mask,
+ vision_cross_attn_mask=vision_cross_attn_mask,
+ **kwargs,
+ )
+
+ # Extract query hidden states (without presence token) for box refinement
+ query_hidden_states = hidden_states[:, 1:]
+
+ # Box refinement: predict delta and update reference boxes
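+ # The delta is applied in inverse-sigmoid (logit) space and squashed back to [0, 1]; detaching stops gradients from flowing through reference boxes across layers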
+ reference_boxes_before_sigmoid = inverse_sigmoid(reference_boxes)
+ delta_boxes = self.box_head(self.output_layer_norm(query_hidden_states))
+ new_reference_boxes = (delta_boxes + reference_boxes_before_sigmoid).sigmoid()
+ reference_boxes = new_reference_boxes.detach()
+
+ intermediate_outputs.append(self.output_layer_norm(query_hidden_states))
+ intermediate_boxes.append(new_reference_boxes)
+
+ # Process presence token
+ presence_hidden = hidden_states[:, :1]
+ presence_logits = self.presence_head(self.presence_layer_norm(presence_hidden)).squeeze(-1)
+ presence_logits = presence_logits.clamp(
+ min=-self.clamp_presence_logit_max_val, max=self.clamp_presence_logit_max_val
+ )
+ intermediate_presence_logits.append(presence_logits)
+
+ # Stack outputs from all layers
+ intermediate_outputs = torch.stack(intermediate_outputs)
+ intermediate_boxes = torch.stack(intermediate_boxes[:-1])
+ intermediate_presence_logits = torch.stack(intermediate_presence_logits)
+
+ return Sam3LiteTextDETRDecoderOutput(
+ intermediate_hidden_states=intermediate_outputs,
+ reference_boxes=intermediate_boxes,
+ presence_logits=intermediate_presence_logits,
+ )
+
+
+class Sam3LiteTextDotProductScoring(nn.Module):
+ """
+ Computes classification scores as the scaled dot product between projected decoder queries and pooled text features.
+ This is used to determine confidence/presence scores for each query.
+ """
+
+ def __init__(self, config: Sam3LiteTextConfig):
+ super().__init__()
+ self.config = config
+ hidden_size = config.detr_decoder_config.hidden_size
+ projection_dim = config.detr_decoder_config.hidden_size
+
+ self.text_mlp = Sam3LiteTextDecoderMLP(
+ input_dim=hidden_size,
+ hidden_dim=config.detr_decoder_config.intermediate_size,
+ output_dim=hidden_size,
+ num_layers=2,
+ )
+ self.text_mlp_dropout = nn.Dropout(config.detr_decoder_config.dropout)
+ self.text_mlp_out_norm = nn.LayerNorm(hidden_size)
+
+ # Projections for text and query features
+ self.text_proj = nn.Linear(hidden_size, projection_dim)
+ self.query_proj = nn.Linear(hidden_size, projection_dim)
+
+ # Scale factor for dot product
+ self.scale = float(1.0 / np.sqrt(projection_dim))
+
+ # Clamping to avoid numerical issues
+ self.clamp_logits = True
+ self.clamp_max_val = 12.0
+
+ def _pool_text_features(self, text_features: torch.Tensor, text_mask: torch.Tensor | None) -> torch.Tensor:
+ """
+ Mean pool text features, accounting for padding.
+
+ Args:
+ text_features: [batch_size, seq_len, hidden_size]
+ text_mask: [batch_size, seq_len] where True indicates valid tokens, False indicates padding
+
+ Returns:
+ pooled_text: [batch_size, hidden_size]
+ """
+ if text_mask is None:
+ # No padding, simple mean
+ return text_features.mean(dim=1)
+
+ is_valid = text_mask.to(text_features.dtype).unsqueeze(-1) # [batch_size, seq_len, 1]
+
+ # Count valid tokens per batch
+ num_valid = is_valid.sum(dim=1).clamp(min=1.0) # [batch_size, 1]
+
+ # Mean pool only over valid tokens
+ pooled_text = (text_features * is_valid).sum(dim=1) / num_valid # [batch_size, hidden_size]
+
+ return pooled_text
+
+ def forward(
+ self,
+ decoder_hidden_states: torch.Tensor,
+ text_features: torch.Tensor,
+ text_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """
+ Compute classification scores via dot product.
+
+ Args:
+ decoder_hidden_states: [num_layers, batch_size, num_queries, hidden_size]
+ text_features: [batch_size, seq_len, hidden_size]
+ text_mask: [batch_size, seq_len] where True=valid, False=padding
+
+ Returns:
+ scores: [num_layers, batch_size, num_queries, 1]
+ """
+ orig_text_features = text_features
+ text_features = self.text_mlp(text_features)
+ text_features = self.text_mlp_dropout(text_features)
+ text_features = text_features + orig_text_features
+ text_features = self.text_mlp_out_norm(text_features)
+
+ pooled_text = self._pool_text_features(text_features, text_mask)
+
+ proj_text = self.text_proj(pooled_text)
+ proj_queries = self.query_proj(decoder_hidden_states)
+
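+ # proj_text: [batch_size, dim] -> [1, batch_size, dim, 1]; broadcast matmul with proj_queries [num_layers, batch_size, num_queries, dim] gives [num_layers, batch_size, num_queries, 1]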
+ proj_text = proj_text.unsqueeze(-1)
+ scores = torch.matmul(proj_queries, proj_text.unsqueeze(0))
+ scores = scores * self.scale
+ if self.clamp_logits:
+ scores = scores.clamp(min=-self.clamp_max_val, max=self.clamp_max_val)
+
+ return scores
+
+
+class Sam3LiteTextMaskEmbedder(nn.Module):
+ """
+ MLP that embeds object queries for mask prediction.
+ Similar to MaskFormer's mask embedder.
+ """
+
+ def __init__(self, config: Sam3LiteTextMaskDecoderConfig):
+ super().__init__()
+ self.config = config
+ hidden_size = config.hidden_size
+
+ self.layers = nn.ModuleList(
+ [
+ nn.Linear(hidden_size, hidden_size),
+ nn.Linear(hidden_size, hidden_size),
+ nn.Linear(hidden_size, hidden_size),
+ ]
+ )
+ self.activation = nn.ReLU()
+
+ def forward(self, queries: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ queries: Query embeddings [batch_size, num_queries, hidden_size]
+
+ Returns:
+ Mask embeddings [batch_size, num_queries, hidden_size]
+ """
+ hidden_states = queries
+ for i, layer in enumerate(self.layers):
+ hidden_states = layer(hidden_states)
+ if i < len(self.layers) - 1:
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class Sam3LiteTextPixelDecoder(nn.Module):
+ """
+ Feature Pyramid Network (FPN) decoder that generates pixel-level features.
+ Inspired by MaskFormer's pixel decoder.
+ """
+
+ def __init__(self, config: Sam3LiteTextMaskDecoderConfig):
+ super().__init__()
+ self.config = config
+ hidden_size = config.hidden_size
+ num_upsampling_stages = config.num_upsampling_stages
+
+ # Create conv layers and norms for FPN
+ self.conv_layers = nn.ModuleList(
+ [
+ nn.Conv2d(hidden_size, hidden_size, kernel_size=3, stride=1, padding=1)
+ for _ in range(num_upsampling_stages)
+ ]
+ )
+ self.norms = nn.ModuleList([nn.GroupNorm(8, hidden_size) for _ in range(num_upsampling_stages)])
+
+ self.out_channels = hidden_size
+
+ def forward(self, backbone_features: list[torch.Tensor]) -> torch.Tensor:
+ """
+ Args:
+ backbone_features: List of backbone features [batch_size, hidden_size, H_i, W_i],
+ ordered from the finest to the coarsest resolution (assumes already projected to hidden_size)
+
+ Returns:
+ Pixel embeddings [batch_size, hidden_size, H, W] at the finest resolution
+ """
+ # Start from the coarsest feature (last in list)
+ prev_fpn = backbone_features[-1]
+ # Iterate through features from coarse to fine (excluding the last which we started with)
+ for layer_idx, backbone_feat in enumerate(reversed(backbone_features[:-1])):
+ # Upsample previous FPN output to match current backbone feature size
+ prev_fpn = F.interpolate(prev_fpn, size=backbone_feat.shape[-2:], mode="nearest")
+
+ # Add skip connection
+ prev_fpn = prev_fpn + backbone_feat
+
+ # Apply conv and norm
+ prev_fpn = self.conv_layers[layer_idx](prev_fpn)
+ prev_fpn = self.norms[layer_idx](prev_fpn)
+ prev_fpn = F.relu(prev_fpn)
+
+ return prev_fpn
+
+
+class Sam3LiteTextMaskDecoder(Sam3LiteTextPreTrainedModel):
+ """
+ Mask decoder that combines object queries with pixel-level features to predict instance masks.
+ Also produces a semantic segmentation output and supports cross-attention to prompts.
+ """
+
+ _can_record_outputs = {
+ "attentions": Sam3LiteTextAttention,
+ }
+
+ def __init__(self, config: Sam3LiteTextMaskDecoderConfig):
+ super().__init__(config)
+ self.config = config
+ hidden_size = config.hidden_size
+
+ # Pixel decoder (FPN)
+ self.pixel_decoder = Sam3LiteTextPixelDecoder(config)
+
+ # Mask embedder (MLP to transform queries)
+ self.mask_embedder = Sam3LiteTextMaskEmbedder(config)
+
+ # Projection from pixel decoder output to mask embedding space
+ self.instance_projection = nn.Conv2d(self.pixel_decoder.out_channels, hidden_size, kernel_size=1)
+
+ # Semantic segmentation head (always present in UniversalSegmentationHead)
+ self.semantic_projection = nn.Conv2d(self.pixel_decoder.out_channels, 1, kernel_size=1)
+
+ self.prompt_cross_attn = Sam3LiteTextAttention(config)
+ self.prompt_cross_attn_norm = nn.LayerNorm(hidden_size)
+ self.prompt_cross_attn_dropout = nn.Dropout(config.dropout)
+
+ self.post_init()
+
+ @merge_with_config_defaults
+ @capture_outputs
+ def forward(
+ self,
+ decoder_queries: torch.Tensor,
+ backbone_features: list[torch.Tensor],
+ encoder_hidden_states: torch.Tensor,
+ prompt_features: torch.Tensor | None = None,
+ prompt_mask: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | Sam3LiteTextMaskDecoderOutput:
+ """
+ Args:
+ decoder_queries: Decoder output queries [batch_size, num_queries, hidden_size]
+ backbone_features: List of backbone features to process through FPN
+ encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size]
+ prompt_features: Prompt features (text + geometry) for cross-attention [batch_size, prompt_len, hidden_size]
+ prompt_mask: Padding mask [batch_size, prompt_len] where True=valid, False=padding
+
+ Returns:
+ Sam3LiteTextMaskDecoderOutput containing predicted masks and semantic segmentation.
+ """
+ if prompt_features is not None:
+ # Cross-attention: encoder features attend to prompt features
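+ # Pre-norm residual attention: normalized encoder features act as queries, prompt features provide keys and values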
+ residual = encoder_hidden_states
+ normed_hidden_states = self.prompt_cross_attn_norm(encoder_hidden_states)
+
+ cross_attn_mask = None
+ if prompt_mask is not None:
+ cross_attn_mask = create_bidirectional_mask(
+ config=self.config,
+ inputs_embeds=normed_hidden_states,
+ encoder_hidden_states=prompt_features,
+ attention_mask=prompt_mask,
+ )
+
+ attn_output, _ = self.prompt_cross_attn(
+ query=normed_hidden_states,
+ key=prompt_features,
+ value=prompt_features,
+ attention_mask=cross_attn_mask,
+ **kwargs,
+ )
+ encoder_hidden_states = residual + self.prompt_cross_attn_dropout(attn_output)
+
+ # Process backbone features through FPN to get pixel embeddings
+ pixel_embed = self._embed_pixels(
+ backbone_features=backbone_features,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ # Predict instance masks via dot product between query embeddings and pixel embeddings
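+ # einsum over the channel dim: each query's mask embedding is dotted with every pixel embedding, yielding one [H, W] logit map per query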
+ instance_embeds = self.instance_projection(pixel_embed)
+ mask_embeddings = self.mask_embedder(decoder_queries)
+ pred_masks = torch.einsum("bqc,bchw->bqhw", mask_embeddings, instance_embeds)
+
+ # Generate semantic segmentation
+ semantic_seg = self.semantic_projection(pixel_embed)
+
+ return Sam3LiteTextMaskDecoderOutput(
+ pred_masks=pred_masks,
+ semantic_seg=semantic_seg,
+ )
+
+ def _embed_pixels(
+ self,
+ backbone_features: list[torch.Tensor],
+ encoder_hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Embed pixels by combining backbone FPN features with encoder vision features.
+ The encoder vision features replace the coarsest-resolution backbone feature.
+
+ Args:
+ backbone_features: List of backbone features [batch_size, C, H_i, W_i]
+ encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size]
+
+ Returns:
+ Pixel embeddings [batch_size, hidden_size, H, W]
+ """
+ backbone_visual_feats = [feat.clone() for feat in backbone_features]
+
+ # Extract vision features from encoder output and reshape to spatial format
+ spatial_dim = backbone_features[-1].shape[-2] * backbone_features[-1].shape[-1]
+ encoder_visual_embed = encoder_hidden_states[:, :spatial_dim, :]
+ batch_size, _, hidden_size = encoder_visual_embed.shape
+ height, width = backbone_features[-1].shape[-2:]
+ encoder_visual_embed = encoder_visual_embed.transpose(1, 2).reshape(batch_size, hidden_size, height, width)
+
+ # Replace the coarsest (lowest-resolution) backbone feature with the encoder vision features
+ backbone_visual_feats[-1] = encoder_visual_embed
+
+ # Process through FPN decoder
+ pixel_embed = self.pixel_decoder(backbone_visual_feats)
+
+ return pixel_embed
+
+
+class Sam3LiteTextModel(Sam3LiteTextPreTrainedModel):
+ input_modalities = ["image", "text"]
+ base_model_prefix = "detector_model"
+ _keys_to_ignore_on_load_unexpected = [
+ r"^tracker_model.",
+ r"^tracker_neck.",
+ ]
+ # DETR components create float masks from features, so flash/flex attention cannot be dispatched safely.
+ _supports_flash_attn = False
+ _supports_flex_attn = False
+
+ def __init__(self, config: Sam3LiteTextConfig):
+ # When loading from a sam3_lite_text_video config, use its nested detector_config
+ if hasattr(config, "detector_config") and config.detector_config is not None:
+ detector_config = config.detector_config
+ if isinstance(detector_config, dict):
+ detector_config = Sam3LiteTextConfig(**detector_config)
+ config = detector_config
+ super().__init__(config)
+ self.vision_encoder = AutoModel.from_config(config.vision_config)
+ self.text_encoder = Sam3LiteTextTextModel(config.text_config)
+ self.vocab_size = config.text_config.vocab_size
+
+ # Project text features from text encoder hidden size to model hidden size
+ # The MobileCLIP text encoder hidden size (512 by default) differs from the DETR hidden size, so a linear projection bridges the two
+ self.text_projection = nn.Linear(config.text_config.hidden_size, config.detr_encoder_config.hidden_size)
+
+ # Pass _attn_implementation to subconfigs BEFORE creating modules
+ config.geometry_encoder_config._attn_implementation = config._attn_implementation
+ config.detr_encoder_config._attn_implementation = config._attn_implementation
+ config.detr_decoder_config._attn_implementation = config._attn_implementation
+ config.mask_decoder_config._attn_implementation = config._attn_implementation
+
+ self.geometry_encoder = Sam3LiteTextGeometryEncoder(config.geometry_encoder_config)
+ self.detr_encoder = Sam3LiteTextDetrEncoder(config.detr_encoder_config)
+ self.detr_decoder = Sam3LiteTextDetrDecoder(config.detr_decoder_config)
+ self.mask_decoder = Sam3LiteTextMaskDecoder(config.mask_decoder_config)
+
+ # Dot product scoring to compute classification scores
+ self.dot_product_scoring = Sam3LiteTextDotProductScoring(config)
+
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def get_text_features(
+ self,
+ input_ids: torch.LongTensor,
+ attention_mask: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | BaseModelOutputWithPooling:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import Sam3LiteTextModel, Sam3LiteTextProcessor
+ >>> from PIL import Image
+ >>> import httpx
+ >>> from io import BytesIO
+
+ >>> model = Sam3LiteTextModel.from_pretrained("facebook/sam3_lite_text")
+ >>> processor = Sam3LiteTextProcessor.from_pretrained("facebook/sam3_lite_text")
+
+ >>> # Pre-compute text embeddings
+ >>> text_inputs = processor(text="cat", return_tensors="pt")
+ >>> text_embeds = model.get_text_features(**text_inputs).pooler_output
+
+ >>> # Reuse text embeddings for multiple images
+ >>> url = "http://images.cocodataset.org/val2017/000000077595.jpg"
+ >>> with httpx.stream("GET", url) as response:
+ ... image = Image.open(BytesIO(response.read()))
+ >>> img_inputs = processor(images=image, return_tensors="pt")
+ >>> outputs = model(pixel_values=img_inputs.pixel_values, text_embeds=text_embeds)
+ ```
+ """
+ text_outputs = self.text_encoder(
+ input_ids=input_ids, attention_mask=attention_mask, return_dict=True, **kwargs
+ )
+ last_hidden_state = text_outputs.last_hidden_state
+ text_outputs.pooler_output = self.text_projection(last_hidden_state)
+
+ return text_outputs
+
+ @auto_docstring
+ def get_vision_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Sam3LiteTextVisionEncoderOutput:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import Sam3LiteTextModel, Sam3LiteTextProcessor
+ >>> from PIL import Image
+ >>> import httpx
+ >>> from io import BytesIO
+
+ >>> model = Sam3LiteTextModel.from_pretrained("facebook/sam3_lite_text")
+ >>> processor = Sam3LiteTextProcessor.from_pretrained("facebook/sam3_lite_text")
+
+ >>> # Pre-compute vision embeddings
+ >>> url = "http://images.cocodataset.org/val2017/000000077595.jpg"
+ >>> with httpx.stream("GET", url) as response:
+ ... image = Image.open(BytesIO(response.read()))
+ >>> img_inputs = processor(images=image, return_tensors="pt")
+ >>> vision_embeds = model.get_vision_features(pixel_values=img_inputs.pixel_values)
+
+ >>> # Reuse vision embeddings for multiple text prompts
+ >>> text_inputs = processor(text="cat", return_tensors="pt")
+ >>> outputs = model(vision_embeds=vision_embeds, input_ids=text_inputs.input_ids)
+ ```
+ """
+ vision_outputs = self.vision_encoder(pixel_values, **kwargs)
+ return vision_outputs
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor | None = None,
+ vision_embeds: Sam3LiteTextVisionEncoderOutput | None = None,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ text_embeds: torch.FloatTensor | None = None,
+ input_boxes: torch.FloatTensor | None = None,
+ input_boxes_labels: torch.LongTensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Sam3LiteTextImageSegmentationOutput:
+ r"""
+ vision_embeds (`Sam3LiteTextVisionEncoderOutput`, *optional*):
+ Pre-computed vision embeddings. Can be used to easily reuse vision embeddings. If provided, `pixel_values`
+ should not be passed. Mutually exclusive with `pixel_values`.
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Pre-computed text embeddings. Can be used to easily reuse text embeddings. If provided, `input_ids`
+ should not be passed. Mutually exclusive with `input_ids`.
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`, *optional*):
+ Normalized box coordinates in [0, 1] range, in (cx, cy, w, h) format.
+ input_boxes_labels (`torch.LongTensor` of shape `(batch_size, num_boxes)`, *optional*):
+ Labels for boxes: 1 (positive), 0 (negative).
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import httpx
+ >>> from io import BytesIO
+ >>> from transformers import AutoModel, AutoProcessor
+
+ >>> model = AutoModel.from_pretrained("facebook/sam3_lite_text")
+ >>> processor = AutoProcessor.from_pretrained("facebook/sam3_lite_text")
+
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
+ >>> with httpx.stream("GET", url) as response:
+ ... image = Image.open(BytesIO(response.read())).convert("RGB")
+ >>> text = "car"
+ >>> inputs = processor(images=image, text=text, return_tensors="pt")
+
+ >>> # Get segmentation output
+ >>> outputs = model(**inputs)
+ >>> pred_masks = outputs.pred_masks
+ >>> pred_boxes = outputs.pred_boxes
+ ```
+ """
+ if (pixel_values is None) == (vision_embeds is None):
+ raise ValueError("You must specify exactly one of pixel_values or vision_embeds")
+
+ if (input_ids is None) == (text_embeds is None):
+ raise ValueError("You must specify exactly one of input_ids or text_embeds")
+
+ if pixel_values is not None:
+ batch_size = pixel_values.shape[0]
+ device = pixel_values.device
+ else:
+ batch_size = vision_embeds.fpn_hidden_states[0].shape[0]
+ device = vision_embeds.fpn_hidden_states[0].device
+
+ if vision_embeds is None:
+ vision_outputs = self.vision_encoder(pixel_values, **kwargs)
+ else:
+ vision_outputs = vision_embeds
+
+ fpn_hidden_states = vision_outputs.fpn_hidden_states[:-1]
+ fpn_position_encoding = vision_outputs.fpn_position_encoding[:-1]
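+ # The last FPN level is dropped; the remaining levels feed the geometry encoder and mask decoder, and the final (coarsest) one also feeds the DETR encoder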
+
+ if text_embeds is None:
+ text_features = self.get_text_features(
+ input_ids=input_ids, attention_mask=attention_mask, return_dict=True
+ ).pooler_output
+ else:
+ text_features = text_embeds
+
+ text_mask = attention_mask.bool() if attention_mask is not None else None
+ has_geometry_prompts = input_boxes is not None and input_boxes.numel() > 0
+
+ geometry_prompt_features = None
+ geometry_prompt_mask = None
+
+ if has_geometry_prompts:
+ if input_boxes is not None and input_boxes.numel() > 0:
+ box_embeddings = input_boxes # [batch_size, num_boxes, 4]
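+ # Boxes labelled -10 are padding: they are masked out (box_mask) and their label is reset to 0 before encoding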
+ box_labels = (
+ input_boxes_labels
+ if input_boxes_labels is not None
+ else torch.ones_like(box_embeddings[..., 0], dtype=torch.long)
+ )
+ box_mask = (
+ (input_boxes_labels != -10)
+ if input_boxes_labels is not None
+ else torch.ones(batch_size, input_boxes.shape[1], dtype=torch.bool, device=device)
+ )
+ box_labels = torch.where(box_labels == -10, 0, box_labels)
+ else:
+ box_embeddings = torch.zeros(batch_size, 0, 4, dtype=text_features.dtype, device=device)
+ box_labels = torch.zeros(batch_size, 0, dtype=torch.long, device=device)
+ box_mask = torch.zeros(batch_size, 0, dtype=torch.bool, device=device)
+
+ geometry_outputs = self.geometry_encoder(
+ box_embeddings=box_embeddings,
+ box_mask=box_mask,
+ box_labels=box_labels,
+ img_feats=fpn_hidden_states,
+ img_pos_embeds=fpn_position_encoding,
+ )
+
+ geometry_prompt_features = geometry_outputs.last_hidden_state
+ geometry_prompt_mask = geometry_outputs.attention_mask
+
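+ # Text and geometry prompt tokens are concatenated along the sequence dim so the DETR encoder and decoder cross-attend to both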
+ if geometry_prompt_features is not None:
+ # Repeat text_features for all geometry prompts
+ if text_features.shape[0] == 1 and geometry_prompt_features.shape[0] > 1:
+ text_features = text_features.repeat(geometry_prompt_features.shape[0], 1, 1)
+ combined_prompt_features = torch.cat([text_features, geometry_prompt_features], dim=1)
+ if text_mask is not None and text_mask.shape[0] == 1 and geometry_prompt_mask.shape[0] > 1:
+ text_mask = text_mask.repeat(geometry_prompt_mask.shape[0], 1)
+
+ if text_mask is not None and geometry_prompt_mask is not None:
+ combined_prompt_mask = torch.cat([text_mask, geometry_prompt_mask], dim=1)
+ elif text_mask is not None:
+ geo_valid_mask = torch.ones(
+ batch_size, geometry_prompt_features.shape[1], dtype=torch.bool, device=device
+ )
+ combined_prompt_mask = torch.cat([text_mask, geo_valid_mask], dim=1)
+ elif geometry_prompt_mask is not None:
+ text_valid_mask = torch.ones(batch_size, text_features.shape[1], dtype=torch.bool, device=device)
+ combined_prompt_mask = torch.cat([text_valid_mask, geometry_prompt_mask], dim=1)
+ else:
+ combined_prompt_mask = None
+ else:
+ combined_prompt_features = text_features
+ combined_prompt_mask = text_mask
+
+ encoder_outputs = self.detr_encoder(
+ vision_features=[fpn_hidden_states[-1]],
+ text_features=combined_prompt_features,
+ vision_pos_embeds=[fpn_position_encoding[-1]],
+ text_mask=combined_prompt_mask,
+ **kwargs,
+ )
+
+ decoder_outputs = self.detr_decoder(
+ vision_features=encoder_outputs.last_hidden_state,
+ text_features=encoder_outputs.text_features,
+ vision_pos_encoding=encoder_outputs.pos_embeds_flattened,
+ text_mask=combined_prompt_mask,
+ spatial_shapes=encoder_outputs.spatial_shapes,
+ **kwargs,
+ )
+
+ # Refine boxes from decoder
+ all_box_offsets = self.detr_decoder.box_head(decoder_outputs.intermediate_hidden_states)
+ reference_boxes_inv_sig = inverse_sigmoid(decoder_outputs.reference_boxes)
+ all_pred_boxes_cxcywh = (reference_boxes_inv_sig + all_box_offsets).sigmoid()
+ all_pred_boxes = box_cxcywh_to_xyxy(all_pred_boxes_cxcywh)
+
+ all_pred_logits = self.dot_product_scoring(
+ decoder_hidden_states=decoder_outputs.intermediate_hidden_states,
+ text_features=encoder_outputs.text_features,
+ text_mask=combined_prompt_mask,
+ ).squeeze(-1)
+
+ pred_logits = all_pred_logits[-1]
+ pred_boxes = all_pred_boxes[-1]
+ decoder_hidden_states = decoder_outputs.intermediate_hidden_states[-1]
+ presence_logits = decoder_outputs.presence_logits[-1]
+
+ mask_outputs = self.mask_decoder(
+ decoder_queries=decoder_hidden_states,
+ backbone_features=list(fpn_hidden_states),
+ encoder_hidden_states=encoder_outputs.last_hidden_state,
+ prompt_features=combined_prompt_features,
+ prompt_mask=combined_prompt_mask,
+ **kwargs,
+ )
+
+ return Sam3LiteTextImageSegmentationOutput(
+ pred_masks=mask_outputs.pred_masks,
+ pred_boxes=pred_boxes,
+ pred_logits=pred_logits,
+ presence_logits=presence_logits,
+ semantic_seg=mask_outputs.semantic_seg,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_reference_boxes=decoder_outputs.reference_boxes,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ vision_hidden_states=vision_outputs.hidden_states,
+ vision_attentions=vision_outputs.attentions,
+ detr_encoder_attentions=encoder_outputs.attentions,
+ detr_decoder_attentions=decoder_outputs.attentions,
+ mask_decoder_attentions=mask_outputs.attentions,
+ )
+
+
+__all__ = ["Sam3LiteTextModel", "Sam3LiteTextPreTrainedModel", "Sam3LiteTextTextModel"]
diff --git a/src/transformers/models/sam3_lite_text/modular_sam3_lite_text.py b/src/transformers/models/sam3_lite_text/modular_sam3_lite_text.py
new file mode 100644
index 000000000000..46408464a333
--- /dev/null
+++ b/src/transformers/models/sam3_lite_text/modular_sam3_lite_text.py
@@ -0,0 +1,532 @@
+# Copyright 2026 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from huggingface_hub.dataclasses import strict
+
+from ... import initialization as init
+from ...activations import ACT2FN
+from ...configuration_utils import PreTrainedConfig
+from ...masking_utils import create_bidirectional_mask
+from ...modeling_outputs import BaseModelOutputWithPooling
+from ...processing_utils import Unpack
+from ...utils import auto_docstring
+from ...utils.generic import TransformersKwargs, merge_with_config_defaults
+from ...utils.output_capturing import capture_outputs
+from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel
+from ..sam3.configuration_sam3 import (
+ Sam3DETRDecoderConfig,
+ Sam3DETREncoderConfig,
+ Sam3GeometryEncoderConfig,
+ Sam3MaskDecoderConfig,
+)
+from ..sam3.modeling_sam3 import Sam3Model, Sam3PreTrainedModel
+from ..siglip.modeling_siglip import SiglipAttention, SiglipEncoderLayer, SiglipMLP
+
+
+@auto_docstring(checkpoint="facebook/sam3_lite_text")
+@strict
+class Sam3LiteTextViTConfig(PreTrainedConfig):
+ r"""
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ Base frequency for RoPE.
+ window_size (`int`, *optional*, defaults to 24):
+ Window size for windowed attention.
+ global_attn_indexes (`list[int]`, *optional*, defaults to `[7, 15, 23, 31]`):
+ Indexes of layers with global attention.
+ pretrain_image_size (`int`, *optional*, defaults to 336):
+ Pretrained model image size for position embedding initialization.
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
+ Dropout probability for hidden states.
+ """
+
+ base_config_key = "backbone_config"
+ model_type = "sam3_vit_model"
+
+ hidden_size: int = 1024
+ intermediate_size: int = 4736
+ num_hidden_layers: int = 32
+ num_attention_heads: int = 16
+ num_channels: int = 3
+ image_size: int | list[int] | tuple[int, int] = 1008
+ patch_size: int | list[int] | tuple[int, int] = 14
+ hidden_act: str = "gelu"
+ layer_norm_eps: float = 1e-6
+ attention_dropout: float | int = 0.0
+ rope_theta: float = 10000.0
+ window_size: int = 24
+ global_attn_indexes: list[int] | None = None
+ layer_scale_init_value: float | None = None
+ pretrain_image_size: int | list[int] | tuple[int, int] = 336
+ hidden_dropout: float | int = 0.0
+ initializer_range: float = 0.02
+
+ def __post_init__(self, **kwargs):
+ super().__post_init__(**kwargs)
+ if self.global_attn_indexes is None:
+ self.global_attn_indexes = [7, 15, 23, 31]
+
+
+@auto_docstring(checkpoint="facebook/sam3_lite_text")
+@strict
+class Sam3LiteTextVisionConfig(PreTrainedConfig):
+ r"""
+ fpn_hidden_size (`int`, *optional*, defaults to 256):
+ The hidden dimension of the FPN.
+ backbone_feature_sizes (`list[list[int]]`, *optional*, defaults to `[[288, 288], [144, 144], [72, 72]]`):
+ The spatial sizes (height, width) of the feature maps from the backbone at different scales.
+ scale_factors (`list[float]`, *optional*, defaults to `[4.0, 2.0, 1.0, 0.5]`):
+ Scale factors for FPN multi-scale features. List of scaling factors for each FPN level.
+ """
+
+ base_config_key = "vision_config"
+ model_type = "sam3_vision_model"
+ sub_configs = {"backbone_config": AutoConfig}
+
+ backbone_config: dict | PreTrainedConfig | None = None
+ fpn_hidden_size: int = 256
+ backbone_feature_sizes: list | None = None
+ scale_factors: list[float] | None = None
+ hidden_act: str = "gelu"
+ layer_norm_eps: float = 1e-6
+ initializer_range: float = 0.02
+
+ def __post_init__(self, **kwargs):
+ self.scale_factors = [4.0, 2.0, 1.0, 0.5] if self.scale_factors is None else self.scale_factors
+ if self.backbone_feature_sizes is None:
+ self.backbone_feature_sizes = [[288, 288], [144, 144], [72, 72]]
+
+ if isinstance(self.backbone_config, dict):
+ self.backbone_config["model_type"] = self.backbone_config.get("model_type", "sam3_vit_model")
+ self.backbone_config = CONFIG_MAPPING[self.backbone_config["model_type"]](**self.backbone_config)
+ elif self.backbone_config is None:
+ self.backbone_config = CONFIG_MAPPING["sam3_vit_model"]()
+
+ super().__post_init__(**kwargs)
+
+ @property
+ def image_size(self):
+ """Image size for the vision encoder."""
+ return self.backbone_config.image_size
+
+ @image_size.setter
+ def image_size(self, value):
+ """Set the image size and propagate to backbone."""
+ self.backbone_config.image_size = value
+
+
+@auto_docstring(checkpoint="facebook/sam3_lite_text")
+@strict
+class Sam3LiteTextGeometryEncoderConfig(Sam3GeometryEncoderConfig):
+ pass
+
+
+@auto_docstring(checkpoint="facebook/sam3_lite_text")
+@strict
+class Sam3LiteTextDETREncoderConfig(Sam3DETREncoderConfig):
+ pass
+
+
+@auto_docstring(checkpoint="facebook/sam3_lite_text")
+@strict
+class Sam3LiteTextDETRDecoderConfig(Sam3DETRDecoderConfig):
+ pass
+
+
+@auto_docstring(checkpoint="facebook/sam3_lite_text")
+@strict
+class Sam3LiteTextMaskDecoderConfig(Sam3MaskDecoderConfig):
+ pass
+
+
+@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0")
+@strict
+class Sam3LiteTextTextConfig(PreTrainedConfig):
+ r"""
+ use_repmixer_blocks (`bool`, *optional*, defaults to `True`):
+ Whether to use RepMixer blocks (MobileCLIP-style) for the first and last encoder layers.
+ When `False`, all layers are standard Transformer encoder layers.
+ layer_scale_init_value (`float`, *optional*, defaults to `1e-5`):
+ Initial value for the learnable layer-scale parameters in RepMixer blocks (residual branches).
+ repmixer_kernel_size (`int`, *optional*, defaults to `11`):
+ Kernel size for depthwise convolutions in RepMixer blocks (token mixer and convolutional feed-forward path).
+ """
+
+ model_type = "sam3_lite_text_text_model"
+
+ vocab_size: int = 49408
+ hidden_size: int = 512
+ intermediate_size: int = 2048
+ projection_dim: int = 512
+ num_hidden_layers: int = 12
+ num_attention_heads: int = 8
+ max_position_embeddings: int = 77
+ hidden_act: str = "gelu"
+ layer_norm_eps: float = 1e-5
+ attention_dropout: float = 0.0
+ use_repmixer_blocks: bool = True
+ layer_scale_init_value: float = 1e-5
+ repmixer_kernel_size: int = 11
+
+
+@auto_docstring(checkpoint="facebook/sam3_lite_text")
+@strict
+class Sam3LiteTextConfig(PreTrainedConfig):
+ r"""
+ geometry_encoder_config (`dict` or `Sam3LiteTextGeometryEncoderConfig`, *optional*):
+ Configuration for the geometry encoder.
+ detr_encoder_config (`dict` or `Sam3LiteTextDETREncoderConfig`, *optional*):
+ Configuration for the DETR encoder.
+ detr_decoder_config (`dict` or `Sam3LiteTextDETRDecoderConfig`, *optional*):
+ Configuration for the DETR decoder.
+ mask_decoder_config (`dict` or `Sam3LiteTextMaskDecoderConfig`, *optional*):
+ Configuration for the mask decoder.
+
+ Example:
+ ```python
+ >>> from transformers import Sam3LiteTextConfig, Sam3LiteTextModel
+
+ >>> # Initializing a SAM3_LITE_TEXT configuration
+ >>> configuration = Sam3LiteTextConfig()
+
+ >>> # Initializing a model from the configuration
+ >>> model = Sam3LiteTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "sam3_lite_text"
+ sub_configs = {
+ "vision_config": AutoConfig,
+ "text_config": Sam3LiteTextTextConfig,
+ "geometry_encoder_config": Sam3LiteTextGeometryEncoderConfig,
+ "detr_encoder_config": Sam3LiteTextDETREncoderConfig,
+ "detr_decoder_config": Sam3LiteTextDETRDecoderConfig,
+ "mask_decoder_config": Sam3LiteTextMaskDecoderConfig,
+ }
+
+ vision_config: dict | PreTrainedConfig | None = None
+ text_config: dict | PreTrainedConfig | None = None
+ geometry_encoder_config: dict | PreTrainedConfig | None = None
+ detr_encoder_config: dict | PreTrainedConfig | None = None
+ detr_decoder_config: dict | PreTrainedConfig | None = None
+ mask_decoder_config: dict | PreTrainedConfig | None = None
+ initializer_range: float = 0.02
+
+ def __post_init__(self, **kwargs):
+ if isinstance(self.vision_config, dict):
+ self.vision_config["model_type"] = self.vision_config.get("model_type", "sam3_vision_model")
+ self.vision_config = CONFIG_MAPPING[self.vision_config["model_type"]](**self.vision_config)
+ elif self.vision_config is None:
+ self.vision_config = CONFIG_MAPPING["sam3_vision_model"]()
+
+ if self.text_config is None:
+ self.text_config = Sam3LiteTextTextConfig()
+ if isinstance(self.text_config, dict):
+ self.text_config = Sam3LiteTextTextConfig(**self.text_config)
+
+ if self.geometry_encoder_config is None:
+ self.geometry_encoder_config = Sam3LiteTextGeometryEncoderConfig()
+ if isinstance(self.geometry_encoder_config, dict):
+ self.geometry_encoder_config = Sam3LiteTextGeometryEncoderConfig(**self.geometry_encoder_config)
+
+ if self.detr_encoder_config is None:
+ self.detr_encoder_config = Sam3LiteTextDETREncoderConfig()
+ if isinstance(self.detr_encoder_config, dict):
+ self.detr_encoder_config = Sam3LiteTextDETREncoderConfig(**self.detr_encoder_config)
+
+ if self.detr_decoder_config is None:
+ self.detr_decoder_config = Sam3LiteTextDETRDecoderConfig()
+ if isinstance(self.detr_decoder_config, dict):
+ self.detr_decoder_config = Sam3LiteTextDETRDecoderConfig(**self.detr_decoder_config)
+
+ if self.mask_decoder_config is None:
+ self.mask_decoder_config = Sam3LiteTextMaskDecoderConfig()
+ if isinstance(self.mask_decoder_config, dict):
+ self.mask_decoder_config = Sam3LiteTextMaskDecoderConfig(**self.mask_decoder_config)
+
+ super().__post_init__(**kwargs)
+
+ @property
+ def image_size(self):
+ """Image size for the SAM3_LITE_TEXT model."""
+ return self.vision_config.image_size
+
+ @image_size.setter
+ def image_size(self, value):
+ """Set the image size and propagate to vision config."""
+ self.vision_config.image_size = value
+
+
+@dataclass
+class Sam3LiteTextTextEncoderOutput(BaseModelOutputWithPooling):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Full sequence of hidden states from the text encoder.
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
+ EOT-pooled output projected to `projection_dim` via the internal CLIP-style projection.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+ Tuple of hidden states at each layer, returned when `output_hidden_states=True`.
+ attentions (`tuple(torch.FloatTensor)`, *optional*):
+ Tuple of attention weights at each transformer layer, returned when `output_attentions=True`.
+ """
+
+
+class Sam3LiteTextTextPositionEmbedding(nn.Module):
+ """Learnable positional embedding with bilinear interpolation for variable sequence lengths."""
+
+ def __init__(self, max_position_embeddings: int, hidden_size: int):
+ super().__init__()
+ self.position_embedding = nn.Parameter(torch.empty(1, 1, max_position_embeddings, hidden_size))
+
+ def forward(self, seq_len: int) -> torch.Tensor:
+ position_embedding = self.position_embedding
+ if seq_len != position_embedding.shape[2]:
+ position_embedding = F.interpolate(
+ position_embedding,
+ size=(seq_len, position_embedding.shape[-1]),
+ mode="bilinear",
+ )
+ return position_embedding.reshape(1, seq_len, -1)
+
+
+class Sam3LiteTextMobileOneBlock(nn.Module):
+ """Depthwise conv branch with batch norm on the skip path and after the conv (MobileOne-style)."""
+
+ def __init__(self, hidden_size: int, kernel_size: int = 3):
+ super().__init__()
+ self.batchnorm_skip = nn.BatchNorm2d(hidden_size)
+ self.conv = nn.Conv2d(
+ hidden_size,
+ hidden_size,
+ kernel_size=(1, kernel_size),
+ stride=1,
+ padding=(0, kernel_size // 2),
+ groups=hidden_size,
+ bias=False,
+ )
+ self.batchnorm_conv = nn.BatchNorm2d(hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
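+ # Two batch-normalized branches (depthwise conv path + identity skip) are summed; this MobileOne-style structure is designed to be re-parameterisable into a single conv at inference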
+ residual = hidden_states
+ hidden_states = self.batchnorm_conv(self.conv(hidden_states))
+ hidden_states = hidden_states + self.batchnorm_skip(residual)
+ return hidden_states
+
+
+class Sam3LiteTextConvMLP(SiglipMLP):
+ """Pointwise MLP using 1×1 convolutions, compatible with 4-D (B, C, H, W) feature maps."""
+
+ def __init__(self, config: Sam3LiteTextTextConfig):
+ nn.Module.__init__(self)
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Conv2d(config.hidden_size, config.intermediate_size, kernel_size=1)
+ self.fc2 = nn.Conv2d(config.intermediate_size, config.hidden_size, kernel_size=1)
+
+
+class Sam3LiteTextConvolutionalFeedForward(nn.Module):
+ """Convolutional feed-forward network: depthwise conv + two pointwise projections."""
+
+ def __init__(self, config: Sam3LiteTextTextConfig):
+ super().__init__()
+ self.depthwise_conv = nn.Conv2d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=(1, config.repmixer_kernel_size),
+ padding=(0, config.repmixer_kernel_size // 2),
+ groups=config.hidden_size,
+ bias=False,
+ )
+ self.depthwise_batchnorm = nn.BatchNorm2d(config.hidden_size)
+ self.mlp = Sam3LiteTextConvMLP(config)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.depthwise_batchnorm(self.depthwise_conv(hidden_states))
+ return self.mlp(hidden_states)
+
+
+class Sam3LiteTextLayerScaledResidual(nn.Module):
+ """Common layer-scale residual pattern shared by the RepMixer and feed-forward branches."""
+
+ def __init__(self, hidden_size: int, layer_scale_init_value: float):
+ super().__init__()
+ self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((hidden_size, 1, 1)), requires_grad=True)
+
+ def layer_scale_residual(self, hidden_states: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
+ return hidden_states + self.layer_scale * update
+
+
+class Sam3LiteTextRepMixer(Sam3LiteTextLayerScaledResidual):
+ """Re-parameterisable depthwise-conv token mixer operating on 1D sequence data."""
+
+ def __init__(self, config: Sam3LiteTextTextConfig):
+ super().__init__(config.hidden_size, config.layer_scale_init_value)
+ self.reference_batchnorm = nn.BatchNorm2d(config.hidden_size)
+ self.mixer = Sam3LiteTextMobileOneBlock(config.hidden_size, kernel_size=config.repmixer_kernel_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return self.layer_scale_residual(
+ hidden_states, self.mixer(hidden_states) - self.reference_batchnorm(hidden_states)
+ )
+
+
+class Sam3LiteTextRepMixerBlock(Sam3LiteTextLayerScaledResidual):
+ """Token-mixing RepMixer plus a convolutional feed-forward path, each with layer scale."""
+
+ def __init__(self, config: Sam3LiteTextTextConfig):
+ super().__init__(config.hidden_size, config.layer_scale_init_value)
+ self.token_mixer = Sam3LiteTextRepMixer(config)
+ self.conv_feed_forward = Sam3LiteTextConvolutionalFeedForward(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
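+ # Reshape [batch, seq_len, hidden] -> [batch, hidden, 1, seq_len] so 1D token mixing can reuse the 2D depthwise convs, then reshape back at the end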
+ hidden_states = hidden_states.transpose(1, 2).unsqueeze(2)
+ hidden_states = self.token_mixer(hidden_states)
+ hidden_states = self.layer_scale_residual(hidden_states, self.conv_feed_forward(hidden_states))
+ return hidden_states.squeeze(2).transpose(1, 2)
+
+
+class Sam3LiteTextTextAttention(SiglipAttention):
+ pass
+
+
+class Sam3LiteTextTextMLP(SiglipMLP):
+ pass
+
+
+class Sam3LiteTextTextEncoderLayer(SiglipEncoderLayer):
+ def __init__(self, config: Sam3LiteTextTextConfig):
+ super().__init__(config)
+ self.self_attn = Sam3LiteTextTextAttention(config)
+ self.mlp = Sam3LiteTextTextMLP(config)
+
+
+class Sam3LiteTextTextEmbeddings(nn.Module):
+ """Token embedding + interpolatable positional embedding for the text encoder."""
+
+ def __init__(self, config: Sam3LiteTextTextConfig):
+ super().__init__()
+ self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.position_embedding = Sam3LiteTextTextPositionEmbedding(config.max_position_embeddings, config.hidden_size)
+
+ def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
+ hidden_states = self.token_embedding(input_ids)
+ hidden_states = hidden_states + self.position_embedding(input_ids.shape[1]).to(hidden_states.dtype)
+ return hidden_states
+
+
+@auto_docstring
+class Sam3LiteTextPreTrainedModel(Sam3PreTrainedModel):
+ config_class = Sam3LiteTextConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, Sam3LiteTextTextPositionEmbedding):
+ init.normal_(module.position_embedding, std=module.position_embedding.shape[-1] ** -0.5)
+ elif isinstance(module, Sam3LiteTextTextModel):
+ init.normal_(module.projection.weight, std=module.config.hidden_size**-0.5)
+
+
+@auto_docstring(
+ custom_intro="""
+ MobileCLIP (MCT) text encoder used in SAM3-LiteText.
+
+ When `config.use_repmixer_blocks` is `True`, the first and last layers are
+ `Sam3LiteTextRepMixerBlock` modules; the rest are standard `Sam3LiteTextTextEncoderLayer` layers.
+"""
+)
+class Sam3LiteTextTextModel(Sam3LiteTextPreTrainedModel):
+ config_class = Sam3LiteTextTextConfig
+ config: Sam3LiteTextTextConfig
+ _can_record_outputs = {
+ "hidden_states": Sam3LiteTextTextEncoderLayer,
+ "attentions": Sam3LiteTextTextAttention,
+ }
+
+ def __init__(self, config: Sam3LiteTextTextConfig):
+ super().__init__(config)
+ self.embeddings = Sam3LiteTextTextEmbeddings(config)
+ repmixer_positions = {0, config.num_hidden_layers - 1} if config.use_repmixer_blocks else set()
+ self.layers = nn.ModuleList(
+ [
+ Sam3LiteTextRepMixerBlock(config) if i in repmixer_positions else Sam3LiteTextTextEncoderLayer(config)
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
+ self.post_init()
+
+ @merge_with_config_defaults
+ @capture_outputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | Sam3LiteTextTextEncoderOutput:
+ hidden_states = self.embeddings(input_ids)
+ attention_mask = create_bidirectional_mask(self.config, hidden_states, attention_mask)
+
+ for layer in self.layers:
+ hidden_states = layer(hidden_states, attention_mask=attention_mask, **kwargs)
+
+ hidden_states = self.final_layer_norm(hidden_states)
+
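+ # CLIP-style EOT pooling: take the hidden state at the position of the largest token id (the end-of-text token) and project it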
+ pooled = hidden_states[
+ torch.arange(hidden_states.shape[0], device=hidden_states.device), input_ids.argmax(dim=-1)
+ ]
+ pooled = self.projection(pooled)
+ return Sam3LiteTextTextEncoderOutput(
+ last_hidden_state=hidden_states,
+ pooler_output=pooled,
+ )
+
+
+class Sam3LiteTextModel(Sam3Model):
+ # DETR components create float masks from features, so flash/flex attention cannot be dispatched safely.
+ _supports_flash_attn = False
+ _supports_flex_attn = False
+
+ def __init__(self, config: Sam3LiteTextConfig):
+ super().__init__(config)
+ self.text_encoder = Sam3LiteTextTextModel(config.text_config)
+ self.vision_encoder = AutoModel.from_config(config.vision_config)
+
+
+__all__ = [
+ "Sam3LiteTextConfig",
+ "Sam3LiteTextTextConfig",
+ "Sam3LiteTextGeometryEncoderConfig",
+ "Sam3LiteTextDETREncoderConfig",
+ "Sam3LiteTextDETRDecoderConfig",
+ "Sam3LiteTextMaskDecoderConfig",
+ "Sam3LiteTextModel",
+ "Sam3LiteTextPreTrainedModel",
+ "Sam3LiteTextTextModel",
+]
diff --git a/src/transformers/models/sam3_video/convert_sam3_video_to_hf.py b/src/transformers/models/sam3_video/convert_sam3_video_to_hf.py
index f021aea27329..6cc00e80dfe9 100644
--- a/src/transformers/models/sam3_video/convert_sam3_video_to_hf.py
+++ b/src/transformers/models/sam3_video/convert_sam3_video_to_hf.py
@@ -26,7 +26,7 @@
from transformers import CLIPTokenizerFast
from transformers.models.sam2_video.video_processing_sam2_video import Sam2VideoVideoProcessor
-from transformers.models.sam3.image_processing_sam3_fast import Sam3ImageProcessorFast
+from transformers.models.sam3.image_processing_sam3 import Sam3ImageProcessor
from transformers.models.sam3.modeling_sam3 import Sam3Model
from transformers.models.sam3_tracker.modeling_sam3_tracker import Sam3TrackerModel
from transformers.models.sam3_tracker_video.modeling_sam3_tracker_video import Sam3TrackerVideoModel
@@ -664,7 +664,7 @@ def convert_sam3_checkpoint(
# Save processor
print("Creating and saving processor...")
- image_processor = Sam3ImageProcessorFast()
+ image_processor = Sam3ImageProcessor()
video_processor = Sam2VideoVideoProcessor(
image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5], size={"height": 1008, "width": 1008}
)
diff --git a/tests/models/sam3/test_modeling_sam3.py b/tests/models/sam3/test_modeling_sam3.py
index 6bbd5947c263..df94063c0a7f 100644
--- a/tests/models/sam3/test_modeling_sam3.py
+++ b/tests/models/sam3/test_modeling_sam3.py
@@ -17,8 +17,7 @@
import tempfile
import unittest
-import requests
-
+from transformers.image_utils import load_image
from transformers.testing_utils import (
backend_empty_cache,
require_deterministic_for_xpu,
@@ -26,11 +25,12 @@
slow,
torch_device,
)
-from transformers.utils import is_torch_available, is_vision_available
+from transformers.utils import is_torch_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
+from ...test_processing_common import url_to_local_path
if is_torch_available():
@@ -50,10 +50,6 @@
from transformers.models.sam3.processing_sam3 import Sam3Processor
-if is_vision_available():
- from PIL import Image
-
-
class Sam3VisionModelTester:
def __init__(
self,
@@ -973,16 +969,14 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
def prepare_coco_cat_image():
"""Prepare COCO cat and laptop image (from batched inference notebook)."""
- img_url = "http://images.cocodataset.org/val2017/000000077595.jpg"
- raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
- return raw_image
+ img_url = url_to_local_path("http://images.cocodataset.org/val2017/000000077595.jpg")
+ return load_image(img_url).convert("RGB")
def prepare_coco_kitchen_image():
"""Prepare COCO kitchen scene image (from batched inference notebook)."""
- img_url = "http://images.cocodataset.org/val2017/000000136466.jpg"
- raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
- return raw_image
+ img_url = url_to_local_path("http://images.cocodataset.org/val2017/000000136466.jpg")
+ return load_image(img_url).convert("RGB")
@slow
diff --git a/tests/models/sam3_lite_text/__init__.py b/tests/models/sam3_lite_text/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/sam3_lite_text/test_modeling_sam3_lite_text.py b/tests/models/sam3_lite_text/test_modeling_sam3_lite_text.py
new file mode 100644
index 000000000000..05a9307bfa87
--- /dev/null
+++ b/tests/models/sam3_lite_text/test_modeling_sam3_lite_text.py
@@ -0,0 +1,1346 @@
+# Copyright 2026 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch SAM3 LiteText model."""
+
+import gc
+import tempfile
+import unittest
+
+from transformers.image_utils import load_image
+from transformers.testing_utils import (
+ backend_empty_cache,
+ require_deterministic_for_xpu,
+ require_torch,
+ slow,
+ torch_device,
+)
+from transformers.utils import is_torch_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
+from ...test_processing_common import url_to_local_path
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers.models.sam3.configuration_sam3 import Sam3VisionConfig, Sam3ViTConfig
+ from transformers.models.sam3.processing_sam3 import Sam3Processor as Sam3LiteTextProcessor
+ from transformers.models.sam3_lite_text.configuration_sam3_lite_text import (
+ Sam3LiteTextConfig,
+ Sam3LiteTextDETRDecoderConfig,
+ Sam3LiteTextDETREncoderConfig,
+ Sam3LiteTextGeometryEncoderConfig,
+ Sam3LiteTextMaskDecoderConfig,
+ )
+ from transformers.models.sam3_lite_text.modeling_sam3_lite_text import Sam3LiteTextModel
+
+
+class Sam3LiteTextModelTester:
+ def __init__(
+ self,
+ parent,
+ num_channels=3,
+ image_size=224, # Keep reasonable size: 224 = 16 * 14
+ hidden_size=32,
+ patch_size=14,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ intermediate_size=64,
+ window_size=8, # 224/14 = 16 patches, 16/2 = 8 per window
+ global_attn_indexes=None,
+ fpn_hidden_size=32,
+ scale_factors=None,
+ geometry_encoder_hidden_size=32,
+ geometry_encoder_num_layers=1, # Reduced from 2 to 1
+ detr_encoder_hidden_size=32,
+ detr_encoder_num_layers=1, # Reduced from 2 to 1
+ detr_decoder_hidden_size=32,
+ detr_decoder_num_layers=1, # Reduced from 2 to 1
+ detr_decoder_num_queries=5, # Reduced from 10 to 5
+ mask_decoder_hidden_size=32,
+ batch_size=2,
+ text_seq_length=16,
+ vocab_size=100,
+ is_training=True,
+ ):
+ if global_attn_indexes is None:
+ global_attn_indexes = [0, 1]
+ if scale_factors is None:
+ scale_factors = [2.0, 1.0] # Just 2 scales to reduce params
+
+ self.parent = parent
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.hidden_size = hidden_size
+ self.patch_size = patch_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.window_size = window_size
+ self.global_attn_indexes = global_attn_indexes
+ self.fpn_hidden_size = fpn_hidden_size
+ self.scale_factors = scale_factors
+ self.batch_size = batch_size
+ self.text_seq_length = text_seq_length
+ self.vocab_size = vocab_size
+ self.is_training = is_training
+
+ # Geometry encoder
+ self.geometry_encoder_hidden_size = geometry_encoder_hidden_size
+ self.geometry_encoder_num_layers = geometry_encoder_num_layers
+
+ # DETR encoder/decoder
+ self.detr_encoder_hidden_size = detr_encoder_hidden_size
+ self.detr_encoder_num_layers = detr_encoder_num_layers
+ self.detr_decoder_hidden_size = detr_decoder_hidden_size
+ self.detr_decoder_num_layers = detr_decoder_num_layers
+ self.detr_decoder_num_queries = detr_decoder_num_queries
+
+ # Mask decoder
+ self.mask_decoder_hidden_size = mask_decoder_hidden_size
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+ # Simple text input (will be processed by text encoder)
+ input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.text_seq_length), device=torch_device)
+ attention_mask = torch.ones_like(input_ids)
+
+ config = self.get_config()
+
+ return config, pixel_values, input_ids, attention_mask
+
+ def get_config(self):
+ backbone_config = Sam3ViTConfig(
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ num_channels=self.num_channels,
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ window_size=self.window_size,
+ global_attn_indexes=self.global_attn_indexes,
+ )
+
+ vision_config = Sam3VisionConfig(
+ backbone_config=backbone_config,
+ fpn_hidden_size=self.fpn_hidden_size,
+ scale_factors=self.scale_factors,
+ )
+
+ # Small text config for testing (instead of default full MobileCLIP model)
+ # use_repmixer_blocks=False ensures all layers are standard TransformerLayers so
+ # attention output, hidden state, and SDPA dispatch tests pass correctly.
+ text_config = {
+ "vocab_size": self.vocab_size,
+ "hidden_size": 32,
+ "intermediate_size": 64,
+ "projection_dim": 32,
+ "num_hidden_layers": self.num_hidden_layers,
+ "num_attention_heads": 4,
+ "max_position_embeddings": self.text_seq_length,
+ "hidden_act": "gelu",
+ "use_repmixer_blocks": False,
+ }
+
+ geometry_encoder_config = Sam3LiteTextGeometryEncoderConfig(
+ hidden_size=self.geometry_encoder_hidden_size,
+ num_layers=self.geometry_encoder_num_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ mask_fuser_hidden_size=self.geometry_encoder_hidden_size, # Match hidden_size to reduce params
+ mask_fuser_num_layers=1, # Reduce from default 2 to 1
+ )
+
+ detr_encoder_config = Sam3LiteTextDETREncoderConfig(
+ hidden_size=self.detr_encoder_hidden_size,
+ num_layers=self.detr_encoder_num_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ )
+
+ detr_decoder_config = Sam3LiteTextDETRDecoderConfig(
+ hidden_size=self.detr_decoder_hidden_size,
+ num_layers=self.detr_decoder_num_layers,
+ num_queries=self.detr_decoder_num_queries,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ )
+
+ mask_decoder_config = Sam3LiteTextMaskDecoderConfig(
+ hidden_size=self.mask_decoder_hidden_size,
+ num_upsampling_stages=2, # Reduced from 3 to 2
+ )
+
+ return Sam3LiteTextConfig(
+ vision_config=vision_config,
+ text_config=text_config,
+ geometry_encoder_config=geometry_encoder_config,
+ detr_encoder_config=detr_encoder_config,
+ detr_decoder_config=detr_decoder_config,
+ mask_decoder_config=mask_decoder_config,
+ )
+
+ def create_and_check_model(self, config, pixel_values, input_ids, attention_mask):
+ model = Sam3LiteTextModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
+
+ # Check output shapes
+ self.parent.assertIsNotNone(result.pred_masks)
+ self.parent.assertIsNotNone(result.pred_boxes)
+ self.parent.assertIsNotNone(result.pred_logits)
+
+ # Masks should be [batch_size, num_queries, H, W]
+ self.parent.assertEqual(result.pred_masks.shape[0], self.batch_size)
+ self.parent.assertEqual(result.pred_masks.shape[1], self.detr_decoder_num_queries)
+
+ # Boxes should be [batch_size, num_queries, 4]
+ self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.detr_decoder_num_queries, 4))
+
+ # Logits should be [batch_size, num_queries]
+ self.parent.assertEqual(result.pred_logits.shape, (self.batch_size, self.detr_decoder_num_queries))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, input_ids, attention_mask = config_and_inputs
+ inputs_dict = {
+ "pixel_values": pixel_values,
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class Sam3LiteTextModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ """
+ Tests for the full SAM3-LiteText model.
+ """
+
+ all_model_classes = (Sam3LiteTextModel,) if is_torch_available() else ()
+ pipeline_model_mapping = {"mask-generation": Sam3LiteTextModel} if is_torch_available() else {}
+
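+ # `test_resize_embeddings` is disabled since the resize-embeddings common tests do not apply to this
+ # composite model (presumably because there is no single resizable token embedding); `_is_composite`
+ # enables the composite-model SDPA/eager dispatch checks below.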
+ test_resize_embeddings = False
+ _is_composite = True
+
+ def setUp(self):
+ self.model_tester = Sam3LiteTextModelTester(self)
+ common_properties = ["initializer_range"]
+ self.config_tester = ConfigTester(
+ self, config_class=Sam3LiteTextConfig, has_text_modality=False, common_properties=common_properties
+ )
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ @unittest.skip(reason="SAM3 does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ def test_model_get_set_embeddings(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ # Vision encoder has input embeddings
+ self.assertIsInstance(model.vision_encoder.get_input_embeddings(), nn.Module)
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
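+ # Slightly looser tolerances than the common-test default, likely because tiny numeric differences
+ # between batched and single-sample runs are amplified by the mask/box heads.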
+ def test_batching_equivalence(self, atol=5e-4, rtol=5e-4):
+ super().test_batching_equivalence(atol=atol, rtol=rtol)
+
+ # Override as Sam3LiteTextModel has component-specific attention outputs
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class._from_config(config, attn_implementation="eager")
+ config = model.config
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ # Check that we have the component-specific attention outputs
+ # Note: Some may be empty tuples if attentions aren't collected for that component
+ self.assertIsNotNone(outputs.vision_attentions)
+ self.assertIsNotNone(outputs.detr_encoder_attentions)
+ self.assertIsNotNone(outputs.detr_decoder_attentions)
+ self.assertIsNotNone(outputs.mask_decoder_attentions)
+
+ # Check vision attentions (from ViT backbone) - should be properly collected
+ if outputs.vision_attentions:
+ vision_attentions = outputs.vision_attentions
+ self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers)
+
+ # Check that at least vision attentions are present (other components may use a different collection mechanism)
+ self.assertTrue(
+ len(outputs.vision_attentions) > 0,
+ "At least vision attentions should be collected when output_attentions=True",
+ )
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ for k in config.sub_configs:
+ if (subconfig := getattr(config, k)) is not None:
+ subconfig.output_attentions = True
+ # The Sam3LiteText vision subconfig contains its own nested backbone subconfig, so propagate there too
+ for k in subconfig.sub_configs:
+ if (subsubconfig := getattr(subconfig, k)) is not None:
+ subsubconfig.output_attentions = True
+
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ # Verify again with config-based setting
+ self.assertIsNotNone(outputs.vision_attentions)
+ self.assertIsNotNone(outputs.detr_encoder_attentions)
+ self.assertIsNotNone(outputs.detr_decoder_attentions)
+ self.assertIsNotNone(outputs.mask_decoder_attentions)
+
+ # Override as Sam3LiteTextModel has component-specific attention/hidden state outputs
+ def test_retain_grad_hidden_states_attentions(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for k in config.sub_configs:
+ if getattr(config, k) is not None:
+ getattr(config, k).output_hidden_states = True
+ getattr(config, k).output_attentions = True
+
+ config.output_hidden_states = True
+ config.output_attentions = True
+ config._attn_implementation = "eager"
+
+ # Use first model class
+ model_class = self.all_model_classes[0]
+ model = model_class._from_config(config, attn_implementation="eager")
+ model.to(torch_device)
+
+ inputs = self._prepare_for_class(inputs_dict, model_class)
+ outputs = model(**inputs)
+
+ output = outputs[0]
+
+ # SAM3-LiteText has component-specific hidden states and attentions
+ # Check vision hidden states and attentions
+ if outputs.vision_hidden_states is not None and len(outputs.vision_hidden_states) > 0:
+ vision_hidden_states = outputs.vision_hidden_states[0]
+ vision_hidden_states.retain_grad()
+
+ if outputs.vision_attentions is not None and len(outputs.vision_attentions) > 0:
+ vision_attentions = outputs.vision_attentions[0]
+ vision_attentions.retain_grad()
+
+ # Check DETR encoder hidden states and attentions
+ if outputs.encoder_hidden_states is not None and len(outputs.encoder_hidden_states) > 0:
+ encoder_hidden_states = outputs.encoder_hidden_states[0]
+ encoder_hidden_states.retain_grad()
+
+ if outputs.detr_encoder_attentions is not None and len(outputs.detr_encoder_attentions) > 0:
+ detr_encoder_attentions = outputs.detr_encoder_attentions[0]
+ detr_encoder_attentions.retain_grad()
+
+ # Check DETR decoder hidden states and attentions
+ if outputs.decoder_hidden_states is not None and len(outputs.decoder_hidden_states) > 0:
+ decoder_hidden_states = outputs.decoder_hidden_states[0]
+ decoder_hidden_states.retain_grad()
+
+ if outputs.detr_decoder_attentions is not None and len(outputs.detr_decoder_attentions) > 0:
+ detr_decoder_attentions = outputs.detr_decoder_attentions[0]
+ detr_decoder_attentions.retain_grad()
+
+ # Check mask decoder attentions
+ if outputs.mask_decoder_attentions is not None and len(outputs.mask_decoder_attentions) > 0:
+ mask_decoder_attentions = outputs.mask_decoder_attentions[0]
+ mask_decoder_attentions.retain_grad()
+
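+ # Backpropagate from a single scalar so the retained intermediate tensors above receive gradients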
+ output.flatten()[0].backward(retain_graph=True)
+
+ # Check gradients are not None
+ if outputs.vision_hidden_states is not None and len(outputs.vision_hidden_states) > 0:
+ self.assertIsNotNone(vision_hidden_states.grad)
+
+ if outputs.vision_attentions is not None and len(outputs.vision_attentions) > 0:
+ self.assertIsNotNone(vision_attentions.grad)
+
+ if outputs.encoder_hidden_states is not None and len(outputs.encoder_hidden_states) > 0:
+ self.assertIsNotNone(encoder_hidden_states.grad)
+
+ if outputs.detr_encoder_attentions is not None and len(outputs.detr_encoder_attentions) > 0:
+ self.assertIsNotNone(detr_encoder_attentions.grad)
+
+ if outputs.decoder_hidden_states is not None and len(outputs.decoder_hidden_states) > 0:
+ self.assertIsNotNone(decoder_hidden_states.grad)
+
+ if outputs.detr_decoder_attentions is not None and len(outputs.detr_decoder_attentions) > 0:
+ self.assertIsNotNone(detr_decoder_attentions.grad)
+
+ if outputs.mask_decoder_attentions is not None and len(outputs.mask_decoder_attentions) > 0:
+ self.assertIsNotNone(mask_decoder_attentions.grad)
+
+ def test_hidden_states_output(self):
+ """Test that SAM3 properly outputs component-specific hidden states."""
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ # Enable hidden states output
+ config.output_hidden_states = True
+ for k in config.sub_configs:
+ if getattr(config, k) is not None:
+ getattr(config, k).output_hidden_states = True
+
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ # SAM3-LiteText has component-specific hidden states
+ # Check vision hidden states
+ if outputs.vision_hidden_states is not None:
+ vision_hidden_states = outputs.vision_hidden_states
+ self.assertIsInstance(vision_hidden_states, (list, tuple))
+ # Vision encoder outputs hidden states from each layer
+ expected_num_vision_layers = self.model_tester.num_hidden_layers + 1 # +1 for embeddings
+ self.assertEqual(len(vision_hidden_states), expected_num_vision_layers)
+
+ # Check DETR encoder hidden states (stored as encoder_hidden_states)
+ if outputs.encoder_hidden_states is not None:
+ encoder_hidden_states = outputs.encoder_hidden_states
+ self.assertIsInstance(encoder_hidden_states, (list, tuple))
+
+ # Check DETR decoder hidden states (stored as decoder_hidden_states)
+ if outputs.decoder_hidden_states is not None:
+ decoder_hidden_states = outputs.decoder_hidden_states
+ self.assertIsInstance(decoder_hidden_states, (list, tuple))
+
+ def test_sdpa_can_dispatch_composite_models(self):
+ """
+ Tests if composite models dispatch correctly on SDPA/eager when requested.
+ SAM3-LiteText has multiple sub-models: vision_encoder, text_encoder, geometry_encoder,
+ detr_encoder, detr_decoder, mask_decoder.
+ """
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ if not self._is_composite:
+ self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
+
+ for model_class in self.all_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
+ model_sdpa = model_sdpa.eval().to(torch_device)
+
+ # Get all sub-models that support attention
+ vision_encoder_sdpa = model_sdpa.vision_encoder
+ text_encoder_sdpa = getattr(model_sdpa, "text_encoder", None)
+ detr_encoder_sdpa = getattr(model_sdpa, "detr_encoder", None)
+ detr_decoder_sdpa = getattr(model_sdpa, "detr_decoder", None)
+ mask_decoder_sdpa = getattr(model_sdpa, "mask_decoder", None)
+
+ # Check that sub-models dispatch to SDPA if they support it
+ self.assertTrue(vision_encoder_sdpa.config._attn_implementation == "sdpa")
+ if text_encoder_sdpa is not None and hasattr(text_encoder_sdpa, "_supports_sdpa"):
+ # Sam3LiteTextTextModel supports SDPA
+ self.assertTrue(text_encoder_sdpa.config._attn_implementation == "sdpa")
+ if detr_encoder_sdpa is not None:
+ self.assertTrue(detr_encoder_sdpa.config._attn_implementation == "sdpa")
+ if detr_decoder_sdpa is not None:
+ self.assertTrue(detr_decoder_sdpa.config._attn_implementation == "sdpa")
+ if mask_decoder_sdpa is not None:
+ self.assertTrue(mask_decoder_sdpa.config._attn_implementation == "sdpa")
+
+ # Now test with eager
+ model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
+ model_eager = model_eager.eval().to(torch_device)
+
+ self.assertTrue(model_eager.vision_encoder.config._attn_implementation == "eager")
+ if hasattr(model_eager, "text_encoder") and hasattr(model_eager.text_encoder, "config"):
+ self.assertTrue(model_eager.text_encoder.config._attn_implementation == "eager")
+ if hasattr(model_eager, "detr_encoder"):
+ self.assertTrue(model_eager.detr_encoder.config._attn_implementation == "eager")
+ if hasattr(model_eager, "detr_decoder"):
+ self.assertTrue(model_eager.detr_decoder.config._attn_implementation == "eager")
+ if hasattr(model_eager, "mask_decoder"):
+ self.assertTrue(model_eager.mask_decoder.config._attn_implementation == "eager")
+
+ # Verify no SDPA layers in eager model
+ for name, submodule in model_eager.named_modules():
+ class_name = submodule.__class__.__name__
+ if (
+ class_name.endswith("Attention")
+ and getattr(submodule, "config", None)
+ and submodule.config._attn_implementation == "sdpa"
+ ):
+ raise ValueError("The eager model should not have SDPA attention layers")
+
+ def test_forward_with_text_embeds(self):
+ """Test that text_embeds parameter works correctly."""
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ # First get text embeddings
+ with torch.no_grad():
+ text_embeds = model.get_text_features(
+ input_ids=inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"], return_dict=True
+ ).pooler_output
+
+ # Forward with text_embeds (remove input_ids)
+ inputs_with_embeds = {
+ "pixel_values": inputs_dict["pixel_values"],
+ "text_embeds": text_embeds,
+ }
+
+ with torch.no_grad():
+ outputs_with_embeds = model(**inputs_with_embeds)
+
+ # Forward with input_ids
+ with torch.no_grad():
+ outputs_with_ids = model(**inputs_dict)
+
+ # Outputs should be very close
+ self.assertTrue(torch.allclose(outputs_with_embeds.pred_logits, outputs_with_ids.pred_logits, atol=1e-5))
+ self.assertTrue(torch.allclose(outputs_with_embeds.pred_boxes, outputs_with_ids.pred_boxes, atol=1e-5))
+
+ def test_forward_with_both_input_ids_and_text_embeds_raises_error(self):
+ """Test that passing both input_ids and text_embeds raises an error."""
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ # Get text embeddings
+ with torch.no_grad():
+ text_embeds = model.get_text_features(
+ input_ids=inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"]
+ )
+
+ # Try to pass both (should raise error)
+ inputs_with_both = {
+ "pixel_values": inputs_dict["pixel_values"],
+ "input_ids": inputs_dict["input_ids"],
+ "text_embeds": text_embeds,
+ }
+
+ with self.assertRaises(ValueError):
+ model(**inputs_with_both)
+
+ def test_forward_with_vision_embeds(self):
+ """Test that vision_embeds parameter works correctly."""
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ # First get vision embeddings
+ with torch.no_grad():
+ vision_embeds = model.get_vision_features(pixel_values=inputs_dict["pixel_values"])
+
+ # Forward with vision_embeds (remove pixel_values)
+ inputs_with_embeds = {
+ "vision_embeds": vision_embeds,
+ "input_ids": inputs_dict["input_ids"],
+ "attention_mask": inputs_dict["attention_mask"],
+ }
+
+ with torch.no_grad():
+ outputs_with_embeds = model(**inputs_with_embeds)
+
+ # Forward with pixel_values
+ with torch.no_grad():
+ outputs_with_pixels = model(**inputs_dict)
+
+ # Outputs should be very close
+ self.assertTrue(
+ torch.allclose(outputs_with_embeds.pred_logits, outputs_with_pixels.pred_logits, atol=1e-5)
+ )
+ self.assertTrue(torch.allclose(outputs_with_embeds.pred_boxes, outputs_with_pixels.pred_boxes, atol=1e-5))
+
+ def test_forward_with_both_pixel_values_and_vision_embeds_raises_error(self):
+ """Test that passing both pixel_values and vision_embeds raises an error."""
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ # Get vision embeddings
+ with torch.no_grad():
+ vision_embeds = model.get_vision_features(pixel_values=inputs_dict["pixel_values"])
+
+ # Try to pass both (should raise error)
+ inputs_with_both = {
+ "pixel_values": inputs_dict["pixel_values"],
+ "vision_embeds": vision_embeds,
+ "input_ids": inputs_dict["input_ids"],
+ "attention_mask": inputs_dict["attention_mask"],
+ }
+
+ with self.assertRaises(ValueError):
+ model(**inputs_with_both)
+
+ def test_custom_image_size(self):
+ """Test that custom image size can be set and propagates correctly through nested configs."""
+ config = self.model_tester.get_config()
+ config.image_size = 560
+
+ self.assertEqual(config.image_size, 560)
+ self.assertEqual(config.vision_config.image_size, 560)
+ self.assertEqual(config.vision_config.backbone_config.image_size, 560)
+
+ # Verify model works with custom size
+ model = Sam3LiteTextModel(config=config).to(torch_device).eval()
+ pixel_values = floats_tensor([self.model_tester.batch_size, self.model_tester.num_channels, 560, 560]).to(
+ torch_device
+ )
+ input_ids = torch.randint(
+ 0,
+ config.text_config.vocab_size,
+ (self.model_tester.batch_size, self.model_tester.text_seq_length),
+ device=torch_device,
+ )
+
+ with torch.no_grad():
+ outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=torch.ones_like(input_ids))
+
+ self.assertIsNotNone(outputs.pred_masks)
+ self.assertIsNotNone(outputs.pred_boxes)
+ self.assertIsNotNone(outputs.pred_logits)
+
+ @unittest.skip(reason="SAM3 LiteText model can't be compiled dynamic yet")
+ def test_sdpa_can_compile_dynamic(self):
+ pass
+
+ @unittest.skip(
+ reason="Sam3LiteTextModel creates float attention masks from features (with gradients) in the DETR "
+ "encoder/decoder, which Flash Attention requires to be None."
+ )
+ def test_sdpa_can_dispatch_on_flash(self):
+ pass
+
+ def test_model_outputs_equivalence(self):
+ """
+ Test that tuple and dict outputs are equivalent.
+ SAM3-LiteText returns complex outputs with component-specific fields, so we need to ensure proper conversion.
+ """
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ def set_nan_tensor_to_zero(t):
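+ # NaN never compares equal to itself, so zero out NaNs in place before the allclose comparison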
+ t[t != t] = 0
+ return t
+
+ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
+ with torch.no_grad():
+ tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
+ dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
+
+ def recursive_check(tuple_object, dict_object):
+ if isinstance(tuple_object, (list, tuple)):
+ for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
+ recursive_check(tuple_iterable_value, dict_iterable_value)
+ elif isinstance(tuple_object, dict):
+ for tuple_iterable_value, dict_iterable_value in zip(
+ tuple_object.values(), dict_object.values()
+ ):
+ recursive_check(tuple_iterable_value, dict_iterable_value)
+ elif tuple_object is None:
+ return
+ # model might return non-tensor objects (e.g. Cache class)
+ elif isinstance(tuple_object, torch.Tensor):
+ self.assertTrue(
+ torch.allclose(
+ set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
+ ),
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+ f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
+ f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
+ ),
+ )
+
+ recursive_check(tuple_output, dict_output)
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+ check_equivalence(model, tuple_inputs, dict_inputs)
+
+ # Test with output_hidden_states
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+ check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+ # Test with output_attentions if supported
+ if self.has_attentions:
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+ check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
+
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+ check_equivalence(
+ model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
+ )
+
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ """Override to ensure input_ids and attention_mask are always present for Sam3LiteTextModel."""
+ inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+ # Sam3LiteTextModel always requires input_ids and attention_mask for text encoding
+ if model_class == Sam3LiteTextModel:
+ if "input_ids" not in inputs_dict or inputs_dict.get("input_ids") is None:
+ # Create dummy input_ids if not present
+ # Get batch_size from pixel_values or vision_embeds
+ if "pixel_values" in inputs_dict and inputs_dict.get("pixel_values") is not None:
+ batch_size = inputs_dict["pixel_values"].shape[0]
+ elif "vision_embeds" in inputs_dict and inputs_dict.get("vision_embeds") is not None:
+ vision_embeds = inputs_dict["vision_embeds"]
+ if vision_embeds.fpn_hidden_states is not None and len(vision_embeds.fpn_hidden_states) > 0:
+ batch_size = vision_embeds.fpn_hidden_states[0].shape[0]
+ elif vision_embeds.last_hidden_state is not None:
+ batch_size = vision_embeds.last_hidden_state.shape[0]
+ else:
+ batch_size = 2
+ else:
+ batch_size = 2
+ config = self.model_tester.get_config()
+ # text_config might be a dict or a config object
+ if isinstance(config.text_config, dict):
+ vocab_size = config.text_config.get("vocab_size", 1000)
+ else:
+ vocab_size = getattr(config.text_config, "vocab_size", 1000)
+ inputs_dict["input_ids"] = torch.randint(0, vocab_size, (batch_size, 16), device=torch_device)
+ if "attention_mask" not in inputs_dict or inputs_dict.get("attention_mask") is None:
+ inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["input_ids"])
+
+ return inputs_dict
+
+
+def prepare_coco_cat_image():
+ """Prepare COCO cat and laptop image (from batched inference notebook)."""
+ img_url = url_to_local_path("http://images.cocodataset.org/val2017/000000077595.jpg")
+ return load_image(img_url).convert("RGB")
+
+
+def prepare_coco_kitchen_image():
+ """Prepare COCO kitchen scene image (from batched inference notebook)."""
+ img_url = url_to_local_path("http://images.cocodataset.org/val2017/000000136466.jpg")
+ return load_image(img_url).convert("RGB")
+
+
+@slow
+@require_torch
+class Sam3LiteTextModelIntegrationTest(unittest.TestCase):
+ """Integration tests for SAM3 model with real pretrained weights."""
+
+ def setUp(self):
+ super().setUp()
+ model_name = "yonigozlan/sam3-litetext-s0"
+ self.model = Sam3LiteTextModel.from_pretrained(model_name, dtype=torch.float32)
+ self.processor = Sam3LiteTextProcessor.from_pretrained(model_name)
+ self.model.to(torch_device)
+ self.model.eval()
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_inference_text_prompt_only(self):
+ """Test inference with text prompt only (from multiway_prompting notebook)."""
+ # Example from notebook: "short hair" text prompt
+ raw_image = prepare_coco_cat_image()
+ text = "ear"
+
+ inputs = self.processor(images=raw_image, text=text, return_tensors="pt").to(torch_device)
+
+ with torch.no_grad():
+ outputs = self.model(**inputs)
+
+ # Check exact output shapes
+ self.assertEqual(outputs.pred_masks.shape, (1, 200, 288, 288))
+ self.assertEqual(outputs.pred_boxes.shape, (1, 200, 4))
+ self.assertEqual(outputs.pred_logits.shape, (1, 200))
+
+ # Check that predictions have reasonable scores (after sigmoid)
+ scores = torch.sigmoid(outputs.pred_logits)
+ self.assertTrue((scores >= 0).all() and (scores <= 1).all())
+
+ # Check exact values
+ sorted_indices = torch.argsort(scores.squeeze(), descending=True)
+ top_scores = scores.squeeze()[sorted_indices[:3]]
+ top_logits = outputs.pred_logits.squeeze()[sorted_indices[:3]]
+ top_idx = sorted_indices[0].item()
+
+ torch.testing.assert_close(
+ top_scores, torch.tensor([0.9326, 0.9149, 0.1009]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(
+ top_logits, torch.tensor([2.6268, 2.3755, -2.1877]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(
+ outputs.pred_boxes[0, top_idx],
+ torch.tensor([0.4704, 0.2015, 0.5615, 0.3770]).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+ torch.testing.assert_close(
+ outputs.pred_masks[0, top_idx, :3, :3],
+ torch.tensor(
+ [[-2.1856, -6.2395, -7.0870], [-6.0620, -10.4022, -11.2534], [-8.7157, -10.7961, -9.9844]]
+ ).to(torch_device),
+ atol=1e-2,
+ rtol=1e-2,
+ )
+
+ # Test post-processing
+ results = self.processor.post_process_instance_segmentation(
+ outputs, threshold=0.5, mask_threshold=0.5, target_sizes=inputs.get("original_sizes").tolist()
+ )
+ self.assertEqual(len(results), 1)
+ result = results[0]
+
+ # Check that we have detections
+ self.assertGreater(len(result["masks"]), 0)
+ self.assertGreater(len(result["boxes"]), 0)
+ self.assertGreater(len(result["scores"]), 0)
+
+ # Check exact values for top detection
+ top_pp_score = result["scores"][0]
+ top_pp_box = result["boxes"][0]
+
+ torch.testing.assert_close(top_pp_score, torch.tensor(0.9137).to(torch_device), atol=1e-4, rtol=1e-4)
+ torch.testing.assert_close(
+ top_pp_box, torch.tensor([402.3560, 90.1643, 459.6652, 156.5201]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+
+ def test_inference_single_box_prompt(self):
+ """Test inference with a single bounding box prompt (from batched_inference notebook)."""
+ raw_image = prepare_coco_cat_image()
+ # Example from notebook: laptop region in image 1
+ # Box in xyxy format: [100, 150, 500, 450]
+ box_xyxy = [100, 150, 500, 450]
+ input_boxes = [[box_xyxy]]
+
+ inputs = self.processor(
+ images=raw_image,
+ input_boxes=input_boxes,
+ input_boxes_labels=[[1]], # Positive box
+ return_tensors="pt",
+ ).to(torch_device)
+
+ with torch.no_grad():
+ outputs = self.model(**inputs)
+
+ # Check exact output shapes
+ self.assertEqual(outputs.pred_masks.shape, (1, 200, 288, 288))
+ self.assertEqual(outputs.pred_boxes.shape, (1, 200, 4))
+ self.assertEqual(outputs.pred_logits.shape, (1, 200))
+
+ # Check exact values
+ scores = torch.sigmoid(outputs.pred_logits)
+ sorted_indices = torch.argsort(scores.squeeze(), descending=True)
+ top_scores = scores.squeeze()[sorted_indices[:3]]
+ top_logits = outputs.pred_logits.squeeze()[sorted_indices[:3]]
+ top_idx = sorted_indices[0].item()
+
+ torch.testing.assert_close(
+ top_scores, torch.tensor([0.9387, 0.1474, 0.1087]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(
+ top_logits, torch.tensor([2.7291, -1.7554, -2.1046]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(
+ outputs.pred_boxes[0, top_idx],
+ torch.tensor([0.1628, 0.4168, 0.7534, 0.9935]).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+ torch.testing.assert_close(
+ outputs.pred_masks[0, top_idx, :3, :3],
+ torch.tensor([[-1.9982, -3.7409, -3.9956], [-3.6554, -5.9248, -6.0131], [-4.5402, -6.0183, -6.2711]]).to(
+ torch_device
+ ),
+ atol=1e-2,
+ rtol=1e-2,
+ )
+
+ # Test post-processing
+ results = self.processor.post_process_instance_segmentation(
+ outputs, threshold=0.5, mask_threshold=0.5, target_sizes=inputs.get("original_sizes").tolist()
+ )
+ self.assertEqual(len(results), 1)
+ result = results[0]
+
+ # Check that we have detections
+ self.assertGreater(len(result["masks"]), 0)
+
+ # Check exact values for top detection
+ top_pp_score = result["scores"][0]
+ top_pp_box = result["boxes"][0]
+
+ torch.testing.assert_close(top_pp_score, torch.tensor(0.9387).to(torch_device), atol=1e-4, rtol=1e-4)
+ torch.testing.assert_close(
+ top_pp_box, torch.tensor([104.2026, 177.1267, 482.1449, 422.2436]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+
+ def test_inference_multi_box_prompt(self):
+ """Test inference with multiple box prompts with positive and negative labels (from batched_inference notebook)."""
+ raw_image = prepare_coco_kitchen_image()
+ # Example from notebook: multiple positive boxes (dial + button)
+ # Dial box (xyxy): [59, 144, 76, 163]
+ # Button box (xyxy): [87, 148, 104, 159]
+ box1_xyxy = [59, 144, 76, 163]
+ box2_xyxy = [87, 148, 104, 159]
+
+ input_boxes = [[box1_xyxy, box2_xyxy]]
+ input_boxes_labels = [[1, 1]] # Both positive
+
+ inputs = self.processor(
+ images=raw_image, input_boxes=input_boxes, input_boxes_labels=input_boxes_labels, return_tensors="pt"
+ ).to(torch_device)
+
+ with torch.no_grad():
+ outputs = self.model(**inputs)
+
+ # Check exact output shapes
+ self.assertEqual(outputs.pred_masks.shape, (1, 200, 288, 288))
+ self.assertEqual(outputs.pred_boxes.shape, (1, 200, 4))
+ self.assertEqual(outputs.pred_logits.shape, (1, 200))
+
+ # Check exact values
+ scores = torch.sigmoid(outputs.pred_logits)
+ sorted_indices = torch.argsort(scores.squeeze(), descending=True)
+ top_scores = scores.squeeze()[sorted_indices[:3]]
+ top_logits = outputs.pred_logits.squeeze()[sorted_indices[:3]]
+ top_idx = sorted_indices[0].item()
+
+ torch.testing.assert_close(
+ top_scores, torch.tensor([0.9615, 0.9399, 0.8262]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(
+ top_logits, torch.tensor([3.2166, 2.7504, 1.5591]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(
+ outputs.pred_boxes[0, top_idx],
+ torch.tensor([0.1758, 0.2889, 0.2296, 0.3258]).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+ torch.testing.assert_close(
+ outputs.pred_masks[0, top_idx, :3, :3],
+ torch.tensor(
+ [[-9.1072, -15.0329, -18.6687], [-14.1670, -21.3808, -26.8545], [-15.5131, -23.4600, -17.5040]]
+ ).to(torch_device),
+ atol=1e-2,
+ rtol=1e-2,
+ )
+
+ # Test post-processing
+ results = self.processor.post_process_instance_segmentation(
+ outputs, threshold=0.5, mask_threshold=0.5, target_sizes=inputs.get("original_sizes").tolist()
+ )
+ self.assertEqual(len(results), 1)
+ result = results[0]
+
+ # Check that we have detections
+ self.assertGreater(len(result["masks"]), 0)
+
+ # Check exact values for top detection
+ top_pp_score = result["scores"][0]
+ top_pp_box = result["boxes"][0]
+
+ torch.testing.assert_close(top_pp_score, torch.tensor(0.9399).to(torch_device), atol=1e-4, rtol=1e-4)
+ torch.testing.assert_close(
+ top_pp_box, torch.tensor([86.8764, 147.5362, 104.4958, 159.6079]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+
+ def test_inference_combined_prompts(self):
+ """Test inference with combined text and geometry prompts (text + negative box from batched_inference notebook)."""
+ raw_image = prepare_coco_kitchen_image()
+ # Example from notebook: text "handle" + negative box to exclude oven handle
+ text = "handle"
+ # Negative box covering the oven handle area (xyxy): [40, 183, 318, 204]
+ oven_handle_box = [40, 183, 318, 204]
+
+ input_boxes = [[oven_handle_box]]
+
+ inputs = self.processor(
+ images=raw_image,
+ text=text,
+ input_boxes=input_boxes,
+ input_boxes_labels=[[0]], # 0 = negative
+ return_tensors="pt",
+ ).to(torch_device)
+
+ with torch.no_grad():
+ outputs = self.model(**inputs)
+
+ # Check exact output shapes
+ self.assertEqual(outputs.pred_masks.shape, (1, 200, 288, 288))
+ self.assertEqual(outputs.pred_boxes.shape, (1, 200, 4))
+ self.assertEqual(outputs.pred_logits.shape, (1, 200))
+
+ def test_inference_batched_images(self):
+ """Test batched inference with multiple images (from batched_inference notebook)."""
+ # Example from notebook: batch of 2 images with different text prompts
+ raw_image1 = prepare_coco_cat_image()
+ raw_image2 = prepare_coco_kitchen_image()
+
+ # Batch of 2 images with different text prompts: "ear" for cat, "dial" for kitchen
+ inputs = self.processor(images=[raw_image1, raw_image2], text=["ear", "dial"], return_tensors="pt").to(
+ torch_device
+ )
+
+ with torch.no_grad():
+ outputs = self.model(**inputs)
+
+ # Check exact output shapes
+ self.assertEqual(outputs.pred_masks.shape, (2, 200, 288, 288))
+ self.assertEqual(outputs.pred_boxes.shape, (2, 200, 4))
+ self.assertEqual(outputs.pred_logits.shape, (2, 200))
+
+ # Check scores are reasonable
+ scores = torch.sigmoid(outputs.pred_logits)
+ self.assertTrue((scores >= 0).all() and (scores <= 1).all())
+
+ # Check exact values
+ sorted_indices_0 = torch.argsort(scores[0], descending=True)
+ sorted_indices_1 = torch.argsort(scores[1], descending=True)
+ top_scores_0 = scores[0][sorted_indices_0[:3]]
+ top_scores_1 = scores[1][sorted_indices_1[:3]]
+ top_logits_0 = outputs.pred_logits[0][sorted_indices_0[:3]]
+ top_logits_1 = outputs.pred_logits[1][sorted_indices_1[:3]]
+ top_idx_0 = sorted_indices_0[0].item()
+ top_idx_1 = sorted_indices_1[0].item()
+
+ torch.testing.assert_close(
+ top_scores_0, torch.tensor([0.9326, 0.9149, 0.1009]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(
+ top_scores_1, torch.tensor([0.8532, 0.8497, 0.8473]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(
+ top_logits_0, torch.tensor([2.6268, 2.3755, -2.1877]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(
+ top_logits_1, torch.tensor([1.7598, 1.7324, 1.7133]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(
+ outputs.pred_boxes[0, top_idx_0],
+ torch.tensor([0.4704, 0.2015, 0.5615, 0.3770]).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+ torch.testing.assert_close(
+ outputs.pred_boxes[1, top_idx_1],
+ torch.tensor([0.5088, 0.2749, 0.5775, 0.3228]).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+ torch.testing.assert_close(
+ outputs.pred_masks[0, top_idx_0, :3, :3],
+ torch.tensor(
+ [[-2.1858, -6.2385, -7.0859], [-6.0630, -10.4010, -11.2507], [-8.7169, -10.7953, -9.9857]]
+ ).to(torch_device),
+ atol=1e-2,
+ rtol=1e-2,
+ )
+ torch.testing.assert_close(
+ outputs.pred_masks[1, top_idx_1, :3, :3],
+ torch.tensor(
+ [[-5.7639, -10.4997, -11.1576], [-8.6470, -14.8703, -17.1438], [-8.9861, -14.4202, -12.6988]]
+ ).to(torch_device),
+ atol=1e-2,
+ rtol=1e-2,
+ )
+
+ # Test post-processing
+ results = self.processor.post_process_instance_segmentation(
+ outputs, threshold=0.3, mask_threshold=0.5, target_sizes=inputs.get("original_sizes").tolist()
+ )
+ self.assertEqual(len(results), 2)
+
+ # Check that both have detections
+ self.assertGreater(len(results[0]["masks"]), 0)
+ self.assertGreater(len(results[1]["masks"]), 0)
+
+ # Check exact values for top detection in each image
+ top_pp_score_0 = results[0]["scores"][0]
+ top_pp_box_0 = results[0]["boxes"][0]
+ top_pp_score_1 = results[1]["scores"][0]
+ top_pp_box_1 = results[1]["boxes"][0]
+
+ torch.testing.assert_close(top_pp_score_0, torch.tensor(0.9137).to(torch_device), atol=1e-4, rtol=1e-4)
+ torch.testing.assert_close(
+ top_pp_box_0, torch.tensor([402.3560, 90.1644, 459.6652, 156.5200]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(top_pp_score_1, torch.tensor(0.5882).to(torch_device), atol=1e-4, rtol=1e-4)
+ torch.testing.assert_close(
+ top_pp_box_1, torch.tensor([110.6558, 271.1691, 137.3885, 301.4023]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+
+ def test_inference_batched_mixed_prompts(self):
+ """Test batched inference with mixed prompt types (from batched_inference notebook)."""
+ # Example from notebook: Image 1 with text "laptop", Image 2 with visual prompt (dial)
+ raw_image1 = prepare_coco_cat_image()
+ raw_image2 = prepare_coco_kitchen_image()
+
+ # Box for dial in image 2 (xyxy): [59, 144, 76, 163]
+ box2_xyxy = [59, 144, 76, 163]
+
+ inputs = self.processor(
+ images=[raw_image1, raw_image2],
+ text=["laptop", None], # Only first image has text
+ input_boxes=[None, [box2_xyxy]], # Only second image has box
+ input_boxes_labels=[None, [1]],
+ return_tensors="pt",
+ ).to(torch_device)
+
+ with torch.no_grad():
+ outputs = self.model(**inputs)
+
+ # Check exact output shapes
+ self.assertEqual(outputs.pred_masks.shape, (2, 200, 288, 288))
+ self.assertEqual(outputs.pred_boxes.shape, (2, 200, 4))
+ self.assertEqual(outputs.pred_logits.shape, (2, 200))
+
+ # Check exact values
+ scores = torch.sigmoid(outputs.pred_logits)
+ sorted_indices_0 = torch.argsort(scores[0], descending=True)
+ sorted_indices_1 = torch.argsort(scores[1], descending=True)
+ top_scores_0 = scores[0][sorted_indices_0[:3]]
+ top_scores_1 = scores[1][sorted_indices_1[:3]]
+ top_logits_0 = outputs.pred_logits[0][sorted_indices_0[:3]]
+ top_logits_1 = outputs.pred_logits[1][sorted_indices_1[:3]]
+ top_idx_0 = sorted_indices_0[0].item()
+ top_idx_1 = sorted_indices_1[0].item()
+
+ # Top-1 score/logit is always stable; positions 2-3 can swap when scores are very close
+ torch.testing.assert_close(top_scores_0[0], torch.tensor(0.9696).to(torch_device), atol=1e-4, rtol=1e-4)
+ torch.testing.assert_close(top_scores_1[0], torch.tensor(0.9696).to(torch_device), atol=1e-4, rtol=1e-4)
+ torch.testing.assert_close(top_logits_0[0], torch.tensor(3.4640).to(torch_device), atol=1e-4, rtol=1e-4)
+ torch.testing.assert_close(top_logits_1[0], torch.tensor(3.4608).to(torch_device), atol=1e-4, rtol=1e-4)
+ # Positions 2-3 of the first image use a relaxed tolerance; these low-score predictions (~0.16 and ~0.07) are less numerically stable
+ torch.testing.assert_close(
+ top_scores_0[1:], torch.tensor([0.1615, 0.0683]).to(torch_device), atol=1e-2, rtol=1e-2
+ )
+ torch.testing.assert_close(
+ top_scores_1[1:], torch.tensor([0.8302, 0.8153]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(
+ top_logits_1[1:], torch.tensor([1.5873, 1.4849]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(
+ outputs.pred_boxes[0, top_idx_0],
+ torch.tensor([-0.0012, 0.0017, 0.4518, 0.9965]).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+ torch.testing.assert_close(
+ outputs.pred_boxes[1, top_idx_1],
+ torch.tensor([0.1775, 0.2877, 0.2297, 0.3261]).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+ torch.testing.assert_close(
+ outputs.pred_masks[0, top_idx_0, :3, :3],
+ torch.tensor([[0.0345, 0.2871, 0.3752], [0.6577, 0.9573, 1.0289], [0.8247, 0.9766, 0.9590]]).to(
+ torch_device
+ ),
+ atol=1e-2,
+ rtol=1e-2,
+ )
+ torch.testing.assert_close(
+ outputs.pred_masks[1, top_idx_1, :3, :3],
+ torch.tensor(
+ [[-9.2271, -14.6975, -18.0719], [-14.1411, -21.1871, -26.6270], [-15.6623, -23.1574, -18.0739]]
+ ).to(torch_device),
+ atol=1e-2,
+ rtol=1e-2,
+ )
+
+ # Test post-processing
+ results = self.processor.post_process_instance_segmentation(
+ outputs, threshold=0.3, mask_threshold=0.5, target_sizes=inputs.get("original_sizes").tolist()
+ )
+ self.assertEqual(len(results), 2)
+
+ # Check that both have detections
+ self.assertGreater(len(results[0]["masks"]), 0)
+ self.assertGreater(len(results[1]["masks"]), 0)
+
+ # Check exact values for top detection in each image
+ top_pp_score_0 = results[0]["scores"][0]
+ top_pp_box_0 = results[0]["boxes"][0]
+ top_pp_score_1 = results[1]["scores"][0]
+ top_pp_box_1 = results[1]["boxes"][0]
+
+ torch.testing.assert_close(top_pp_score_0, torch.tensor(0.9556).to(torch_device), atol=1e-4, rtol=1e-4)
+ torch.testing.assert_close(
+ top_pp_box_0, torch.tensor([-0.7773, 0.7140, 289.1797, 423.5321]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(top_pp_score_1, torch.tensor(0.8153).to(torch_device), atol=1e-4, rtol=1e-4)
+ torch.testing.assert_close(
+ top_pp_box_1, torch.tensor([168.9672, 137.3469, 191.7236, 161.3282]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+
+ def test_semantic_segmentation_output(self):
+ """Test that semantic segmentation output is produced."""
+ raw_image = prepare_coco_cat_image()
+ inputs = self.processor(images=raw_image, text="ear", return_tensors="pt").to(torch_device)
+
+ with torch.no_grad():
+ outputs = self.model(**inputs)
+
+ # Check exact semantic segmentation output shape
+ self.assertEqual(outputs.semantic_seg.shape, (1, 1, 288, 288))
+ # Check that semantic seg has same spatial size as pred_masks
+ self.assertEqual(outputs.semantic_seg.shape[-2:], outputs.pred_masks.shape[-2:])
+
+ @require_deterministic_for_xpu
+ def test_efficient_multi_prompt_single_image(self):
+ """Test efficient inference with multiple prompts on a single image using get_vision_features."""
+ raw_image = prepare_coco_cat_image()
+
+ # Pre-compute vision embeddings once
+ img_inputs = self.processor(images=raw_image, return_tensors="pt").to(torch_device)
+ with torch.no_grad():
+ vision_embeds = self.model.get_vision_features(pixel_values=img_inputs.pixel_values)
+
+ # Run multiple text prompts efficiently
+ text_prompts = ["ear", "eye"]
+ all_results = []
+
+ for prompt in text_prompts:
+ text_inputs = self.processor(text=prompt, return_tensors="pt").to(torch_device)
+ with torch.no_grad():
+ outputs = self.model(vision_embeds=vision_embeds, **text_inputs)
+
+ results = self.processor.post_process_instance_segmentation(
+ outputs,
+ threshold=0.5,
+ mask_threshold=0.5,
+ target_sizes=img_inputs.get("original_sizes").tolist(),
+ )[0]
+ all_results.append(results)
+
+ # Check that we get results for both prompts
+ self.assertEqual(len(all_results), 2)
+
+ # Verify outputs are equivalent to running with pixel_values directly
+ text_inputs = self.processor(text="ear", return_tensors="pt").to(torch_device)
+ with torch.no_grad():
+ outputs_with_embeds = self.model(vision_embeds=vision_embeds, **text_inputs)
+
+ inputs_direct = self.processor(images=raw_image, text="ear", return_tensors="pt").to(torch_device)
+ with torch.no_grad():
+ outputs_direct = self.model(**inputs_direct)
+
+ # Outputs should be identical
+ torch.testing.assert_close(outputs_with_embeds.pred_logits, outputs_direct.pred_logits, atol=1e-5, rtol=1e-5)
+ torch.testing.assert_close(outputs_with_embeds.pred_boxes, outputs_direct.pred_boxes, atol=1e-5, rtol=1e-5)
+ torch.testing.assert_close(outputs_with_embeds.pred_masks, outputs_direct.pred_masks, atol=1e-5, rtol=1e-5)
+
+ @require_deterministic_for_xpu
+ def test_efficient_single_prompt_multi_images(self):
+ """Test efficient inference with same prompt on multiple images using get_text_features."""
+ raw_image1 = prepare_coco_cat_image()
+ raw_image2 = prepare_coco_kitchen_image()
+
+ # Pre-compute text embeddings once
+ text_prompt = "handle"
+ text_inputs = self.processor(text=text_prompt, return_tensors="pt").to(torch_device)
+ with torch.no_grad():
+ text_embeds = self.model.get_text_features(**text_inputs).pooler_output
+
+ # Run inference on multiple images reusing text embeddings
+ # Note: attention_mask must be passed along with text_embeds for proper masking
+ images = [raw_image1, raw_image2]
+ all_results = []
+
+ for image in images:
+ img_inputs = self.processor(images=image, return_tensors="pt").to(torch_device)
+ with torch.no_grad():
+ outputs = self.model(
+ text_embeds=text_embeds,
+ attention_mask=text_inputs.attention_mask,
+ **img_inputs,
+ )
+
+ results = self.processor.post_process_instance_segmentation(
+ outputs,
+ threshold=0.5,
+ mask_threshold=0.5,
+ target_sizes=img_inputs.get("original_sizes").tolist(),
+ )[0]
+ all_results.append(results)
+
+ # Check that we get results for both images
+ self.assertEqual(len(all_results), 2)
+
+ # Verify outputs are equivalent to running with input_ids directly
+ img_inputs = self.processor(images=raw_image2, return_tensors="pt").to(torch_device)
+ with torch.no_grad():
+ outputs_with_embeds = self.model(
+ text_embeds=text_embeds,
+ attention_mask=text_inputs.attention_mask,
+ **img_inputs,
+ )
+
+ inputs_direct = self.processor(images=raw_image2, text=text_prompt, return_tensors="pt").to(torch_device)
+ with torch.no_grad():
+ outputs_direct = self.model(**inputs_direct)
+
+ # Outputs should be identical
+ torch.testing.assert_close(outputs_with_embeds.pred_logits, outputs_direct.pred_logits, atol=1e-5, rtol=1e-5)
+ torch.testing.assert_close(outputs_with_embeds.pred_boxes, outputs_direct.pred_boxes, atol=1e-5, rtol=1e-5)
+ torch.testing.assert_close(outputs_with_embeds.pred_masks, outputs_direct.pred_masks, atol=1e-5, rtol=1e-5)
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index c987d506f319..8d78898dc4cf 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -83,6 +83,8 @@
"AutoformerConfig": ["num_static_real_features", "num_time_features"],
"SamVisionConfig": ["mlp_ratio"],
"Sam3VisionConfig": ["backbone_feature_sizes"],
+ "Sam3LiteTextViTConfig": ["global_attn_indexes", "window_size"],
+ "Sam3LiteTextVisionConfig": ["fpn_hidden_size", "scale_factors"],
"SamHQVisionConfig": ["mlp_ratio"],
"ClapAudioConfig": ["num_classes"],
"ClvpDecoderConfig": ["add_cross_attention"],
diff --git a/utils/check_repo.py b/utils/check_repo.py
index b1a3d158c716..4a464be9db41 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -272,6 +272,7 @@
"UVDocBridge", # Building part of a bigger model, tested implicitly through UVDocModel
"Gemma4VisionModel", # Building part of a bigger model, tested implicitly
"Gemma4AudioModel", # Building part of a bigger model, tested implicitly
+ "Sam3LiteTextTextModel", # Building part of a bigger model, tested implicitly through Sam3LiteTextModel
]
)
diff --git a/utils/fetch_hub_objects_for_ci.py b/utils/fetch_hub_objects_for_ci.py
index 59cf65117913..dba74dc418d7 100644
--- a/utils/fetch_hub_objects_for_ci.py
+++ b/utils/fetch_hub_objects_for_ci.py
@@ -18,6 +18,8 @@
"http://images.cocodataset.org/val2017/000000000802.jpg",
"http://images.cocodataset.org/val2017/000000000872.jpg",
"http://images.cocodataset.org/val2017/000000039769.jpg",
+ "http://images.cocodataset.org/val2017/000000077595.jpg",
+ "http://images.cocodataset.org/val2017/000000136466.jpg",
"https://www.ilankelman.org/stopsigns/australia.jpg",
"https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg",
"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",