@@ -25,98 +25,7 @@
from ..auto import CONFIG_MAPPING, AutoConfig


@auto_docstring(checkpoint="facebook/sam3_lite_text")
@strict
class Sam3LiteTextViTConfig(PreTrainedConfig):
r"""
rope_theta (`float`, *optional*, defaults to 10000.0):
Base frequency for RoPE.
window_size (`int`, *optional*, defaults to 24):
Window size for windowed attention.
global_attn_indexes (`list[int]`, *optional*, defaults to `[7, 15, 23, 31]`):
Indexes of layers with global attention.
pretrain_image_size (`int`, *optional*, defaults to 336):
Pretrained model image size for position embedding initialization.
hidden_dropout (`float`, *optional*, defaults to 0.0):
Dropout probability for hidden states.
"""

base_config_key = "backbone_config"
model_type = "sam3_vit_model"

hidden_size: int = 1024
intermediate_size: int = 4736
num_hidden_layers: int = 32
num_attention_heads: int = 16
num_channels: int = 3
image_size: int | list[int] | tuple[int, int] = 1008
patch_size: int | list[int] | tuple[int, int] = 14
hidden_act: str = "gelu"
layer_norm_eps: float = 1e-6
attention_dropout: float | int = 0.0
rope_theta: float = 10000.0
window_size: int = 24
global_attn_indexes: list[int] | None = None
layer_scale_init_value: float | None = None
pretrain_image_size: int | list[int] | tuple[int, int] = 336
hidden_dropout: float | int = 0.0
initializer_range: float = 0.02

def __post_init__(self, **kwargs):
super().__post_init__(**kwargs)
if self.global_attn_indexes is None:
self.global_attn_indexes = [7, 15, 23, 31]
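
# Illustration of the default above: leaving `global_attn_indexes` unset resolves it in __post_init__.
# >>> Sam3LiteTextViTConfig().global_attn_indexes
# [7, 15, 23, 31]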


@auto_docstring(checkpoint="facebook/sam3_lite_text")
@strict
class Sam3LiteTextVisionConfig(PreTrainedConfig):
r"""
fpn_hidden_size (`int`, *optional*, defaults to 256):
The hidden dimension of the FPN.
backbone_feature_sizes (`List[List[int]]`, *optional*, defaults to `[[288, 288], [144, 144], [72, 72]]`):
The spatial sizes (height, width) of the feature maps from the backbone at different scales.
scale_factors (`list[float]`, *optional*, defaults to `[4.0, 2.0, 1.0, 0.5]`):
Scale factors for FPN multi-scale features. List of scaling factors for each FPN level.
"""

base_config_key = "vision_config"
model_type = "sam3_vision_model"
sub_configs = {"backbone_config": AutoConfig}

backbone_config: dict | PreTrainedConfig | None = None
fpn_hidden_size: int = 256
backbone_feature_sizes: list | None = None
scale_factors: list[float] | None = None
hidden_act: str = "gelu"
layer_norm_eps: float = 1e-6
initializer_range: float = 0.02

def __post_init__(self, **kwargs):
self.scale_factors = [4.0, 2.0, 1.0, 0.5] if self.scale_factors is None else self.scale_factors
if self.backbone_feature_sizes is None:
self.backbone_feature_sizes = [[288, 288], [144, 144], [72, 72]]

if isinstance(self.backbone_config, dict):
self.backbone_config["model_type"] = self.backbone_config.get("model_type", "sam3_vit_model")
self.backbone_config = CONFIG_MAPPING[self.backbone_config["model_type"]](**self.backbone_config)
elif self.backbone_config is None:
self.backbone_config = CONFIG_MAPPING["sam3_vit_model"]()

super().__post_init__(**kwargs)

@property
def image_size(self):
"""Image size for the vision encoder."""
return self.backbone_config.image_size

@image_size.setter
def image_size(self, value):
"""Set the image size and propagate to backbone."""
self.backbone_config.image_size = value
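
# Illustrative usage of the property pair above: `image_size` is stored only on the backbone config,
# so reads and writes on the vision config pass straight through to it.
# >>> config = Sam3LiteTextVisionConfig()
# >>> config.image_size                    # read from config.backbone_config.image_size -> 1008
# >>> config.image_size = 672              # setter writes through to the ViT backbone config
# >>> config.backbone_config.image_size
# 672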


@auto_docstring(checkpoint="facebook/sam3_lite_text")
@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0")
@strict
class Sam3LiteTextGeometryEncoderConfig(PreTrainedConfig):
r"""
@@ -138,7 +47,7 @@ class Sam3LiteTextGeometryEncoderConfig(PreTrainedConfig):
initializer_range: float = 0.02


@auto_docstring(checkpoint="facebook/sam3_lite_text")
@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0")
@strict
class Sam3LiteTextDETREncoderConfig(PreTrainedConfig):
r"""
@@ -159,7 +68,7 @@ class Sam3LiteTextDETREncoderConfig(PreTrainedConfig):
initializer_range: float = 0.02


@auto_docstring(checkpoint="facebook/sam3_lite_text")
@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0")
@strict
class Sam3LiteTextDETRDecoderConfig(PreTrainedConfig):
r"""
@@ -181,7 +90,7 @@ class Sam3LiteTextDETRDecoderConfig(PreTrainedConfig):
initializer_range: float = 0.02


@auto_docstring(checkpoint="facebook/sam3_lite_text")
@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0")
@strict
class Sam3LiteTextMaskDecoderConfig(PreTrainedConfig):
r"""
@@ -229,7 +138,7 @@ class Sam3LiteTextTextConfig(PreTrainedConfig):
repmixer_kernel_size: int = 11


@auto_docstring(checkpoint="facebook/sam3_lite_text")
@auto_docstring(checkpoint="yonigozlan/sam3-litetext-s0")
@strict
class Sam3LiteTextConfig(PreTrainedConfig):
r"""
152 changes: 1 addition & 151 deletions src/transformers/models/sam3_lite_text/modeling_sam3_lite_text.py
@@ -19,7 +19,7 @@
# limitations under the License.

import math
from collections.abc import Callable, Iterable
from collections.abc import Callable
from dataclasses import dataclass

import numpy as np
@@ -47,7 +47,6 @@
Sam3LiteTextGeometryEncoderConfig,
Sam3LiteTextMaskDecoderConfig,
Sam3LiteTextTextConfig,
Sam3LiteTextViTConfig,
)


@@ -341,140 +340,6 @@ def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
return hidden_states


class Sam3LiteTextViTRotaryEmbedding(nn.Module):
"""
Vision Rotary Position Embedding for SAM3_LITE_TEXT, following transformers library standards.
Supports 2D (axial) rotary embeddings for spatial dimensions.
"""

def __init__(self, config: Sam3LiteTextViTConfig, end_x: int, end_y: int, scale: float = 1.0):
super().__init__()
dim = config.hidden_size // config.num_attention_heads
# The per-head dimension must be divisible by 4 for axial (x/y) RoPE splitting
if dim % 4 != 0:
raise ValueError("Dimension must be divisible by 4 for axial RoPE")
self.end_x, self.end_y = end_x, end_y
self.dim = dim
self.rope_theta = config.rope_theta
self.scale = scale
freqs = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))

flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
x_positions = (flattened_indices % end_x) * scale
y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * scale
freqs_x = torch.outer(x_positions, freqs).float()
freqs_y = torch.outer(y_positions, freqs).float()
inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
inv_freq = inv_freq.repeat_interleave(2, dim=-1)
# directly register the cos and sin embeddings as we have a fixed feature shape
self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)

@torch.no_grad()
def forward(self) -> tuple[torch.Tensor, torch.Tensor]:
# As the feature map size is fixed for each stage, we can just return the pre-computed embeddings.
return self.rope_embeddings_cos, self.rope_embeddings_sin
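
# Shape walkthrough for the buffers above, using the default Sam3LiteTextViTConfig values
# (hidden_size=1024, num_attention_heads=16, so head_dim = 1024 // 16 = 64):
#   freqs:                       head_dim // 4 = 16 base frequencies
#   freqs_x, freqs_y:            (end_x * end_y, 16) each
#   cat + repeat_interleave(2):  cos/sin buffers of shape (end_x * end_y, 64)
# For a 24x24 attention window (window_size=24) this gives cos/sin tensors of shape (576, 64).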


class Sam3LiteTextViTPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""

def __init__(self, config: Sam3LiteTextViTConfig):
super().__init__()
image_size, patch_size = config.pretrain_image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size

image_size = image_size if isinstance(image_size, Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches

self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=False)

def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
embeddings = self.projection(pixel_values.to(self.projection.weight.dtype)).flatten(2).transpose(1, 2)
return embeddings
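
# Shape walkthrough for the projection above, using the default config (patch_size=14, hidden_size=1024):
#   pixel_values:                 (batch_size, 3, 1008, 1008)
#   Conv2d with stride 14:        (batch_size, 1024, 72, 72)
#   flatten(2).transpose(1, 2):   (batch_size, 72 * 72, 1024) = (batch_size, 5184, 1024)
# num_patches itself is derived from pretrain_image_size=336, i.e. (336 // 14) ** 2 = 576.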


class Sam3LiteTextViTEmbeddings(nn.Module):
"""
Construct the patch embeddings and position embeddings for SAM3_LITE_TEXT ViT.

Position embeddings are tiled (not interpolated) when resizing to match different input sizes.
"""

def __init__(self, config: Sam3LiteTextViTConfig):
super().__init__()

self.patch_embeddings = Sam3LiteTextViTPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(
torch.randn(1, num_patches, config.hidden_size)
) # !Remove cls token in convert weights!

self.dropout = nn.Dropout(config.hidden_dropout)
self.patch_size = config.patch_size

def _tile_position_embeddings(
self,
position_embeddings: torch.Tensor,
height: int,
width: int,
) -> torch.Tensor:
"""
Tile position embeddings to match target spatial dimensions.
Args:
position_embeddings: Shape [1, num_pretrain_patches, hidden_size]
height: Target height in patches
width: Target width in patches

Returns:
Shape [1, height * width, hidden_size]
"""
pretrain_size = int(position_embeddings.shape[1] ** 0.5)

# Skip tiling if sizes match (but always tile during tracing for consistent graph)
if not torch.jit.is_tracing() and pretrain_size == height and pretrain_size == width:
return position_embeddings.reshape(1, height * width, -1)

# Tile position embeddings to match target spatial dimensions
hidden_size = position_embeddings.shape[-1]
pos_embed = position_embeddings.reshape(1, pretrain_size, pretrain_size, hidden_size).permute(0, 3, 1, 2)
repeat_h = height // pretrain_size + 1
repeat_w = width // pretrain_size + 1
pos_embed = pos_embed.tile([1, 1, repeat_h, repeat_w])[:, :, :height, :width]
return pos_embed.permute(0, 2, 3, 1).reshape(1, height * width, hidden_size)

def forward(
self,
pixel_values: torch.Tensor,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
height, width = pixel_values.shape[-2:]
embeddings = self.patch_embeddings(pixel_values)

# Calculate spatial dimensions in patches
height_patches = height // self.patch_size
width_patches = width // self.patch_size

position_embeddings = self._tile_position_embeddings(
self.position_embeddings,
height_patches,
width_patches,
)
embeddings = embeddings + position_embeddings
embeddings = self.dropout(embeddings)

return embeddings
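
# Tiling example for _tile_position_embeddings above, with the default config: the learned table covers
# (336 // 14) ** 2 = 24 x 24 positions. For a 1008 x 1008 input (72 x 72 patches) that grid is tiled
# 72 // 24 + 1 = 4 times along each axis and cropped back to 72 x 72, giving position embeddings of
# shape (1, 5184, hidden_size). No interpolation is performed.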


@auto_docstring
class Sam3LiteTextPreTrainedModel(PreTrainedModel):
config_class = Sam3LiteTextConfig
@@ -490,21 +355,6 @@ class Sam3LiteTextPreTrainedModel(PreTrainedModel):
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, Sam3LiteTextViTEmbeddings):
init.normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
elif isinstance(module, Sam3LiteTextViTRotaryEmbedding):
end_x, end_y = module.end_x, module.end_y
dim = module.dim
freqs = 1.0 / (module.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
x_positions = (flattened_indices % end_x) * module.scale
y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * module.scale
freqs_x = torch.outer(x_positions, freqs).float()
freqs_y = torch.outer(y_positions, freqs).float()
inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
inv_freq = inv_freq.repeat_interleave(2, dim=-1)
init.copy_(module.rope_embeddings_cos, inv_freq.cos())
init.copy_(module.rope_embeddings_sin, inv_freq.sin())
if isinstance(module, Sam3LiteTextTextPositionEmbedding):
init.normal_(module.position_embedding, std=module.position_embedding.shape[-1] ** -0.5)
elif isinstance(module, Sam3LiteTextTextModel):