From aa2dc53929d7ecbd526adf494221b31a16e2ff18 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Mon, 20 Apr 2026 23:14:35 +0000 Subject: [PATCH 1/2] fix sam3 lite text --- .../configuration_sam3_lite_text.py | 101 +----------- .../sam3_lite_text/modeling_sam3_lite_text.py | 153 +----------------- .../sam3_lite_text/modular_sam3_lite_text.py | 104 +----------- .../test_modeling_sam3_lite_text.py | 7 + utils/check_config_attributes.py | 2 - 5 files changed, 20 insertions(+), 347 deletions(-) diff --git a/src/transformers/models/sam3_lite_text/configuration_sam3_lite_text.py b/src/transformers/models/sam3_lite_text/configuration_sam3_lite_text.py index 696751075611..f77fa99677f4 100644 --- a/src/transformers/models/sam3_lite_text/configuration_sam3_lite_text.py +++ b/src/transformers/models/sam3_lite_text/configuration_sam3_lite_text.py @@ -25,98 +25,7 @@ from ..auto import CONFIG_MAPPING, AutoConfig -@auto_docstring(checkpoint="facebook/sam3_lite_text") -@strict -class Sam3LiteTextViTConfig(PreTrainedConfig): - r""" - rope_theta (`float`, *optional*, defaults to 10000.0): - Base frequency for RoPE. - window_size (`int`, *optional*, defaults to 24): - Window size for windowed attention. - global_attn_indexes (`list[int]`, *optional*, defaults to `[7, 15, 23, 31]`): - Indexes of layers with global attention. - pretrain_image_size (`int`, *optional*, defaults to 336): - Pretrained model image size for position embedding initialization. - hidden_dropout (`float`, *optional*, defaults to 0.0): - Dropout probability for hidden states. - """ - - base_config_key = "backbone_config" - model_type = "sam3_vit_model" - - hidden_size: int = 1024 - intermediate_size: int = 4736 - num_hidden_layers: int = 32 - num_attention_heads: int = 16 - num_channels: int = 3 - image_size: int | list[int] | tuple[int, int] = 1008 - patch_size: int | list[int] | tuple[int, int] = 14 - hidden_act: str = "gelu" - layer_norm_eps: float = 1e-6 - attention_dropout: float | int = 0.0 - rope_theta: float = 10000.0 - window_size: int = 24 - global_attn_indexes: list[int] | None = None - layer_scale_init_value: float | None = None - pretrain_image_size: int | list[int] | tuple[int, int] = 336 - hidden_dropout: float | int = 0.0 - initializer_range: float = 0.02 - - def __post_init__(self, **kwargs): - super().__post_init__(**kwargs) - if self.global_attn_indexes is None: - self.global_attn_indexes = [7, 15, 23, 31] - - -@auto_docstring(checkpoint="facebook/sam3_lite_text") -@strict -class Sam3LiteTextVisionConfig(PreTrainedConfig): - r""" - fpn_hidden_size (`int`, *optional*, defaults to 256): - The hidden dimension of the FPN. - backbone_feature_sizes (`List[List[int]]`, *optional*, defaults to `[[288, 288], [144, 144], [72, 72]]`): - The spatial sizes (height, width) of the feature maps from the backbone at different scales. - scale_factors (`list[float]`, *optional*, defaults to `[4.0, 2.0, 1.0, 0.5]`): - Scale factors for FPN multi-scale features. List of scaling factors for each FPN level. - """ - - base_config_key = "vision_config" - model_type = "sam3_vision_model" - sub_configs = {"backbone_config": AutoConfig} - - backbone_config: dict | PreTrainedConfig | None = None - fpn_hidden_size: int = 256 - backbone_feature_sizes: list | None = None - scale_factors: list[float] | None = None - hidden_act: str = "gelu" - layer_norm_eps: float = 1e-6 - initializer_range: float = 0.02 - - def __post_init__(self, **kwargs): - self.scale_factors = [4.0, 2.0, 1.0, 0.5] if self.scale_factors is None else self.scale_factors - if self.backbone_feature_sizes is None: - self.backbone_feature_sizes = [[288, 288], [144, 144], [72, 72]] - - if isinstance(self.backbone_config, dict): - self.backbone_config["model_type"] = self.backbone_config.get("model_type", "sam3_vit_model") - self.backbone_config = CONFIG_MAPPING[self.backbone_config["model_type"]](**self.backbone_config) - elif self.backbone_config is None: - self.backbone_config = CONFIG_MAPPING["sam3_vit_model"]() - - super().__post_init__(**kwargs) - - @property - def image_size(self): - """Image size for the vision encoder.""" - return self.backbone_config.image_size - - @image_size.setter - def image_size(self, value): - """Set the image size and propagate to backbone.""" - self.backbone_config.image_size = value - - -@auto_docstring(checkpoint="facebook/sam3_lite_text") +@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0") @strict class Sam3LiteTextGeometryEncoderConfig(PreTrainedConfig): r""" @@ -138,7 +47,7 @@ class Sam3LiteTextGeometryEncoderConfig(PreTrainedConfig): initializer_range: float = 0.02 -@auto_docstring(checkpoint="facebook/sam3_lite_text") +@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0") @strict class Sam3LiteTextDETREncoderConfig(PreTrainedConfig): r""" @@ -159,7 +68,7 @@ class Sam3LiteTextDETREncoderConfig(PreTrainedConfig): initializer_range: float = 0.02 -@auto_docstring(checkpoint="facebook/sam3_lite_text") +@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0") @strict class Sam3LiteTextDETRDecoderConfig(PreTrainedConfig): r""" @@ -181,7 +90,7 @@ class Sam3LiteTextDETRDecoderConfig(PreTrainedConfig): initializer_range: float = 0.02 -@auto_docstring(checkpoint="facebook/sam3_lite_text") +@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0") @strict class Sam3LiteTextMaskDecoderConfig(PreTrainedConfig): r""" @@ -229,7 +138,7 @@ class Sam3LiteTextTextConfig(PreTrainedConfig): repmixer_kernel_size: int = 11 -@auto_docstring(checkpoint="facebook/sam3_lite_text") +@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0") @strict class Sam3LiteTextConfig(PreTrainedConfig): r""" diff --git a/src/transformers/models/sam3_lite_text/modeling_sam3_lite_text.py b/src/transformers/models/sam3_lite_text/modeling_sam3_lite_text.py index 5a7b02880edd..da3cee6c9c90 100644 --- a/src/transformers/models/sam3_lite_text/modeling_sam3_lite_text.py +++ b/src/transformers/models/sam3_lite_text/modeling_sam3_lite_text.py @@ -19,7 +19,7 @@ # limitations under the License. import math -from collections.abc import Callable, Iterable +from collections.abc import Callable from dataclasses import dataclass import numpy as np @@ -47,7 +47,6 @@ Sam3LiteTextGeometryEncoderConfig, Sam3LiteTextMaskDecoderConfig, Sam3LiteTextTextConfig, - Sam3LiteTextViTConfig, ) @@ -341,140 +340,6 @@ def forward(self, input_ids: torch.LongTensor) -> torch.Tensor: return hidden_states -class Sam3LiteTextViTRotaryEmbedding(nn.Module): - """ - Vision Rotary Position Embedding for SAM3_LITE_TEXT, following transformers library standards. - Supports 2D (axial) rotary embeddings for spatial dimensions. - """ - - def __init__(self, config: Sam3LiteTextViTConfig, end_x: int, end_y: int, scale: float = 1.0): - super().__init__() - dim = config.hidden_size // config.num_attention_heads - # Ensure even dimension for proper axial splitting - if dim % 4 != 0: - raise ValueError("Dimension must be divisible by 4 for axial RoPE") - self.end_x, self.end_y = end_x, end_y - self.dim = dim - self.rope_theta = config.rope_theta - self.scale = scale - freqs = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) - - flattened_indices = torch.arange(end_x * end_y, dtype=torch.long) - x_positions = (flattened_indices % end_x) * scale - y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * scale - freqs_x = torch.outer(x_positions, freqs).float() - freqs_y = torch.outer(y_positions, freqs).float() - inv_freq = torch.cat([freqs_x, freqs_y], dim=-1) - inv_freq = inv_freq.repeat_interleave(2, dim=-1) - # directly register the cos and sin embeddings as we have a fixed feature shape - self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False) - self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False) - - @torch.no_grad() - def forward(self) -> tuple[torch.Tensor, torch.Tensor]: - # As the feature map size is fixed for each stage, we can just return the pre-computed embeddings. - return self.rope_embeddings_cos, self.rope_embeddings_sin - - -class Sam3LiteTextViTPatchEmbeddings(nn.Module): - """ - This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial - `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a - Transformer. - """ - - def __init__(self, config: Sam3LiteTextViTConfig): - super().__init__() - image_size, patch_size = config.pretrain_image_size, config.patch_size - num_channels, hidden_size = config.num_channels, config.hidden_size - - image_size = image_size if isinstance(image_size, Iterable) else (image_size, image_size) - patch_size = patch_size if isinstance(patch_size, Iterable) else (patch_size, patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_patches = num_patches - - self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=False) - - def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: - embeddings = self.projection(pixel_values.to(self.projection.weight.dtype)).flatten(2).transpose(1, 2) - return embeddings - - -class Sam3LiteTextViTEmbeddings(nn.Module): - """ - Construct the patch embeddings and position embeddings for SAM3_LITE_TEXT ViT. - - Position embeddings are tiled (not interpolated) when resizing to match different input sizes. - """ - - def __init__(self, config: Sam3LiteTextViTConfig): - super().__init__() - - self.patch_embeddings = Sam3LiteTextViTPatchEmbeddings(config) - num_patches = self.patch_embeddings.num_patches - self.position_embeddings = nn.Parameter( - torch.randn(1, num_patches, config.hidden_size) - ) # !Remove cls token in convert weights! - - self.dropout = nn.Dropout(config.hidden_dropout) - self.patch_size = config.patch_size - - def _tile_position_embeddings( - self, - position_embeddings: torch.Tensor, - height: int, - width: int, - ) -> torch.Tensor: - """ - Tile position embeddings to match target spatial dimensions. - Args: - position_embeddings: Shape [1, num_pretrain_patches, hidden_size] - height: Target height in patches - width: Target width in patches - - Returns: - Shape [1, height * width, hidden_size] - """ - pretrain_size = int(position_embeddings.shape[1] ** 0.5) - - # Skip tiling if sizes match (but always tile during tracing for consistent graph) - if not torch.jit.is_tracing() and pretrain_size == height and pretrain_size == width: - return position_embeddings.reshape(1, height * width, -1) - - # Tile position embeddings to match target spatial dimensions - hidden_size = position_embeddings.shape[-1] - pos_embed = position_embeddings.reshape(1, pretrain_size, pretrain_size, hidden_size).permute(0, 3, 1, 2) - repeat_h = height // pretrain_size + 1 - repeat_w = width // pretrain_size + 1 - pos_embed = pos_embed.tile([1, 1, repeat_h, repeat_w])[:, :, :height, :width] - return pos_embed.permute(0, 2, 3, 1).reshape(1, height * width, hidden_size) - - def forward( - self, - pixel_values: torch.Tensor, - interpolate_pos_encoding: bool = False, - ) -> torch.Tensor: - height, width = pixel_values.shape[-2:] - embeddings = self.patch_embeddings(pixel_values) - - # Calculate spatial dimensions in patches - height_patches = height // self.patch_size - width_patches = width // self.patch_size - - position_embeddings = self._tile_position_embeddings( - self.position_embeddings, - height_patches, - width_patches, - ) - embeddings = embeddings + position_embeddings - embeddings = self.dropout(embeddings) - - return embeddings - - @auto_docstring class Sam3LiteTextPreTrainedModel(PreTrainedModel): config_class = Sam3LiteTextConfig @@ -490,21 +355,6 @@ class Sam3LiteTextPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) - if isinstance(module, Sam3LiteTextViTEmbeddings): - init.normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range) - elif isinstance(module, Sam3LiteTextViTRotaryEmbedding): - end_x, end_y = module.end_x, module.end_y - dim = module.dim - freqs = 1.0 / (module.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) - flattened_indices = torch.arange(end_x * end_y, dtype=torch.long) - x_positions = (flattened_indices % end_x) * module.scale - y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * module.scale - freqs_x = torch.outer(x_positions, freqs).float() - freqs_y = torch.outer(y_positions, freqs).float() - inv_freq = torch.cat([freqs_x, freqs_y], dim=-1) - inv_freq = inv_freq.repeat_interleave(2, dim=-1) - init.copy_(module.rope_embeddings_cos, inv_freq.cos()) - init.copy_(module.rope_embeddings_sin, inv_freq.sin()) if isinstance(module, Sam3LiteTextTextPositionEmbedding): init.normal_(module.position_embedding, std=module.position_embedding.shape[-1] ** -0.5) elif isinstance(module, Sam3LiteTextTextModel): @@ -2043,7 +1893,6 @@ class Sam3LiteTextModel(Sam3LiteTextPreTrainedModel): r"^tracker_model.", r"^tracker_neck.", ] - # DETR components create float masks from features, so flash/flex attention cannot be dispatched safely. _supports_flash_attn = False _supports_flex_attn = False diff --git a/src/transformers/models/sam3_lite_text/modular_sam3_lite_text.py b/src/transformers/models/sam3_lite_text/modular_sam3_lite_text.py index 46408464a333..4e830a6d5ec3 100644 --- a/src/transformers/models/sam3_lite_text/modular_sam3_lite_text.py +++ b/src/transformers/models/sam3_lite_text/modular_sam3_lite_text.py @@ -24,6 +24,7 @@ from ...configuration_utils import PreTrainedConfig from ...masking_utils import create_bidirectional_mask from ...modeling_outputs import BaseModelOutputWithPooling +from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring from ...utils.generic import TransformersKwargs, merge_with_config_defaults @@ -39,116 +40,25 @@ from ..siglip.modeling_siglip import SiglipAttention, SiglipEncoderLayer, SiglipMLP -@auto_docstring(checkpoint="facebook/sam3_lite_text") -@strict -class Sam3LiteTextViTConfig(PreTrainedConfig): - r""" - rope_theta (`float`, *optional*, defaults to 10000.0): - Base frequency for RoPE. - window_size (`int`, *optional*, defaults to 24): - Window size for windowed attention. - global_attn_indexes (`list[int]`, *optional*, defaults to `[7, 15, 23, 31]`): - Indexes of layers with global attention. - pretrain_image_size (`int`, *optional*, defaults to 336): - Pretrained model image size for position embedding initialization. - hidden_dropout (`float`, *optional*, defaults to 0.0): - Dropout probability for hidden states. - """ - - base_config_key = "backbone_config" - model_type = "sam3_vit_model" - - hidden_size: int = 1024 - intermediate_size: int = 4736 - num_hidden_layers: int = 32 - num_attention_heads: int = 16 - num_channels: int = 3 - image_size: int | list[int] | tuple[int, int] = 1008 - patch_size: int | list[int] | tuple[int, int] = 14 - hidden_act: str = "gelu" - layer_norm_eps: float = 1e-6 - attention_dropout: float | int = 0.0 - rope_theta: float = 10000.0 - window_size: int = 24 - global_attn_indexes: list[int] | None = None - layer_scale_init_value: float | None = None - pretrain_image_size: int | list[int] | tuple[int, int] = 336 - hidden_dropout: float | int = 0.0 - initializer_range: float = 0.02 - - def __post_init__(self, **kwargs): - super().__post_init__(**kwargs) - if self.global_attn_indexes is None: - self.global_attn_indexes = [7, 15, 23, 31] - - -@auto_docstring(checkpoint="facebook/sam3_lite_text") -@strict -class Sam3LiteTextVisionConfig(PreTrainedConfig): - r""" - fpn_hidden_size (`int`, *optional*, defaults to 256): - The hidden dimension of the FPN. - backbone_feature_sizes (`List[List[int]]`, *optional*, defaults to `[[288, 288], [144, 144], [72, 72]]`): - The spatial sizes (height, width) of the feature maps from the backbone at different scales. - scale_factors (`list[float]`, *optional*, defaults to `[4.0, 2.0, 1.0, 0.5]`): - Scale factors for FPN multi-scale features. List of scaling factors for each FPN level. - """ - - base_config_key = "vision_config" - model_type = "sam3_vision_model" - sub_configs = {"backbone_config": AutoConfig} - - backbone_config: dict | PreTrainedConfig | None = None - fpn_hidden_size: int = 256 - backbone_feature_sizes: list | None = None - scale_factors: list[float] | None = None - hidden_act: str = "gelu" - layer_norm_eps: float = 1e-6 - initializer_range: float = 0.02 - - def __post_init__(self, **kwargs): - self.scale_factors = [4.0, 2.0, 1.0, 0.5] if self.scale_factors is None else self.scale_factors - if self.backbone_feature_sizes is None: - self.backbone_feature_sizes = [[288, 288], [144, 144], [72, 72]] - - if isinstance(self.backbone_config, dict): - self.backbone_config["model_type"] = self.backbone_config.get("model_type", "sam3_vit_model") - self.backbone_config = CONFIG_MAPPING[self.backbone_config["model_type"]](**self.backbone_config) - elif self.backbone_config is None: - self.backbone_config = CONFIG_MAPPING["sam3_vit_model"]() - - super().__post_init__(**kwargs) - - @property - def image_size(self): - """Image size for the vision encoder.""" - return self.backbone_config.image_size - - @image_size.setter - def image_size(self, value): - """Set the image size and propagate to backbone.""" - self.backbone_config.image_size = value - - -@auto_docstring(checkpoint="facebook/sam3_lite_text") +@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0") @strict class Sam3LiteTextGeometryEncoderConfig(Sam3GeometryEncoderConfig): pass -@auto_docstring(checkpoint="facebook/sam3_lite_text") +@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0") @strict class Sam3LiteTextDETREncoderConfig(Sam3DETREncoderConfig): pass -@auto_docstring(checkpoint="facebook/sam3_lite_text") +@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0") @strict class Sam3LiteTextDETRDecoderConfig(Sam3DETRDecoderConfig): pass -@auto_docstring(checkpoint="facebook/sam3_lite_text") +@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0") @strict class Sam3LiteTextMaskDecoderConfig(Sam3MaskDecoderConfig): pass @@ -184,7 +94,7 @@ class Sam3LiteTextTextConfig(PreTrainedConfig): repmixer_kernel_size: int = 11 -@auto_docstring(checkpoint="facebook/sam3_lite_text") +@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0") @strict class Sam3LiteTextConfig(PreTrainedConfig): r""" @@ -444,7 +354,7 @@ class Sam3LiteTextPreTrainedModel(Sam3PreTrainedModel): @torch.no_grad() def _init_weights(self, module): - super()._init_weights(module) + PreTrainedModel._init_weights(module) if isinstance(module, Sam3LiteTextTextPositionEmbedding): init.normal_(module.position_embedding, std=module.position_embedding.shape[-1] ** -0.5) elif isinstance(module, Sam3LiteTextTextModel): diff --git a/tests/models/sam3_lite_text/test_modeling_sam3_lite_text.py b/tests/models/sam3_lite_text/test_modeling_sam3_lite_text.py index 05a9307bfa87..c9b0766dc5d1 100644 --- a/tests/models/sam3_lite_text/test_modeling_sam3_lite_text.py +++ b/tests/models/sam3_lite_text/test_modeling_sam3_lite_text.py @@ -661,6 +661,13 @@ def test_sdpa_can_compile_dynamic(self): def test_sdpa_can_dispatch_on_flash(self): pass + @unittest.skip( + reason="Sam3LiteTextModel creates float attention masks from features (with gradients) in the DETR " + "encoder/decoder, which Flash Attention requires to be None." + ) + def test_flash_attn_2_can_dispatch_composite_models(self): + pass + def test_model_outputs_equivalence(self): """ Test that tuple and dict outputs are equivalent. diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index b0496b38a10f..65a202cd3f9a 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -84,8 +84,6 @@ "AutoformerConfig": ["num_static_real_features", "num_time_features"], "SamVisionConfig": ["mlp_ratio"], "Sam3VisionConfig": ["backbone_feature_sizes"], - "Sam3LiteTextViTConfig": ["global_attn_indexes", "window_size"], - "Sam3LiteTextVisionConfig": ["fpn_hidden_size", "scale_factors"], "SamHQVisionConfig": ["mlp_ratio"], "ClapAudioConfig": ["num_classes"], "ClvpDecoderConfig": ["add_cross_attention"], From 79142e24bf1d8d2a690bb8f190a51daf46e50f13 Mon Sep 17 00:00:00 2001 From: vasqu Date: Tue, 21 Apr 2026 14:08:00 +0200 Subject: [PATCH 2/2] sync modular --- .../models/sam3_lite_text/modeling_sam3_lite_text.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/sam3_lite_text/modeling_sam3_lite_text.py b/src/transformers/models/sam3_lite_text/modeling_sam3_lite_text.py index da3cee6c9c90..05a28e4bea2a 100644 --- a/src/transformers/models/sam3_lite_text/modeling_sam3_lite_text.py +++ b/src/transformers/models/sam3_lite_text/modeling_sam3_lite_text.py @@ -1893,6 +1893,7 @@ class Sam3LiteTextModel(Sam3LiteTextPreTrainedModel): r"^tracker_model.", r"^tracker_neck.", ] + # DETR components create float masks from features, so flash/flex attention cannot be dispatched safely. _supports_flash_attn = False _supports_flex_attn = False