Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ def forward(
):
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
for level_id, (height, width) in enumerate(value_spatial_shapes_list):
# batch_size, height*width, num_heads, hidden_dim
# -> batch_size, height*width, num_heads*hidden_dim
# -> batch_size, num_heads*hidden_dim, height*width
Expand Down
30 changes: 24 additions & 6 deletions src/transformers/models/grounding_dino/modeling_grounding_dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@
_CHECKPOINT_FOR_DOC = "IDEA-Research/grounding-dino-tiny"


# Copied from models.deformable_detr.MultiScaleDeformableAttention
@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttention
class MultiScaleDeformableAttention(nn.Module):
def forward(
self,
Expand All @@ -66,10 +66,10 @@ def forward(
):
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
for level_id, (height, width) in enumerate(value_spatial_shapes_list):
# batch_size, height*width, num_heads, hidden_dim
# -> batch_size, height*width, num_heads*hidden_dim
# -> batch_size, num_heads*hidden_dim, height*width
Expand Down Expand Up @@ -1015,6 +1015,7 @@ def forward(
position_embeddings: torch.Tensor = None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
output_attentions: bool = False,
):
Expand All @@ -1030,6 +1031,8 @@ def forward(
Reference points.
spatial_shapes (`torch.LongTensor`, *optional*):
Spatial shapes of the backbone feature maps.
spatial_shapes_list (`List[Tuple[int, int]]`, *optional*):
Spatial shapes of the backbone feature maps (but as list for export compatibility).
level_start_index (`torch.LongTensor`, *optional*):
Level start index.
output_attentions (`bool`, *optional*):
Expand All @@ -1047,6 +1050,7 @@ def forward(
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
Expand Down Expand Up @@ -1147,6 +1151,7 @@ def forward(
vision_features: Tensor,
vision_position_embedding: Tensor,
spatial_shapes: Tensor,
spatial_shapes_list: List[Tuple[int, int]],
level_start_index: Tensor,
key_padding_mask: Tensor,
reference_points: Tensor,
Expand Down Expand Up @@ -1179,6 +1184,7 @@ def forward(
position_embeddings=vision_position_embedding,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
)

Expand Down Expand Up @@ -1295,6 +1301,7 @@ def forward(
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
vision_encoder_hidden_states: Optional[torch.Tensor] = None,
vision_encoder_attention_mask: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -1347,6 +1354,7 @@ def forward(
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
Expand Down Expand Up @@ -1594,6 +1602,7 @@ def forward(
vision_attention_mask: Tensor,
vision_position_embedding: Tensor,
spatial_shapes: Tensor,
spatial_shapes_list: List[Tuple[int, int]],
level_start_index: Tensor,
valid_ratios=None,
text_features: Optional[Tensor] = None,
Expand All @@ -1618,6 +1627,8 @@ def forward(
Position embeddings that are added to the queries and keys in each self-attention layer.
spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
Spatial shapes of each feature map.
spatial_shapes_list (`List[Tuple[int, int]]`):
Spatial shapes of each feature map (but as list for export compatibility).
level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
Starting index of each feature map.
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
Expand Down Expand Up @@ -1670,6 +1681,7 @@ def forward(
vision_features=vision_features,
vision_position_embedding=vision_position_embedding,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
key_padding_mask=vision_attention_mask,
reference_points=reference_points,
Expand Down Expand Up @@ -1748,6 +1760,7 @@ def forward(
text_encoder_attention_mask=None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
valid_ratios=None,
self_attn_mask=None,
Expand Down Expand Up @@ -1775,6 +1788,8 @@ def forward(
Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
Spatial shapes of the feature maps.
spatial_shapes_list (`List[Tuple[int, int]]`):
Spatial shapes of the feature maps (but as list for export compatibility).
level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
Indexes for the start of each feature level. In range `[0, sequence_length]`.
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
Expand Down Expand Up @@ -1867,6 +1882,7 @@ def custom_forward(*inputs):
position_embeddings=query_pos,
reference_points=reference_points_input,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
vision_encoder_hidden_states=vision_encoder_hidden_states,
vision_encoder_attention_mask=vision_encoder_attention_mask,
Expand Down Expand Up @@ -2248,11 +2264,11 @@ def forward(
source_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
spatial_shapes_list = []
for level, (source, mask, pos_embed) in enumerate(zip(feature_maps, masks, position_embeddings_list)):
batch_size, num_channels, height, width = source.shape
spatial_shape = (height, width)
spatial_shapes.append(spatial_shape)
spatial_shapes_list.append(spatial_shape)
source = source.flatten(2).transpose(1, 2)
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
Expand All @@ -2263,7 +2279,7 @@ def forward(
source_flatten = torch.cat(source_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
valid_ratios = valid_ratios.float()
Expand All @@ -2276,6 +2292,7 @@ def forward(
vision_attention_mask=~mask_flatten,
vision_position_embedding=lvl_pos_embed_flatten,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
text_features=text_features,
Expand Down Expand Up @@ -2352,6 +2369,7 @@ def forward(
text_encoder_attention_mask=~text_token_mask,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
self_attn_mask=None,
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ class OmDetTurboObjectDetectionOutput(ModelOutput):
classes_structure: Optional[torch.LongTensor] = None


# Copied from models.deformable_detr.MultiScaleDeformableAttention
@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttention
class MultiScaleDeformableAttention(nn.Module):
def forward(
self,
Expand All @@ -187,10 +187,10 @@ def forward(
):
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
for level_id, (height, width) in enumerate(value_spatial_shapes_list):
# batch_size, height*width, num_heads, hidden_dim
# -> batch_size, height*width, num_heads*hidden_dim
# -> batch_size, num_heads*hidden_dim, height*width
Expand Down
7 changes: 3 additions & 4 deletions src/transformers/models/rt_detr/modeling_rt_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
_CHECKPOINT_FOR_DOC = "PekingU/rtdetr_r50vd"


# Copied from models.deformable_detr.MultiScaleDeformableAttention
@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttention
class MultiScaleDeformableAttention(nn.Module):
def forward(
self,
Expand All @@ -65,10 +65,10 @@ def forward(
):
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
for level_id, (height, width) in enumerate(value_spatial_shapes_list):
# batch_size, height*width, num_heads, hidden_dim
# -> batch_size, height*width, num_heads*hidden_dim
# -> batch_size, num_heads*hidden_dim, height*width
Expand Down Expand Up @@ -1998,7 +1998,6 @@ def forward(
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.model(
Expand Down
1 change: 0 additions & 1 deletion src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1997,7 +1997,6 @@ def forward(
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.model(
Expand Down