From 91db38eedbd0f6c42bcd27f47d212b9eaf034b7c Mon Sep 17 00:00:00 2001 From: qubvel Date: Mon, 24 Mar 2025 11:14:00 +0000 Subject: [PATCH 1/2] Fix pytorch path for DeformableAttention --- .../models/deformable_detr/modeling_deformable_detr.py | 4 ++-- .../models/omdet_turbo/modeling_omdet_turbo.py | 6 +++--- src/transformers/models/rt_detr/modeling_rt_detr.py | 7 +++---- src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py | 1 - 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 6ffebca32dcf..8db75abc0a77 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -70,10 +70,10 @@ def forward( ): batch_size, _, num_heads, hidden_dim = value.shape _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape - value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) + value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] - for level_id, (height, width) in enumerate(value_spatial_shapes): + for level_id, (height, width) in enumerate(value_spatial_shapes_list): # batch_size, height*width, num_heads, hidden_dim # -> batch_size, height*width, num_heads*hidden_dim # -> batch_size, num_heads*hidden_dim, height*width diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py index 61cc747ca752..570c8cc3a32b 100644 --- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -172,8 +172,8 @@ class OmDetTurboObjectDetectionOutput(ModelOutput): classes_structure: Optional[torch.LongTensor] = None -# Copied from models.deformable_detr.MultiScaleDeformableAttention @use_kernel_forward_from_hub("MultiScaleDeformableAttention") +# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttention class MultiScaleDeformableAttention(nn.Module): def forward( self, @@ -187,10 +187,10 @@ def forward( ): batch_size, _, num_heads, hidden_dim = value.shape _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape - value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) + value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] - for level_id, (height, width) in enumerate(value_spatial_shapes): + for level_id, (height, width) in enumerate(value_spatial_shapes_list): # batch_size, height*width, num_heads, hidden_dim # -> batch_size, height*width, num_heads*hidden_dim # -> batch_size, num_heads*hidden_dim, height*width diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py index 54dfc43bf867..d7305d426527 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr.py @@ -50,8 +50,8 @@ _CHECKPOINT_FOR_DOC = "PekingU/rtdetr_r50vd" -# Copied from models.deformable_detr.MultiScaleDeformableAttention @use_kernel_forward_from_hub("MultiScaleDeformableAttention") +# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttention class MultiScaleDeformableAttention(nn.Module): def forward( self, @@ -65,10 +65,10 @@ def forward( ): batch_size, _, num_heads, hidden_dim = value.shape _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape - value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) + value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] - for level_id, (height, width) in enumerate(value_spatial_shapes): + for level_id, (height, width) in enumerate(value_spatial_shapes_list): # batch_size, height*width, num_heads, hidden_dim # -> batch_size, height*width, num_heads*hidden_dim # -> batch_size, num_heads*hidden_dim, height*width @@ -1998,7 +1998,6 @@ def forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.model( diff --git a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py index 514ea362eb8f..f707f5af27cb 100644 --- a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +++ b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py @@ -1997,7 +1997,6 @@ def forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.model( From 6e40c6fd42228a2bcdce8ecf108f21d0112b2892 Mon Sep 17 00:00:00 2001 From: qubvel Date: Mon, 24 Mar 2025 11:15:48 +0000 Subject: [PATCH 2/2] Apply for GroundingDino --- .../grounding_dino/modeling_grounding_dino.py | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index a8b244be5f6a..0b3b2899c145 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -51,8 +51,8 @@ _CHECKPOINT_FOR_DOC = "IDEA-Research/grounding-dino-tiny" -# Copied from models.deformable_detr.MultiScaleDeformableAttention @use_kernel_forward_from_hub("MultiScaleDeformableAttention") +# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttention class MultiScaleDeformableAttention(nn.Module): def forward( self, @@ -66,10 +66,10 @@ def forward( ): batch_size, _, num_heads, hidden_dim = value.shape _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape - value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) + value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] - for level_id, (height, width) in enumerate(value_spatial_shapes): + for level_id, (height, width) in enumerate(value_spatial_shapes_list): # batch_size, height*width, num_heads, hidden_dim # -> batch_size, height*width, num_heads*hidden_dim # -> batch_size, num_heads*hidden_dim, height*width @@ -1015,6 +1015,7 @@ def forward( position_embeddings: torch.Tensor = None, reference_points=None, spatial_shapes=None, + spatial_shapes_list=None, level_start_index=None, output_attentions: bool = False, ): @@ -1030,6 +1031,8 @@ def forward( Reference points. spatial_shapes (`torch.LongTensor`, *optional*): Spatial shapes of the backbone feature maps. + spatial_shapes_list (`List[Tuple[int, int]]`, *optional*): + Spatial shapes of the backbone feature maps (but as list for export compatibility). level_start_index (`torch.LongTensor`, *optional*): Level start index. output_attentions (`bool`, *optional*): @@ -1047,6 +1050,7 @@ def forward( position_embeddings=position_embeddings, reference_points=reference_points, spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, output_attentions=output_attentions, ) @@ -1147,6 +1151,7 @@ def forward( vision_features: Tensor, vision_position_embedding: Tensor, spatial_shapes: Tensor, + spatial_shapes_list: List[Tuple[int, int]], level_start_index: Tensor, key_padding_mask: Tensor, reference_points: Tensor, @@ -1179,6 +1184,7 @@ def forward( position_embeddings=vision_position_embedding, reference_points=reference_points, spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, ) @@ -1295,6 +1301,7 @@ def forward( position_embeddings: Optional[torch.Tensor] = None, reference_points=None, spatial_shapes=None, + spatial_shapes_list=None, level_start_index=None, vision_encoder_hidden_states: Optional[torch.Tensor] = None, vision_encoder_attention_mask: Optional[torch.Tensor] = None, @@ -1347,6 +1354,7 @@ def forward( position_embeddings=position_embeddings, reference_points=reference_points, spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, output_attentions=output_attentions, ) @@ -1594,6 +1602,7 @@ def forward( vision_attention_mask: Tensor, vision_position_embedding: Tensor, spatial_shapes: Tensor, + spatial_shapes_list: List[Tuple[int, int]], level_start_index: Tensor, valid_ratios=None, text_features: Optional[Tensor] = None, @@ -1618,6 +1627,8 @@ def forward( Position embeddings that are added to the queries and keys in each self-attention layer. spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`): Spatial shapes of each feature map. + spatial_shapes_list (`List[Tuple[int, int]]`): + Spatial shapes of each feature map (but as list for export compatibility). level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`): Starting index of each feature map. valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): @@ -1670,6 +1681,7 @@ def forward( vision_features=vision_features, vision_position_embedding=vision_position_embedding, spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, key_padding_mask=vision_attention_mask, reference_points=reference_points, @@ -1748,6 +1760,7 @@ def forward( text_encoder_attention_mask=None, reference_points=None, spatial_shapes=None, + spatial_shapes_list=None, level_start_index=None, valid_ratios=None, self_attn_mask=None, @@ -1775,6 +1788,8 @@ def forward( Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area. spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`): Spatial shapes of the feature maps. + spatial_shapes_list (`List[Tuple[int, int]]`): + Spatial shapes of the feature maps (but as list for export compatibility). level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*): Indexes for the start of each feature level. In range `[0, sequence_length]`. valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*): @@ -1867,6 +1882,7 @@ def custom_forward(*inputs): position_embeddings=query_pos, reference_points=reference_points_input, spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, vision_encoder_hidden_states=vision_encoder_hidden_states, vision_encoder_attention_mask=vision_encoder_attention_mask, @@ -2248,11 +2264,11 @@ def forward( source_flatten = [] mask_flatten = [] lvl_pos_embed_flatten = [] - spatial_shapes = [] + spatial_shapes_list = [] for level, (source, mask, pos_embed) in enumerate(zip(feature_maps, masks, position_embeddings_list)): batch_size, num_channels, height, width = source.shape spatial_shape = (height, width) - spatial_shapes.append(spatial_shape) + spatial_shapes_list.append(spatial_shape) source = source.flatten(2).transpose(1, 2) mask = mask.flatten(1) pos_embed = pos_embed.flatten(2).transpose(1, 2) @@ -2263,7 +2279,7 @@ def forward( source_flatten = torch.cat(source_flatten, 1) mask_flatten = torch.cat(mask_flatten, 1) lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) - spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device) + spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) valid_ratios = valid_ratios.float() @@ -2276,6 +2292,7 @@ def forward( vision_attention_mask=~mask_flatten, vision_position_embedding=lvl_pos_embed_flatten, spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, valid_ratios=valid_ratios, text_features=text_features, @@ -2352,6 +2369,7 @@ def forward( text_encoder_attention_mask=~text_token_mask, reference_points=reference_points, spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, valid_ratios=valid_ratios, self_attn_mask=None,