diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index ba69db1c5e78..496b4d95c51c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1351,6 +1351,8 @@ title: SAM3 - local: model_doc/sam3_video title: SAM3 Video + - local: model_doc/sam3_lite_text + title: SAM3-LiteText - local: model_doc/shieldgemma2 title: ShieldGemma2 - local: model_doc/siglip diff --git a/docs/source/en/model_doc/nomic_bert.md b/docs/source/en/model_doc/nomic_bert.md index 73b3adc8a35f..2017805fe42a 100644 --- a/docs/source/en/model_doc/nomic_bert.md +++ b/docs/source/en/model_doc/nomic_bert.md @@ -23,7 +23,7 @@ limitations under the License. ## Overview -NomicBERT was proposed in [Nomic Embed: Training a Reproducible Long Context Text Embedder](https://arxiv.org/abs/2402.01613) by +NomicBERT was proposed in [Nomic Embed: Training a Reproducible Long Context Text Embedder](https://huggingface.co/papers/2402.01613) by Zach Nussbaum, John X. Morris, Brandon Duderstadt, and Andriy Mulyar. It is BERT-inspired with the most notable extension applying [Rotary Position Embeddings](https://huggingface.co/papers/2104.09864.pdf) to an encoder model. diff --git a/docs/source/en/model_doc/sam3_lite_text.md b/docs/source/en/model_doc/sam3_lite_text.md new file mode 100644 index 000000000000..79ebb9a1701c --- /dev/null +++ b/docs/source/en/model_doc/sam3_lite_text.md @@ -0,0 +1,118 @@ + +*This model was released on 2026-02-12 and added to Hugging Face Transformers on 2026-04-12.* + +# SAM3-LiteText + +
+    style="float: right;">
+    <div class="flex flex-wrap space-x-1">
+        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
+    </div>
+</div>
+ +## Overview + +SAM3-LiteText was proposed in [SAM3-LiteText: An Anatomical Study of the SAM3 Text Encoder for Efficient Vision-Language Segmentation](https://huggingface.co/papers/2602.12173) by Chengxi Zeng, Yuxuan Jiang, Ge Gao, Shuai Wang, Duolikun Danier, Bin Zhu, Stevan Rudinac, David Bull, and Fan Zhang. + +SAM3-LiteText is a lightweight variant of [SAM3](sam3) that replaces the heavy SAM3 text encoder (353M parameters) with a compact MobileCLIP-based text encoder optimized through knowledge distillation. The SAM3 ViT-H image encoder is kept intact. This reduces text encoder parameters by up to 88% while maintaining segmentation performance comparable to the original model. + +The abstract from the paper is the following: + +*Vision-language segmentation models such as SAM3 enable flexible, prompt-driven visual grounding, but inherit large, general-purpose text encoders originally designed for open-ended language understanding. In practice, segmentation prompts are short, structured, and semantically constrained, leading to substantial over-provisioning in text encoder capacity and persistent computational and memory overhead. In this paper, we perform a large-scale anatomical analysis of text prompting in vision-language segmentation, covering 404,796 real prompts across multiple benchmarks. Our analysis reveals severe redundancy: most context windows are underutilized, vocabulary usage is highly sparse, and text embeddings lie on low-dimensional manifold despite high-dimensional representations. Motivated by these findings, we propose SAM3-LiteText, a lightweight text encoding framework that replaces the original SAM3 text encoder with a compact MobileCLIP student that is optimized by knowledge distillation. Extensive experiments on image and video segmentation benchmarks show that SAM3-LiteText reduces text encoder parameters by up to 88%, substantially reducing static memory footprint, while maintaining segmentation performance comparable to the original model.* + +The text encoder architecture is based on [MobileCLIP](https://huggingface.co/papers/2311.17049) and comes in three variants: + +| Variant | Text Encoder | Text Params | Reduction | +|---|---|---|---| +| SAM3-LiteText-S0-16 | MobileCLIP-S0 | 42.54M | ~88% | +| SAM3-LiteText-S1-16 | MobileCLIP-S1 | 63.53M | ~82% | +| SAM3-LiteText-L-16 | MobileCLIP2-L | 123.80M | ~65% | + +This model was contributed by [nielsr](https://huggingface.co/nielsr) and [yonigozlan](https://huggingface.co/yonigozlan). +The original code can be found [here](https://github.com/SimonZeng7108/efficientsam3/tree/sam3_litetext). + +## Usage + +SAM3-LiteText is a drop-in replacement for SAM3 with a lightweight text encoder. It uses the same processor ([`Sam3Processor`]) and supports the same prompting interface. Refer to the [SAM3 documentation](sam3) for detailed usage examples including text prompts, box prompts, batched inference, and more. 
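+
+As a quick check of the parameter savings reported above, you can compare the size of the distilled text encoder with the rest of the model. This is a minimal sketch: it assumes the `yonigozlan/sam3-litetext-s0` checkpoint used below and the `text_encoder` submodule exposed by `Sam3LiteTextModel`.
+
+```python
+from transformers import AutoModel
+
+model = AutoModel.from_pretrained("yonigozlan/sam3-litetext-s0")
+
+# Compare the lightweight MobileCLIP-based text encoder against the full model.
+text_params = sum(p.numel() for p in model.text_encoder.parameters())
+total_params = sum(p.numel() for p in model.parameters())
+print(f"Text encoder: {text_params / 1e6:.2f}M parameters ({100 * text_params / total_params:.1f}% of the model)")
+```
+
+The example below runs text-prompted instance segmentation on a single image: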
+ +```python +from io import BytesIO + +import httpx +from transformers import AutoModel, AutoProcessor +from PIL import Image + +model = AutoModel.from_pretrained("yonigozlan/sam3-litetext-s0", device_map="auto") +processor = AutoProcessor.from_pretrained("yonigozlan/sam3-litetext-s0") + +image_url = "http://images.cocodataset.org/val2017/000000077595.jpg" +image = Image.open(BytesIO(httpx.get(image_url).content)).convert("RGB") + +inputs = processor(images=image, text="ear", return_tensors="pt").to(model.device) + +outputs = model(**inputs) + +results = processor.post_process_instance_segmentation( + outputs, + threshold=0.5, + mask_threshold=0.5, + target_sizes=inputs.get("original_sizes").tolist(), +)[0] + +print(f"Found {len(results['masks'])} objects") +``` + +## Sam3LiteTextConfig + +[[autodoc]] Sam3LiteTextConfig + +## Sam3LiteTextTextConfig + +[[autodoc]] Sam3LiteTextTextConfig + +## Sam3LiteTextGeometryEncoderConfig + +[[autodoc]] Sam3LiteTextGeometryEncoderConfig + +## Sam3LiteTextDETREncoderConfig + +[[autodoc]] Sam3LiteTextDETREncoderConfig + +## Sam3LiteTextDETRDecoderConfig + +[[autodoc]] Sam3LiteTextDETRDecoderConfig + +## Sam3LiteTextMaskDecoderConfig + +[[autodoc]] Sam3LiteTextMaskDecoderConfig + +## Sam3LiteTextTextModel + +[[autodoc]] Sam3LiteTextTextModel + - forward + +## Sam3LiteTextModel + +[[autodoc]] Sam3LiteTextModel + - forward + +## Sam3LiteTextPreTrainedModel + +[[autodoc]] Sam3LiteTextPreTrainedModel + - forward diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 989be9eb114e..acc5e2fdeac0 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -367,6 +367,7 @@ from .sam2 import * from .sam2_video import * from .sam3 import * + from .sam3_lite_text import * from .sam3_tracker import * from .sam3_tracker_video import * from .sam3_video import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 2c0fe88d0e74..07b0ec3854ef 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -421,6 +421,8 @@ ("sam2_video", "Sam2VideoConfig"), ("sam2_vision_model", "Sam2VisionConfig"), ("sam3", "Sam3Config"), + ("sam3_lite_text", "Sam3LiteTextConfig"), + ("sam3_lite_text_text_model", "Sam3LiteTextTextConfig"), ("sam3_tracker", "Sam3TrackerConfig"), ("sam3_tracker_video", "Sam3TrackerVideoConfig"), ("sam3_video", "Sam3VideoConfig"), @@ -954,6 +956,8 @@ ("sam2_video", "Sam2VideoModel"), ("sam2_vision_model", "Sam2VisionModel"), ("sam3", "SAM3"), + ("sam3_lite_text", "SAM3-LiteText"), + ("sam3_lite_text_text_model", "SAM3-LiteText Text Model"), ("sam3_tracker", "Sam3Tracker"), ("sam3_tracker_video", "Sam3TrackerVideo"), ("sam3_video", "Sam3VideoModel"), @@ -1142,6 +1146,7 @@ ("sam_vision_model", "sam"), ("sam2_vision_model", "sam2"), ("sam2_hiera_det_model", "sam2"), + ("sam3_lite_text_text_model", "sam3_lite_text"), ("sam3_vit_model", "sam3"), ("sam3_vision_model", "sam3"), ("edgetam_vision_model", "edgetam"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 4ada3ba6b8ed..e554249cac4a 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -226,6 +226,7 @@ ("sam2", {"torchvision": "Sam2ImageProcessor"}), ("sam2_video", {"torchvision": "Sam2ImageProcessor"}), ("sam3", {"torchvision": "Sam3ImageProcessor"}), + 
("sam3_lite_text", {"torchvision": "Sam3ImageProcessor"}), ("sam3_tracker", {"torchvision": "Sam3ImageProcessor"}), ("sam3_tracker_video", {"torchvision": "Sam3ImageProcessor"}), ("sam3_video", {"torchvision": "Sam3ImageProcessor"}), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index d4cb17cddfa6..50bbd5721413 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -399,6 +399,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("sam2_video", "Sam2VideoModel"), ("sam2_vision_model", "Sam2VisionModel"), ("sam3", "Sam3Model"), + ("sam3_lite_text", "Sam3LiteTextModel"), + ("sam3_lite_text_text_model", "Sam3LiteTextTextModel"), ("sam3_tracker", "Sam3TrackerModel"), ("sam3_tracker", "Sam3TrackerModel"), ("sam3_tracker_video", "Sam3TrackerVideoModel"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 262480b71485..1b6384b41364 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -152,6 +152,7 @@ ("sam", "SamProcessor"), ("sam2", "Sam2Processor"), ("sam3", "Sam3Processor"), + ("sam3_lite_text", "Sam3Processor"), ("sam_hq", "SamHQProcessor"), ("seamless_m4t", "SeamlessM4TProcessor"), ("sew", "Wav2Vec2Processor"), diff --git a/src/transformers/models/sam3/convert_sam3_to_hf.py b/src/transformers/models/sam3/convert_sam3_to_hf.py index 927b1cc5f208..af1786c94c7b 100644 --- a/src/transformers/models/sam3/convert_sam3_to_hf.py +++ b/src/transformers/models/sam3/convert_sam3_to_hf.py @@ -25,7 +25,7 @@ import regex as re import torch -from transformers import CLIPTokenizerFast, Sam3Config, Sam3ImageProcessorFast, Sam3Model, Sam3Processor +from transformers import CLIPTokenizerFast, Sam3Config, Sam3ImageProcessor, Sam3Model, Sam3Processor from transformers.utils import logging @@ -383,7 +383,7 @@ def convert_sam3_checkpoint( # Save processor print("Creating and saving processor...") - image_processor = Sam3ImageProcessorFast() + image_processor = Sam3ImageProcessor() tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", max_length=32, model_max_length=32) processor = Sam3Processor(image_processor=image_processor, tokenizer=tokenizer) processor.save_pretrained(output_path) diff --git a/src/transformers/models/sam3/modeling_sam3.py b/src/transformers/models/sam3/modeling_sam3.py index 5a9aa329daa0..0dee80edf5db 100644 --- a/src/transformers/models/sam3/modeling_sam3.py +++ b/src/transformers/models/sam3/modeling_sam3.py @@ -279,7 +279,7 @@ def box_cxcywh_to_xyxy(x): class Sam3MLP(nn.Module): - def __init__(self, config: Sam3ViTConfig): + def __init__(self, config): super().__init__() self.config = config self.activation_fn = ACT2FN[config.hidden_act] diff --git a/src/transformers/models/sam3_lite_text/__init__.py b/src/transformers/models/sam3_lite_text/__init__.py new file mode 100644 index 000000000000..753f93a87c0e --- /dev/null +++ b/src/transformers/models/sam3_lite_text/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_sam3_lite_text import * + from .modeling_sam3_lite_text import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/sam3_lite_text/configuration_sam3_lite_text.py b/src/transformers/models/sam3_lite_text/configuration_sam3_lite_text.py new file mode 100644 index 000000000000..696751075611 --- /dev/null +++ b/src/transformers/models/sam3_lite_text/configuration_sam3_lite_text.py @@ -0,0 +1,330 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/sam3_lite_text/modular_sam3_lite_text.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_sam3_lite_text.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...utils import auto_docstring +from ..auto import CONFIG_MAPPING, AutoConfig + + +@auto_docstring(checkpoint="facebook/sam3_lite_text") +@strict +class Sam3LiteTextViTConfig(PreTrainedConfig): + r""" + rope_theta (`float`, *optional*, defaults to 10000.0): + Base frequency for RoPE. + window_size (`int`, *optional*, defaults to 24): + Window size for windowed attention. + global_attn_indexes (`list[int]`, *optional*, defaults to `[7, 15, 23, 31]`): + Indexes of layers with global attention. + pretrain_image_size (`int`, *optional*, defaults to 336): + Pretrained model image size for position embedding initialization. + hidden_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for hidden states. 
+ """ + + base_config_key = "backbone_config" + model_type = "sam3_vit_model" + + hidden_size: int = 1024 + intermediate_size: int = 4736 + num_hidden_layers: int = 32 + num_attention_heads: int = 16 + num_channels: int = 3 + image_size: int | list[int] | tuple[int, int] = 1008 + patch_size: int | list[int] | tuple[int, int] = 14 + hidden_act: str = "gelu" + layer_norm_eps: float = 1e-6 + attention_dropout: float | int = 0.0 + rope_theta: float = 10000.0 + window_size: int = 24 + global_attn_indexes: list[int] | None = None + layer_scale_init_value: float | None = None + pretrain_image_size: int | list[int] | tuple[int, int] = 336 + hidden_dropout: float | int = 0.0 + initializer_range: float = 0.02 + + def __post_init__(self, **kwargs): + super().__post_init__(**kwargs) + if self.global_attn_indexes is None: + self.global_attn_indexes = [7, 15, 23, 31] + + +@auto_docstring(checkpoint="facebook/sam3_lite_text") +@strict +class Sam3LiteTextVisionConfig(PreTrainedConfig): + r""" + fpn_hidden_size (`int`, *optional*, defaults to 256): + The hidden dimension of the FPN. + backbone_feature_sizes (`List[List[int]]`, *optional*, defaults to `[[288, 288], [144, 144], [72, 72]]`): + The spatial sizes (height, width) of the feature maps from the backbone at different scales. + scale_factors (`list[float]`, *optional*, defaults to `[4.0, 2.0, 1.0, 0.5]`): + Scale factors for FPN multi-scale features. List of scaling factors for each FPN level. + """ + + base_config_key = "vision_config" + model_type = "sam3_vision_model" + sub_configs = {"backbone_config": AutoConfig} + + backbone_config: dict | PreTrainedConfig | None = None + fpn_hidden_size: int = 256 + backbone_feature_sizes: list | None = None + scale_factors: list[float] | None = None + hidden_act: str = "gelu" + layer_norm_eps: float = 1e-6 + initializer_range: float = 0.02 + + def __post_init__(self, **kwargs): + self.scale_factors = [4.0, 2.0, 1.0, 0.5] if self.scale_factors is None else self.scale_factors + if self.backbone_feature_sizes is None: + self.backbone_feature_sizes = [[288, 288], [144, 144], [72, 72]] + + if isinstance(self.backbone_config, dict): + self.backbone_config["model_type"] = self.backbone_config.get("model_type", "sam3_vit_model") + self.backbone_config = CONFIG_MAPPING[self.backbone_config["model_type"]](**self.backbone_config) + elif self.backbone_config is None: + self.backbone_config = CONFIG_MAPPING["sam3_vit_model"]() + + super().__post_init__(**kwargs) + + @property + def image_size(self): + """Image size for the vision encoder.""" + return self.backbone_config.image_size + + @image_size.setter + def image_size(self, value): + """Set the image size and propagate to backbone.""" + self.backbone_config.image_size = value + + +@auto_docstring(checkpoint="facebook/sam3_lite_text") +@strict +class Sam3LiteTextGeometryEncoderConfig(PreTrainedConfig): + r""" + roi_size (`int`, *optional*, defaults to 7): + ROI size for box pooling operations. + """ + + model_type = "sam3_lite_text_geometry_encoder" + + hidden_size: int = 256 + num_layers: int = 3 + num_attention_heads: int = 8 + intermediate_size: int = 2048 + dropout: float | int = 0.1 + hidden_act: str = "relu" + hidden_dropout: float | int = 0.0 + layer_norm_eps: float = 1e-6 + roi_size: int = 7 + initializer_range: float = 0.02 + + +@auto_docstring(checkpoint="facebook/sam3_lite_text") +@strict +class Sam3LiteTextDETREncoderConfig(PreTrainedConfig): + r""" + hidden_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for hidden states. 
+ """ + + model_type = "sam3_lite_text_detr_encoder" + + hidden_size: int = 256 + num_layers: int = 6 + num_attention_heads: int = 8 + intermediate_size: int = 2048 + dropout: float | int = 0.1 + hidden_act: str = "relu" + hidden_dropout: float | int = 0.0 + layer_norm_eps: float = 1e-6 + initializer_range: float = 0.02 + + +@auto_docstring(checkpoint="facebook/sam3_lite_text") +@strict +class Sam3LiteTextDETRDecoderConfig(PreTrainedConfig): + r""" + num_queries (`int`, *optional*, defaults to 200): + Number of object queries. + """ + + model_type = "sam3_lite_text_detr_decoder" + + hidden_size: int = 256 + num_layers: int = 6 + num_queries: int = 200 + num_attention_heads: int = 8 + intermediate_size: int = 2048 + dropout: float | int = 0.1 + hidden_act: str = "relu" + hidden_dropout: float | int = 0.0 + layer_norm_eps: float = 1e-6 + initializer_range: float = 0.02 + + +@auto_docstring(checkpoint="facebook/sam3_lite_text") +@strict +class Sam3LiteTextMaskDecoderConfig(PreTrainedConfig): + r""" + num_upsampling_stages (`int`, *optional*, defaults to 3): + Number of upsampling stages in the pixel decoder (FPN). + """ + + model_type = "sam3_lite_text_mask_decoder" + + hidden_size: int = 256 + num_upsampling_stages: int = 3 + layer_norm_eps: float = 1e-6 + dropout: float | int = 0.0 + num_attention_heads: int = 8 + initializer_range: float = 0.02 + + +@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0") +@strict +class Sam3LiteTextTextConfig(PreTrainedConfig): + r""" + use_repmixer_blocks (`bool`, *optional*, defaults to `True`): + Whether to use RepMixer blocks (MobileCLIP-style) for the first and last encoder layers. + When `False`, all layers are standard Transformer encoder layers. + layer_scale_init_value (`float`, *optional*, defaults to `1e-5`): + Initial value for the learnable layer-scale parameters in RepMixer blocks (residual branches). + repmixer_kernel_size (`int`, *optional*, defaults to `11`): + Kernel size for depthwise convolutions in RepMixer blocks (token mixer and convolutional feed-forward path). + """ + + model_type = "sam3_lite_text_text_model" + + vocab_size: int = 49408 + hidden_size: int = 512 + intermediate_size: int = 2048 + projection_dim: int = 512 + num_hidden_layers: int = 12 + num_attention_heads: int = 8 + max_position_embeddings: int = 77 + hidden_act: str = "gelu" + layer_norm_eps: float = 1e-5 + attention_dropout: float = 0.0 + use_repmixer_blocks: bool = True + layer_scale_init_value: float = 1e-5 + repmixer_kernel_size: int = 11 + + +@auto_docstring(checkpoint="facebook/sam3_lite_text") +@strict +class Sam3LiteTextConfig(PreTrainedConfig): + r""" + geometry_encoder_config (`dict` or `Sam3LiteTextGeometryEncoderConfig`, *optional*): + Configuration for the geometry encoder. + detr_encoder_config (`dict` or `Sam3LiteTextDETREncoderConfig`, *optional*): + Configuration for the DETR encoder. + detr_decoder_config (`dict` or `Sam3LiteTextDETRDecoderConfig`, *optional*): + Configuration for the DETR decoder. + mask_decoder_config (`dict` or `Sam3LiteTextMaskDecoderConfig`, *optional*): + Configuration for the mask decoder. 
+ + Example: + ```python + >>> from transformers import Sam3LiteTextConfig, Sam3LiteTextModel + + >>> # Initializing a SAM3_LITE_TEXT configuration + >>> configuration = Sam3LiteTextConfig() + + >>> # Initializing a model from the configuration + >>> model = Sam3LiteTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "sam3_lite_text" + sub_configs = { + "vision_config": AutoConfig, + "text_config": Sam3LiteTextTextConfig, + "geometry_encoder_config": Sam3LiteTextGeometryEncoderConfig, + "detr_encoder_config": Sam3LiteTextDETREncoderConfig, + "detr_decoder_config": Sam3LiteTextDETRDecoderConfig, + "mask_decoder_config": Sam3LiteTextMaskDecoderConfig, + } + + vision_config: dict | PreTrainedConfig | None = None + text_config: dict | PreTrainedConfig | None = None + geometry_encoder_config: dict | PreTrainedConfig | None = None + detr_encoder_config: dict | PreTrainedConfig | None = None + detr_decoder_config: dict | PreTrainedConfig | None = None + mask_decoder_config: dict | PreTrainedConfig | None = None + initializer_range: float = 0.02 + + def __post_init__(self, **kwargs): + if isinstance(self.vision_config, dict): + self.vision_config["model_type"] = self.vision_config.get("model_type", "sam3_vision_model") + self.vision_config = CONFIG_MAPPING[self.vision_config["model_type"]](**self.vision_config) + elif self.vision_config is None: + self.vision_config = CONFIG_MAPPING["sam3_vision_model"]() + + if self.text_config is None: + self.text_config = Sam3LiteTextTextConfig() + if isinstance(self.text_config, dict): + self.text_config = Sam3LiteTextTextConfig(**self.text_config) + + if self.geometry_encoder_config is None: + self.geometry_encoder_config = Sam3LiteTextGeometryEncoderConfig() + if isinstance(self.geometry_encoder_config, dict): + self.geometry_encoder_config = Sam3LiteTextGeometryEncoderConfig(**self.geometry_encoder_config) + + if self.detr_encoder_config is None: + self.detr_encoder_config = Sam3LiteTextDETREncoderConfig() + if isinstance(self.detr_encoder_config, dict): + self.detr_encoder_config = Sam3LiteTextDETREncoderConfig(**self.detr_encoder_config) + + if self.detr_decoder_config is None: + self.detr_decoder_config = Sam3LiteTextDETRDecoderConfig() + if isinstance(self.detr_decoder_config, dict): + self.detr_decoder_config = Sam3LiteTextDETRDecoderConfig(**self.detr_decoder_config) + + if self.mask_decoder_config is None: + self.mask_decoder_config = Sam3LiteTextMaskDecoderConfig() + if isinstance(self.mask_decoder_config, dict): + self.mask_decoder_config = Sam3LiteTextMaskDecoderConfig(**self.mask_decoder_config) + + super().__post_init__(**kwargs) + + @property + def image_size(self): + """Image size for the SAM3_LITE_TEXT model.""" + return self.vision_config.image_size + + @image_size.setter + def image_size(self, value): + """Set the image size and propagate to vision config.""" + self.vision_config.image_size = value + + +__all__ = [ + "Sam3LiteTextConfig", + "Sam3LiteTextTextConfig", + "Sam3LiteTextGeometryEncoderConfig", + "Sam3LiteTextDETREncoderConfig", + "Sam3LiteTextDETRDecoderConfig", + "Sam3LiteTextMaskDecoderConfig", +] diff --git a/src/transformers/models/sam3_lite_text/convert_sam3_lite_text_to_hf.py b/src/transformers/models/sam3_lite_text/convert_sam3_lite_text_to_hf.py new file mode 100644 index 000000000000..8dac30f41302 --- /dev/null +++ b/src/transformers/models/sam3_lite_text/convert_sam3_lite_text_to_hf.py @@ -0,0 +1,496 @@ +# Copyright 2026 The HuggingFace 
Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert EfficientSAM3 LiteText checkpoints to Hugging Face format.""" + +import argparse +import gc +import os + +import regex as re +import torch +from huggingface_hub import hf_hub_download + +from transformers import CLIPTokenizerFast, Sam3ImageProcessor, Sam3Processor +from transformers.models.sam3_lite_text.configuration_sam3_lite_text import Sam3LiteTextConfig, Sam3LiteTextTextConfig +from transformers.models.sam3_lite_text.modeling_sam3_lite_text import Sam3LiteTextModel +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +# fmt: off +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + # Strip the "detector." prefix that wraps all components + r"^detector\.": r"", + + # ============================================================================ + # Vision Encoder - ViT Backbone + # ============================================================================ + r"^backbone\.vision_backbone\.trunk\.": r"vision_encoder.backbone.", + r"^vision_encoder\.backbone\.pos_embed": r"vision_encoder.backbone.embeddings.position_embeddings", + r"^vision_encoder\.backbone\.patch_embed\.proj\.": r"vision_encoder.backbone.embeddings.patch_embeddings.projection.", + r"^vision_encoder\.backbone\.ln_pre\.": r"vision_encoder.backbone.layer_norm.", + r"^vision_encoder\.backbone\.blocks\.(\d+)\.norm1\.": r"vision_encoder.backbone.layers.\1.layer_norm1.", + r"^vision_encoder\.backbone\.blocks\.(\d+)\.norm2\.": r"vision_encoder.backbone.layers.\1.layer_norm2.", + r"^vision_encoder\.backbone\.blocks\.(\d+)\.attn\.qkv\.": r"vision_encoder.backbone.layers.\1.attention.qkv.", + r"^vision_encoder\.backbone\.blocks\.(\d+)\.attn\.proj\.": r"vision_encoder.backbone.layers.\1.attention.o_proj.", + r"^vision_encoder\.backbone\.blocks\.(\d+)\.attn\.freqs_cis": r"vision_encoder.backbone.layers.\1.rotary_emb.rope_embeddings", + r"^vision_encoder\.backbone\.blocks\.(\d+)\.mlp\.fc1\.": r"vision_encoder.backbone.layers.\1.mlp.fc1.", + r"^vision_encoder\.backbone\.blocks\.(\d+)\.mlp\.fc2\.": r"vision_encoder.backbone.layers.\1.mlp.fc2.", + + # Vision Encoder - FPN Neck + r"^backbone\.vision_backbone\.neck\.fpn\.(\d+)\.": r"vision_encoder.neck.fpn_layers.\1.", + r"^backbone\.vision_backbone\.convs\.(\d+)\.dconv_2x2_0\.": r"vision_encoder.neck.fpn_layers.\1.scale_layers.0.", + r"^backbone\.vision_backbone\.convs\.(\d+)\.dconv_2x2_1\.": r"vision_encoder.neck.fpn_layers.\1.scale_layers.2.", + r"^backbone\.vision_backbone\.convs\.(\d+)\.dconv_2x2\.": r"vision_encoder.neck.fpn_layers.\1.scale_layers.0.", + r"^backbone\.vision_backbone\.convs\.(\d+)\.maxpool_2x2\.": r"vision_encoder.neck.fpn_layers.\1.scale_layers.0.", + r"^backbone\.vision_backbone\.convs\.(\d+)\.conv_1x1\.": r"vision_encoder.neck.fpn_layers.\1.proj1.", + r"^backbone\.vision_backbone\.convs\.(\d+)\.conv_3x3\.": r"vision_encoder.neck.fpn_layers.\1.proj2.", + + # 
============================================================================ + # Text Encoder - LiteText (MobileCLIP student) + # ============================================================================ + # Embeddings + r"^backbone\.language_backbone\.encoder\.embedding_layer\.": r"text_encoder.embeddings.token_embedding.", + r"^backbone\.language_backbone\.encoder\.positional_embedding\.pos_embed\.pos_embed$": r"text_encoder.embeddings.position_embedding.position_embedding", + r"^backbone\.language_backbone\.encoder\.final_layer_norm\.": r"text_encoder.final_layer_norm.", + r"^backbone\.language_backbone\.encoder\.projection_layer$": r"text_encoder.projection.weight", + # text_projection: projects from text hidden-dim to DETR hidden-dim + r"^backbone\.language_backbone\.projector\.": r"text_projection.", + # RepMixer blocks (layer 0 and the last layer in the mct variant) + r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.layer_scale$": r"text_encoder.layers.\1.layer_scale", + r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.token_mixer\.layer_scale$": r"text_encoder.layers.\1.token_mixer.layer_scale", + r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.token_mixer\.norm\.rbr_skip\.": r"text_encoder.layers.\1.token_mixer.reference_batchnorm.", + r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.token_mixer\.mixer\.rbr_skip\.": r"text_encoder.layers.\1.token_mixer.mixer.batchnorm_skip.", + r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.token_mixer\.mixer\.rbr_conv\.0\.conv\.": r"text_encoder.layers.\1.token_mixer.mixer.conv.", + r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.token_mixer\.mixer\.rbr_conv\.0\.bn\.": r"text_encoder.layers.\1.token_mixer.mixer.batchnorm_conv.", + r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.convffn\.conv\.conv\.": r"text_encoder.layers.\1.conv_feed_forward.depthwise_conv.", + r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.convffn\.conv\.bn\.": r"text_encoder.layers.\1.conv_feed_forward.depthwise_batchnorm.", + r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.convffn\.fc1\.": r"text_encoder.layers.\1.conv_feed_forward.mlp.fc1.", + r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.convffn\.fc2\.": r"text_encoder.layers.\1.conv_feed_forward.mlp.fc2.", + # Standard transformer layers (pre-norm MHA + FFN) + r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.pre_norm_mha\.0\.": r"text_encoder.layers.\1.layer_norm1.", + r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.pre_norm_mha\.1\.qkv_proj\.": r"text_encoder.layers.\1.self_attn.in_proj_", + r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.pre_norm_mha\.1\.out_proj\.": r"text_encoder.layers.\1.self_attn.out_proj.", + r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.pre_norm_ffn\.0\.": r"text_encoder.layers.\1.layer_norm2.", + r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.pre_norm_ffn\.1\.": r"text_encoder.layers.\1.mlp.fc1.", + r"^backbone\.language_backbone\.encoder\.transformer\.(\d+)\.pre_norm_ffn\.4\.": r"text_encoder.layers.\1.mlp.fc2.", + + # ============================================================================ + # Geometry Encoder + # ============================================================================ + r"^geometry_encoder\.points_direct_project\.": r"geometry_encoder.boxes_direct_project.", + r"^geometry_encoder\.points_pool_project\.": r"geometry_encoder.boxes_pool_project.", + 
r"^geometry_encoder\.points_pos_enc_project\.": r"geometry_encoder.boxes_pos_enc_project.", + r"^geometry_encoder\.encode\.(\d+)\.cross_attn_image\.out_proj\.": r"geometry_encoder.layers.\1.cross_attn.o_proj.", + r"^geometry_encoder\.encode\.(\d+)\.cross_attn_image\.": r"geometry_encoder.layers.\1.cross_attn.", + r"^geometry_encoder\.encode\.(\d+)\.self_attn\.out_proj\.": r"geometry_encoder.layers.\1.self_attn.o_proj.", + r"^geometry_encoder\.encode\.(\d+)\.self_attn\.": r"geometry_encoder.layers.\1.self_attn.", + r"^geometry_encoder\.encode\.(\d+)\.linear1\.": r"geometry_encoder.layers.\1.mlp.fc1.", + r"^geometry_encoder\.encode\.(\d+)\.linear2\.": r"geometry_encoder.layers.\1.mlp.fc2.", + r"^geometry_encoder\.encode\.(\d+)\.norm1\.": r"geometry_encoder.layers.\1.layer_norm1.", + r"^geometry_encoder\.encode\.(\d+)\.norm2\.": r"geometry_encoder.layers.\1.layer_norm2.", + r"^geometry_encoder\.encode\.(\d+)\.norm3\.": r"geometry_encoder.layers.\1.layer_norm3.", + r"^geometry_encoder\.img_pre_norm\.": r"geometry_encoder.vision_layer_norm.", + r"^geometry_encoder\.norm\.": r"geometry_encoder.prompt_layer_norm.", + r"^geometry_encoder\.encode_norm\.": r"geometry_encoder.output_layer_norm.", + + # ============================================================================ + # DETR Encoder + # ============================================================================ + r"^transformer\.encoder\.layers\.(\d+)\.cross_attn_image\.out_proj\.": r"detr_encoder.layers.\1.cross_attn.o_proj.", + r"^transformer\.encoder\.layers\.(\d+)\.cross_attn_image\.": r"detr_encoder.layers.\1.cross_attn.", + r"^transformer\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.": r"detr_encoder.layers.\1.self_attn.o_proj.", + r"^transformer\.encoder\.layers\.(\d+)\.self_attn\.": r"detr_encoder.layers.\1.self_attn.", + r"^transformer\.encoder\.layers\.(\d+)\.cross_attn\.out_proj\.": r"detr_encoder.layers.\1.cross_attn.o_proj.", + r"^transformer\.encoder\.layers\.(\d+)\.cross_attn\.": r"detr_encoder.layers.\1.cross_attn.", + r"^transformer\.encoder\.layers\.(\d+)\.linear1\.": r"detr_encoder.layers.\1.mlp.fc1.", + r"^transformer\.encoder\.layers\.(\d+)\.linear2\.": r"detr_encoder.layers.\1.mlp.fc2.", + r"^transformer\.encoder\.layers\.(\d+)\.norm1\.": r"detr_encoder.layers.\1.layer_norm1.", + r"^transformer\.encoder\.layers\.(\d+)\.norm2\.": r"detr_encoder.layers.\1.layer_norm2.", + r"^transformer\.encoder\.layers\.(\d+)\.norm3\.": r"detr_encoder.layers.\1.layer_norm3.", + + # ============================================================================ + # DETR Decoder + # ============================================================================ + r"^transformer\.decoder\.query_embed\.": r"detr_decoder.query_embed.", + r"^transformer\.decoder\.reference_points\.": r"detr_decoder.reference_points.", + r"^transformer\.decoder\.instance_query_embed\.": r"detr_decoder.instance_query_embed.", + r"^transformer\.decoder\.instance_reference_points\.": r"detr_decoder.instance_reference_points.", + r"^transformer\.decoder\.presence_token\.": r"detr_decoder.presence_token.", + r"^transformer\.decoder\.presence_token_head\.layers\.0\.": r"detr_decoder.presence_head.layer1.", + r"^transformer\.decoder\.presence_token_head\.layers\.1\.": r"detr_decoder.presence_head.layer2.", + r"^transformer\.decoder\.presence_token_head\.layers\.2\.": r"detr_decoder.presence_head.layer3.", + r"^transformer\.decoder\.presence_token_out_norm\.": r"detr_decoder.presence_layer_norm.", + r"^transformer\.decoder\.norm\.": r"detr_decoder.output_layer_norm.", + 
r"^transformer\.decoder\.bbox_embed\.layers\.0\.": r"detr_decoder.box_head.layer1.", + r"^transformer\.decoder\.bbox_embed\.layers\.1\.": r"detr_decoder.box_head.layer2.", + r"^transformer\.decoder\.bbox_embed\.layers\.2\.": r"detr_decoder.box_head.layer3.", + r"^transformer\.decoder\.instance_bbox_embed\.layers\.0\.": r"detr_decoder.instance_box_head.layer1.", + r"^transformer\.decoder\.instance_bbox_embed\.layers\.1\.": r"detr_decoder.instance_box_head.layer2.", + r"^transformer\.decoder\.instance_bbox_embed\.layers\.2\.": r"detr_decoder.instance_box_head.layer3.", + r"^transformer\.decoder\.ref_point_head\.layers\.0\.": r"detr_decoder.ref_point_head.layer1.", + r"^transformer\.decoder\.ref_point_head\.layers\.1\.": r"detr_decoder.ref_point_head.layer2.", + r"^transformer\.decoder\.boxRPB_embed_x\.layers\.0\.": r"detr_decoder.box_rpb_embed_x.layer1.", + r"^transformer\.decoder\.boxRPB_embed_x\.layers\.1\.": r"detr_decoder.box_rpb_embed_x.layer2.", + r"^transformer\.decoder\.boxRPB_embed_y\.layers\.0\.": r"detr_decoder.box_rpb_embed_y.layer1.", + r"^transformer\.decoder\.boxRPB_embed_y\.layers\.1\.": r"detr_decoder.box_rpb_embed_y.layer2.", + r"^transformer\.decoder\.layers\.(\d+)\.self_attn\.out_proj\.": r"detr_decoder.layers.\1.self_attn.o_proj.", + r"^transformer\.decoder\.layers\.(\d+)\.self_attn\.": r"detr_decoder.layers.\1.self_attn.", + r"^transformer\.decoder\.layers\.(\d+)\.ca_text\.out_proj\.": r"detr_decoder.layers.\1.text_cross_attn.o_proj.", + r"^transformer\.decoder\.layers\.(\d+)\.ca_text\.": r"detr_decoder.layers.\1.text_cross_attn.", + r"^transformer\.decoder\.layers\.(\d+)\.cross_attn\.out_proj\.": r"detr_decoder.layers.\1.vision_cross_attn.o_proj.", + r"^transformer\.decoder\.layers\.(\d+)\.cross_attn\.": r"detr_decoder.layers.\1.vision_cross_attn.", + r"^transformer\.decoder\.layers\.(\d+)\.linear1\.": r"detr_decoder.layers.\1.mlp.fc1.", + r"^transformer\.decoder\.layers\.(\d+)\.linear2\.": r"detr_decoder.layers.\1.mlp.fc2.", + r"^transformer\.decoder\.layers\.(\d+)\.norm1\.": r"detr_decoder.layers.\1.vision_cross_attn_layer_norm.", + r"^transformer\.decoder\.layers\.(\d+)\.catext_norm\.": r"detr_decoder.layers.\1.text_cross_attn_layer_norm.", + r"^transformer\.decoder\.layers\.(\d+)\.norm2\.": r"detr_decoder.layers.\1.self_attn_layer_norm.", + r"^transformer\.decoder\.layers\.(\d+)\.norm3\.": r"detr_decoder.layers.\1.mlp_layer_norm.", + + # ============================================================================ + # Dot Product Scoring + # ============================================================================ + r"^dot_prod_scoring\.prompt_mlp\.layers\.0\.": r"dot_product_scoring.text_mlp.layer1.", + r"^dot_prod_scoring\.prompt_mlp\.layers\.1\.": r"dot_product_scoring.text_mlp.layer2.", + r"^dot_prod_scoring\.prompt_mlp\.out_norm\.": r"dot_product_scoring.text_mlp_out_norm.", + r"^dot_prod_scoring\.prompt_proj\.": r"dot_product_scoring.text_proj.", + r"^dot_prod_scoring\.hs_proj\.": r"dot_product_scoring.query_proj.", + + # ============================================================================ + # Mask Decoder + # ============================================================================ + r"^segmentation_head\.pixel_decoder\.conv_layers\.(\d+)\.": r"mask_decoder.pixel_decoder.conv_layers.\1.", + r"^segmentation_head\.pixel_decoder\.norms\.(\d+)\.": r"mask_decoder.pixel_decoder.norms.\1.", + r"^segmentation_head\.mask_embed\.layers\.(\d+)\.": r"mask_decoder.mask_embedder.layers.\1.", + 
r"^segmentation_head\.mask_predictor\.mask_embed\.layers\.(\d+)\.": r"mask_decoder.mask_embedder.layers.\1.", + r"^segmentation_head\.instance_seg_head\.": r"mask_decoder.instance_projection.", + r"^segmentation_head\.semantic_seg_head\.": r"mask_decoder.semantic_projection.", + r"^segmentation_head\.cross_attend_prompt\.out_proj\.": r"mask_decoder.prompt_cross_attn.o_proj.", + r"^segmentation_head\.cross_attend_prompt\.": r"mask_decoder.prompt_cross_attn.", + r"^segmentation_head\.cross_attn_norm\.": r"mask_decoder.prompt_cross_attn_norm.", +} +# fmt: on + + +def convert_old_keys_to_new_keys(state_dict_keys: list[str]) -> dict[str, str]: + """ + Convert original SAM3 LiteText checkpoint keys to HuggingFace format. + + Applies all regex patterns in `ORIGINAL_TO_CONVERTED_KEY_MAPPING` at once + using a multiline bulk-substitution. + """ + output_dict = {} + if state_dict_keys is not None: + old_text = "\n".join(state_dict_keys) + new_text = old_text + for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items(): + new_text = re.sub(pattern, replacement, new_text, flags=re.MULTILINE) + output_dict = dict(zip(old_text.split("\n"), new_text.split("\n"))) + return output_dict + + +def split_qkv(state_dict: dict) -> dict: + """Split combined QKV projections into separate Q, K, V projections.""" + # Vision backbone: .attention.qkv.* → .attention.{q,k,v}_proj.* + for key in [k for k in state_dict if ".attention.qkv." in k]: + qkv = state_dict.pop(key) + q, k, v = torch.chunk(qkv, 3, dim=0) + state_dict[key.replace(".qkv.", ".q_proj.")] = q + state_dict[key.replace(".qkv.", ".k_proj.")] = k + state_dict[key.replace(".qkv.", ".v_proj.")] = v + + # Text encoder & attention layers: .in_proj_weight/bias → .{q,k,v}_proj.* + for key in [k for k in state_dict if ".in_proj_" in k]: + in_proj = state_dict.pop(key) + q, k, v = torch.chunk(in_proj, 3, dim=0) + if key.endswith("in_proj_weight"): + base = key.replace("in_proj_weight", "") + state_dict[base + "q_proj.weight"] = q + state_dict[base + "k_proj.weight"] = k + state_dict[base + "v_proj.weight"] = v + elif key.endswith("in_proj_bias"): + base = key.replace("in_proj_bias", "") + state_dict[base + "q_proj.bias"] = q + state_dict[base + "k_proj.bias"] = k + state_dict[base + "v_proj.bias"] = v + + return state_dict + + +def load_original_state_dict(checkpoint_path: str) -> dict[str, torch.Tensor]: + """Load the original EfficientSAM3 LiteText checkpoint.""" + print(f"Loading original checkpoint from {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location="cpu") + if "model" in checkpoint: + state_dict = checkpoint["model"] + elif "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + print(f"Loaded {len(state_dict)} keys from checkpoint") + return state_dict + + +def _infer_text_config(state_dict: dict[str, torch.Tensor]) -> Sam3LiteTextTextConfig: + """Infer LiteText encoder hyper-parameters from the raw checkpoint.""" + prefix = "detector.backbone.language_backbone.encoder." + hidden_size = state_dict[f"{prefix}embedding_layer.weight"].shape[1] + context_length = state_dict[f"{prefix}positional_embedding.pos_embed.pos_embed"].shape[2] + use_repmixer_blocks = any(f"{prefix}transformer.0.token_mixer" in k for k in state_dict) + if use_repmixer_blocks: + num_hidden_layers = 6 + else: + layer_ids = { + int(k.split("transformer.")[1].split(".")[0]) + for k in state_dict + if f"{prefix}transformer." in k and ".pre_norm_mha." 
in k + } + num_hidden_layers = max(layer_ids) + 1 + + return Sam3LiteTextTextConfig( + vocab_size=49408, + hidden_size=hidden_size, + intermediate_size=hidden_size * 4, + num_hidden_layers=num_hidden_layers, + num_attention_heads=hidden_size // 64, + max_position_embeddings=context_length, + projection_dim=hidden_size, + use_repmixer_blocks=use_repmixer_blocks, + ) + + +def get_sam3_lite_text_config(state_dict: dict[str, torch.Tensor]) -> Sam3LiteTextConfig: + """Build a Sam3LiteTextConfig inferred from the raw checkpoint.""" + text_config = _infer_text_config(state_dict) + config = Sam3LiteTextConfig() + config.text_config = text_config + return config + + +def convert_sam3_lite_text_checkpoint( + checkpoint_path: str, + output_path: str, + config: Sam3LiteTextConfig | None = None, + push_to_hub: bool = False, + repo_id: str | None = None, +): + """ + Convert an EfficientSAM3 LiteText checkpoint to HuggingFace format. + + Args: + checkpoint_path: Path to the original `.pt` checkpoint file. + output_path: Directory where the converted model will be saved. + config: Optional pre-built `Sam3LiteTextConfig` (defaults to auto-inferred). + push_to_hub: Whether to push the model to the Hugging Face Hub. + repo_id: Hub repository ID (required when ``push_to_hub=True``). + """ + os.makedirs(output_path, exist_ok=True) + + # Load original checkpoint + state_dict_old = load_original_state_dict(checkpoint_path) + + # Build config from checkpoint + if config is None: + config = get_sam3_lite_text_config(state_dict_old) + + config.architectures = ["Sam3LiteTextModel"] + config.save_pretrained(output_path) + print("Model config saved successfully") + + # Convert keys + print("Converting checkpoint keys...") + all_keys = list(state_dict_old.keys()) + key_mapping = convert_old_keys_to_new_keys(all_keys) + + state_dict_new = {} + for old_key in all_keys: + new_key = key_mapping.get(old_key, old_key) + # num_batches_tracked from BatchNorm is not needed + if "num_batches_tracked" in new_key: + continue + # Parallel SAM2 neck branch in the original checkpoint; HF vision uses `convs` only. 
+ if "vision_backbone.sam2_convs" in new_key: + continue + # Drop keys whose names were not transformed (unrecognised / legacy keys) + if new_key == old_key: + continue + # Strip the first position (cls token) from ViT position embeddings + if new_key == "vision_encoder.backbone.embeddings.position_embeddings": + state_dict_new[new_key] = state_dict_old[old_key][:, 1:, :] + else: + state_dict_new[new_key] = state_dict_old[old_key] + + del state_dict_old + gc.collect() + + # Split combined QKV projections + print("Splitting QKV projections...") + state_dict_new = split_qkv(state_dict_new) + + # HF models compute the RoPE table on the fly + for k in list(state_dict_new.keys()): + if k.endswith("rotary_emb.rope_embeddings"): + state_dict_new.pop(k) + + print( + "Converted key counts:", + { + prefix: sum(1 for k in state_dict_new if k.startswith(prefix)) + for prefix in ( + "vision_encoder.", + "text_encoder.", + "geometry_encoder.", + "detr_encoder.", + "detr_decoder.", + "mask_decoder.", + ) + }, + ) + + # Load weights into HF model + print("Loading weights into Sam3LiteTextModel...") + model = Sam3LiteTextModel(config) + missing_keys, unexpected_keys = model.load_state_dict(state_dict_new, strict=False) + + if missing_keys: + logger.warning(f"Missing keys ({len(missing_keys)}):") + for key in missing_keys: + logger.warning(f" - {key}") + + if unexpected_keys: + logger.warning(f"Unexpected keys ({len(unexpected_keys)}):") + for key in unexpected_keys: + logger.warning(f" - {key}") + + # Save model + print(f"Saving converted model to {output_path}") + model.save_pretrained(output_path) + + # Save processor + print("Creating and saving processor...") + image_processor = Sam3ImageProcessor() + tokenizer = CLIPTokenizerFast.from_pretrained( + "openai/clip-vit-base-patch32", + max_length=config.text_config.max_position_embeddings, + model_max_length=config.text_config.max_position_embeddings, + ) + processor = Sam3Processor(image_processor=image_processor, tokenizer=tokenizer) + processor.save_pretrained(output_path) + + if push_to_hub: + if repo_id is None: + raise ValueError("repo_id must be provided when push_to_hub=True") + print(f"Pushing model to Hub: {repo_id}") + model.push_to_hub(repo_id) + processor.push_to_hub(repo_id) + + print("Conversion complete!") + + # Cleanup and verify + del state_dict_new, model + gc.collect() + + print("\nVerifying converted checkpoint can be loaded...") + try: + model = Sam3LiteTextModel.from_pretrained(output_path) + param_count = sum(p.numel() for p in model.parameters()) + print(f"Successfully loaded model with {param_count:,} parameters") + del model + gc.collect() + except Exception as e: + print(f"Failed to reload model: {e}") + + print("\n" + "=" * 80) + print("Conversion finished!") + print("=" * 80) + print(f"Output directory: {output_path}") + print("\nTo use the model:") + print(">>> from transformers import Sam3LiteTextModel, Sam3Processor") + print(f">>> model = Sam3LiteTextModel.from_pretrained('{output_path}')") + print("=" * 80) + + +MODEL_VARIANTS = { + "s0": "sam3_litetext/efficient_sam3_image_encoder_mobileclip_s0_ctx16.pt", + "s1": "sam3_litetext/efficient_sam3_image_encoder_mobileclip_s1_ctx16.pt", + "l": "sam3_litetext/efficient_sam3_image_encoder_mobileclip2_l_ctx16.pt", +} + + +def main(): + parser = argparse.ArgumentParser(description="Convert EfficientSAM3 LiteText checkpoint to HuggingFace format") + parser.add_argument( + "--model_variant", + type=str, + choices=list(MODEL_VARIANTS), + default=None, + help="Model variant to 
download and convert: 's0' (MobileCLIP-S0, 42M), 's1' (MobileCLIP-S1, 63M), " + "or 'l' (MobileCLIP2-L, 124M). Takes precedence over --filename when set.", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + default=None, + help="Path to the original .pt checkpoint file. If omitted, the checkpoint is downloaded from the Hub.", + ) + parser.add_argument( + "--output_path", + type=str, + required=True, + help="Directory where the converted checkpoint will be saved.", + ) + parser.add_argument( + "--repo_id", + type=str, + default="Simon7108528/EfficientSAM3", + help="Hub repository ID to download the checkpoint from (used when --checkpoint_path is not provided).", + ) + parser.add_argument( + "--filename", + type=str, + default="sam3_litetext/efficient_sam3_image_encoder_mobileclip_s0_ctx16.pt", + help="Filename within the Hub repository to download (ignored when --model_variant is set).", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether to push the converted model to the Hugging Face Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="Hub repository ID to push to (e.g. 'my-org/sam3-litetext-s0').", + ) + args = parser.parse_args() + + filename = MODEL_VARIANTS[args.model_variant] if args.model_variant else args.filename + + checkpoint_path = args.checkpoint_path + if checkpoint_path is None: + print(f"Downloading checkpoint {filename} from {args.repo_id}...") + checkpoint_path = hf_hub_download(args.repo_id, filename) + + convert_sam3_lite_text_checkpoint( + checkpoint_path=checkpoint_path, + output_path=args.output_path, + push_to_hub=args.push_to_hub, + repo_id=args.hub_model_id, + ) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/sam3_lite_text/modeling_sam3_lite_text.py b/src/transformers/models/sam3_lite_text/modeling_sam3_lite_text.py new file mode 100644 index 000000000000..5a7b02880edd --- /dev/null +++ b/src/transformers/models/sam3_lite_text/modeling_sam3_lite_text.py @@ -0,0 +1,2351 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/sam3_lite_text/modular_sam3_lite_text.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_sam3_lite_text.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Callable, Iterable +from dataclasses import dataclass + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from ... 
import initialization as init +from ...activations import ACT2FN +from ...masking_utils import create_bidirectional_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import compile_compatible_method_lru_cache +from ...utils import auto_docstring, can_return_tuple, is_torchvision_available, logging +from ...utils.generic import TransformersKwargs, is_flash_attention_requested, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from ..auto import AutoModel +from .configuration_sam3_lite_text import ( + Sam3LiteTextConfig, + Sam3LiteTextDETRDecoderConfig, + Sam3LiteTextDETREncoderConfig, + Sam3LiteTextGeometryEncoderConfig, + Sam3LiteTextMaskDecoderConfig, + Sam3LiteTextTextConfig, + Sam3LiteTextViTConfig, +) + + +if is_torchvision_available(): + import torchvision + + +logger = logging.get_logger(__name__) + + +@dataclass +class Sam3LiteTextTextEncoderOutput(BaseModelOutputWithPooling): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Full sequence of hidden states from the text encoder. + pooler_output (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): + EOT-pooled output projected to `projection_dim` via the internal CLIP-style projection. + hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of hidden states at each layer, returned when `output_hidden_states=True`. + attentions (`tuple(torch.FloatTensor)`, *optional*): + Tuple of attention weights at each transformer layer, returned when `output_attentions=True`. + """ + + +class Sam3LiteTextTextPositionEmbedding(nn.Module): + """Learnable positional embedding with bilinear interpolation for variable sequence lengths.""" + + def __init__(self, max_position_embeddings: int, hidden_size: int): + super().__init__() + self.position_embedding = nn.Parameter(torch.empty(1, 1, max_position_embeddings, hidden_size)) + + def forward(self, seq_len: int) -> torch.Tensor: + position_embedding = self.position_embedding + if seq_len != position_embedding.shape[2]: + position_embedding = F.interpolate( + position_embedding, + size=(seq_len, position_embedding.shape[-1]), + mode="bilinear", + ) + return position_embedding.reshape(1, seq_len, -1) + + +class Sam3LiteTextMobileOneBlock(nn.Module): + """Depthwise conv branch with batch norm on the skip path and after the conv (MobileOne-style).""" + + def __init__(self, hidden_size: int, kernel_size: int = 3): + super().__init__() + self.batchnorm_skip = nn.BatchNorm2d(hidden_size) + self.conv = nn.Conv2d( + hidden_size, + hidden_size, + kernel_size=(1, kernel_size), + stride=1, + padding=(0, kernel_size // 2), + groups=hidden_size, + bias=False, + ) + self.batchnorm_conv = nn.BatchNorm2d(hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.batchnorm_conv(self.conv(hidden_states)) + hidden_states = hidden_states + self.batchnorm_skip(residual) + return hidden_states + + +class Sam3LiteTextConvMLP(nn.Module): + """Pointwise MLP using 1×1 convolutions, compatible with 4-D (B, C, H, W) feature maps.""" + + def __init__(self, config: Sam3LiteTextTextConfig): + super().__init__() + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Conv2d(config.hidden_size, config.intermediate_size, 
kernel_size=1) + self.fc2 = nn.Conv2d(config.intermediate_size, config.hidden_size, kernel_size=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Sam3LiteTextConvolutionalFeedForward(nn.Module): + """Convolutional feed-forward network: depthwise conv + two pointwise projections.""" + + def __init__(self, config: Sam3LiteTextTextConfig): + super().__init__() + self.depthwise_conv = nn.Conv2d( + config.hidden_size, + config.hidden_size, + kernel_size=(1, config.repmixer_kernel_size), + padding=(0, config.repmixer_kernel_size // 2), + groups=config.hidden_size, + bias=False, + ) + self.depthwise_batchnorm = nn.BatchNorm2d(config.hidden_size) + self.mlp = Sam3LiteTextConvMLP(config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.depthwise_batchnorm(self.depthwise_conv(hidden_states)) + return self.mlp(hidden_states) + + +class Sam3LiteTextLayerScaledResidual(nn.Module): + """Common layer-scale residual pattern shared by the RepMixer and feed-forward branches.""" + + def __init__(self, hidden_size: int, layer_scale_init_value: float): + super().__init__() + self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((hidden_size, 1, 1)), requires_grad=True) + + def layer_scale_residual(self, hidden_states: torch.Tensor, update: torch.Tensor) -> torch.Tensor: + return hidden_states + self.layer_scale * update + + +class Sam3LiteTextRepMixer(Sam3LiteTextLayerScaledResidual): + """Re-parameterisable depthwise-conv token mixer operating on 1D sequence data.""" + + def __init__(self, config: Sam3LiteTextTextConfig): + super().__init__(config.hidden_size, config.layer_scale_init_value) + self.reference_batchnorm = nn.BatchNorm2d(config.hidden_size) + self.mixer = Sam3LiteTextMobileOneBlock(config.hidden_size, kernel_size=config.repmixer_kernel_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.layer_scale_residual( + hidden_states, self.mixer(hidden_states) - self.reference_batchnorm(hidden_states) + ) + + +class Sam3LiteTextRepMixerBlock(Sam3LiteTextLayerScaledResidual): + """Token-mixing RepMixer plus a convolutional feed-forward path, each with layer scale.""" + + def __init__(self, config: Sam3LiteTextTextConfig): + super().__init__(config.hidden_size, config.layer_scale_init_value) + self.token_mixer = Sam3LiteTextRepMixer(config) + self.conv_feed_forward = Sam3LiteTextConvolutionalFeedForward(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + hidden_states = hidden_states.transpose(1, 2).unsqueeze(2) + hidden_states = self.token_mixer(hidden_states) + hidden_states = self.layer_scale_residual(hidden_states, self.conv_feed_forward(hidden_states)) + return hidden_states.squeeze(2).transpose(1, 2) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = 
nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Sam3LiteTextTextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Input shape: Batch x Time x Channel""" + + input_shape = hidden_states.shape[:-1] + + hidden_shape = (*input_shape, -1, self.head_dim) + queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class Sam3LiteTextTextMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Sam3LiteTextTextEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Sam3LiteTextTextConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = Sam3LiteTextTextAttention(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Sam3LiteTextTextMLP(config) + + @auto_docstring + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = 
self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Sam3LiteTextTextEmbeddings(nn.Module): + """Token embedding + interpolatable positional embedding for the text encoder.""" + + def __init__(self, config: Sam3LiteTextTextConfig): + super().__init__() + self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embedding = Sam3LiteTextTextPositionEmbedding(config.max_position_embeddings, config.hidden_size) + + def forward(self, input_ids: torch.LongTensor) -> torch.Tensor: + hidden_states = self.token_embedding(input_ids) + hidden_states = hidden_states + self.position_embedding(input_ids.shape[1]).to(hidden_states.dtype) + return hidden_states + + +class Sam3LiteTextViTRotaryEmbedding(nn.Module): + """ + Vision Rotary Position Embedding for SAM3_LITE_TEXT, following transformers library standards. + Supports 2D (axial) rotary embeddings for spatial dimensions. + """ + + def __init__(self, config: Sam3LiteTextViTConfig, end_x: int, end_y: int, scale: float = 1.0): + super().__init__() + dim = config.hidden_size // config.num_attention_heads + # Ensure even dimension for proper axial splitting + if dim % 4 != 0: + raise ValueError("Dimension must be divisible by 4 for axial RoPE") + self.end_x, self.end_y = end_x, end_y + self.dim = dim + self.rope_theta = config.rope_theta + self.scale = scale + freqs = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + flattened_indices = torch.arange(end_x * end_y, dtype=torch.long) + x_positions = (flattened_indices % end_x) * scale + y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * scale + freqs_x = torch.outer(x_positions, freqs).float() + freqs_y = torch.outer(y_positions, freqs).float() + inv_freq = torch.cat([freqs_x, freqs_y], dim=-1) + inv_freq = inv_freq.repeat_interleave(2, dim=-1) + # directly register the cos and sin embeddings as we have a fixed feature shape + self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False) + self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False) + + @torch.no_grad() + def forward(self) -> tuple[torch.Tensor, torch.Tensor]: + # As the feature map size is fixed for each stage, we can just return the pre-computed embeddings. + return self.rope_embeddings_cos, self.rope_embeddings_sin + + +class Sam3LiteTextViTPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
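+
+    Illustrative shape walk-through (example sizes, not checkpoint defaults): with `pretrain_image_size=1024`
+    and `patch_size=16`, the projection yields a (1024 // 16) * (1024 // 16) = 64 * 64 = 4096 patch grid, so
+    `pixel_values` of shape `(batch_size, 3, 1024, 1024)` become embeddings of shape `(batch_size, 4096, hidden_size)`.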
+ """ + + def __init__(self, config: Sam3LiteTextViTConfig): + super().__init__() + image_size, patch_size = config.pretrain_image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=False) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + embeddings = self.projection(pixel_values.to(self.projection.weight.dtype)).flatten(2).transpose(1, 2) + return embeddings + + +class Sam3LiteTextViTEmbeddings(nn.Module): + """ + Construct the patch embeddings and position embeddings for SAM3_LITE_TEXT ViT. + + Position embeddings are tiled (not interpolated) when resizing to match different input sizes. + """ + + def __init__(self, config: Sam3LiteTextViTConfig): + super().__init__() + + self.patch_embeddings = Sam3LiteTextViTPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter( + torch.randn(1, num_patches, config.hidden_size) + ) # !Remove cls token in convert weights! + + self.dropout = nn.Dropout(config.hidden_dropout) + self.patch_size = config.patch_size + + def _tile_position_embeddings( + self, + position_embeddings: torch.Tensor, + height: int, + width: int, + ) -> torch.Tensor: + """ + Tile position embeddings to match target spatial dimensions. 
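+        For example (sizes chosen purely for illustration), a pretrained 32x32 position grid tiled to a
+        48x48 patch target is repeated 2x2, cropped back to 48x48, and flattened to shape [1, 48 * 48, hidden_size].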
+ Args: + position_embeddings: Shape [1, num_pretrain_patches, hidden_size] + height: Target height in patches + width: Target width in patches + + Returns: + Shape [1, height * width, hidden_size] + """ + pretrain_size = int(position_embeddings.shape[1] ** 0.5) + + # Skip tiling if sizes match (but always tile during tracing for consistent graph) + if not torch.jit.is_tracing() and pretrain_size == height and pretrain_size == width: + return position_embeddings.reshape(1, height * width, -1) + + # Tile position embeddings to match target spatial dimensions + hidden_size = position_embeddings.shape[-1] + pos_embed = position_embeddings.reshape(1, pretrain_size, pretrain_size, hidden_size).permute(0, 3, 1, 2) + repeat_h = height // pretrain_size + 1 + repeat_w = width // pretrain_size + 1 + pos_embed = pos_embed.tile([1, 1, repeat_h, repeat_w])[:, :, :height, :width] + return pos_embed.permute(0, 2, 3, 1).reshape(1, height * width, hidden_size) + + def forward( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + height, width = pixel_values.shape[-2:] + embeddings = self.patch_embeddings(pixel_values) + + # Calculate spatial dimensions in patches + height_patches = height // self.patch_size + width_patches = width // self.patch_size + + position_embeddings = self._tile_position_embeddings( + self.position_embeddings, + height_patches, + width_patches, + ) + embeddings = embeddings + position_embeddings + embeddings = self.dropout(embeddings) + + return embeddings + + +@auto_docstring +class Sam3LiteTextPreTrainedModel(PreTrainedModel): + config_class = Sam3LiteTextConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + input_modalities = ["image", "text"] + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = True + supports_gradient_checkpointing = True + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, Sam3LiteTextViTEmbeddings): + init.normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, Sam3LiteTextViTRotaryEmbedding): + end_x, end_y = module.end_x, module.end_y + dim = module.dim + freqs = 1.0 / (module.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + flattened_indices = torch.arange(end_x * end_y, dtype=torch.long) + x_positions = (flattened_indices % end_x) * module.scale + y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * module.scale + freqs_x = torch.outer(x_positions, freqs).float() + freqs_y = torch.outer(y_positions, freqs).float() + inv_freq = torch.cat([freqs_x, freqs_y], dim=-1) + inv_freq = inv_freq.repeat_interleave(2, dim=-1) + init.copy_(module.rope_embeddings_cos, inv_freq.cos()) + init.copy_(module.rope_embeddings_sin, inv_freq.sin()) + if isinstance(module, Sam3LiteTextTextPositionEmbedding): + init.normal_(module.position_embedding, std=module.position_embedding.shape[-1] ** -0.5) + elif isinstance(module, Sam3LiteTextTextModel): + init.normal_(module.projection.weight, std=module.config.hidden_size**-0.5) + + +@auto_docstring( + custom_intro=""" + MobileCLIP MCT text encoder used in EfficientSAM3 LiteText. + + When `config.use_repmixer_blocks` is `True`, the first and last layers are + `Sam3LiteTextRepMixerBlock` modules; the rest are standard `Sam3LiteTextTextEncoderLayer` layers. 
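+
+    Example (untested sketch: builds the text encoder from a default config with random weights, and assumes the
+    standard `transformers.models.sam3_lite_text` module layout; loading a released checkpoint would use
+    `from_pretrained` instead):
+
+    ```python
+    import torch
+
+    from transformers.models.sam3_lite_text.configuration_sam3_lite_text import Sam3LiteTextTextConfig
+    from transformers.models.sam3_lite_text.modeling_sam3_lite_text import Sam3LiteTextTextModel
+
+    config = Sam3LiteTextTextConfig()
+    model = Sam3LiteTextTextModel(config).eval()
+
+    input_ids = torch.randint(0, config.vocab_size, (2, 16))
+    with torch.no_grad():
+        outputs = model(input_ids=input_ids)
+
+    print(outputs.last_hidden_state.shape)  # (2, 16, config.hidden_size)
+    print(outputs.pooler_output.shape)      # (2, config.projection_dim)
+    ```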
+""" +) +class Sam3LiteTextTextModel(Sam3LiteTextPreTrainedModel): + config_class = Sam3LiteTextTextConfig + config: Sam3LiteTextTextConfig + _can_record_outputs = { + "hidden_states": Sam3LiteTextTextEncoderLayer, + "attentions": Sam3LiteTextTextAttention, + } + + def __init__(self, config: Sam3LiteTextTextConfig): + super().__init__(config) + self.embeddings = Sam3LiteTextTextEmbeddings(config) + repmixer_positions = {0, config.num_hidden_layers - 1} if config.use_repmixer_blocks else set() + self.layers = nn.ModuleList( + [ + Sam3LiteTextRepMixerBlock(config) if i in repmixer_positions else Sam3LiteTextTextEncoderLayer(config) + for i in range(config.num_hidden_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False) + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Sam3LiteTextTextEncoderOutput: + hidden_states = self.embeddings(input_ids) + attention_mask = create_bidirectional_mask(self.config, hidden_states, attention_mask) + + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask=attention_mask, **kwargs) + + hidden_states = self.final_layer_norm(hidden_states) + + pooled = hidden_states[ + torch.arange(hidden_states.shape[0], device=hidden_states.device), input_ids.argmax(dim=-1) + ] + pooled = self.projection(pooled) + return Sam3LiteTextTextEncoderOutput( + last_hidden_state=hidden_states, + pooler_output=pooled, + ) + + +@dataclass +@auto_docstring +class Sam3LiteTextVisionEncoderOutput(BaseModelOutputWithPooling): + r""" + fpn_hidden_states (`tuple[torch.FloatTensor]`): + Tuple of multi-level FPN feature maps. + fpn_position_encoding (`tuple[torch.FloatTensor]`): + Tuple of position encodings for each FPN level. + """ + + fpn_hidden_states: tuple[torch.FloatTensor, ...] = None + fpn_position_encoding: tuple[torch.FloatTensor, ...] = None + + +@dataclass +@auto_docstring +class Sam3LiteTextGeometryEncoderOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_prompts, hidden_size)`): + Encoded geometry prompt features (boxes). + attention_mask (`torch.BoolTensor` of shape `(batch_size, num_prompts)`, *optional*): + Attention mask for geometry prompts where True indicates valid positions and False indicates padding. + """ + + last_hidden_state: torch.FloatTensor = None + attention_mask: torch.BoolTensor | None = None + + +@dataclass +@auto_docstring +class Sam3LiteTextDETREncoderOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Encoded vision features (flattened from multi-level features). + pos_embeds_flattened (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Flattened position embeddings for the vision features. + text_features (`torch.FloatTensor` of shape `(batch_size, text_seq_len, hidden_size)`, *optional*): + Text features (may be pooled after encoder processing). + spatial_shapes (`torch.LongTensor` of shape `(num_levels, 2)`, *optional*): + Spatial shapes (height, width) for each feature pyramid level. + hidden_states (`tuple[torch.FloatTensor]`, *optional*): + Tuple of hidden states from all encoder layers. 
+ attentions (`tuple[torch.FloatTensor]`, *optional*): + Tuple of attention weights from all encoder layers. + """ + + last_hidden_state: torch.FloatTensor = None + pos_embeds_flattened: torch.FloatTensor | None = None + text_features: torch.FloatTensor | None = None + spatial_shapes: torch.LongTensor | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + + +@dataclass +@auto_docstring +class Sam3LiteTextDETRDecoderOutput(ModelOutput): + r""" + intermediate_hidden_states (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, hidden_size)`): + Decoder hidden states from all layers. + reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`): + Predicted reference boxes from all decoder layers in (cx, cy, w, h) format. + presence_logits (`torch.FloatTensor` of shape `(num_layers, batch_size, 1)`): + Presence logits from all decoder layers indicating object presence confidence. + hidden_states (`tuple[torch.FloatTensor]`, *optional*): + Tuple of hidden states from all decoder layers. + attentions (`tuple[torch.FloatTensor]`, *optional*): + Tuple of attention weights from all decoder layers (self-attention and cross-attention). + """ + + intermediate_hidden_states: torch.FloatTensor = None + reference_boxes: torch.FloatTensor = None + presence_logits: torch.FloatTensor = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + + +@dataclass +@auto_docstring +class Sam3LiteTextMaskDecoderOutput(ModelOutput): + r""" + pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`): + Predicted segmentation masks for each query. + semantic_seg (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*): + Semantic segmentation output. + attentions (`tuple[torch.FloatTensor]`, *optional*): + Tuple of attention weights from mask decoder cross-attention layers. + """ + + pred_masks: torch.FloatTensor = None + semantic_seg: torch.FloatTensor | None = None + attentions: tuple[torch.FloatTensor] | None = None + + +@dataclass +@auto_docstring +class Sam3LiteTextImageSegmentationOutput(ModelOutput): + r""" + pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`): + Predicted segmentation masks for each query. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Predicted bounding boxes in (x1, y1, x2, y2) format. + pred_logits (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): + Classification confidence scores for each query, computed via dot product between + decoder query features and text features. + presence_logits (`torch.FloatTensor` of shape `(batch_size, 1)`, *optional*): + Presence logits from the DETR decoder presence token (last layer only). These indicate whether objects + are present in the scene. Can be used to compute final scores by multiplying with pred_logits: + `final_scores = pred_logits.sigmoid() * presence_logits.sigmoid()`. + semantic_seg (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*): + Semantic segmentation output. + decoder_hidden_states (`tuple[torch.FloatTensor]`, *optional*): + Tuple of hidden states from all DETR decoder layers. Each tensor has shape `(batch_size, num_queries, hidden_size)`. + decoder_reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`, *optional*): + Reference boxes from all DETR decoder layers. 
+ encoder_hidden_states (`tuple[torch.FloatTensor]`, *optional*): + Tuple of hidden states from all DETR encoder layers. + vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*): + Tuple of hidden states from all vision encoder (ViT) layers. + vision_attentions (`tuple[torch.FloatTensor]`, *optional*): + Attention weights from vision encoder (ViT) layers. + detr_encoder_attentions (`tuple[torch.FloatTensor]`, *optional*): + Attention weights from DETR encoder layers. + detr_decoder_attentions (`tuple[torch.FloatTensor]`, *optional*): + Attention weights from DETR decoder layers (self-attention and cross-attention). + mask_decoder_attentions (`tuple[torch.FloatTensor]`, *optional*): + Attention weights from mask decoder layers. + """ + + pred_masks: torch.FloatTensor = None + pred_boxes: torch.FloatTensor = None + pred_logits: torch.FloatTensor | None = None + presence_logits: torch.FloatTensor | None = None + semantic_seg: torch.FloatTensor | None = None + decoder_hidden_states: tuple[torch.FloatTensor] | None = None + decoder_reference_boxes: torch.FloatTensor | None = None + encoder_hidden_states: tuple[torch.FloatTensor] | None = None + vision_hidden_states: tuple[torch.FloatTensor] | None = None + vision_attentions: tuple[torch.FloatTensor] | None = None + detr_encoder_attentions: tuple[torch.FloatTensor] | None = None + detr_decoder_attentions: tuple[torch.FloatTensor] | None = None + mask_decoder_attentions: tuple[torch.FloatTensor] | None = None + + +class Sam3LiteTextMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Sam3LiteTextAttention(nn.Module): + """ + Multi-head attention. + Handles standard [batch_size, seq_len, hidden_size] tensors. 
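+    When used for cross-attention, `query` and `key`/`value` may have different sequence lengths: for example,
+    a query of shape [batch_size, num_queries, hidden_size] attending over keys/values of shape
+    [batch_size, height * width, hidden_size] produces an output of shape [batch_size, num_queries, hidden_size].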
+ """ + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.hidden_size // config.num_attention_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.k_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.v_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: + query: [batch_size, query_len, hidden_size] + key: [batch_size, key_len, hidden_size] + value: [batch_size, value_len, hidden_size] + attention_mask: [batch_size, num_heads, query_len, key_len] or broadcastable + + Returns: + Tuple of (output, attention_weights) + output: [batch_size, query_len, hidden_size] + attention_weights: [batch_size, num_heads, query_len, key_len] + """ + batch_size = query.shape[0] + query_len = query.shape[1] + key_len = key.shape[1] + + query = self.q_proj(query).view(batch_size, query_len, self.num_attention_heads, self.head_dim).transpose(1, 2) + key = self.k_proj(key).view(batch_size, key_len, self.num_attention_heads, self.head_dim).transpose(1, 2) + value = self.v_proj(value).view(batch_size, key_len, self.num_attention_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + if ( + is_flash_attention_requested(self.config) + and attention_mask is not None + and attention_mask.dtype != torch.bool + ): + # Relative position bias tensors are represented as float masks and are incompatible with Flash Attention + # Fallback to SDPA for this call only so the rest of the model can still benefit from FA + attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"] + logger.warning_once( + "Sam3LiteTextAttention: falling back to SDPA for relative-position cross-attention because " + "Flash Attention does not support additive bias masks." + ) + + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask=attention_mask, + dropout=0.0, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) + + attn_output = attn_output.reshape(batch_size, query_len, self.num_attention_heads * self.head_dim).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class Sam3LiteTextSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. + """ + + def __init__( + self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: float | None = None + ): + super().__init__() + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + self.scale = 2 * math.pi if scale is None else scale + + def encode_1d_positions(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Encode 1D coordinate pairs using sine/cosine positional embeddings. 
+ + Args: + x: 1D tensor of x coordinates (flattened) + y: 1D tensor of y coordinates (flattened) + + Returns: + Tuple of (pos_x, pos_y) positional embeddings + """ + x_embed = x * self.scale + y_embed = y * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=x.device).to(x.dtype) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, None] / dim_t + pos_y = y_embed[:, None] / dim_t + pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1) + pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1) + return pos_x, pos_y + + def encode_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """ + Encode 4D box coordinates (x, y, w, h) for decoder conditioning using sine/cosine embeddings. + + Args: + boxes: Box coordinates [batch_size, num_queries, 4] in (x, y, w, h) format + + Returns: + Position embeddings [batch_size, num_queries, num_pos_feats*4] + """ + assert boxes.size(-1) == 4, f"Expected 4D box coordinates (x, y, w, h), got shape {boxes.shape}" + dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=boxes.device).to(boxes.dtype) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) + + x_embed = boxes[:, :, 0] * self.scale + y_embed = boxes[:, :, 1] * self.scale + w_embed = boxes[:, :, 2] * self.scale + h_embed = boxes[:, :, 3] * self.scale + + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_w = w_embed[:, :, None] / dim_t + pos_h = h_embed[:, :, None] / dim_t + + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) + pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) + pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + + return pos + + @compile_compatible_method_lru_cache(maxsize=4) + def forward( + self, + shape: torch.Size, + device: torch.device | str, + dtype: torch.dtype, + mask: Tensor | None = None, + ) -> Tensor: + if mask is None: + mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool) + not_mask = (~mask).to(dtype) + y_embed = not_mask.cumsum(1) + x_embed = not_mask.cumsum(2) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class Sam3LiteTextGeometryEncoderLayer(nn.Module): + def __init__(self, config: Sam3LiteTextGeometryEncoderConfig): + super().__init__() + self.layer_norm1 = nn.LayerNorm(config.hidden_size) + self.self_attn = Sam3LiteTextAttention(config) + self.dropout = nn.Dropout(config.dropout) + + self.cross_attn = Sam3LiteTextAttention(config) + self.layer_norm2 = nn.LayerNorm(config.hidden_size) + + self.mlp = 
Sam3LiteTextMLP(config) + self.layer_norm3 = nn.LayerNorm(config.hidden_size) + + def forward( + self, + prompt_feats: Tensor, + vision_feats: Tensor, + vision_pos_encoding: Tensor, + prompt_mask: Tensor, + **kwargs: Unpack[TransformersKwargs], + ): + residual = prompt_feats + hidden_states = self.layer_norm1(prompt_feats) + hidden_states, _ = self.self_attn( + query=hidden_states, key=hidden_states, value=hidden_states, attention_mask=prompt_mask, **kwargs + ) + hidden_states = self.dropout(hidden_states) + residual + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + key = vision_feats + vision_pos_encoding + hidden_states, _ = self.cross_attn(query=hidden_states, key=key, value=vision_feats, **kwargs) + hidden_states = self.dropout(hidden_states) + residual + residual = hidden_states + hidden_states = self.layer_norm3(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.dropout(hidden_states) + residual + + return hidden_states + + +def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False): + """ + Concatenates two right-padded sequences, such that the resulting sequence + is contiguous and also right-padded. + + Tensors are batch-first, masks are batch-first with True=valid, False=padding. + + Args: + seq1: A tensor of shape (batch_size, seq1_length, hidden_size). + mask1: A tensor of shape (batch_size, seq1_length) with True=valid, False=padding. + seq2: A tensor of shape (batch_size, seq2_length, hidden_size). + mask2: A tensor of shape (batch_size, seq2_length) with True=valid, False=padding. + return_index: If True, also returns the index of the ids of the element of seq2 + in the concatenated sequence. This can be used to retrieve the elements of seq2. + + Returns: + A tuple (concatenated_sequence, concatenated_mask) if return_index is False, + otherwise (concatenated_sequence, concatenated_mask, index). + The concatenated_mask uses True=valid, False=padding convention. 
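+
+    Illustrative example (toy sizes, hidden dimension omitted): seq1 has 3 slots with mask1 = [True, True, False]
+    and seq2 has 2 slots with mask2 = [True, False]. The output has 3 + 2 = 5 slots: the two valid seq1 tokens sit
+    at positions 0-1, the single valid seq2 token is scattered to position 2, and the concatenated mask is
+    [True, True, True, False, False].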
+ """ + batch_size, seq1_length, hidden_size = seq1.shape + batch_size2, seq2_length, hidden_size2 = seq2.shape + + assert batch_size == batch_size2 == mask1.size(0) == mask2.size(0) + assert hidden_size == hidden_size2 + assert seq1_length == mask1.size(1) + assert seq2_length == mask2.size(1) + + actual_seq1_lengths = mask1.sum(dim=-1) + actual_seq2_lengths = mask2.sum(dim=-1) + + final_lengths = actual_seq1_lengths + actual_seq2_lengths + max_length = seq1_length + seq2_length + + concatenated_mask = ( + torch.arange(max_length, device=seq2.device)[None].repeat(batch_size, 1) < final_lengths[:, None] + ) + + concatenated_sequence = torch.zeros((batch_size, max_length, hidden_size), device=seq2.device, dtype=seq2.dtype) + concatenated_sequence[:, :seq1_length, :] = seq1 + + # Shift seq2 elements to start at the end of valid seq1 + index = torch.arange(seq2_length, device=seq2.device)[None].repeat(batch_size, 1) + index = index + actual_seq1_lengths[:, None] + + # Scatter seq2 into the right positions + concatenated_sequence = concatenated_sequence.scatter(1, index[:, :, None].expand(-1, -1, hidden_size), seq2) + + if return_index: + return concatenated_sequence, concatenated_mask, index + + return concatenated_sequence, concatenated_mask + + +def box_cxcywh_to_xyxy(x): + """Convert boxes from (cx, cy, w, h) format to (x1, y1, x2, y2) format.""" + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +class Sam3LiteTextGeometryEncoder(nn.Module): + """ + Encoder for geometric prompts (boxes). + + Boxes are encoded using three approaches: + - Direct projection: linear projection from coordinate space to hidden_size + - Pooling: pool features from the backbone at the specified location (ROI align for boxes) + - Position encoding: use position encoding of the box center + + These encodings are combined additively and further processed with transformer layers. + """ + + def __init__(self, config: Sam3LiteTextGeometryEncoderConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.roi_size = config.roi_size + + self.position_encoding = Sam3LiteTextSinePositionEmbedding( + num_pos_feats=config.hidden_size // 2, normalize=True + ) + self.label_embed = nn.Embedding(2, self.hidden_size) + self.cls_embed = nn.Embedding(1, self.hidden_size) + + # Box encoding layers + self.boxes_direct_project = nn.Linear(4, self.hidden_size) + self.boxes_pool_project = nn.Conv2d(self.hidden_size, self.hidden_size, self.roi_size) + self.boxes_pos_enc_project = nn.Linear(self.hidden_size + 2, self.hidden_size) + + # Image feature normalization + self.vision_layer_norm = nn.LayerNorm(self.hidden_size) + + # Prompt projection and normalization + self.final_proj = nn.Linear(self.hidden_size, self.hidden_size) + self.prompt_layer_norm = nn.LayerNorm(self.hidden_size) + + # Transformer layers + self.layers = nn.ModuleList([Sam3LiteTextGeometryEncoderLayer(config) for _ in range(config.num_layers)]) + self.output_layer_norm = nn.LayerNorm(self.hidden_size) + + def _encode_box_coordinates( + self, center_x: torch.Tensor, center_y: torch.Tensor, width: torch.Tensor, height: torch.Tensor + ) -> torch.Tensor: + """ + Encode box coordinates by combining position-encoded centers with raw width/height. 
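+        The output feature dimension is hidden_size + 2: the sine-encoded y and x centers contribute
+        hidden_size // 2 features each, and the raw height and width contribute one feature each, matching the
+        input size of `boxes_pos_enc_project`.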
+ + Args: + center_x: 1D tensor of box center x coordinates + center_y: 1D tensor of box center y coordinates + width: 1D tensor of box widths + height: 1D tensor of box heights + + Returns: + Encoded box coordinates [N, embedding_dim] + """ + pos_x, pos_y = self.position_encoding.encode_1d_positions(center_x, center_y) + pos = torch.cat((pos_y, pos_x, height[:, None], width[:, None]), dim=1) + return pos + + def _encode_boxes(self, boxes, boxes_mask, boxes_labels, vision_features): + """Encode box prompts. Mask convention: True=valid, False=padding.""" + batch_size, num_boxes = boxes.shape[:2] + height, width = vision_features.shape[-2:] + boxes_embed = self.boxes_direct_project(boxes) + + # Pool features using ROI align + # Convert boxes from CxCyWH to xyxy format and denormalize + boxes_xyxy = box_cxcywh_to_xyxy(boxes) + scale = torch.tensor([width, height, width, height], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device) + scale = scale.view(1, 1, 4) + boxes_xyxy = boxes_xyxy * scale + # ROI align expects list of boxes per batch element, + # convert from bfloat16 to float16 as roi_align only supports float16 and float32 + dtype = torch.float16 if vision_features.dtype == torch.bfloat16 else vision_features.dtype + sampled_features = torchvision.ops.roi_align( + vision_features.to(dtype), boxes_xyxy.to(dtype).unbind(0), self.roi_size + ).to(vision_features.dtype) + + pooled_projection = self.boxes_pool_project(sampled_features) + pooled_projection = pooled_projection.view(batch_size, num_boxes, self.hidden_size) + boxes_embed = boxes_embed + pooled_projection + + # Add position encoding + center_x, center_y, box_width, box_height = boxes.unbind(-1) + pos_enc = self._encode_box_coordinates( + center_x.flatten(), center_y.flatten(), box_width.flatten(), box_height.flatten() + ) + pos_enc = pos_enc.view(batch_size, num_boxes, pos_enc.shape[-1]) + pos_projection = self.boxes_pos_enc_project(pos_enc) + boxes_embed = boxes_embed + pos_projection + + # Add label embeddings (positive/negative) + label_embed = self.label_embed(boxes_labels.long()) + return label_embed + boxes_embed, boxes_mask + + def forward( + self, + box_embeddings: torch.Tensor, + box_mask: torch.Tensor, + box_labels: torch.Tensor, + img_feats: tuple[torch.Tensor, ...], + img_pos_embeds: tuple[torch.Tensor, ...] | None = None, + ): + """ + Forward pass for encoding geometric prompts. + + Args: + box_embeddings: Box coordinates in CxCyWH format [batch_size, num_boxes, 4] + box_mask: Attention mask for boxes [batch_size, num_boxes] + box_labels: Labels for boxes (positive/negative) [batch_size, num_boxes] + img_feats: Image features from vision encoder + img_pos_embeds: Optional position embeddings for image features + + Returns: + Sam3LiteTextGeometryEncoderOutput containing encoded geometry features and attention mask. 
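+
+        Note (illustrative; assumes boxes are normalized to [0, 1], as implied by the ROI-align scaling in
+        `_encode_boxes`): a box covering the left half of the image would be passed as
+        (cx, cy, w, h) = (0.25, 0.5, 0.5, 1.0).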
+ """ + batch_size = box_embeddings.shape[0] + + # Prepare vision features for cross-attention: flatten spatial dimensions + vision_feats = img_feats[-1] # [B, C, H, W] + vision_pos_embeds = img_pos_embeds[-1] if img_pos_embeds is not None else torch.zeros_like(vision_feats) + vision_feats_flat = vision_feats.flatten(2).transpose(1, 2) # [B, H*W, C] + vision_pos_embeds_flat = vision_pos_embeds.flatten(2).transpose(1, 2) # [B, H*W, C] + + # Normalize image features for pooling operations + img_feats_last = img_feats[-1] # [B, C, H, W] + img_feats_last = img_feats_last.permute(0, 2, 3, 1) # [B, H, W, C] + normalized_img_feats = self.vision_layer_norm(img_feats_last) + normalized_img_feats = normalized_img_feats.permute(0, 3, 1, 2) # [B, C, H, W] + + prompt_embeds, prompt_mask = self._encode_boxes(box_embeddings, box_mask, box_labels, normalized_img_feats) + + # Add CLS token (always valid) + cls_embed = self.cls_embed.weight.view(1, self.hidden_size).unsqueeze(0).expand(batch_size, -1, -1) + cls_mask = torch.ones(batch_size, 1, dtype=prompt_mask.dtype, device=prompt_mask.device) + prompt_embeds, prompt_mask = concat_padded_sequences(prompt_embeds, prompt_mask, cls_embed, cls_mask) + + prompt_embeds = self.prompt_layer_norm(self.final_proj(prompt_embeds)) + + # Create bidirectional attention mask for transformer layers + prompt_attention_mask = None + if prompt_mask is not None: + prompt_attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=prompt_embeds, + attention_mask=prompt_mask, + ) + + # Apply transformer layers with cross-attention to vision features + for layer in self.layers: + prompt_embeds = layer( + prompt_feats=prompt_embeds, + vision_feats=vision_feats_flat, + vision_pos_encoding=vision_pos_embeds_flat, + prompt_mask=prompt_attention_mask, + ) + + # Final output normalization + prompt_embeds = self.output_layer_norm(prompt_embeds) + + return Sam3LiteTextGeometryEncoderOutput( + last_hidden_state=prompt_embeds, + attention_mask=prompt_mask, + ) + + +class Sam3LiteTextDetrEncoderLayer(nn.Module): + """DETR encoder layer with self-attention and cross-attention.""" + + def __init__(self, config: Sam3LiteTextDETREncoderConfig): + super().__init__() + self.config = config + self.layer_norm1 = nn.LayerNorm(config.hidden_size) + self.self_attn = Sam3LiteTextAttention(config) + self.dropout = nn.Dropout(config.dropout) + + self.cross_attn = Sam3LiteTextAttention(config) + self.layer_norm2 = nn.LayerNorm(config.hidden_size) + + self.mlp = Sam3LiteTextMLP(config) + self.layer_norm3 = nn.LayerNorm(config.hidden_size) + + def forward( + self, + vision_feats: Tensor, + prompt_feats: Tensor, + vision_pos_encoding: Tensor, + prompt_cross_attn_mask: Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ): + """ + Forward pass for DETR encoder layer. 
+ + Args: + vision_feats: Vision features [batch_size, vision_len, hidden_size] (main hidden states) + prompt_feats: Text prompt features [batch_size, text_len, hidden_size] + vision_pos_encoding: Position encoding for vision [batch_size, vision_len, hidden_size] + prompt_cross_attn_mask: Cross-attention mask for prompt features + + Returns: + Updated vision features [batch_size, vision_len, hidden_size] + """ + # Self-attention on vision features with position encoding + residual = vision_feats + hidden_states = self.layer_norm1(vision_feats) + hidden_states_with_pos = hidden_states + vision_pos_encoding + hidden_states, _ = self.self_attn( + query=hidden_states_with_pos, + key=hidden_states_with_pos, + value=hidden_states, + **kwargs, + ) + hidden_states = self.dropout(hidden_states) + residual + + # Cross-attention: vision queries attend to text/prompt features + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + + hidden_states, _ = self.cross_attn( + query=hidden_states, + key=prompt_feats, + value=prompt_feats, + attention_mask=prompt_cross_attn_mask, + **kwargs, + ) + hidden_states = self.dropout(hidden_states) + residual + + # MLP + residual = hidden_states + hidden_states = self.layer_norm3(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.dropout(hidden_states) + residual + + return hidden_states + + +class Sam3LiteTextDetrEncoder(Sam3LiteTextPreTrainedModel): + """ + DETR-style encoder that processes multi-level vision features with text fusion. + + This encoder processes vision features from multiple levels (e.g., FPN features at different + resolutions) and fuses them with text prompts through a stack of transformer encoder layers. + """ + + _can_record_outputs = { + "hidden_states": Sam3LiteTextDetrEncoderLayer, + "attentions": Sam3LiteTextAttention, + } + + def __init__(self, config: Sam3LiteTextDETREncoderConfig): + super().__init__(config) + self.config = config + self.hidden_size = config.hidden_size + + self.layers = nn.ModuleList([Sam3LiteTextDetrEncoderLayer(config) for _ in range(config.num_layers)]) + + self.post_init() + + def _prepare_multilevel_features( + self, + vision_features: list[torch.Tensor], + vision_pos_embeds: list[torch.Tensor], + ): + """ + Prepare multi-level vision features by flattening spatial dimensions and adding level embeddings. 
+ + Args: + vision_features: List of vision features at different levels [batch_size, channels, height, width] + vision_pos_embeds: List of position embeddings for each level [batch_size, channels, height, width] + + Returns: + Tuple containing flattened features, position embeddings, and spatial metadata + """ + features_flattened = [] + pos_embeds_flattened = [] + spatial_shapes = [] + + for features, pos_embed in zip(vision_features, vision_pos_embeds): + height, width = features.shape[-2:] + spatial_shapes.append((height, width)) + + # Flatten spatial dimensions: [batch_size, channels, height, width] -> [batch_size, height*width, channels] + features = features.flatten(2).transpose(1, 2) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + + features_flattened.append(features) + pos_embeds_flattened.append(pos_embed) + + # Concatenate all levels into single sequence + features_flattened = torch.cat(features_flattened, dim=1) + pos_embeds_flattened = torch.cat(pos_embeds_flattened, dim=1) + + spatial_shapes = torch.tensor(spatial_shapes, dtype=torch.long, device=features_flattened.device) + + return ( + features_flattened, + pos_embeds_flattened, + spatial_shapes, + ) + + @merge_with_config_defaults + @capture_outputs + def forward( + self, + vision_features: list[torch.Tensor], + text_features: torch.Tensor, + vision_pos_embeds: list[torch.Tensor] | None = None, + text_mask: torch.Tensor | None = None, + spatial_sizes: list[tuple[int, int]] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Sam3LiteTextDETREncoderOutput: + """ + Forward pass for the DETR encoder. + + Args: + vision_features: List of vision features at different levels + text_features: Text prompt features [batch_size, seq_len, hidden_size] + vision_pos_embeds: Optional list of position embeddings for each level + text_mask: Optional text padding mask [batch_size, seq_len] + spatial_sizes: Optional list of (height, width) tuples for reshaping + + Returns: + Sam3LiteTextDETREncoderOutput containing encoded features and metadata. + """ + batch_size = vision_features[0].shape[0] if vision_features[0].dim() == 4 else vision_features[0].shape[1] + + # TODO: See if we can remove that reshaping and just use the features as is. 
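+        # The reshaping below assumes sequence-first per-level features of shape
+        # [height * width, batch_size, channels] and restores the spatial layout
+        # [batch_size, channels, height, width] expected by `_prepare_multilevel_features`.
+        # Illustrative sizes: a level of shape [64 * 64, 2, 256] with spatial size (64, 64) becomes [2, 256, 64, 64].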
+ if spatial_sizes is not None: + for i, (height, width) in enumerate(spatial_sizes): + # Reshape from [height*width, batch_size, channels] to [batch_size, channels, height, width] + vision_features[i] = vision_features[i].reshape(height, width, batch_size, -1).permute(2, 3, 0, 1) + vision_pos_embeds[i] = vision_pos_embeds[i].reshape(height, width, batch_size, -1).permute(2, 3, 0, 1) + + # Flatten multi-level features for encoder processing + ( + features_flattened, + pos_embeds_flattened, + spatial_shapes, + ) = self._prepare_multilevel_features(vision_features, vision_pos_embeds) + + prompt_cross_attn_mask = None + if text_mask is not None: + prompt_cross_attn_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=features_flattened, + attention_mask=text_mask, + encoder_hidden_states=text_features, + ) + + hidden_states = features_flattened + for layer in self.layers: + hidden_states = layer( + hidden_states, + prompt_feats=text_features, + vision_pos_encoding=pos_embeds_flattened, + prompt_cross_attn_mask=prompt_cross_attn_mask, + **kwargs, + ) + return Sam3LiteTextDETREncoderOutput( + last_hidden_state=hidden_states, + pos_embeds_flattened=pos_embeds_flattened, + text_features=text_features, + spatial_shapes=spatial_shapes, + ) + + +class Sam3LiteTextDecoderMLP(nn.Module): + """Simple 2 or 3-layer MLP for decoder components.""" + + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 2): + super().__init__() + if num_layers == 2: + self.layer1 = nn.Linear(input_dim, hidden_dim) + self.layer2 = nn.Linear(hidden_dim, output_dim) + self.layer3 = None + elif num_layers == 3: + self.layer1 = nn.Linear(input_dim, hidden_dim) + self.layer2 = nn.Linear(hidden_dim, hidden_dim) + self.layer3 = nn.Linear(hidden_dim, output_dim) + else: + raise ValueError(f"Only 2 or 3 layers supported, got {num_layers}") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.relu(self.layer1(x)) + if self.layer3 is not None: + x = F.relu(self.layer2(x)) + x = self.layer3(x) + else: + x = self.layer2(x) + return x + + +class Sam3LiteTextDetrDecoderLayer(nn.Module): + """DETR decoder layer with self-attention, text cross-attention, and vision cross-attention.""" + + def __init__(self, config: Sam3LiteTextDETRDecoderConfig): + super().__init__() + self.config = config + self.self_attn = Sam3LiteTextAttention(config) + self.self_attn_dropout = nn.Dropout(config.dropout) + self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size) + + self.text_cross_attn = Sam3LiteTextAttention(config) + self.text_cross_attn_dropout = nn.Dropout(config.dropout) + self.text_cross_attn_layer_norm = nn.LayerNorm(config.hidden_size) + + self.vision_cross_attn = Sam3LiteTextAttention(config) + self.vision_cross_attn_dropout = nn.Dropout(config.dropout) + self.vision_cross_attn_layer_norm = nn.LayerNorm(config.hidden_size) + + self.mlp = Sam3LiteTextMLP(config) + self.mlp_layer_norm = nn.LayerNorm(config.hidden_size) + self.mlp_dropout = nn.Dropout(config.dropout) + + def forward( + self, + hidden_states: torch.Tensor, + query_pos: torch.Tensor, + text_features: torch.Tensor, + vision_features: torch.Tensor, + vision_pos_encoding: torch.Tensor, + text_cross_attn_mask: torch.Tensor | None = None, + vision_cross_attn_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + """ + Forward pass for decoder layer. 
+ + Args: + hidden_states: Query features [batch_size, num_queries + 1, hidden_size] (includes presence token at position 0) + query_pos: Query position embeddings [batch_size, num_queries, hidden_size] + text_features: Text features [batch_size, seq_len, hidden_size] + vision_features: Vision features [batch_size, height*width, hidden_size] + vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size] + text_cross_attn_mask: Text cross-attention mask + vision_cross_attn_mask: Vision cross-attention mask, already expanded for presence token + + Returns: + Updated hidden states (including presence token at position 0) + """ + # Prepend zeros to query_pos for presence token + query_pos = F.pad(query_pos, (0, 0, 1, 0), mode="constant", value=0) + + # Self-attention with query position encoding + residual = hidden_states + query_with_pos = hidden_states + query_pos + attn_output, _ = self.self_attn( + query=query_with_pos, + key=query_with_pos, + value=hidden_states, + attention_mask=None, + **kwargs, + ) + hidden_states = residual + self.self_attn_dropout(attn_output) + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Text cross-attention: queries attend to text features + residual = hidden_states + query_with_pos = hidden_states + query_pos + + attn_output, _ = self.text_cross_attn( + query=query_with_pos, + key=text_features, + value=text_features, + attention_mask=text_cross_attn_mask, + **kwargs, + ) + hidden_states = residual + self.text_cross_attn_dropout(attn_output) + hidden_states = self.text_cross_attn_layer_norm(hidden_states) + + # Vision cross-attention: queries attend to vision features (with RPB) + residual = hidden_states + query_with_pos = hidden_states + query_pos + key_with_pos = vision_features + vision_pos_encoding + attn_output, _ = self.vision_cross_attn( + query=query_with_pos, + key=key_with_pos, + value=vision_features, + attention_mask=vision_cross_attn_mask, + **kwargs, + ) + hidden_states = residual + self.vision_cross_attn_dropout(attn_output) + hidden_states = self.vision_cross_attn_layer_norm(hidden_states) + + # MLP + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.mlp_dropout(hidden_states) + hidden_states = self.mlp_layer_norm(hidden_states) + + return hidden_states + + +def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor: + """The inverse function for sigmoid activation function.""" + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +class Sam3LiteTextDetrDecoder(Sam3LiteTextPreTrainedModel): + """ + DETR-style decoder with box refinement and presence token. 
+ + Simplified version that assumes: + - Box refinement is always enabled + - Intermediate outputs are always returned + - BoxRPB (relative position bias) with log-scale encoding + - Presence token is used + """ + + _can_record_outputs = { + "hidden_states": Sam3LiteTextDetrDecoderLayer, + "attentions": Sam3LiteTextAttention, + } + + def __init__( + self, + config: Sam3LiteTextDETRDecoderConfig, + ): + super().__init__(config) + self.config = config + self.hidden_size = config.hidden_size + + self.layers = nn.ModuleList([Sam3LiteTextDetrDecoderLayer(config) for _ in range(config.num_layers)]) + + self.output_layer_norm = nn.LayerNorm(config.hidden_size) + + self.box_head = Sam3LiteTextDecoderMLP(config.hidden_size, config.hidden_size, 4, 3) + + self.query_embed = nn.Embedding(config.num_queries, config.hidden_size) + self.reference_points = nn.Embedding(config.num_queries, 4) + + self.presence_token = nn.Embedding(1, config.hidden_size) + self.presence_head = Sam3LiteTextDecoderMLP(config.hidden_size, config.hidden_size, 1, 3) + self.presence_layer_norm = nn.LayerNorm(config.hidden_size) + self.clamp_presence_logit_max_val = 10.0 + + self.ref_point_head = Sam3LiteTextDecoderMLP(2 * config.hidden_size, config.hidden_size, config.hidden_size, 2) + + self.box_rpb_embed_x = Sam3LiteTextDecoderMLP(2, config.hidden_size, config.num_attention_heads, 2) + self.box_rpb_embed_y = Sam3LiteTextDecoderMLP(2, config.hidden_size, config.num_attention_heads, 2) + + self.position_encoding = Sam3LiteTextSinePositionEmbedding( + num_pos_feats=config.hidden_size // 2, normalize=False + ) + + self.post_init() + + @compile_compatible_method_lru_cache(maxsize=1) + def _get_coords( + self, height: torch.Tensor, width: torch.Tensor, dtype: torch.dtype, device: torch.device + ) -> tuple[torch.Tensor, torch.Tensor]: + """Generate normalized coordinate grids.""" + coords_h = torch.arange(0, height, device=device, dtype=dtype) / height + coords_w = torch.arange(0, width, device=device, dtype=dtype) / width + return coords_h, coords_w + + def _get_rpb_matrix( + self, reference_boxes: torch.Tensor, spatial_shape: tuple[torch.Tensor, torch.Tensor] + ) -> torch.Tensor: + """ + Compute box relative position bias (RPB) matrix using log-scale encoding. + RPB helps the decoder attend to relevant spatial locations based on predicted box positions. 
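+        As a worked example of the log-scale encoding below (illustrative numbers): a normalized delta of 0.5
+        between a grid coordinate and a box edge is scaled by 8 to 4.0 and mapped to
+        sign(4.0) * log2(|4.0| + 1) / log2(8) = log2(5) / 3 ≈ 0.77, which compresses large offsets.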
+ + Args: + reference_boxes: Reference boxes [batch_size, num_queries, 4] in sigmoid space + spatial_shape: (height, width) of the vision features as tensors + + Returns: + RPB matrix [batch_size, num_heads, num_queries, height*width] + """ + height, width = spatial_shape + boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes) + batch_size, num_queries, _ = boxes_xyxy.shape + + # Generate coordinate grids + coords_h, coords_w = self._get_coords( + height, width, dtype=reference_boxes.dtype, device=reference_boxes.device + ) + + # Compute deltas between coordinates and box boundaries + deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2] + deltas_y = deltas_y.view(batch_size, num_queries, -1, 2) + deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2] + deltas_x = deltas_x.view(batch_size, num_queries, -1, 2) + + # Apply log-scale encoding + deltas_x_log = deltas_x * 8 + deltas_x_log = torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / math.log2(8) + deltas_y_log = deltas_y * 8 + deltas_y_log = torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 1.0) / math.log2(8) + + # Embed deltas + deltas_x = self.box_rpb_embed_x(deltas_x_log) # [batch_size, num_queries, width, num_heads] + deltas_y = self.box_rpb_embed_y(deltas_y_log) # [batch_size, num_queries, height, num_heads] + + # Combine into 2D bias matrix + rpb_matrix = deltas_y.unsqueeze(3) + deltas_x.unsqueeze( + 2 + ) # [batch_size, num_queries, height, width, num_heads] + rpb_matrix = rpb_matrix.flatten(2, 3) # [batch_size, num_queries, height*width, num_heads] + rpb_matrix = rpb_matrix.permute(0, 3, 1, 2).contiguous() # [batch_size, num_heads, num_queries, height*width] + return rpb_matrix + + @merge_with_config_defaults + @capture_outputs + def forward( + self, + vision_features: torch.Tensor, + text_features: torch.Tensor, + vision_pos_encoding: torch.Tensor, + text_mask: torch.Tensor | None = None, + spatial_shapes: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Sam3LiteTextDETRDecoderOutput: + """ + Forward pass for the DETR decoder. + + Args: + vision_features: Vision features [batch_size, height*width, hidden_size] + text_features: Text features [batch_size, seq_len, hidden_size] + vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size] + text_mask: Text padding mask [batch_size, seq_len] where True=valid, False=padding + spatial_shapes: Spatial shapes [num_levels, 2] + + Returns: + Sam3LiteTextDETRDecoderOutput containing decoder outputs from all layers. 
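+
+        Note: the per-layer presence logits are typically combined downstream with the per-query classification
+        logits, e.g. `final_scores = pred_logits.sigmoid() * presence_logits.sigmoid()` (see
+        `Sam3LiteTextImageSegmentationOutput`); this decoder itself only returns the raw logits.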
+ """ + batch_size = vision_features.shape[0] + + query_embeds = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1) + reference_boxes = self.reference_points.weight.unsqueeze(0).expand(batch_size, -1, -1) + reference_boxes = reference_boxes.sigmoid() + presence_token = self.presence_token.weight.unsqueeze(0).expand(batch_size, -1, -1) + + # Concatenate presence token with query embeddings + hidden_states = torch.cat([presence_token, query_embeds], dim=1) + + text_cross_attn_mask = None + if text_mask is not None: + text_cross_attn_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=hidden_states, + attention_mask=text_mask, + encoder_hidden_states=text_features, + ) + + intermediate_outputs = [] + intermediate_boxes = [reference_boxes] + intermediate_presence_logits = [] + + for layer in self.layers: + # Generate sine embeddings for conditional queries + reference_points_input = reference_boxes.unsqueeze(2) + query_sine_embed = self.position_encoding.encode_boxes(reference_points_input[:, :, 0, :]) + query_pos = self.ref_point_head(query_sine_embed) + + # Compute box relative position bias (RPB) attention mask + vision_cross_attn_mask = None + if spatial_shapes is not None and spatial_shapes.shape[0] == 1: + spatial_shape = (spatial_shapes[0, 0], spatial_shapes[0, 1]) + rpb_matrix = self._get_rpb_matrix(reference_boxes, spatial_shape) + # Prepend zeros row for presence token (it attends to all vision tokens equally) + vision_cross_attn_mask = F.pad(rpb_matrix, (0, 0, 1, 0), mode="constant", value=0) + + hidden_states = layer( + hidden_states, + query_pos=query_pos, + text_features=text_features, + vision_features=vision_features, + vision_pos_encoding=vision_pos_encoding, + text_cross_attn_mask=text_cross_attn_mask, + vision_cross_attn_mask=vision_cross_attn_mask, + **kwargs, + ) + + # Extract query hidden states (without presence token) for box refinement + query_hidden_states = hidden_states[:, 1:] + + # Box refinement: predict delta and update reference boxes + reference_boxes_before_sigmoid = inverse_sigmoid(reference_boxes) + delta_boxes = self.box_head(self.output_layer_norm(query_hidden_states)) + new_reference_boxes = (delta_boxes + reference_boxes_before_sigmoid).sigmoid() + reference_boxes = new_reference_boxes.detach() + + intermediate_outputs.append(self.output_layer_norm(query_hidden_states)) + intermediate_boxes.append(new_reference_boxes) + + # Process presence token + presence_hidden = hidden_states[:, :1] + presence_logits = self.presence_head(self.presence_layer_norm(presence_hidden)).squeeze(-1) + presence_logits = presence_logits.clamp( + min=-self.clamp_presence_logit_max_val, max=self.clamp_presence_logit_max_val + ) + intermediate_presence_logits.append(presence_logits) + + # Stack outputs from all layers + intermediate_outputs = torch.stack(intermediate_outputs) + intermediate_boxes = torch.stack(intermediate_boxes[:-1]) + intermediate_presence_logits = torch.stack(intermediate_presence_logits) + + return Sam3LiteTextDETRDecoderOutput( + intermediate_hidden_states=intermediate_outputs, + reference_boxes=intermediate_boxes, + presence_logits=intermediate_presence_logits, + ) + + +class Sam3LiteTextDotProductScoring(nn.Module): + """ + Computes classification scores by computing dot product between projected decoder queries and pooled text features. + This is used to determine confidence/presence scores for each query. 
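+
+    Shape sketch (illustrative): decoder queries of shape [num_layers, batch_size, num_queries, hidden_size] and
+    text features of shape [batch_size, seq_len, hidden_size] are projected after mean-pooling the text over valid
+    tokens, and the scaled dot product yields scores of shape [num_layers, batch_size, num_queries, 1].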
+ """ + + def __init__(self, config: Sam3LiteTextConfig): + super().__init__() + self.config = config + hidden_size = config.detr_decoder_config.hidden_size + projection_dim = config.detr_decoder_config.hidden_size + + self.text_mlp = Sam3LiteTextDecoderMLP( + input_dim=hidden_size, + hidden_dim=config.detr_decoder_config.intermediate_size, + output_dim=hidden_size, + num_layers=2, + ) + self.text_mlp_dropout = nn.Dropout(config.detr_decoder_config.dropout) + self.text_mlp_out_norm = nn.LayerNorm(hidden_size) + + # Projections for text and query features + self.text_proj = nn.Linear(hidden_size, projection_dim) + self.query_proj = nn.Linear(hidden_size, projection_dim) + + # Scale factor for dot product + self.scale = float(1.0 / np.sqrt(projection_dim)) + + # Clamping to avoid numerical issues + self.clamp_logits = True + self.clamp_max_val = 12.0 + + def _pool_text_features(self, text_features: torch.Tensor, text_mask: torch.Tensor | None) -> torch.Tensor: + """ + Mean pool text features, accounting for padding. + + Args: + text_features: [batch_size, seq_len, hidden_size] + text_mask: [batch_size, seq_len] where True indicates valid tokens, False indicates padding + + Returns: + pooled_text: [batch_size, hidden_size] + """ + if text_mask is None: + # No padding, simple mean + return text_features.mean(dim=1) + + is_valid = text_mask.to(text_features.dtype).unsqueeze(-1) # [batch_size, seq_len, 1] + + # Count valid tokens per batch + num_valid = is_valid.sum(dim=1).clamp(min=1.0) # [batch_size, 1] + + # Mean pool only over valid tokens + pooled_text = (text_features * is_valid).sum(dim=1) / num_valid # [batch_size, hidden_size] + + return pooled_text + + def forward( + self, + decoder_hidden_states: torch.Tensor, + text_features: torch.Tensor, + text_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Compute classification scores via dot product. + + Args: + decoder_hidden_states: [num_layers, batch_size, num_queries, hidden_size] + text_features: [batch_size, seq_len, hidden_size] + text_mask: [batch_size, seq_len] where True=valid, False=padding + + Returns: + scores: [num_layers, batch_size, num_queries, 1] + """ + orig_text_features = text_features + text_features = self.text_mlp(text_features) + text_features = self.text_mlp_dropout(text_features) + text_features = text_features + orig_text_features + text_features = self.text_mlp_out_norm(text_features) + + pooled_text = self._pool_text_features(text_features, text_mask) + + proj_text = self.text_proj(pooled_text) + proj_queries = self.query_proj(decoder_hidden_states) + + proj_text = proj_text.unsqueeze(-1) + scores = torch.matmul(proj_queries, proj_text.unsqueeze(0)) + scores = scores * self.scale + if self.clamp_logits: + scores = scores.clamp(min=-self.clamp_max_val, max=self.clamp_max_val) + + return scores + + +class Sam3LiteTextMaskEmbedder(nn.Module): + """ + MLP that embeds object queries for mask prediction. + Similar to MaskFormer's mask embedder. 
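For reference, the masked mean pooling and scaled dot-product scoring above can be reproduced in a few lines; the shapes and values here are arbitrary toy data.

```python
import torch


def masked_mean_pool(text_features: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
    # Average only over valid (non-padding) tokens, mirroring _pool_text_features.
    is_valid = text_mask.to(text_features.dtype).unsqueeze(-1)  # (batch, seq, 1)
    num_valid = is_valid.sum(dim=1).clamp(min=1.0)              # (batch, 1)
    return (text_features * is_valid).sum(dim=1) / num_valid    # (batch, hidden)


num_layers, batch, seq, queries, dim = 1, 2, 5, 3, 8
text_features = torch.randn(batch, seq, dim)
text_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)
query_features = torch.randn(num_layers, batch, queries, dim)

pooled = masked_mean_pool(text_features, text_mask)                       # (batch, dim)
scores = torch.matmul(query_features, pooled.unsqueeze(-1).unsqueeze(0))  # (layers, batch, queries, 1)
scores = scores / dim**0.5                                                # same 1/sqrt(dim) scaling
print(scores.shape)  # torch.Size([1, 2, 3, 1])
```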
+ """ + + def __init__(self, config: Sam3LiteTextMaskDecoderConfig): + super().__init__() + self.config = config + hidden_size = config.hidden_size + + self.layers = nn.ModuleList( + [ + nn.Linear(hidden_size, hidden_size), + nn.Linear(hidden_size, hidden_size), + nn.Linear(hidden_size, hidden_size), + ] + ) + self.activation = nn.ReLU() + + def forward(self, queries: torch.Tensor) -> torch.Tensor: + """ + Args: + queries: Query embeddings [batch_size, num_queries, hidden_size] + + Returns: + Mask embeddings [batch_size, num_queries, hidden_size] + """ + hidden_states = queries + for i, layer in enumerate(self.layers): + hidden_states = layer(hidden_states) + if i < len(self.layers) - 1: + hidden_states = self.activation(hidden_states) + return hidden_states + + +class Sam3LiteTextPixelDecoder(nn.Module): + """ + Feature Pyramid Network (FPN) decoder that generates pixel-level features. + Inspired by MaskFormer's pixel decoder. + """ + + def __init__(self, config: Sam3LiteTextMaskDecoderConfig): + super().__init__() + self.config = config + hidden_size = config.hidden_size + num_upsampling_stages = config.num_upsampling_stages + + # Create conv layers and norms for FPN + self.conv_layers = nn.ModuleList( + [ + nn.Conv2d(hidden_size, hidden_size, kernel_size=3, stride=1, padding=1) + for _ in range(num_upsampling_stages) + ] + ) + self.norms = nn.ModuleList([nn.GroupNorm(8, hidden_size) for _ in range(num_upsampling_stages)]) + + self.out_channels = hidden_size + + def forward(self, backbone_features: list[torch.Tensor]) -> torch.Tensor: + """ + Args: + backbone_features: List of backbone features [batch_size, hidden_size, H_i, W_i] + from low to high resolution (assumes already projected to hidden_size) + + Returns: + Pixel embeddings [batch_size, hidden_size, H, W] at the finest resolution + """ + # Start from the coarsest feature (last in list) + prev_fpn = backbone_features[-1] + # Iterate through features from coarse to fine (excluding the last which we started with) + for layer_idx, backbone_feat in enumerate(reversed(backbone_features[:-1])): + # Upsample previous FPN output to match current backbone feature size + prev_fpn = F.interpolate(prev_fpn, size=backbone_feat.shape[-2:], mode="nearest") + + # Add skip connection + prev_fpn = prev_fpn + backbone_feat + + # Apply conv and norm + prev_fpn = self.conv_layers[layer_idx](prev_fpn) + prev_fpn = self.norms[layer_idx](prev_fpn) + prev_fpn = F.relu(prev_fpn) + + return prev_fpn + + +class Sam3LiteTextMaskDecoder(Sam3LiteTextPreTrainedModel): + """ + Mask decoder that combines object queries with pixel-level features to predict instance masks. + Also produces a semantic segmentation output and supports cross-attention to prompts. 
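A compact sketch of the coarse-to-fine FPN loop implemented by the pixel decoder above: upsample the running map to the next scale, add the skip connection, then conv + GroupNorm + ReLU. The feature sizes and channel count are arbitrary; the list is ordered so the coarsest map comes last, which is the map the decoder starts from.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden = 32
convs = nn.ModuleList([nn.Conv2d(hidden, hidden, kernel_size=3, padding=1) for _ in range(2)])
norms = nn.ModuleList([nn.GroupNorm(8, hidden) for _ in range(2)])

# Toy backbone features, all already projected to `hidden` channels; coarsest map last.
feats = [torch.randn(1, hidden, 64, 64), torch.randn(1, hidden, 32, 32), torch.randn(1, hidden, 16, 16)]

prev = feats[-1]                                                      # start from the coarsest map
for i, skip in enumerate(reversed(feats[:-1])):
    prev = F.interpolate(prev, size=skip.shape[-2:], mode="nearest")  # upsample to the next scale
    prev = F.relu(norms[i](convs[i](prev + skip)))                    # skip add, conv, norm, activation
print(prev.shape)  # torch.Size([1, 32, 64, 64]), the finest-resolution pixel embedding
```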
+ """ + + _can_record_outputs = { + "attentions": Sam3LiteTextAttention, + } + + def __init__(self, config: Sam3LiteTextMaskDecoderConfig): + super().__init__(config) + self.config = config + hidden_size = config.hidden_size + + # Pixel decoder (FPN) + self.pixel_decoder = Sam3LiteTextPixelDecoder(config) + + # Mask embedder (MLP to transform queries) + self.mask_embedder = Sam3LiteTextMaskEmbedder(config) + + # Projection from pixel decoder output to mask embedding space + self.instance_projection = nn.Conv2d(self.pixel_decoder.out_channels, hidden_size, kernel_size=1) + + # Semantic segmentation head (always present in UniversalSegmentationHead) + self.semantic_projection = nn.Conv2d(self.pixel_decoder.out_channels, 1, kernel_size=1) + + self.prompt_cross_attn = Sam3LiteTextAttention(config) + self.prompt_cross_attn_norm = nn.LayerNorm(hidden_size) + self.prompt_cross_attn_dropout = nn.Dropout(config.dropout) + + self.post_init() + + @merge_with_config_defaults + @capture_outputs + def forward( + self, + decoder_queries: torch.Tensor, + backbone_features: list[torch.Tensor], + encoder_hidden_states: torch.Tensor, + prompt_features: torch.Tensor | None = None, + prompt_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Sam3LiteTextMaskDecoderOutput: + """ + Args: + decoder_queries: Decoder output queries [batch_size, num_queries, hidden_size] + backbone_features: List of backbone features to process through FPN + encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size] + prompt_features: Prompt features (text + geometry) for cross-attention [batch_size, prompt_len, hidden_size] + prompt_mask: Padding mask [batch_size, prompt_len] where True=valid, False=padding + + Returns: + Sam3LiteTextMaskDecoderOutput containing predicted masks and semantic segmentation. + """ + if prompt_features is not None: + # Cross-attention: encoder features attend to prompt features + residual = encoder_hidden_states + normed_hidden_states = self.prompt_cross_attn_norm(encoder_hidden_states) + + cross_attn_mask = None + if prompt_mask is not None: + cross_attn_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=normed_hidden_states, + encoder_hidden_states=prompt_features, + attention_mask=prompt_mask, + ) + + attn_output, _ = self.prompt_cross_attn( + query=normed_hidden_states, + key=prompt_features, + value=prompt_features, + attention_mask=cross_attn_mask, + **kwargs, + ) + encoder_hidden_states = residual + self.prompt_cross_attn_dropout(attn_output) + + # Process backbone features through FPN to get pixel embeddings + pixel_embed = self._embed_pixels( + backbone_features=backbone_features, + encoder_hidden_states=encoder_hidden_states, + ) + + # Predict instance masks via dot product between query embeddings and pixel embeddings + instance_embeds = self.instance_projection(pixel_embed) + mask_embeddings = self.mask_embedder(decoder_queries) + pred_masks = torch.einsum("bqc,bchw->bqhw", mask_embeddings, instance_embeds) + + # Generate semantic segmentation + semantic_seg = self.semantic_projection(pixel_embed) + + return Sam3LiteTextMaskDecoderOutput( + pred_masks=pred_masks, + semantic_seg=semantic_seg, + ) + + def _embed_pixels( + self, + backbone_features: list[torch.Tensor], + encoder_hidden_states: torch.Tensor, + ) -> torch.Tensor: + """ + Embed pixels by combining backbone FPN features with encoder vision features. + The encoder vision features replace the finest-resolution backbone feature. 
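The mask prediction itself is just a dot product between each embedded query and every pixel embedding; a toy-shape sketch of the einsum used above:

```python
import torch

batch, queries, channels, height, width = 1, 5, 32, 64, 64
mask_embeddings = torch.randn(batch, queries, channels)         # MLP-embedded decoder queries
pixel_embeddings = torch.randn(batch, channels, height, width)  # projected pixel-decoder output

# One mask logit map per query: dot product of its embedding with every pixel embedding.
pred_masks = torch.einsum("bqc,bchw->bqhw", mask_embeddings, pixel_embeddings)
print(pred_masks.shape)  # torch.Size([1, 5, 64, 64])
```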
+ + Args: + backbone_features: List of backbone features [batch_size, C, H_i, W_i] + encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size] + + Returns: + Pixel embeddings [batch_size, hidden_size, H, W] + """ + backbone_visual_feats = [feat.clone() for feat in backbone_features] + + # Extract vision features from encoder output and reshape to spatial format + spatial_dim = backbone_features[-1].shape[-2] * backbone_features[-1].shape[-1] + encoder_visual_embed = encoder_hidden_states[:, :spatial_dim, :] + batch_size, _, hidden_size = encoder_visual_embed.shape + height, width = backbone_features[-1].shape[-2:] + encoder_visual_embed = encoder_visual_embed.transpose(1, 2).reshape(batch_size, hidden_size, height, width) + + # Replace finest backbone feature with encoder vision features + backbone_visual_feats[-1] = encoder_visual_embed + + # Process through FPN decoder + pixel_embed = self.pixel_decoder(backbone_visual_feats) + + return pixel_embed + + +class Sam3LiteTextModel(Sam3LiteTextPreTrainedModel): + input_modalities = ["image", "text"] + base_model_prefix = "detector_model" + _keys_to_ignore_on_load_unexpected = [ + r"^tracker_model.", + r"^tracker_neck.", + ] + # DETR components create float masks from features, so flash/flex attention cannot be dispatched safely. + _supports_flash_attn = False + _supports_flex_attn = False + + def __init__(self, config: Sam3LiteTextConfig): + # loading from a sam3_lite_text_video config + if hasattr(config, "detector_config") and config.detector_config is not None: + detector_config = config.detector_config + if isinstance(detector_config, dict): + detector_config = Sam3LiteTextConfig(**detector_config) + config = detector_config + super().__init__(config) + self.vision_encoder = AutoModel.from_config(config.vision_config) + self.text_encoder = Sam3LiteTextTextModel(config.text_config) + self.vocab_size = config.text_config.vocab_size + + # Project text features from text encoder hidden size to model hidden size + # CLIP text encoder outputs 1024-dim features, but we need 256-dim for DETR + self.text_projection = nn.Linear(config.text_config.hidden_size, config.detr_encoder_config.hidden_size) + + # Pass _attn_implementation to subconfigs BEFORE creating modules + config.geometry_encoder_config._attn_implementation = config._attn_implementation + config.detr_encoder_config._attn_implementation = config._attn_implementation + config.detr_decoder_config._attn_implementation = config._attn_implementation + config.mask_decoder_config._attn_implementation = config._attn_implementation + + self.geometry_encoder = Sam3LiteTextGeometryEncoder(config.geometry_encoder_config) + self.detr_encoder = Sam3LiteTextDetrEncoder(config.detr_encoder_config) + self.detr_decoder = Sam3LiteTextDetrDecoder(config.detr_decoder_config) + self.mask_decoder = Sam3LiteTextMaskDecoder(config.mask_decoder_config) + + # Dot product scoring to compute classification scores + self.dot_product_scoring = Sam3LiteTextDotProductScoring(config) + + self.post_init() + + @can_return_tuple + @auto_docstring + def get_text_features( + self, + input_ids: torch.LongTensor, + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + Example: + + ```python + >>> from transformers import Sam3LiteTextModel, Sam3LiteTextProcessor + >>> from PIL import Image + >>> import httpx + >>> from io import BytesIO + + >>> model = Sam3LiteTextModel.from_pretrained("facebook/sam3_lite_text") + >>> processor = 
Sam3LiteTextProcessor.from_pretrained("facebook/sam3_lite_text") + + >>> # Pre-compute text embeddings + >>> text_inputs = processor(text="cat", return_tensors="pt") + >>> text_embeds = model.get_text_features(**text_inputs).pooler_output + + >>> # Reuse text embeddings for multiple images + >>> url = "http://images.cocodataset.org/val2017/000000077595.jpg" + >>> with httpx.stream("GET", url) as response: + ... image = Image.open(BytesIO(response.read())) + >>> img_inputs = processor(images=image, return_tensors="pt") + >>> outputs = model(pixel_values=img_inputs.pixel_values, text_embeds=text_embeds) + ``` + """ + text_outputs = self.text_encoder( + input_ids=input_ids, attention_mask=attention_mask, return_dict=True, **kwargs + ) + last_hidden_state = text_outputs.last_hidden_state + text_outputs.pooler_output = self.text_projection(last_hidden_state) + + return text_outputs + + @auto_docstring + def get_vision_features( + self, + pixel_values: torch.FloatTensor, + **kwargs: Unpack[TransformersKwargs], + ) -> Sam3LiteTextVisionEncoderOutput: + r""" + Example: + + ```python + >>> from transformers import Sam3LiteTextModel, Sam3LiteTextProcessor + >>> from PIL import Image + >>> import httpx + >>> from io import BytesIO + + >>> model = Sam3LiteTextModel.from_pretrained("facebook/sam3_lite_text") + >>> processor = Sam3LiteTextProcessor.from_pretrained("facebook/sam3_lite_text") + + >>> # Pre-compute vision embeddings + >>> url = "http://images.cocodataset.org/val2017/000000077595.jpg" + >>> with httpx.stream("GET", url) as response: + ... image = Image.open(BytesIO(response.read())) + >>> img_inputs = processor(images=image, return_tensors="pt") + >>> vision_embeds = model.get_vision_features(pixel_values=img_inputs.pixel_values) + + >>> # Reuse vision embeddings for multiple text prompts + >>> text_inputs = processor(text="cat", return_tensors="pt") + >>> outputs = model(vision_embeds=vision_embeds, input_ids=text_inputs.input_ids) + ``` + """ + vision_outputs = self.vision_encoder(pixel_values, **kwargs) + return vision_outputs + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor | None = None, + vision_embeds: Sam3LiteTextVisionEncoderOutput | None = None, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + text_embeds: torch.FloatTensor | None = None, + input_boxes: torch.FloatTensor | None = None, + input_boxes_labels: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Sam3LiteTextImageSegmentationOutput: + r""" + vision_embeds (`Sam3LiteTextVisionEncoderOutput`, *optional*): + Pre-computed vision embeddings. Can be used to easily reuse vision embeddings. If provided, `pixel_values` + should not be passed. Mutually exclusive with `pixel_values`. + text_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Pre-computed text embeddings. Can be used to easily reuse text embeddings. If provided, `input_ids` + should not be passed. Mutually exclusive with `input_ids`. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`, *optional*): + Normalized box coordinates in [0, 1] range, in (cx, cy, w, h) format. + input_boxes_labels (`torch.LongTensor` of shape `(batch_size, num_boxes)`, *optional*): + Labels for boxes: 1 (positive), 0 (negative). 
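To make the box-prompt arguments concrete, here is a hedged sketch of how `input_boxes` and `input_boxes_labels` could be laid out for a batch of two images. The coordinates are invented, and the -10 padding sentinel is the value the forward pass below masks out.

```python
import torch

# Image 1: one positive and one negative box; image 2: a single positive box padded to length 2.
input_boxes = torch.tensor(
    [
        [[0.50, 0.50, 0.40, 0.30], [0.20, 0.25, 0.10, 0.10]],
        [[0.55, 0.45, 0.30, 0.30], [0.00, 0.00, 0.00, 0.00]],
    ]
)  # (batch, num_boxes, 4), normalized (cx, cy, w, h)
input_boxes_labels = torch.tensor([[1, 0], [1, -10]])  # 1 = positive, 0 = negative, -10 = padding

# outputs = model(**inputs, input_boxes=input_boxes, input_boxes_labels=input_boxes_labels)
```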
+ + Example: + + ```python + >>> from PIL import Image + >>> import httpx + >>> from io import BytesIO + >>> from transformers import AutoModel, AutoProcessor + + >>> model = AutoModel.from_pretrained("facebook/sam3_lite_text") + >>> processor = AutoProcessor.from_pretrained("facebook/sam3_lite_text") + + >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png" + >>> with httpx.stream("GET", url) as response: + ... image = Image.open(BytesIO(response.read())).convert("RGB") + >>> text = "car" + >>> inputs = processor(images=image, text=text, return_tensors="pt") + + >>> # Get segmentation output + >>> outputs = model(**inputs) + >>> pred_masks = outputs.pred_masks + >>> pred_boxes = outputs.pred_boxes + ``` + """ + if (pixel_values is None) == (vision_embeds is None): + raise ValueError("You must specify exactly one of pixel_values or vision_embeds") + + if (input_ids is None) == (text_embeds is None): + raise ValueError("You must specify exactly one of input_ids or text_embeds") + + if pixel_values is not None: + batch_size = pixel_values.shape[0] + device = pixel_values.device + else: + batch_size = vision_embeds.fpn_hidden_states[0].shape[0] + device = vision_embeds.fpn_hidden_states[0].device + + if vision_embeds is None: + vision_outputs = self.vision_encoder(pixel_values, **kwargs) + else: + vision_outputs = vision_embeds + + fpn_hidden_states = vision_outputs.fpn_hidden_states[:-1] + fpn_position_encoding = vision_outputs.fpn_position_encoding[:-1] + + if text_embeds is None: + text_features = self.get_text_features( + input_ids=input_ids, attention_mask=attention_mask, return_dict=True + ).pooler_output + else: + text_features = text_embeds + + text_mask = attention_mask.bool() if attention_mask is not None else None + has_geometry_prompts = input_boxes is not None and input_boxes.numel() > 0 + + geometry_prompt_features = None + geometry_prompt_mask = None + + if has_geometry_prompts: + if input_boxes is not None and input_boxes.numel() > 0: + box_embeddings = input_boxes # [batch_size, num_boxes, 4] + box_labels = ( + input_boxes_labels + if input_boxes_labels is not None + else torch.ones_like(box_embeddings[..., 0], dtype=torch.long) + ) + box_mask = ( + (input_boxes_labels != -10) + if input_boxes_labels is not None + else torch.ones(batch_size, input_boxes.shape[1], dtype=torch.bool, device=device) + ) + box_labels = torch.where(box_labels == -10, 0, box_labels) + else: + box_embeddings = torch.zeros(batch_size, 0, 4, dtype=text_features.dtype, device=device) + box_labels = torch.zeros(batch_size, 0, dtype=torch.long, device=device) + box_mask = torch.zeros(batch_size, 0, dtype=torch.bool, device=device) + + geometry_outputs = self.geometry_encoder( + box_embeddings=box_embeddings, + box_mask=box_mask, + box_labels=box_labels, + img_feats=fpn_hidden_states, + img_pos_embeds=fpn_position_encoding, + ) + + geometry_prompt_features = geometry_outputs.last_hidden_state + geometry_prompt_mask = geometry_outputs.attention_mask + + if geometry_prompt_features is not None: + # Repeat text_features for all geometry prompts + if text_features.shape[0] == 1 and geometry_prompt_features.shape[0] > 1: + text_features = text_features.repeat(geometry_prompt_features.shape[0], 1, 1) + combined_prompt_features = torch.cat([text_features, geometry_prompt_features], dim=1) + if text_mask is not None and text_mask.shape[0] == 1 and geometry_prompt_mask.shape[0] > 1: + text_mask = text_mask.repeat(geometry_prompt_mask.shape[0], 
1) + + if text_mask is not None and geometry_prompt_mask is not None: + combined_prompt_mask = torch.cat([text_mask, geometry_prompt_mask], dim=1) + elif text_mask is not None: + geo_valid_mask = torch.ones( + batch_size, geometry_prompt_features.shape[1], dtype=torch.bool, device=device + ) + combined_prompt_mask = torch.cat([text_mask, geo_valid_mask], dim=1) + elif geometry_prompt_mask is not None: + text_valid_mask = torch.ones(batch_size, text_features.shape[1], dtype=torch.bool, device=device) + combined_prompt_mask = torch.cat([text_valid_mask, geometry_prompt_mask], dim=1) + else: + combined_prompt_mask = None + else: + combined_prompt_features = text_features + combined_prompt_mask = text_mask + + encoder_outputs = self.detr_encoder( + vision_features=[fpn_hidden_states[-1]], + text_features=combined_prompt_features, + vision_pos_embeds=[fpn_position_encoding[-1]], + text_mask=combined_prompt_mask, + **kwargs, + ) + + decoder_outputs = self.detr_decoder( + vision_features=encoder_outputs.last_hidden_state, + text_features=encoder_outputs.text_features, + vision_pos_encoding=encoder_outputs.pos_embeds_flattened, + text_mask=combined_prompt_mask, + spatial_shapes=encoder_outputs.spatial_shapes, + **kwargs, + ) + + # Refine boxes from decoder + all_box_offsets = self.detr_decoder.box_head(decoder_outputs.intermediate_hidden_states) + reference_boxes_inv_sig = inverse_sigmoid(decoder_outputs.reference_boxes) + all_pred_boxes_cxcywh = (reference_boxes_inv_sig + all_box_offsets).sigmoid() + all_pred_boxes = box_cxcywh_to_xyxy(all_pred_boxes_cxcywh) + + all_pred_logits = self.dot_product_scoring( + decoder_hidden_states=decoder_outputs.intermediate_hidden_states, + text_features=encoder_outputs.text_features, + text_mask=combined_prompt_mask, + ).squeeze(-1) + + pred_logits = all_pred_logits[-1] + pred_boxes = all_pred_boxes[-1] + decoder_hidden_states = decoder_outputs.intermediate_hidden_states[-1] + presence_logits = decoder_outputs.presence_logits[-1] + + mask_outputs = self.mask_decoder( + decoder_queries=decoder_hidden_states, + backbone_features=list(fpn_hidden_states), + encoder_hidden_states=encoder_outputs.last_hidden_state, + prompt_features=combined_prompt_features, + prompt_mask=combined_prompt_mask, + **kwargs, + ) + + return Sam3LiteTextImageSegmentationOutput( + pred_masks=mask_outputs.pred_masks, + pred_boxes=pred_boxes, + pred_logits=pred_logits, + presence_logits=presence_logits, + semantic_seg=mask_outputs.semantic_seg, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_reference_boxes=decoder_outputs.reference_boxes, + encoder_hidden_states=encoder_outputs.hidden_states, + vision_hidden_states=vision_outputs.hidden_states, + vision_attentions=vision_outputs.attentions, + detr_encoder_attentions=encoder_outputs.attentions, + detr_decoder_attentions=decoder_outputs.attentions, + mask_decoder_attentions=mask_outputs.attentions, + ) + + +__all__ = ["Sam3LiteTextModel", "Sam3LiteTextPreTrainedModel", "Sam3LiteTextTextModel"] diff --git a/src/transformers/models/sam3_lite_text/modular_sam3_lite_text.py b/src/transformers/models/sam3_lite_text/modular_sam3_lite_text.py new file mode 100644 index 000000000000..46408464a333 --- /dev/null +++ b/src/transformers/models/sam3_lite_text/modular_sam3_lite_text.py @@ -0,0 +1,532 @@ +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from huggingface_hub.dataclasses import strict + +from ... import initialization as init +from ...activations import ACT2FN +from ...configuration_utils import PreTrainedConfig +from ...masking_utils import create_bidirectional_mask +from ...modeling_outputs import BaseModelOutputWithPooling +from ...processing_utils import Unpack +from ...utils import auto_docstring +from ...utils.generic import TransformersKwargs, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel +from ..sam3.configuration_sam3 import ( + Sam3DETRDecoderConfig, + Sam3DETREncoderConfig, + Sam3GeometryEncoderConfig, + Sam3MaskDecoderConfig, +) +from ..sam3.modeling_sam3 import Sam3Model, Sam3PreTrainedModel +from ..siglip.modeling_siglip import SiglipAttention, SiglipEncoderLayer, SiglipMLP + + +@auto_docstring(checkpoint="facebook/sam3_lite_text") +@strict +class Sam3LiteTextViTConfig(PreTrainedConfig): + r""" + rope_theta (`float`, *optional*, defaults to 10000.0): + Base frequency for RoPE. + window_size (`int`, *optional*, defaults to 24): + Window size for windowed attention. + global_attn_indexes (`list[int]`, *optional*, defaults to `[7, 15, 23, 31]`): + Indexes of layers with global attention. + pretrain_image_size (`int`, *optional*, defaults to 336): + Pretrained model image size for position embedding initialization. + hidden_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for hidden states. + """ + + base_config_key = "backbone_config" + model_type = "sam3_vit_model" + + hidden_size: int = 1024 + intermediate_size: int = 4736 + num_hidden_layers: int = 32 + num_attention_heads: int = 16 + num_channels: int = 3 + image_size: int | list[int] | tuple[int, int] = 1008 + patch_size: int | list[int] | tuple[int, int] = 14 + hidden_act: str = "gelu" + layer_norm_eps: float = 1e-6 + attention_dropout: float | int = 0.0 + rope_theta: float = 10000.0 + window_size: int = 24 + global_attn_indexes: list[int] | None = None + layer_scale_init_value: float | None = None + pretrain_image_size: int | list[int] | tuple[int, int] = 336 + hidden_dropout: float | int = 0.0 + initializer_range: float = 0.02 + + def __post_init__(self, **kwargs): + super().__post_init__(**kwargs) + if self.global_attn_indexes is None: + self.global_attn_indexes = [7, 15, 23, 31] + + +@auto_docstring(checkpoint="facebook/sam3_lite_text") +@strict +class Sam3LiteTextVisionConfig(PreTrainedConfig): + r""" + fpn_hidden_size (`int`, *optional*, defaults to 256): + The hidden dimension of the FPN. + backbone_feature_sizes (`List[List[int]]`, *optional*, defaults to `[[288, 288], [144, 144], [72, 72]]`): + The spatial sizes (height, width) of the feature maps from the backbone at different scales. + scale_factors (`list[float]`, *optional*, defaults to `[4.0, 2.0, 1.0, 0.5]`): + Scale factors for FPN multi-scale features. List of scaling factors for each FPN level. 
+ """ + + base_config_key = "vision_config" + model_type = "sam3_vision_model" + sub_configs = {"backbone_config": AutoConfig} + + backbone_config: dict | PreTrainedConfig | None = None + fpn_hidden_size: int = 256 + backbone_feature_sizes: list | None = None + scale_factors: list[float] | None = None + hidden_act: str = "gelu" + layer_norm_eps: float = 1e-6 + initializer_range: float = 0.02 + + def __post_init__(self, **kwargs): + self.scale_factors = [4.0, 2.0, 1.0, 0.5] if self.scale_factors is None else self.scale_factors + if self.backbone_feature_sizes is None: + self.backbone_feature_sizes = [[288, 288], [144, 144], [72, 72]] + + if isinstance(self.backbone_config, dict): + self.backbone_config["model_type"] = self.backbone_config.get("model_type", "sam3_vit_model") + self.backbone_config = CONFIG_MAPPING[self.backbone_config["model_type"]](**self.backbone_config) + elif self.backbone_config is None: + self.backbone_config = CONFIG_MAPPING["sam3_vit_model"]() + + super().__post_init__(**kwargs) + + @property + def image_size(self): + """Image size for the vision encoder.""" + return self.backbone_config.image_size + + @image_size.setter + def image_size(self, value): + """Set the image size and propagate to backbone.""" + self.backbone_config.image_size = value + + +@auto_docstring(checkpoint="facebook/sam3_lite_text") +@strict +class Sam3LiteTextGeometryEncoderConfig(Sam3GeometryEncoderConfig): + pass + + +@auto_docstring(checkpoint="facebook/sam3_lite_text") +@strict +class Sam3LiteTextDETREncoderConfig(Sam3DETREncoderConfig): + pass + + +@auto_docstring(checkpoint="facebook/sam3_lite_text") +@strict +class Sam3LiteTextDETRDecoderConfig(Sam3DETRDecoderConfig): + pass + + +@auto_docstring(checkpoint="facebook/sam3_lite_text") +@strict +class Sam3LiteTextMaskDecoderConfig(Sam3MaskDecoderConfig): + pass + + +@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0") +@strict +class Sam3LiteTextTextConfig(PreTrainedConfig): + r""" + use_repmixer_blocks (`bool`, *optional*, defaults to `True`): + Whether to use RepMixer blocks (MobileCLIP-style) for the first and last encoder layers. + When `False`, all layers are standard Transformer encoder layers. + layer_scale_init_value (`float`, *optional*, defaults to `1e-5`): + Initial value for the learnable layer-scale parameters in RepMixer blocks (residual branches). + repmixer_kernel_size (`int`, *optional*, defaults to `11`): + Kernel size for depthwise convolutions in RepMixer blocks (token mixer and convolutional feed-forward path). + """ + + model_type = "sam3_lite_text_text_model" + + vocab_size: int = 49408 + hidden_size: int = 512 + intermediate_size: int = 2048 + projection_dim: int = 512 + num_hidden_layers: int = 12 + num_attention_heads: int = 8 + max_position_embeddings: int = 77 + hidden_act: str = "gelu" + layer_norm_eps: float = 1e-5 + attention_dropout: float = 0.0 + use_repmixer_blocks: bool = True + layer_scale_init_value: float = 1e-5 + repmixer_kernel_size: int = 11 + + +@auto_docstring(checkpoint="facebook/sam3_lite_text") +@strict +class Sam3LiteTextConfig(PreTrainedConfig): + r""" + geometry_encoder_config (`dict` or `Sam3LiteTextGeometryEncoderConfig`, *optional*): + Configuration for the geometry encoder. + detr_encoder_config (`dict` or `Sam3LiteTextDETREncoderConfig`, *optional*): + Configuration for the DETR encoder. + detr_decoder_config (`dict` or `Sam3LiteTextDETRDecoderConfig`, *optional*): + Configuration for the DETR decoder. 
+ mask_decoder_config (`dict` or `Sam3LiteTextMaskDecoderConfig`, *optional*): + Configuration for the mask decoder. + + Example: + ```python + >>> from transformers import Sam3LiteTextConfig, Sam3LiteTextModel + + >>> # Initializing a SAM3_LITE_TEXT configuration + >>> configuration = Sam3LiteTextConfig() + + >>> # Initializing a model from the configuration + >>> model = Sam3LiteTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "sam3_lite_text" + sub_configs = { + "vision_config": AutoConfig, + "text_config": Sam3LiteTextTextConfig, + "geometry_encoder_config": Sam3LiteTextGeometryEncoderConfig, + "detr_encoder_config": Sam3LiteTextDETREncoderConfig, + "detr_decoder_config": Sam3LiteTextDETRDecoderConfig, + "mask_decoder_config": Sam3LiteTextMaskDecoderConfig, + } + + vision_config: dict | PreTrainedConfig | None = None + text_config: dict | PreTrainedConfig | None = None + geometry_encoder_config: dict | PreTrainedConfig | None = None + detr_encoder_config: dict | PreTrainedConfig | None = None + detr_decoder_config: dict | PreTrainedConfig | None = None + mask_decoder_config: dict | PreTrainedConfig | None = None + initializer_range: float = 0.02 + + def __post_init__(self, **kwargs): + if isinstance(self.vision_config, dict): + self.vision_config["model_type"] = self.vision_config.get("model_type", "sam3_vision_model") + self.vision_config = CONFIG_MAPPING[self.vision_config["model_type"]](**self.vision_config) + elif self.vision_config is None: + self.vision_config = CONFIG_MAPPING["sam3_vision_model"]() + + if self.text_config is None: + self.text_config = Sam3LiteTextTextConfig() + if isinstance(self.text_config, dict): + self.text_config = Sam3LiteTextTextConfig(**self.text_config) + + if self.geometry_encoder_config is None: + self.geometry_encoder_config = Sam3LiteTextGeometryEncoderConfig() + if isinstance(self.geometry_encoder_config, dict): + self.geometry_encoder_config = Sam3LiteTextGeometryEncoderConfig(**self.geometry_encoder_config) + + if self.detr_encoder_config is None: + self.detr_encoder_config = Sam3LiteTextDETREncoderConfig() + if isinstance(self.detr_encoder_config, dict): + self.detr_encoder_config = Sam3LiteTextDETREncoderConfig(**self.detr_encoder_config) + + if self.detr_decoder_config is None: + self.detr_decoder_config = Sam3LiteTextDETRDecoderConfig() + if isinstance(self.detr_decoder_config, dict): + self.detr_decoder_config = Sam3LiteTextDETRDecoderConfig(**self.detr_decoder_config) + + if self.mask_decoder_config is None: + self.mask_decoder_config = Sam3LiteTextMaskDecoderConfig() + if isinstance(self.mask_decoder_config, dict): + self.mask_decoder_config = Sam3LiteTextMaskDecoderConfig(**self.mask_decoder_config) + + super().__post_init__(**kwargs) + + @property + def image_size(self): + """Image size for the SAM3_LITE_TEXT model.""" + return self.vision_config.image_size + + @image_size.setter + def image_size(self, value): + """Set the image size and propagate to vision config.""" + self.vision_config.image_size = value + + +@dataclass +class Sam3LiteTextTextEncoderOutput(BaseModelOutputWithPooling): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Full sequence of hidden states from the text encoder. + pooler_output (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): + EOT-pooled output projected to `projection_dim` via the internal CLIP-style projection. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of hidden states at each layer, returned when `output_hidden_states=True`. + attentions (`tuple(torch.FloatTensor)`, *optional*): + Tuple of attention weights at each transformer layer, returned when `output_attentions=True`. + """ + + +class Sam3LiteTextTextPositionEmbedding(nn.Module): + """Learnable positional embedding with bilinear interpolation for variable sequence lengths.""" + + def __init__(self, max_position_embeddings: int, hidden_size: int): + super().__init__() + self.position_embedding = nn.Parameter(torch.empty(1, 1, max_position_embeddings, hidden_size)) + + def forward(self, seq_len: int) -> torch.Tensor: + position_embedding = self.position_embedding + if seq_len != position_embedding.shape[2]: + position_embedding = F.interpolate( + position_embedding, + size=(seq_len, position_embedding.shape[-1]), + mode="bilinear", + ) + return position_embedding.reshape(1, seq_len, -1) + + +class Sam3LiteTextMobileOneBlock(nn.Module): + """Depthwise conv branch with batch norm on the skip path and after the conv (MobileOne-style).""" + + def __init__(self, hidden_size: int, kernel_size: int = 3): + super().__init__() + self.batchnorm_skip = nn.BatchNorm2d(hidden_size) + self.conv = nn.Conv2d( + hidden_size, + hidden_size, + kernel_size=(1, kernel_size), + stride=1, + padding=(0, kernel_size // 2), + groups=hidden_size, + bias=False, + ) + self.batchnorm_conv = nn.BatchNorm2d(hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.batchnorm_conv(self.conv(hidden_states)) + hidden_states = hidden_states + self.batchnorm_skip(residual) + return hidden_states + + +class Sam3LiteTextConvMLP(SiglipMLP): + """Pointwise MLP using 1×1 convolutions, compatible with 4-D (B, C, H, W) feature maps.""" + + def __init__(self, config: Sam3LiteTextTextConfig): + nn.Module.__init__(self) + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Conv2d(config.hidden_size, config.intermediate_size, kernel_size=1) + self.fc2 = nn.Conv2d(config.intermediate_size, config.hidden_size, kernel_size=1) + + +class Sam3LiteTextConvolutionalFeedForward(nn.Module): + """Convolutional feed-forward network: depthwise conv + two pointwise projections.""" + + def __init__(self, config: Sam3LiteTextTextConfig): + super().__init__() + self.depthwise_conv = nn.Conv2d( + config.hidden_size, + config.hidden_size, + kernel_size=(1, config.repmixer_kernel_size), + padding=(0, config.repmixer_kernel_size // 2), + groups=config.hidden_size, + bias=False, + ) + self.depthwise_batchnorm = nn.BatchNorm2d(config.hidden_size) + self.mlp = Sam3LiteTextConvMLP(config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.depthwise_batchnorm(self.depthwise_conv(hidden_states)) + return self.mlp(hidden_states) + + +class Sam3LiteTextLayerScaledResidual(nn.Module): + """Common layer-scale residual pattern shared by the RepMixer and feed-forward branches.""" + + def __init__(self, hidden_size: int, layer_scale_init_value: float): + super().__init__() + self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((hidden_size, 1, 1)), requires_grad=True) + + def layer_scale_residual(self, hidden_states: torch.Tensor, update: torch.Tensor) -> torch.Tensor: + return hidden_states + self.layer_scale * update + + +class Sam3LiteTextRepMixer(Sam3LiteTextLayerScaledResidual): + """Re-parameterisable depthwise-conv token mixer operating on 1D sequence 
data.""" + + def __init__(self, config: Sam3LiteTextTextConfig): + super().__init__(config.hidden_size, config.layer_scale_init_value) + self.reference_batchnorm = nn.BatchNorm2d(config.hidden_size) + self.mixer = Sam3LiteTextMobileOneBlock(config.hidden_size, kernel_size=config.repmixer_kernel_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.layer_scale_residual( + hidden_states, self.mixer(hidden_states) - self.reference_batchnorm(hidden_states) + ) + + +class Sam3LiteTextRepMixerBlock(Sam3LiteTextLayerScaledResidual): + """Token-mixing RepMixer plus a convolutional feed-forward path, each with layer scale.""" + + def __init__(self, config: Sam3LiteTextTextConfig): + super().__init__(config.hidden_size, config.layer_scale_init_value) + self.token_mixer = Sam3LiteTextRepMixer(config) + self.conv_feed_forward = Sam3LiteTextConvolutionalFeedForward(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + hidden_states = hidden_states.transpose(1, 2).unsqueeze(2) + hidden_states = self.token_mixer(hidden_states) + hidden_states = self.layer_scale_residual(hidden_states, self.conv_feed_forward(hidden_states)) + return hidden_states.squeeze(2).transpose(1, 2) + + +class Sam3LiteTextTextAttention(SiglipAttention): + pass + + +class Sam3LiteTextTextMLP(SiglipMLP): + pass + + +class Sam3LiteTextTextEncoderLayer(SiglipEncoderLayer): + def __init__(self, config: Sam3LiteTextTextConfig): + super().__init__(config) + self.self_attn = Sam3LiteTextTextAttention(config) + self.mlp = Sam3LiteTextTextMLP(config) + + +class Sam3LiteTextTextEmbeddings(nn.Module): + """Token embedding + interpolatable positional embedding for the text encoder.""" + + def __init__(self, config: Sam3LiteTextTextConfig): + super().__init__() + self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embedding = Sam3LiteTextTextPositionEmbedding(config.max_position_embeddings, config.hidden_size) + + def forward(self, input_ids: torch.LongTensor) -> torch.Tensor: + hidden_states = self.token_embedding(input_ids) + hidden_states = hidden_states + self.position_embedding(input_ids.shape[1]).to(hidden_states.dtype) + return hidden_states + + +@auto_docstring +class Sam3LiteTextPreTrainedModel(Sam3PreTrainedModel): + config_class = Sam3LiteTextConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, Sam3LiteTextTextPositionEmbedding): + init.normal_(module.position_embedding, std=module.position_embedding.shape[-1] ** -0.5) + elif isinstance(module, Sam3LiteTextTextModel): + init.normal_(module.projection.weight, std=module.config.hidden_size**-0.5) + + +@auto_docstring( + custom_intro=""" + MobileCLIP MCT text encoder used in EfficientSAM3 LiteText. + + When `config.use_repmixer_blocks` is `True`, the first and last layers are + `Sam3LiteTextRepMixerBlock` modules; the rest are standard `Sam3LiteTextTextEncoderLayer` layers. 
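A minimal sketch (with made-up shapes) of the RepMixer trick used by the text encoder above: the (batch, seq_len, hidden) sequence is lifted to a channels-first map of height 1 so a depthwise 1×k convolution can mix tokens, then reshaped back for the surrounding transformer layers.

```python
import torch
import torch.nn as nn

batch, seq_len, hidden, kernel = 2, 16, 32, 11

# Depthwise conv over the sequence axis, standing in for the MobileOne-style token mixer.
token_mixer = nn.Conv2d(
    hidden, hidden, kernel_size=(1, kernel), padding=(0, kernel // 2), groups=hidden, bias=False
)

hidden_states = torch.randn(batch, seq_len, hidden)
as_map = hidden_states.transpose(1, 2).unsqueeze(2)  # (batch, hidden, 1, seq_len)
mixed = token_mixer(as_map)                          # token mixing along the sequence axis
back = mixed.squeeze(2).transpose(1, 2)              # (batch, seq_len, hidden)
print(back.shape)  # torch.Size([2, 16, 32])
```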
+""" +) +class Sam3LiteTextTextModel(Sam3LiteTextPreTrainedModel): + config_class = Sam3LiteTextTextConfig + config: Sam3LiteTextTextConfig + _can_record_outputs = { + "hidden_states": Sam3LiteTextTextEncoderLayer, + "attentions": Sam3LiteTextTextAttention, + } + + def __init__(self, config: Sam3LiteTextTextConfig): + super().__init__(config) + self.embeddings = Sam3LiteTextTextEmbeddings(config) + repmixer_positions = {0, config.num_hidden_layers - 1} if config.use_repmixer_blocks else set() + self.layers = nn.ModuleList( + [ + Sam3LiteTextRepMixerBlock(config) if i in repmixer_positions else Sam3LiteTextTextEncoderLayer(config) + for i in range(config.num_hidden_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False) + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Sam3LiteTextTextEncoderOutput: + hidden_states = self.embeddings(input_ids) + attention_mask = create_bidirectional_mask(self.config, hidden_states, attention_mask) + + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask=attention_mask, **kwargs) + + hidden_states = self.final_layer_norm(hidden_states) + + pooled = hidden_states[ + torch.arange(hidden_states.shape[0], device=hidden_states.device), input_ids.argmax(dim=-1) + ] + pooled = self.projection(pooled) + return Sam3LiteTextTextEncoderOutput( + last_hidden_state=hidden_states, + pooler_output=pooled, + ) + + +class Sam3LiteTextModel(Sam3Model): + # DETR components create float masks from features, so flash/flex attention cannot be dispatched safely. 
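For clarity, the EOT pooling in the text model's forward above can be shown on toy data. The token ids below are illustrative, relying only on the CLIP convention that the end-of-text token carries the largest id (49407 in a 49408-token vocabulary).

```python
import torch

batch, seq_len, hidden = 2, 6, 8
hidden_states = torch.randn(batch, seq_len, hidden)

# argmax over ids finds the end-of-text position (49407 is the largest id in the CLIP vocab).
input_ids = torch.tensor(
    [
        [49406, 320, 2368, 49407, 0, 0],
        [49406, 320, 1929, 525, 518, 49407],
    ]
)
eot_positions = input_ids.argmax(dim=-1)
pooled = hidden_states[torch.arange(batch), eot_positions]  # (batch, hidden), fed to the projection
print(eot_positions.tolist(), pooled.shape)  # [3, 5] torch.Size([2, 8])
```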
+ _supports_flash_attn = False + _supports_flex_attn = False + + def __init__(self, config: Sam3LiteTextConfig): + super().__init__(config) + self.text_encoder = Sam3LiteTextTextModel(config.text_config) + self.vision_encoder = AutoModel.from_config(config.vision_config) + + +__all__ = [ + "Sam3LiteTextConfig", + "Sam3LiteTextTextConfig", + "Sam3LiteTextGeometryEncoderConfig", + "Sam3LiteTextDETREncoderConfig", + "Sam3LiteTextDETRDecoderConfig", + "Sam3LiteTextMaskDecoderConfig", + "Sam3LiteTextModel", + "Sam3LiteTextPreTrainedModel", + "Sam3LiteTextTextModel", +] diff --git a/src/transformers/models/sam3_video/convert_sam3_video_to_hf.py b/src/transformers/models/sam3_video/convert_sam3_video_to_hf.py index f021aea27329..6cc00e80dfe9 100644 --- a/src/transformers/models/sam3_video/convert_sam3_video_to_hf.py +++ b/src/transformers/models/sam3_video/convert_sam3_video_to_hf.py @@ -26,7 +26,7 @@ from transformers import CLIPTokenizerFast from transformers.models.sam2_video.video_processing_sam2_video import Sam2VideoVideoProcessor -from transformers.models.sam3.image_processing_sam3_fast import Sam3ImageProcessorFast +from transformers.models.sam3.image_processing_sam3 import Sam3ImageProcessor from transformers.models.sam3.modeling_sam3 import Sam3Model from transformers.models.sam3_tracker.modeling_sam3_tracker import Sam3TrackerModel from transformers.models.sam3_tracker_video.modeling_sam3_tracker_video import Sam3TrackerVideoModel @@ -664,7 +664,7 @@ def convert_sam3_checkpoint( # Save processor print("Creating and saving processor...") - image_processor = Sam3ImageProcessorFast() + image_processor = Sam3ImageProcessor() video_processor = Sam2VideoVideoProcessor( image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5], size={"height": 1008, "width": 1008} ) diff --git a/tests/models/sam3/test_modeling_sam3.py b/tests/models/sam3/test_modeling_sam3.py index 6bbd5947c263..df94063c0a7f 100644 --- a/tests/models/sam3/test_modeling_sam3.py +++ b/tests/models/sam3/test_modeling_sam3.py @@ -17,8 +17,7 @@ import tempfile import unittest -import requests - +from transformers.image_utils import load_image from transformers.testing_utils import ( backend_empty_cache, require_deterministic_for_xpu, @@ -26,11 +25,12 @@ slow, torch_device, ) -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import is_torch_available from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor from ...test_pipeline_mixin import PipelineTesterMixin +from ...test_processing_common import url_to_local_path if is_torch_available(): @@ -50,10 +50,6 @@ from transformers.models.sam3.processing_sam3 import Sam3Processor -if is_vision_available(): - from PIL import Image - - class Sam3VisionModelTester: def __init__( self, @@ -973,16 +969,14 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): def prepare_coco_cat_image(): """Prepare COCO cat and laptop image (from batched inference notebook).""" - img_url = "http://images.cocodataset.org/val2017/000000077595.jpg" - raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") - return raw_image + img_url = url_to_local_path("http://images.cocodataset.org/val2017/000000077595.jpg") + return load_image(img_url).convert("RGB") def prepare_coco_kitchen_image(): """Prepare COCO kitchen scene image (from batched inference notebook).""" - img_url = "http://images.cocodataset.org/val2017/000000136466.jpg" - raw_image = 
Image.open(requests.get(img_url, stream=True).raw).convert("RGB") - return raw_image + img_url = url_to_local_path("http://images.cocodataset.org/val2017/000000136466.jpg") + return load_image(img_url).convert("RGB") @slow diff --git a/tests/models/sam3_lite_text/__init__.py b/tests/models/sam3_lite_text/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/sam3_lite_text/test_modeling_sam3_lite_text.py b/tests/models/sam3_lite_text/test_modeling_sam3_lite_text.py new file mode 100644 index 000000000000..05a9307bfa87 --- /dev/null +++ b/tests/models/sam3_lite_text/test_modeling_sam3_lite_text.py @@ -0,0 +1,1346 @@ +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch SAM3 LiteText model.""" + +import gc +import tempfile +import unittest + +from transformers.image_utils import load_image +from transformers.testing_utils import ( + backend_empty_cache, + require_deterministic_for_xpu, + require_torch, + slow, + torch_device, +) +from transformers.utils import is_torch_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin +from ...test_processing_common import url_to_local_path + + +if is_torch_available(): + import torch + from torch import nn + + from transformers.models.sam3.configuration_sam3 import Sam3VisionConfig, Sam3ViTConfig + from transformers.models.sam3.processing_sam3 import Sam3Processor as Sam3LiteTextProcessor + from transformers.models.sam3_lite_text.configuration_sam3_lite_text import ( + Sam3LiteTextConfig, + Sam3LiteTextDETRDecoderConfig, + Sam3LiteTextDETREncoderConfig, + Sam3LiteTextGeometryEncoderConfig, + Sam3LiteTextMaskDecoderConfig, + ) + from transformers.models.sam3_lite_text.modeling_sam3_lite_text import Sam3LiteTextModel + + +class Sam3LiteTextModelTester: + def __init__( + self, + parent, + num_channels=3, + image_size=224, # Keep reasonable size: 224 = 16 * 14 + hidden_size=32, + patch_size=14, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=64, + window_size=8, # 224/14 = 16 patches, 16/2 = 8 per window + global_attn_indexes=None, + fpn_hidden_size=32, + scale_factors=None, + geometry_encoder_hidden_size=32, + geometry_encoder_num_layers=1, # Reduced from 2 to 1 + detr_encoder_hidden_size=32, + detr_encoder_num_layers=1, # Reduced from 2 to 1 + detr_decoder_hidden_size=32, + detr_decoder_num_layers=1, # Reduced from 2 to 1 + detr_decoder_num_queries=5, # Reduced from 10 to 5 + mask_decoder_hidden_size=32, + batch_size=2, + text_seq_length=16, + vocab_size=100, + is_training=True, + ): + if global_attn_indexes is None: + global_attn_indexes = [0, 1] + if scale_factors is None: + scale_factors = [2.0, 1.0] # Just 2 scales to reduce params + + self.parent = parent + self.num_channels = num_channels + self.image_size = image_size + self.hidden_size = hidden_size + self.patch_size = 
patch_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.fpn_hidden_size = fpn_hidden_size + self.scale_factors = scale_factors + self.batch_size = batch_size + self.text_seq_length = text_seq_length + self.vocab_size = vocab_size + self.is_training = is_training + + # Geometry encoder + self.geometry_encoder_hidden_size = geometry_encoder_hidden_size + self.geometry_encoder_num_layers = geometry_encoder_num_layers + + # DETR encoder/decoder + self.detr_encoder_hidden_size = detr_encoder_hidden_size + self.detr_encoder_num_layers = detr_encoder_num_layers + self.detr_decoder_hidden_size = detr_decoder_hidden_size + self.detr_decoder_num_layers = detr_decoder_num_layers + self.detr_decoder_num_queries = detr_decoder_num_queries + + # Mask decoder + self.mask_decoder_hidden_size = mask_decoder_hidden_size + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + # Simple text input (will be processed by text encoder) + input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.text_seq_length), device=torch_device) + attention_mask = torch.ones_like(input_ids) + + config = self.get_config() + + return config, pixel_values, input_ids, attention_mask + + def get_config(self): + backbone_config = Sam3ViTConfig( + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + num_channels=self.num_channels, + image_size=self.image_size, + patch_size=self.patch_size, + window_size=self.window_size, + global_attn_indexes=self.global_attn_indexes, + ) + + vision_config = Sam3VisionConfig( + backbone_config=backbone_config, + fpn_hidden_size=self.fpn_hidden_size, + scale_factors=self.scale_factors, + ) + + # Small text config for testing (instead of default full MobileCLIP model) + # use_repmixer_blocks=False ensures all layers are standard TransformerLayers so + # attention output, hidden state, and SDPA dispatch tests pass correctly. 
+ text_config = { + "vocab_size": self.vocab_size, + "hidden_size": 32, + "intermediate_size": 64, + "projection_dim": 32, + "num_hidden_layers": self.num_hidden_layers, + "num_attention_heads": 4, + "max_position_embeddings": self.text_seq_length, + "hidden_act": "gelu", + "use_repmixer_blocks": False, + } + + geometry_encoder_config = Sam3LiteTextGeometryEncoderConfig( + hidden_size=self.geometry_encoder_hidden_size, + num_layers=self.geometry_encoder_num_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + mask_fuser_hidden_size=self.geometry_encoder_hidden_size, # Match hidden_size to reduce params + mask_fuser_num_layers=1, # Reduce from default 2 to 1 + ) + + detr_encoder_config = Sam3LiteTextDETREncoderConfig( + hidden_size=self.detr_encoder_hidden_size, + num_layers=self.detr_encoder_num_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + ) + + detr_decoder_config = Sam3LiteTextDETRDecoderConfig( + hidden_size=self.detr_decoder_hidden_size, + num_layers=self.detr_decoder_num_layers, + num_queries=self.detr_decoder_num_queries, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + ) + + mask_decoder_config = Sam3LiteTextMaskDecoderConfig( + hidden_size=self.mask_decoder_hidden_size, + num_upsampling_stages=2, # Reduced from 3 to 2 + ) + + return Sam3LiteTextConfig( + vision_config=vision_config, + text_config=text_config, + geometry_encoder_config=geometry_encoder_config, + detr_encoder_config=detr_encoder_config, + detr_decoder_config=detr_decoder_config, + mask_decoder_config=mask_decoder_config, + ) + + def create_and_check_model(self, config, pixel_values, input_ids, attention_mask): + model = Sam3LiteTextModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask) + + # Check output shapes + self.parent.assertIsNotNone(result.pred_masks) + self.parent.assertIsNotNone(result.pred_boxes) + self.parent.assertIsNotNone(result.pred_logits) + + # Masks should be [batch_size, num_queries, H, W] + self.parent.assertEqual(result.pred_masks.shape[0], self.batch_size) + self.parent.assertEqual(result.pred_masks.shape[1], self.detr_decoder_num_queries) + + # Boxes should be [batch_size, num_queries, 4] + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.detr_decoder_num_queries, 4)) + + # Logits should be [batch_size, num_queries] + self.parent.assertEqual(result.pred_logits.shape, (self.batch_size, self.detr_decoder_num_queries)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, input_ids, attention_mask = config_and_inputs + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class Sam3LiteTextModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Tests for SAM3 full model. 
+ """ + + all_model_classes = (Sam3LiteTextModel,) if is_torch_available() else () + pipeline_model_mapping = {"mask-generation": Sam3LiteTextModel} if is_torch_available() else {} + + test_resize_embeddings = False + _is_composite = True + + def setUp(self): + self.model_tester = Sam3LiteTextModelTester(self) + common_properties = ["initializer_range"] + self.config_tester = ConfigTester( + self, config_class=Sam3LiteTextConfig, has_text_modality=False, common_properties=common_properties + ) + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="SAM3 does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + def test_model_get_set_embeddings(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + # Vision encoder has input embeddings + self.assertIsInstance(model.vision_encoder.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_batching_equivalence(self, atol=5e-4, rtol=5e-4): + super().test_batching_equivalence(atol=atol, rtol=rtol) + + # Override as SAM3Model has component-specific attention outputs + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class._from_config(config, attn_implementation="eager") + config = model.config + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + # Check that we have the component-specific attention outputs + # Note: Some may be empty tuples if attentions aren't collected for that component + self.assertIsNotNone(outputs.vision_attentions) + self.assertIsNotNone(outputs.detr_encoder_attentions) + self.assertIsNotNone(outputs.detr_decoder_attentions) + self.assertIsNotNone(outputs.mask_decoder_attentions) + + # Check vision attentions (from ViT backbone) - should be properly collected + if outputs.vision_attentions: + vision_attentions = outputs.vision_attentions + self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers) + + # Check that at least vision attentions are present (others may require different collection mechanism) + self.assertTrue( + len(outputs.vision_attentions) > 0, + "At least vision attentions should be collected when output_attentions=True", + ) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + for k in config.sub_configs: + if (subconfig := getattr(config, k)) is not None: + subconfig.output_attentions = True + # Sam3LiteText has a vision subconfig with itself a sub config.... 
+ for k in subconfig.sub_configs: + if (subsubconfig := getattr(subconfig, k)) is not None: + subsubconfig.output_attentions = True + + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + # Verify again with config-based setting + self.assertIsNotNone(outputs.vision_attentions) + self.assertIsNotNone(outputs.detr_encoder_attentions) + self.assertIsNotNone(outputs.detr_decoder_attentions) + self.assertIsNotNone(outputs.mask_decoder_attentions) + + # Override as SAM3Model has component-specific attention/hidden state outputs + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for k in config.sub_configs: + if getattr(config, k) is not None: + getattr(config, k).output_hidden_states = True + getattr(config, k).output_attentions = True + + config.output_hidden_states = True + config.output_attentions = True + config._attn_implementation = "eager" + + # Use first model class + model_class = self.all_model_classes[0] + model = model_class._from_config(config, attn_implementation="eager") + model.to(torch_device) + + inputs = self._prepare_for_class(inputs_dict, model_class) + outputs = model(**inputs) + + output = outputs[0] + + # SAM3 has component-specific hidden states and attentions + # Check vision hidden states and attentions + if outputs.vision_hidden_states is not None and len(outputs.vision_hidden_states) > 0: + vision_hidden_states = outputs.vision_hidden_states[0] + vision_hidden_states.retain_grad() + + if outputs.vision_attentions is not None and len(outputs.vision_attentions) > 0: + vision_attentions = outputs.vision_attentions[0] + vision_attentions.retain_grad() + + # Check DETR encoder hidden states and attentions + if outputs.encoder_hidden_states is not None and len(outputs.encoder_hidden_states) > 0: + encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_hidden_states.retain_grad() + + if outputs.detr_encoder_attentions is not None and len(outputs.detr_encoder_attentions) > 0: + detr_encoder_attentions = outputs.detr_encoder_attentions[0] + detr_encoder_attentions.retain_grad() + + # Check DETR decoder hidden states and attentions + if outputs.decoder_hidden_states is not None and len(outputs.decoder_hidden_states) > 0: + decoder_hidden_states = outputs.decoder_hidden_states[0] + decoder_hidden_states.retain_grad() + + if outputs.detr_decoder_attentions is not None and len(outputs.detr_decoder_attentions) > 0: + detr_decoder_attentions = outputs.detr_decoder_attentions[0] + detr_decoder_attentions.retain_grad() + + # Check mask decoder attentions + if outputs.mask_decoder_attentions is not None and len(outputs.mask_decoder_attentions) > 0: + mask_decoder_attentions = outputs.mask_decoder_attentions[0] + mask_decoder_attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + # Check gradients are not None + if outputs.vision_hidden_states is not None and len(outputs.vision_hidden_states) > 0: + self.assertIsNotNone(vision_hidden_states.grad) + + if outputs.vision_attentions is not None and len(outputs.vision_attentions) > 0: + self.assertIsNotNone(vision_attentions.grad) + + if outputs.encoder_hidden_states is not None and len(outputs.encoder_hidden_states) > 0: + self.assertIsNotNone(encoder_hidden_states.grad) + + if outputs.detr_encoder_attentions is not None and len(outputs.detr_encoder_attentions) > 0: + 
self.assertIsNotNone(detr_encoder_attentions.grad) + + if outputs.decoder_hidden_states is not None and len(outputs.decoder_hidden_states) > 0: + self.assertIsNotNone(decoder_hidden_states.grad) + + if outputs.detr_decoder_attentions is not None and len(outputs.detr_decoder_attentions) > 0: + self.assertIsNotNone(detr_decoder_attentions.grad) + + if outputs.mask_decoder_attentions is not None and len(outputs.mask_decoder_attentions) > 0: + self.assertIsNotNone(mask_decoder_attentions.grad) + + def test_hidden_states_output(self): + """Test that SAM3 properly outputs component-specific hidden states.""" + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + # Enable hidden states output + config.output_hidden_states = True + for k in config.sub_configs: + if getattr(config, k) is not None: + getattr(config, k).output_hidden_states = True + + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + # SAM3 has component-specific hidden states + # Check vision hidden states + if outputs.vision_hidden_states is not None: + vision_hidden_states = outputs.vision_hidden_states + self.assertIsInstance(vision_hidden_states, (list, tuple)) + # Vision encoder outputs hidden states from each layer + expected_num_vision_layers = self.model_tester.num_hidden_layers + 1 # +1 for embeddings + self.assertEqual(len(vision_hidden_states), expected_num_vision_layers) + + # Check DETR encoder hidden states (stored as encoder_hidden_states) + if outputs.encoder_hidden_states is not None: + encoder_hidden_states = outputs.encoder_hidden_states + self.assertIsInstance(encoder_hidden_states, (list, tuple)) + + # Check DETR decoder hidden states (stored as decoder_hidden_states) + if outputs.decoder_hidden_states is not None: + decoder_hidden_states = outputs.decoder_hidden_states + self.assertIsInstance(decoder_hidden_states, (list, tuple)) + + def test_sdpa_can_dispatch_composite_models(self): + """ + Tests if composite models dispatch correctly on SDPA/eager when requested. + SAM3 has multiple sub-models: vision_encoder, text_encoder, geometry_encoder, + detr_encoder, detr_decoder, mask_decoder. 
+ """ + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self._is_composite: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa") + model_sdpa = model_sdpa.eval().to(torch_device) + + # Get all sub-models that support attention + vision_encoder_sdpa = getattr(model_sdpa, "vision_encoder") + text_encoder_sdpa = getattr(model_sdpa, "text_encoder", None) + detr_encoder_sdpa = getattr(model_sdpa, "detr_encoder", None) + detr_decoder_sdpa = getattr(model_sdpa, "detr_decoder", None) + mask_decoder_sdpa = getattr(model_sdpa, "mask_decoder", None) + + # Check that sub-models dispatch to SDPA if they support it + self.assertTrue(vision_encoder_sdpa.config._attn_implementation == "sdpa") + if text_encoder_sdpa is not None and hasattr(text_encoder_sdpa, "_supports_sdpa"): + # Sam3LiteTextTextModel supports SDPA + self.assertTrue(text_encoder_sdpa.config._attn_implementation == "sdpa") + if detr_encoder_sdpa is not None: + self.assertTrue(detr_encoder_sdpa.config._attn_implementation == "sdpa") + if detr_decoder_sdpa is not None: + self.assertTrue(detr_decoder_sdpa.config._attn_implementation == "sdpa") + if mask_decoder_sdpa is not None: + self.assertTrue(mask_decoder_sdpa.config._attn_implementation == "sdpa") + + # Now test with eager + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + + self.assertTrue(getattr(model_eager, "vision_encoder").config._attn_implementation == "eager") + if hasattr(model_eager, "text_encoder") and hasattr(model_eager.text_encoder, "config"): + self.assertTrue(model_eager.text_encoder.config._attn_implementation == "eager") + if hasattr(model_eager, "detr_encoder"): + self.assertTrue(model_eager.detr_encoder.config._attn_implementation == "eager") + if hasattr(model_eager, "detr_decoder"): + self.assertTrue(model_eager.detr_decoder.config._attn_implementation == "eager") + if hasattr(model_eager, "mask_decoder"): + self.assertTrue(model_eager.mask_decoder.config._attn_implementation == "eager") + + # Verify no SDPA layers in eager model + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if ( + class_name.endswith("Attention") + and getattr(submodule, "config", None) + and submodule.config._attn_implementation == "sdpa" + ): + raise ValueError("The eager model should not have SDPA attention layers") + + def test_forward_with_text_embeds(self): + """Test that text_embeds parameter works correctly.""" + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + # First get text embeddings + with torch.no_grad(): + text_embeds = model.get_text_features( + input_ids=inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"], return_dict=True + ).pooler_output + + # Forward with text_embeds (remove input_ids) + inputs_with_embeds = { + "pixel_values": inputs_dict["pixel_values"], + "text_embeds": text_embeds, + } + + with torch.no_grad(): + outputs_with_embeds = 
model(**inputs_with_embeds) + + # Forward with input_ids + with torch.no_grad(): + outputs_with_ids = model(**inputs_dict) + + # Outputs should be very close + self.assertTrue(torch.allclose(outputs_with_embeds.pred_logits, outputs_with_ids.pred_logits, atol=1e-5)) + self.assertTrue(torch.allclose(outputs_with_embeds.pred_boxes, outputs_with_ids.pred_boxes, atol=1e-5)) + + def test_forward_with_both_input_ids_and_text_embeds_raises_error(self): + """Test that passing both input_ids and text_embeds raises an error.""" + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + # Get text embeddings + with torch.no_grad(): + text_embeds = model.get_text_features( + input_ids=inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"] + ) + + # Try to pass both (should raise error) + inputs_with_both = { + "pixel_values": inputs_dict["pixel_values"], + "input_ids": inputs_dict["input_ids"], + "text_embeds": text_embeds, + } + + with self.assertRaises(ValueError): + model(**inputs_with_both) + + def test_forward_with_vision_embeds(self): + """Test that vision_embeds parameter works correctly.""" + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + # First get vision embeddings + with torch.no_grad(): + vision_embeds = model.get_vision_features(pixel_values=inputs_dict["pixel_values"]) + + # Forward with vision_embeds (remove pixel_values) + inputs_with_embeds = { + "vision_embeds": vision_embeds, + "input_ids": inputs_dict["input_ids"], + "attention_mask": inputs_dict["attention_mask"], + } + + with torch.no_grad(): + outputs_with_embeds = model(**inputs_with_embeds) + + # Forward with pixel_values + with torch.no_grad(): + outputs_with_pixels = model(**inputs_dict) + + # Outputs should be very close + self.assertTrue( + torch.allclose(outputs_with_embeds.pred_logits, outputs_with_pixels.pred_logits, atol=1e-5) + ) + self.assertTrue(torch.allclose(outputs_with_embeds.pred_boxes, outputs_with_pixels.pred_boxes, atol=1e-5)) + + def test_forward_with_both_pixel_values_and_vision_embeds_raises_error(self): + """Test that passing both pixel_values and vision_embeds raises an error.""" + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + # Get vision embeddings + with torch.no_grad(): + vision_embeds = model.get_vision_features(pixel_values=inputs_dict["pixel_values"]) + + # Try to pass both (should raise error) + inputs_with_both = { + "pixel_values": inputs_dict["pixel_values"], + "vision_embeds": vision_embeds, + "input_ids": inputs_dict["input_ids"], + "attention_mask": inputs_dict["attention_mask"], + } + + with self.assertRaises(ValueError): + model(**inputs_with_both) + + def test_custom_image_size(self): + """Test that custom image size can be set and propagates correctly through nested configs.""" + config = self.model_tester.get_config() + config.image_size = 560 + + self.assertEqual(config.image_size, 560) + self.assertEqual(config.vision_config.image_size, 560) + self.assertEqual(config.vision_config.backbone_config.image_size, 560) + + # Verify model works with custom size + model = Sam3LiteTextModel(config=config).to(torch_device).eval() + 
pixel_values = floats_tensor([self.model_tester.batch_size, self.model_tester.num_channels, 560, 560]).to( + torch_device + ) + input_ids = torch.randint( + 0, + config.text_config.vocab_size, + (self.model_tester.batch_size, self.model_tester.text_seq_length), + device=torch_device, + ) + + with torch.no_grad(): + outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=torch.ones_like(input_ids)) + + self.assertIsNotNone(outputs.pred_masks) + self.assertIsNotNone(outputs.pred_boxes) + self.assertIsNotNone(outputs.pred_logits) + + @unittest.skip(reason="SAM3 LiteText model can't be compiled dynamic yet") + def test_sdpa_can_compile_dynamic(self): + pass + + @unittest.skip( + reason="Sam3LiteTextModel creates float attention masks from features (with gradients) in the DETR " + "encoder/decoder, which Flash Attention requires to be None." + ) + def test_sdpa_can_dispatch_on_flash(self): + pass + + def test_model_outputs_equivalence(self): + """ + Test that tuple and dict outputs are equivalent. + SAM3 returns complex outputs with component-specific fields, so we need to ensure proper conversion. + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def set_nan_tensor_to_zero(t): + t[t != t] = 0 + return t + + def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + with torch.no_grad(): + tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) + dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (list, tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, dict): + for tuple_iterable_value, dict_iterable_value in zip( + tuple_object.values(), dict_object.values() + ): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + # model might return non-tensors objects (e.g. Cache class) + elif isinstance(tuple_object, torch.Tensor): + self.assertTrue( + torch.allclose( + set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 + ), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has" + f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}." 
+ ), + ) + + recursive_check(tuple_output, dict_output) + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs) + + # Test with output_hidden_states + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + # Test with output_attentions if supported + if self.has_attentions: + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence( + model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True} + ) + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + """Override to ensure input_ids and attention_mask are always present for Sam3LiteTextModel.""" + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + # Sam3LiteTextModel always requires input_ids and attention_mask for text encoding + if model_class == Sam3LiteTextModel: + if "input_ids" not in inputs_dict or inputs_dict.get("input_ids") is None: + # Create dummy input_ids if not present + # Get batch_size from pixel_values or vision_embeds + if "pixel_values" in inputs_dict and inputs_dict.get("pixel_values") is not None: + batch_size = inputs_dict["pixel_values"].shape[0] + elif "vision_embeds" in inputs_dict and inputs_dict.get("vision_embeds") is not None: + vision_embeds = inputs_dict["vision_embeds"] + if vision_embeds.fpn_hidden_states is not None and len(vision_embeds.fpn_hidden_states) > 0: + batch_size = vision_embeds.fpn_hidden_states[0].shape[0] + elif vision_embeds.last_hidden_state is not None: + batch_size = vision_embeds.last_hidden_state.shape[0] + else: + batch_size = 2 + else: + batch_size = 2 + config = self.model_tester.get_config() + # text_config might be a dict or a config object + if isinstance(config.text_config, dict): + vocab_size = config.text_config.get("vocab_size", 1000) + else: + vocab_size = getattr(config.text_config, "vocab_size", 1000) + inputs_dict["input_ids"] = torch.randint(0, vocab_size, (batch_size, 16), device=torch_device) + if "attention_mask" not in inputs_dict or inputs_dict.get("attention_mask") is None: + inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["input_ids"]) + + return inputs_dict + + +def prepare_coco_cat_image(): + """Prepare COCO cat and laptop image (from batched inference notebook).""" + img_url = url_to_local_path("http://images.cocodataset.org/val2017/000000077595.jpg") + return load_image(img_url).convert("RGB") + + +def prepare_coco_kitchen_image(): + """Prepare COCO kitchen scene image (from batched inference notebook).""" + img_url = url_to_local_path("http://images.cocodataset.org/val2017/000000136466.jpg") + return load_image(img_url).convert("RGB") + + +@slow +@require_torch +class Sam3LiteTextModelIntegrationTest(unittest.TestCase): + """Integration tests for SAM3 model with real pretrained weights.""" + + def setUp(self): + 
super().setUp() + model_name = "yonigozlan/sam3-litetext-s0" + self.model = Sam3LiteTextModel.from_pretrained(model_name, dtype=torch.float32) + self.processor = Sam3LiteTextProcessor.from_pretrained(model_name) + self.model.to(torch_device) + self.model.eval() + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def test_inference_text_prompt_only(self): + """Test inference with text prompt only (from multiway_prompting notebook).""" + # Example from notebook: "short hair" text prompt + raw_image = prepare_coco_cat_image() + text = "ear" + + inputs = self.processor(images=raw_image, text=text, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = self.model(**inputs) + + # Check exact output shapes + self.assertEqual(outputs.pred_masks.shape, (1, 200, 288, 288)) + self.assertEqual(outputs.pred_boxes.shape, (1, 200, 4)) + self.assertEqual(outputs.pred_logits.shape, (1, 200)) + + # Check that predictions have reasonable scores (after sigmoid) + scores = torch.sigmoid(outputs.pred_logits) + self.assertTrue((scores >= 0).all() and (scores <= 1).all()) + + # Check exact values + sorted_indices = torch.argsort(scores.squeeze(), descending=True) + top_scores = scores.squeeze()[sorted_indices[:3]] + top_logits = outputs.pred_logits.squeeze()[sorted_indices[:3]] + top_idx = sorted_indices[0].item() + + torch.testing.assert_close( + top_scores, torch.tensor([0.9326, 0.9149, 0.1009]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + top_logits, torch.tensor([2.6268, 2.3755, -2.1877]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + outputs.pred_boxes[0, top_idx], + torch.tensor([0.4704, 0.2015, 0.5615, 0.3770]).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + torch.testing.assert_close( + outputs.pred_masks[0, top_idx, :3, :3], + torch.tensor( + [[-2.1856, -6.2395, -7.0870], [-6.0620, -10.4022, -11.2534], [-8.7157, -10.7961, -9.9844]] + ).to(torch_device), + atol=1e-2, + rtol=1e-2, + ) + + # Test post-processing + results = self.processor.post_process_instance_segmentation( + outputs, threshold=0.5, mask_threshold=0.5, target_sizes=inputs.get("original_sizes").tolist() + ) + self.assertEqual(len(results), 1) + result = results[0] + + # Check that we have detections + self.assertGreater(len(result["masks"]), 0) + self.assertGreater(len(result["boxes"]), 0) + self.assertGreater(len(result["scores"]), 0) + + # Check exact values for top detection + top_pp_score = result["scores"][0] + top_pp_box = result["boxes"][0] + + torch.testing.assert_close(top_pp_score, torch.tensor(0.9137).to(torch_device), atol=1e-4, rtol=1e-4) + torch.testing.assert_close( + top_pp_box, torch.tensor([402.3560, 90.1643, 459.6652, 156.5201]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + + def test_inference_single_box_prompt(self): + """Test inference with a single bounding box prompt (from batched_inference notebook).""" + raw_image = prepare_coco_cat_image() + # Example from notebook: laptop region in image 1 + # Box in xyxy format: [100, 150, 500, 450] + box_xyxy = [100, 150, 500, 450] + input_boxes = [[box_xyxy]] + + inputs = self.processor( + images=raw_image, + input_boxes=input_boxes, + input_boxes_labels=[[1]], # Positive box + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = self.model(**inputs) + + # Check exact output shapes + self.assertEqual(outputs.pred_masks.shape, (1, 200, 288, 288)) + self.assertEqual(outputs.pred_boxes.shape, (1, 200, 4)) + 
self.assertEqual(outputs.pred_logits.shape, (1, 200)) + + # Check exact values + scores = torch.sigmoid(outputs.pred_logits) + sorted_indices = torch.argsort(scores.squeeze(), descending=True) + top_scores = scores.squeeze()[sorted_indices[:3]] + top_logits = outputs.pred_logits.squeeze()[sorted_indices[:3]] + top_idx = sorted_indices[0].item() + + torch.testing.assert_close( + top_scores, torch.tensor([0.9387, 0.1474, 0.1087]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + top_logits, torch.tensor([2.7291, -1.7554, -2.1046]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + outputs.pred_boxes[0, top_idx], + torch.tensor([0.1628, 0.4168, 0.7534, 0.9935]).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + torch.testing.assert_close( + outputs.pred_masks[0, top_idx, :3, :3], + torch.tensor([[-1.9982, -3.7409, -3.9956], [-3.6554, -5.9248, -6.0131], [-4.5402, -6.0183, -6.2711]]).to( + torch_device + ), + atol=1e-2, + rtol=1e-2, + ) + + # Test post-processing + results = self.processor.post_process_instance_segmentation( + outputs, threshold=0.5, mask_threshold=0.5, target_sizes=inputs.get("original_sizes").tolist() + ) + self.assertEqual(len(results), 1) + result = results[0] + + # Check that we have detections + self.assertGreater(len(result["masks"]), 0) + + # Check exact values for top detection + top_pp_score = result["scores"][0] + top_pp_box = result["boxes"][0] + + torch.testing.assert_close(top_pp_score, torch.tensor(0.9387).to(torch_device), atol=1e-4, rtol=1e-4) + torch.testing.assert_close( + top_pp_box, torch.tensor([104.2026, 177.1267, 482.1449, 422.2436]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + + def test_inference_multi_box_prompt(self): + """Test inference with multiple box prompts with positive and negative labels (from batched_inference notebook).""" + raw_image = prepare_coco_kitchen_image() + # Example from notebook: multiple positive boxes (dial + button) + # Dial box (xyxy): [59, 144, 76, 163] + # Button box (xyxy): [87, 148, 104, 159] + box1_xyxy = [59, 144, 76, 163] + box2_xyxy = [87, 148, 104, 159] + + input_boxes = [[box1_xyxy, box2_xyxy]] + input_boxes_labels = [[1, 1]] # Both positive + + inputs = self.processor( + images=raw_image, input_boxes=input_boxes, input_boxes_labels=input_boxes_labels, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = self.model(**inputs) + + # Check exact output shapes + self.assertEqual(outputs.pred_masks.shape, (1, 200, 288, 288)) + self.assertEqual(outputs.pred_boxes.shape, (1, 200, 4)) + self.assertEqual(outputs.pred_logits.shape, (1, 200)) + + # Check exact values + scores = torch.sigmoid(outputs.pred_logits) + sorted_indices = torch.argsort(scores.squeeze(), descending=True) + top_scores = scores.squeeze()[sorted_indices[:3]] + top_logits = outputs.pred_logits.squeeze()[sorted_indices[:3]] + top_idx = sorted_indices[0].item() + + torch.testing.assert_close( + top_scores, torch.tensor([0.9615, 0.9399, 0.8262]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + top_logits, torch.tensor([3.2166, 2.7504, 1.5591]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + outputs.pred_boxes[0, top_idx], + torch.tensor([0.1758, 0.2889, 0.2296, 0.3258]).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + torch.testing.assert_close( + outputs.pred_masks[0, top_idx, :3, :3], + torch.tensor( + [[-9.1072, -15.0329, -18.6687], [-14.1670, -21.3808, -26.8545], [-15.5131, -23.4600, -17.5040]] + ).to(torch_device), + 
atol=1e-2, + rtol=1e-2, + ) + + # Test post-processing + results = self.processor.post_process_instance_segmentation( + outputs, threshold=0.5, mask_threshold=0.5, target_sizes=inputs.get("original_sizes").tolist() + ) + self.assertEqual(len(results), 1) + result = results[0] + + # Check that we have detections + self.assertGreater(len(result["masks"]), 0) + + # Check exact values for top detection + top_pp_score = result["scores"][0] + top_pp_box = result["boxes"][0] + + torch.testing.assert_close(top_pp_score, torch.tensor(0.9399).to(torch_device), atol=1e-4, rtol=1e-4) + torch.testing.assert_close( + top_pp_box, torch.tensor([86.8764, 147.5362, 104.4958, 159.6079]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + + def test_inference_combined_prompts(self): + """Test inference with combined text and geometry prompts (text + negative box from batched_inference notebook).""" + raw_image = prepare_coco_kitchen_image() + # Example from notebook: text "handle" + negative box to exclude oven handle + text = "handle" + # Negative box covering the oven handle area (xyxy): [40, 183, 318, 204] + oven_handle_box = [40, 183, 318, 204] + + input_boxes = [[oven_handle_box]] + + inputs = self.processor( + images=raw_image, + text=text, + input_boxes=input_boxes, + input_boxes_labels=[[0]], # 0 = negative + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = self.model(**inputs) + + # Check exact output shapes + self.assertEqual(outputs.pred_masks.shape, (1, 200, 288, 288)) + self.assertEqual(outputs.pred_boxes.shape, (1, 200, 4)) + self.assertEqual(outputs.pred_logits.shape, (1, 200)) + + def test_inference_batched_images(self): + """Test batched inference with multiple images (from batched_inference notebook).""" + # Example from notebook: batch of 2 images with different text prompts + raw_image1 = prepare_coco_cat_image() + raw_image2 = prepare_coco_kitchen_image() + + # Batch of 2 images with different text prompts: "ear" for cat, "dial" for kitchen + inputs = self.processor(images=[raw_image1, raw_image2], text=["ear", "dial"], return_tensors="pt").to( + torch_device + ) + + with torch.no_grad(): + outputs = self.model(**inputs) + + # Check exact output shapes + self.assertEqual(outputs.pred_masks.shape, (2, 200, 288, 288)) + self.assertEqual(outputs.pred_boxes.shape, (2, 200, 4)) + self.assertEqual(outputs.pred_logits.shape, (2, 200)) + + # Check scores are reasonable + scores = torch.sigmoid(outputs.pred_logits) + self.assertTrue((scores >= 0).all() and (scores <= 1).all()) + + # Check exact values + sorted_indices_0 = torch.argsort(scores[0], descending=True) + sorted_indices_1 = torch.argsort(scores[1], descending=True) + top_scores_0 = scores[0][sorted_indices_0[:3]] + top_scores_1 = scores[1][sorted_indices_1[:3]] + top_logits_0 = outputs.pred_logits[0][sorted_indices_0[:3]] + top_logits_1 = outputs.pred_logits[1][sorted_indices_1[:3]] + top_idx_0 = sorted_indices_0[0].item() + top_idx_1 = sorted_indices_1[0].item() + + torch.testing.assert_close( + top_scores_0, torch.tensor([0.9326, 0.9149, 0.1009]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + top_scores_1, torch.tensor([0.8532, 0.8497, 0.8473]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + top_logits_0, torch.tensor([2.6268, 2.3755, -2.1877]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + top_logits_1, torch.tensor([1.7598, 1.7324, 1.7133]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + 
outputs.pred_boxes[0, top_idx_0], + torch.tensor([0.4704, 0.2015, 0.5615, 0.3770]).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + torch.testing.assert_close( + outputs.pred_boxes[1, top_idx_1], + torch.tensor([0.5088, 0.2749, 0.5775, 0.3228]).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + torch.testing.assert_close( + outputs.pred_masks[0, top_idx_0, :3, :3], + torch.tensor( + [[-2.1858, -6.2385, -7.0859], [-6.0630, -10.4010, -11.2507], [-8.7169, -10.7953, -9.9857]] + ).to(torch_device), + atol=1e-2, + rtol=1e-2, + ) + torch.testing.assert_close( + outputs.pred_masks[1, top_idx_1, :3, :3], + torch.tensor( + [[-5.7639, -10.4997, -11.1576], [-8.6470, -14.8703, -17.1438], [-8.9861, -14.4202, -12.6988]] + ).to(torch_device), + atol=1e-2, + rtol=1e-2, + ) + + # Test post-processing + results = self.processor.post_process_instance_segmentation( + outputs, threshold=0.3, mask_threshold=0.5, target_sizes=inputs.get("original_sizes").tolist() + ) + self.assertEqual(len(results), 2) + + # Check that both have detections + self.assertGreater(len(results[0]["masks"]), 0) + self.assertGreater(len(results[1]["masks"]), 0) + + # Check exact values for top detection in each image + top_pp_score_0 = results[0]["scores"][0] + top_pp_box_0 = results[0]["boxes"][0] + top_pp_score_1 = results[1]["scores"][0] + top_pp_box_1 = results[1]["boxes"][0] + + torch.testing.assert_close(top_pp_score_0, torch.tensor(0.9137).to(torch_device), atol=1e-4, rtol=1e-4) + torch.testing.assert_close( + top_pp_box_0, torch.tensor([402.3560, 90.1644, 459.6652, 156.5200]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close(top_pp_score_1, torch.tensor(0.5882).to(torch_device), atol=1e-4, rtol=1e-4) + torch.testing.assert_close( + top_pp_box_1, torch.tensor([110.6558, 271.1691, 137.3885, 301.4023]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + + def test_inference_batched_mixed_prompts(self): + """Test batched inference with mixed prompt types (from batched_inference notebook).""" + # Example from notebook: Image 1 with text "laptop", Image 2 with visual prompt (dial) + raw_image1 = prepare_coco_cat_image() + raw_image2 = prepare_coco_kitchen_image() + + # Box for dial in image 2 (xyxy): [59, 144, 76, 163] + box2_xyxy = [59, 144, 76, 163] + + inputs = self.processor( + images=[raw_image1, raw_image2], + text=["laptop", None], # Only first image has text + input_boxes=[None, [box2_xyxy]], # Only second image has box + input_boxes_labels=[None, [1]], + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = self.model(**inputs) + + # Check exact output shapes + self.assertEqual(outputs.pred_masks.shape, (2, 200, 288, 288)) + self.assertEqual(outputs.pred_boxes.shape, (2, 200, 4)) + self.assertEqual(outputs.pred_logits.shape, (2, 200)) + + # Check exact values + scores = torch.sigmoid(outputs.pred_logits) + sorted_indices_0 = torch.argsort(scores[0], descending=True) + sorted_indices_1 = torch.argsort(scores[1], descending=True) + top_scores_0 = scores[0][sorted_indices_0[:3]] + top_scores_1 = scores[1][sorted_indices_1[:3]] + top_logits_0 = outputs.pred_logits[0][sorted_indices_0[:3]] + top_logits_1 = outputs.pred_logits[1][sorted_indices_1[:3]] + top_idx_0 = sorted_indices_0[0].item() + top_idx_1 = sorted_indices_1[0].item() + + # Top-1 score/logit is always stable; positions 2-3 can swap when scores are very close + torch.testing.assert_close(top_scores_0[0], torch.tensor(0.9696).to(torch_device), atol=1e-4, rtol=1e-4) + torch.testing.assert_close(top_scores_1[0], 
torch.tensor(0.9696).to(torch_device), atol=1e-4, rtol=1e-4) + torch.testing.assert_close(top_logits_0[0], torch.tensor(3.4640).to(torch_device), atol=1e-4, rtol=1e-4) + torch.testing.assert_close(top_logits_1[0], torch.tensor(3.4608).to(torch_device), atol=1e-4, rtol=1e-4) + # Positions 2-3 use relaxed tolerance since their scores are very close (~0.16 and ~0.07) + torch.testing.assert_close( + top_scores_0[1:], torch.tensor([0.1615, 0.0683]).to(torch_device), atol=1e-2, rtol=1e-2 + ) + torch.testing.assert_close( + top_scores_1[1:], torch.tensor([0.8302, 0.8153]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + top_logits_1[1:], torch.tensor([1.5873, 1.4849]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + outputs.pred_boxes[0, top_idx_0], + torch.tensor([-0.0012, 0.0017, 0.4518, 0.9965]).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + torch.testing.assert_close( + outputs.pred_boxes[1, top_idx_1], + torch.tensor([0.1775, 0.2877, 0.2297, 0.3261]).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + torch.testing.assert_close( + outputs.pred_masks[0, top_idx_0, :3, :3], + torch.tensor([[0.0345, 0.2871, 0.3752], [0.6577, 0.9573, 1.0289], [0.8247, 0.9766, 0.9590]]).to( + torch_device + ), + atol=1e-2, + rtol=1e-2, + ) + torch.testing.assert_close( + outputs.pred_masks[1, top_idx_1, :3, :3], + torch.tensor( + [[-9.2271, -14.6975, -18.0719], [-14.1411, -21.1871, -26.6270], [-15.6623, -23.1574, -18.0739]] + ).to(torch_device), + atol=1e-2, + rtol=1e-2, + ) + + # Test post-processing + results = self.processor.post_process_instance_segmentation( + outputs, threshold=0.3, mask_threshold=0.5, target_sizes=inputs.get("original_sizes").tolist() + ) + self.assertEqual(len(results), 2) + + # Check that both have detections + self.assertGreater(len(results[0]["masks"]), 0) + self.assertGreater(len(results[1]["masks"]), 0) + + # Check exact values for top detection in each image + top_pp_score_0 = results[0]["scores"][0] + top_pp_box_0 = results[0]["boxes"][0] + top_pp_score_1 = results[1]["scores"][0] + top_pp_box_1 = results[1]["boxes"][0] + + torch.testing.assert_close(top_pp_score_0, torch.tensor(0.9556).to(torch_device), atol=1e-4, rtol=1e-4) + torch.testing.assert_close( + top_pp_box_0, torch.tensor([-0.7773, 0.7140, 289.1797, 423.5321]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close(top_pp_score_1, torch.tensor(0.8153).to(torch_device), atol=1e-4, rtol=1e-4) + torch.testing.assert_close( + top_pp_box_1, torch.tensor([168.9672, 137.3469, 191.7236, 161.3282]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + + def test_semantic_segmentation_output(self): + """Test that semantic segmentation output is produced.""" + raw_image = prepare_coco_cat_image() + inputs = self.processor(images=raw_image, text="ear", return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = self.model(**inputs) + + # Check exact semantic segmentation output shape + self.assertEqual(outputs.semantic_seg.shape, (1, 1, 288, 288)) + # Check that semantic seg has same spatial size as pred_masks + self.assertEqual(outputs.semantic_seg.shape[-2:], outputs.pred_masks.shape[-2:]) + + @require_deterministic_for_xpu + def test_efficient_multi_prompt_single_image(self): + """Test efficient inference with multiple prompts on a single image using get_vision_features.""" + raw_image = prepare_coco_cat_image() + + # Pre-compute vision embeddings once + img_inputs = self.processor(images=raw_image, return_tensors="pt").to(torch_device) + with 
torch.no_grad(): + vision_embeds = self.model.get_vision_features(pixel_values=img_inputs.pixel_values) + + # Run multiple text prompts efficiently + text_prompts = ["ear", "eye"] + all_results = [] + + for prompt in text_prompts: + text_inputs = self.processor(text=prompt, return_tensors="pt").to(torch_device) + with torch.no_grad(): + outputs = self.model(vision_embeds=vision_embeds, **text_inputs) + + results = self.processor.post_process_instance_segmentation( + outputs, + threshold=0.5, + mask_threshold=0.5, + target_sizes=img_inputs.get("original_sizes").tolist(), + )[0] + all_results.append(results) + + # Check that we get results for both prompts + self.assertEqual(len(all_results), 2) + + # Verify outputs are equivalent to running with pixel_values directly + text_inputs = self.processor(text="ear", return_tensors="pt").to(torch_device) + with torch.no_grad(): + outputs_with_embeds = self.model(vision_embeds=vision_embeds, **text_inputs) + + inputs_direct = self.processor(images=raw_image, text="ear", return_tensors="pt").to(torch_device) + with torch.no_grad(): + outputs_direct = self.model(**inputs_direct) + + # Outputs should be identical + torch.testing.assert_close(outputs_with_embeds.pred_logits, outputs_direct.pred_logits, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(outputs_with_embeds.pred_boxes, outputs_direct.pred_boxes, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(outputs_with_embeds.pred_masks, outputs_direct.pred_masks, atol=1e-5, rtol=1e-5) + + @require_deterministic_for_xpu + def test_efficient_single_prompt_multi_images(self): + """Test efficient inference with same prompt on multiple images using get_text_features.""" + raw_image1 = prepare_coco_cat_image() + raw_image2 = prepare_coco_kitchen_image() + + # Pre-compute text embeddings once + text_prompt = "handle" + text_inputs = self.processor(text=text_prompt, return_tensors="pt").to(torch_device) + with torch.no_grad(): + text_embeds = self.model.get_text_features(**text_inputs).pooler_output + + # Run inference on multiple images reusing text embeddings + # Note: attention_mask must be passed along with text_embeds for proper masking + images = [raw_image1, raw_image2] + all_results = [] + + for image in images: + img_inputs = self.processor(images=image, return_tensors="pt").to(torch_device) + with torch.no_grad(): + outputs = self.model( + text_embeds=text_embeds, + attention_mask=text_inputs.attention_mask, + **img_inputs, + ) + + results = self.processor.post_process_instance_segmentation( + outputs, + threshold=0.5, + mask_threshold=0.5, + target_sizes=img_inputs.get("original_sizes").tolist(), + )[0] + all_results.append(results) + + # Check that we get results for both images + self.assertEqual(len(all_results), 2) + + # Verify outputs are equivalent to running with input_ids directly + img_inputs = self.processor(images=raw_image2, return_tensors="pt").to(torch_device) + with torch.no_grad(): + outputs_with_embeds = self.model( + text_embeds=text_embeds, + attention_mask=text_inputs.attention_mask, + **img_inputs, + ) + + inputs_direct = self.processor(images=raw_image2, text=text_prompt, return_tensors="pt").to(torch_device) + with torch.no_grad(): + outputs_direct = self.model(**inputs_direct) + + # Outputs should be identical + torch.testing.assert_close(outputs_with_embeds.pred_logits, outputs_direct.pred_logits, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(outputs_with_embeds.pred_boxes, outputs_direct.pred_boxes, atol=1e-5, rtol=1e-5) + 
torch.testing.assert_close(outputs_with_embeds.pred_masks, outputs_direct.pred_masks, atol=1e-5, rtol=1e-5) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index c987d506f319..8d78898dc4cf 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -83,6 +83,8 @@ "AutoformerConfig": ["num_static_real_features", "num_time_features"], "SamVisionConfig": ["mlp_ratio"], "Sam3VisionConfig": ["backbone_feature_sizes"], + "Sam3LiteTextViTConfig": ["global_attn_indexes", "window_size"], + "Sam3LiteTextVisionConfig": ["fpn_hidden_size", "scale_factors"], "SamHQVisionConfig": ["mlp_ratio"], "ClapAudioConfig": ["num_classes"], "ClvpDecoderConfig": ["add_cross_attention"], diff --git a/utils/check_repo.py b/utils/check_repo.py index b1a3d158c716..4a464be9db41 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -272,6 +272,7 @@ "UVDocBridge", # Building part of a bigger model, tested implicitly through UVDocModel "Gemma4VisionModel", # Building part of a bigger model, tested implicitly "Gemma4AudioModel", # Building part of a bigger model, tested implicitly + "Sam3LiteTextTextModel", # Building part of a bigger model, tested implicitly through Sam3LiteTextModel ] ) diff --git a/utils/fetch_hub_objects_for_ci.py b/utils/fetch_hub_objects_for_ci.py index 59cf65117913..dba74dc418d7 100644 --- a/utils/fetch_hub_objects_for_ci.py +++ b/utils/fetch_hub_objects_for_ci.py @@ -18,6 +18,8 @@ "http://images.cocodataset.org/val2017/000000000802.jpg", "http://images.cocodataset.org/val2017/000000000872.jpg", "http://images.cocodataset.org/val2017/000000039769.jpg", + "http://images.cocodataset.org/val2017/000000077595.jpg", + "http://images.cocodataset.org/val2017/000000136466.jpg", "https://www.ilankelman.org/stopsigns/australia.jpg", "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg", "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",