From eaef822134ab632f7b78ea558b0c05b570dfa77d Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Fri, 27 Feb 2026 22:06:58 +0400 Subject: [PATCH 01/17] init: Add files (v1) --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/deimv2.md | 68 + src/transformers/loss/loss_deimv2.py | 167 ++ src/transformers/loss/loss_utils.py | 2 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + .../models/auto/image_processing_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 2 + src/transformers/models/deimv2/__init__.py | 29 + .../models/deimv2/configuration_deimv2.py | 365 +++ ...eimv2_original_pytorch_checkpoint_to_hf.py | 451 ++++ .../models/deimv2/modeling_deimv2.py | 2098 +++++++++++++++++ .../models/deimv2/modular_deimv2.py | 819 +++++++ tests/models/deimv2/__init__.py | 0 tests/models/deimv2/test_modeling_deimv2.py | 655 +++++ 15 files changed, 4662 insertions(+) create mode 100644 docs/source/en/model_doc/deimv2.md create mode 100644 src/transformers/loss/loss_deimv2.py create mode 100644 src/transformers/models/deimv2/__init__.py create mode 100644 src/transformers/models/deimv2/configuration_deimv2.py create mode 100644 src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py create mode 100644 src/transformers/models/deimv2/modeling_deimv2.py create mode 100644 src/transformers/models/deimv2/modular_deimv2.py create mode 100644 tests/models/deimv2/__init__.py create mode 100644 tests/models/deimv2/test_modeling_deimv2.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 61f34ccb891c..1668ed109a40 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -842,6 +842,8 @@ title: DAB-DETR - local: model_doc/deformable_detr title: Deformable DETR + - local: model_doc/deimv2 + title: DEIMv2 - local: model_doc/deit title: DeiT - local: model_doc/depth_anything diff --git a/docs/source/en/model_doc/deimv2.md b/docs/source/en/model_doc/deimv2.md new file mode 100644 index 000000000000..3d4e4a4a77b1 --- /dev/null +++ b/docs/source/en/model_doc/deimv2.md @@ -0,0 +1,68 @@ + + +# DEIMv2 + +## Overview + +DEIMv2 (DETR with Improved Matching v2) was proposed in [DEIMv2: Real-Time Object Detection Meets DINOv3](https://huggingface.co/papers/2509.20787) by Shihua Huang, Yongjie Hou, Longfei Liu, Xuanlong Yu, and Xi Shen. + +DEIMv2 builds upon D-FINE's distribution-based bounding box refinement approach, adding several key innovations: +- **SwiGLU FFN**: Replaces the standard MLP in decoder layers with a SwiGLU-gated feed-forward network. +- **RMSNorm**: Uses RMSNorm instead of LayerNorm in decoder layers for improved training stability. +- **RepNCSPELAN5**: An enhanced 5-branch CSP-ELAN encoder block (vs D-Fine's 4-branch RepNCSPELAN4). +- **Matching Auxiliary Loss (MAL)**: A focal-style BCE loss with IoU-weighted targets replacing VFL. +- **Dense O2O Matching**: Unified matching across decoder layers for improved training convergence. + +## Usage + +```python +from transformers import AutoImageProcessor, Deimv2ForObjectDetection +from transformers.image_utils import load_image + +url = "http://images.cocodataset.org/val2017/000000039769.jpg" +image = load_image(url) + +# TODO: Replace with Transformers-compatible ckpts once uploaded. 
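+# The repo id below is a placeholder (original-release weights) until converted checkpoints are available.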
+image_processor = AutoImageProcessor.from_pretrained("Intellindust/DEIMv2_HGNetv2_N_COCO") +model = Deimv2ForObjectDetection.from_pretrained("Intellindust/DEIMv2_HGNetv2_N_COCO") + +inputs = image_processor(images=image, return_tensors="pt") +outputs = model(**inputs) + +results = image_processor.post_process_object_detection( + outputs, threshold=0.5, target_sizes=[image.size[::-1]] +) + +for result in results: + for score, label, box in zip(result["scores"], result["labels"], result["boxes"]): + box = [round(i, 2) for i in box.tolist()] + print(f"Detected {model.config.id2label[label.item()]} with confidence {round(score.item(), 3)} at location {box}") +``` + +## Deimv2Config + +[[autodoc]] Deimv2Config + +## Deimv2Model + +[[autodoc]] Deimv2Model + - forward + +## Deimv2ForObjectDetection + +[[autodoc]] Deimv2ForObjectDetection + - forward diff --git a/src/transformers/loss/loss_deimv2.py b/src/transformers/loss/loss_deimv2.py new file mode 100644 index 000000000000..ec71299a3d0a --- /dev/null +++ b/src/transformers/loss/loss_deimv2.py @@ -0,0 +1,167 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn.functional as F + +from ..utils import is_vision_available +from .loss_d_fine import DFineLoss, _set_aux_loss, _set_aux_loss2 +from .loss_for_object_detection import box_iou + + +if is_vision_available(): + from transformers.image_transforms import center_to_corners_format + + +class Deimv2Loss(DFineLoss): + def __init__(self, config): + super().__init__(config) + self.weight_dict = { + "loss_mal": config.weight_loss_mal, + "loss_bbox": config.weight_loss_bbox, + "loss_giou": config.weight_loss_giou, + "loss_fgl": config.weight_loss_fgl, + "loss_ddf": config.weight_loss_ddf, + } + self.losses = ["mal", "boxes", "local"] + self.mal_alpha = config.mal_alpha + self.use_dense_o2o = config.use_dense_o2o + + def loss_labels_mal(self, outputs, targets, indices, num_boxes): + idx = self._get_source_permutation_idx(indices) + + src_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0) + ious, _ = box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes)) + ious = torch.diag(ious).detach() + + src_logits = outputs["logits"] + target_classes_o = torch.cat([t["class_labels"][i] for t, (_, i) in zip(targets, indices)]) + target_classes = torch.full( + src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device + ) + target_classes[idx] = target_classes_o + target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1] + + target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype) + target_score_o[idx] = ious.to(target_score_o.dtype) + target_score = target_score_o.unsqueeze(-1) * target + + pred_score = F.sigmoid(src_logits).detach() + target_score = target_score.pow(self.gamma) + if self.mal_alpha is not None: + weight = self.mal_alpha * 
pred_score.pow(self.gamma) * (1 - target) + target + else: + weight = pred_score.pow(self.gamma) * (1 - target) + target + + loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction="none") + loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes + return {"loss_mal": loss} + + def _get_dense_o2o_indices(self, indices, indices_aux_list): + results = [] + for indices_aux in indices_aux_list: + indices = [ + (torch.cat([idx1[0], idx2[0]]), torch.cat([idx1[1], idx2[1]])) + for idx1, idx2 in zip(indices.copy(), indices_aux.copy()) + ] + + for ind in [torch.cat([idx[0][:, None], idx[1][:, None]], 1) for idx in indices]: + unique, counts = torch.unique(ind, return_counts=True, dim=0) + count_sort_indices = torch.argsort(counts, descending=True) + unique_sorted = unique[count_sort_indices] + column_to_row = {} + for idx_pair in unique_sorted: + row_idx, col_idx = idx_pair[0].item(), idx_pair[1].item() + if row_idx not in column_to_row: + column_to_row[row_idx] = col_idx + final_rows = torch.tensor(list(column_to_row.keys()), device=ind.device) + final_cols = torch.tensor(list(column_to_row.values()), device=ind.device) + results.append((final_rows.long(), final_cols.long())) + return results + + def get_loss(self, loss, outputs, targets, indices, num_boxes): + loss_map = { + "cardinality": self.loss_cardinality, + "local": self.loss_local, + "boxes": self.loss_boxes, + "focal": self.loss_labels_focal, + "vfl": self.loss_labels_vfl, + "mal": self.loss_labels_mal, + } + if loss not in loss_map: + raise ValueError(f"Loss {loss} not supported") + return loss_map[loss](outputs, targets, indices, num_boxes) + + +def Deimv2ForObjectDetectionLoss( + logits, + labels, + device, + pred_boxes, + config, + outputs_class=None, + outputs_coord=None, + enc_topk_logits=None, + enc_topk_bboxes=None, + denoising_meta_values=None, + predicted_corners=None, + initial_reference_points=None, + **kwargs, +): + criterion = Deimv2Loss(config) + criterion.to(device) + outputs_loss = {} + outputs_loss["logits"] = logits + outputs_loss["pred_boxes"] = pred_boxes.clamp(min=0, max=1) + auxiliary_outputs = None + if config.auxiliary_loss: + if denoising_meta_values is not None: + dn_out_coord, outputs_coord = torch.split( + outputs_coord.clamp(min=0, max=1), denoising_meta_values["dn_num_split"], dim=2 + ) + dn_out_class, outputs_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2) + dn_out_corners, out_corners = torch.split(predicted_corners, denoising_meta_values["dn_num_split"], dim=2) + dn_out_refs, out_refs = torch.split(initial_reference_points, denoising_meta_values["dn_num_split"], dim=2) + + auxiliary_outputs = _set_aux_loss2( + outputs_class[:, :-1].transpose(0, 1), + outputs_coord[:, :-1].transpose(0, 1), + out_corners[:, :-1].transpose(0, 1), + out_refs[:, :-1].transpose(0, 1), + out_corners[:, -1], + outputs_class[:, -1], + ) + + outputs_loss["auxiliary_outputs"] = auxiliary_outputs + outputs_loss["auxiliary_outputs"].extend( + _set_aux_loss([enc_topk_logits], [enc_topk_bboxes.clamp(min=0, max=1)]) + ) + + dn_auxiliary_outputs = _set_aux_loss2( + dn_out_class.transpose(0, 1), + dn_out_coord.transpose(0, 1), + dn_out_corners.transpose(0, 1), + dn_out_refs.transpose(0, 1), + dn_out_corners[:, -1], + dn_out_class[:, -1], + ) + outputs_loss["dn_auxiliary_outputs"] = dn_auxiliary_outputs + outputs_loss["denoising_meta_values"] = denoising_meta_values + + loss_dict = criterion(outputs_loss, labels) + + loss = sum(loss_dict.values()) + return loss, 
loss_dict, auxiliary_outputs diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index df269477e9ec..51564d299e55 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -19,6 +19,7 @@ from .loss_d_fine import DFineForObjectDetectionLoss from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss +from .loss_deimv2 import Deimv2ForObjectDetectionLoss from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss from .loss_grounding_dino import GroundingDinoForObjectDetectionLoss from .loss_lw_detr import LwDetrForObjectDetectionLoss @@ -163,6 +164,7 @@ def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs): "RTDetrForObjectDetection": RTDetrForObjectDetectionLoss, "RTDetrV2ForObjectDetection": RTDetrForObjectDetectionLoss, "DFineForObjectDetection": DFineForObjectDetectionLoss, + "Deimv2ForObjectDetection": Deimv2ForObjectDetectionLoss, "CsmForConditionalGeneration": ForCausalLMLoss, "LwDetrForObjectDetection": LwDetrForObjectDetectionLoss, } diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 791a23149934..31a949e80738 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -94,6 +94,7 @@ from .deepseek_vl import * from .deepseek_vl_hybrid import * from .deformable_detr import * + from .deimv2 import * from .deit import * from .deprecated import * from .depth_anything import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 4a493e759d16..3a7ca87b3dce 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -112,6 +112,7 @@ ("deepseek_vl", "DeepseekVLConfig"), ("deepseek_vl_hybrid", "DeepseekVLHybridConfig"), ("deformable_detr", "DeformableDetrConfig"), + ("deimv2", "Deimv2Config"), ("deit", "DeiTConfig"), ("depth_anything", "DepthAnythingConfig"), ("depth_pro", "DepthProConfig"), @@ -595,6 +596,7 @@ ("deepseek_vl", "DeepseekVL"), ("deepseek_vl_hybrid", "DeepseekVLHybrid"), ("deformable_detr", "Deformable DETR"), + ("deimv2", "DEIMv2"), ("deit", "DeiT"), ("deplot", "DePlot"), ("depth_anything", "Depth Anything"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index e4d8f08963d6..512992a6ddf0 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -88,6 +88,7 @@ ("deepseek_vl", ("DeepseekVLImageProcessor", "DeepseekVLImageProcessorFast")), ("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor", "DeepseekVLHybridImageProcessorFast")), ("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")), + ("deimv2", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")), ("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")), ("depth_anything", ("DPTImageProcessor", "DPTImageProcessorFast")), ("depth_pro", ("DepthProImageProcessor", "DepthProImageProcessorFast")), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 1fbab01f684b..34854698f6d6 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -116,6 +116,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("deepseek_vl", "DeepseekVLModel"), 
("deepseek_vl_hybrid", "DeepseekVLHybridModel"), ("deformable_detr", "DeformableDetrModel"), + ("deimv2", "Deimv2Model"), ("deit", "DeiTModel"), ("depth_pro", "DepthProModel"), ("detr", "DetrModel"), @@ -1068,6 +1069,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("d_fine", "DFineForObjectDetection"), ("dab-detr", "DabDetrForObjectDetection"), ("deformable_detr", "DeformableDetrForObjectDetection"), + ("deimv2", "Deimv2ForObjectDetection"), ("detr", "DetrForObjectDetection"), ("lw_detr", "LwDetrForObjectDetection"), ("pp_doclayout_v2", "PPDocLayoutV2ForObjectDetection"), diff --git a/src/transformers/models/deimv2/__init__.py b/src/transformers/models/deimv2/__init__.py new file mode 100644 index 000000000000..07caf7c91851 --- /dev/null +++ b/src/transformers/models/deimv2/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_deimv2 import * + from .modeling_deimv2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/deimv2/configuration_deimv2.py b/src/transformers/models/deimv2/configuration_deimv2.py new file mode 100644 index 000000000000..56936d2cdd7a --- /dev/null +++ b/src/transformers/models/deimv2/configuration_deimv2.py @@ -0,0 +1,365 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/deimv2/modular_deimv2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_deimv2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from ...backbone_utils import consolidate_backbone_kwargs_to_config +from ...configuration_utils import PreTrainedConfig +from ..auto import AutoConfig + + +# TODO: Attribute map assignment logic should be fixed in modular +# as well as super() call parsing because otherwise we cannot re-write args after initialization +class Deimv2Config(PreTrainedConfig): + """ + This is the configuration class to store the configuration of a [`Deimv2Model`]. It is used to instantiate a + DEIMv2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of DEIMv2-L-COCO. + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + initializer_range (`float`, *optional*, defaults to 0.01): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_bias_prior_prob (`float`, *optional*): + The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`. + If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + batch_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the batch normalization layers. + backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `HGNetV2Config()`): + The configuration of the backbone model. + freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`): + Whether to freeze the batch normalization layers in the backbone. + encoder_hidden_dim (`int`, *optional*, defaults to 256): + Dimension of the layers in hybrid encoder. + encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`): + Multi level features input for encoder. + feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`): + Strides used in each feature map. + encoder_layers (`int`, *optional*, defaults to 1): + Total of layers to be used by the encoder. + encoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in encoder. + encoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + dropout (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`): + Indexes of the projected layers to be used in the encoder. + positional_encoding_temperature (`int`, *optional*, defaults to 10000): + The temperature parameter used to create the positional encodings. + encoder_activation_function (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. + activation_function (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the general layer. + eval_size (`tuple[int, int]`, *optional*): + Height and width used to computes the effective height and width of the position embeddings after taking + into account the stride. 
+ normalize_before (`bool`, *optional*, defaults to `False`): + Determine whether to apply layer normalization in the transformer encoder layer before self-attention and + feed-forward modules. + hidden_expansion (`float`, *optional*, defaults to 1.0): + Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer. + d_model (`int`, *optional*, defaults to 256): + Dimension of the layers exclude hybrid encoder. + num_queries (`int`, *optional*, defaults to 300): + Number of object queries. + decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`): + Multi level features dimension for decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of input feature levels. + decoder_n_points (`int`, *optional*, defaults to 4): + The number of sampled keys in each feature level for each attention head in the decoder. + decoder_layers (`int`, *optional*, defaults to 6): + Number of decoder layers. + decoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_activation_function (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the decoder. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_denoising (`int`, *optional*, defaults to 100): + The total number of denoising tasks or queries to be used for contrastive denoising. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + The fraction of denoising labels to which random noise should be added. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale or magnitude of noise to be added to the bounding boxes. + learn_initial_query (`bool`, *optional*, defaults to `False`): + Indicates whether the initial query embeddings for the decoder should be learned during training. + anchor_image_size (`tuple[int, int]`, *optional*): + Height and width of the input image used during evaluation to generate the bounding box anchors. + with_box_refine (`bool`, *optional*, defaults to `True`): + Whether to apply iterative bounding box refinement. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether the architecture has an encoder decoder structure. + matcher_alpha (`float`, *optional*, defaults to 0.25): + Parameter alpha used by the Hungarian Matcher. + matcher_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used by the Hungarian Matcher. + matcher_class_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the class loss used by the Hungarian Matcher. + matcher_bbox_cost (`float`, *optional*, defaults to 5.0): + The relative weight of the bounding box loss used by the Hungarian Matcher. + matcher_giou_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the giou loss of used by the Hungarian Matcher. + use_focal_loss (`bool`, *optional*, defaults to `True`): + Parameter informing if focal loss should be used. + auxiliary_loss (`bool`, *optional*, defaults to `True`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + focal_loss_alpha (`float`, *optional*, defaults to 0.75): + Parameter alpha used to compute the focal loss. + focal_loss_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used to compute the focal loss. 
+ weight_loss_vfl (`float`, *optional*, defaults to 1.0): + Relative weight of the varifocal loss in the object detection loss. + weight_loss_bbox (`float`, *optional*, defaults to 5.0): + Relative weight of the L1 bounding box loss in the object detection loss. + weight_loss_giou (`float`, *optional*, defaults to 2.0): + Relative weight of the generalized IoU loss in the object detection loss. + weight_loss_fgl (`float`, *optional*, defaults to 0.15): + Relative weight of the fine-grained localization loss in the object detection loss. + weight_loss_ddf (`float`, *optional*, defaults to 1.5): + Relative weight of the decoupled distillation focal loss in the object detection loss. + eos_coefficient (`float`, *optional*, defaults to 0.0001): + Relative classification weight of the 'no-object' class in the object detection loss. + eval_idx (`int`, *optional*, defaults to -1): + Index of the decoder layer to use for evaluation. + layer_scale (`float`, *optional*, defaults to `1.0`): + Scaling factor for the hidden dimension in later decoder layers. + max_num_bins (`int`, *optional*, defaults to 32): + Maximum number of bins for the distribution-guided bounding box refinement. + reg_scale (`float`, *optional*, defaults to 4.0): + Scale factor for the regression distribution. + depth_mult (`float`, *optional*, defaults to 1.0): + Multiplier for the number of blocks in RepNCSPELAN5 layers. + top_prob_values (`int`, *optional*, defaults to 4): + Number of top probability values to consider from each corner's distribution. + lqe_hidden_dim (`int`, *optional*, defaults to 64): + Hidden dimension size for the Location Quality Estimator (LQE) network. + lqe_layers (`int`, *optional*, defaults to 2): + Number of layers in the Location Quality Estimator MLP. + decoder_offset_scale (`float`, *optional*, defaults to 0.5): + Offset scale used in deformable attention. + decoder_method (`str`, *optional*, defaults to `"default"`): + The method to use for the decoder: `"default"` or `"discrete"`. + up (`float`, *optional*, defaults to 0.5): + Controls the upper bounds of the Weighting Function. + weight_loss_mal (`float`, *optional*, defaults to 1.0): + Relative weight of the matching auxiliary loss in the object detection loss. + use_dense_o2o (`bool`, *optional*, defaults to `True`): + Whether to use dense one-to-one matching across decoder layers. + mal_alpha (`float`, *optional*): + Alpha parameter for the Matching Auxiliary Loss (MAL). If `None`, uses `focal_loss_alpha`. + encoder_fuse_op (`str`, *optional*, defaults to `"sum"`): + Fusion operation used in the encoder FPN. DEIMv2 uses `"sum"` instead of D-Fine's `"cat"`. + use_spatial_tuning_adapter (`bool`, *optional*, defaults to `False`): + Whether to use the Spatial Tuning Adapter (STA) for DINOv2 backbone variants. + sta_inplanes (`int`, *optional*, defaults to 16): + Number of input planes for the STA convolutional stem. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings. 
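+
+    Examples:
+
+    A minimal instantiation sketch (random weights, default DEIMv2-L-COCO-style hyperparameters):
+
+    ```python
+    >>> from transformers import Deimv2Config, Deimv2ForObjectDetection
+
+    >>> # Initializing a DEIMv2 configuration with the default values
+    >>> configuration = Deimv2Config()
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = Deimv2ForObjectDetection(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```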
+ """ + + model_type = "deimv2" + sub_configs = {"backbone_config": AutoConfig} + layer_types = ["basic", "bottleneck"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + initializer_range=0.01, + initializer_bias_prior_prob=None, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + backbone_config=None, + freeze_backbone_batch_norms=True, + encoder_hidden_dim=256, + encoder_in_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=1024, + encoder_attention_heads=8, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + hidden_expansion=1.0, + d_model=256, + num_queries=300, + decoder_in_channels=[256, 256, 256], + decoder_ffn_dim=1024, + num_feature_levels=3, + decoder_n_points=4, + decoder_layers=6, + decoder_attention_heads=8, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + with_box_refine=True, + is_encoder_decoder=True, + matcher_alpha=0.25, + matcher_gamma=2.0, + matcher_class_cost=2.0, + matcher_bbox_cost=5.0, + matcher_giou_cost=2.0, + use_focal_loss=True, + auxiliary_loss=True, + focal_loss_alpha=0.75, + focal_loss_gamma=2.0, + weight_loss_vfl=1.0, + weight_loss_bbox=5.0, + weight_loss_giou=2.0, + weight_loss_fgl=0.15, + weight_loss_ddf=1.5, + eos_coefficient=1e-4, + eval_idx=-1, + layer_scale=1, + max_num_bins=32, + reg_scale=4.0, + depth_mult=1.0, + top_prob_values=4, + lqe_hidden_dim=64, + lqe_layers=2, + decoder_offset_scale=0.5, + decoder_method="default", + up=0.5, + weight_loss_mal=1.0, + use_dense_o2o=True, + mal_alpha=None, + encoder_fuse_op="sum", + use_spatial_tuning_adapter=False, + sta_inplanes=16, + tie_word_embeddings=True, + **kwargs, + ): + self.weight_loss_mal = weight_loss_mal + self.use_dense_o2o = use_dense_o2o + self.mal_alpha = mal_alpha + self.encoder_fuse_op = encoder_fuse_op + self.use_spatial_tuning_adapter = use_spatial_tuning_adapter + self.sta_inplanes = sta_inplanes + self.initializer_range = initializer_range + self.initializer_bias_prior_prob = initializer_bias_prior_prob + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + + backbone_config, kwargs = consolidate_backbone_kwargs_to_config( + backbone_config=backbone_config, + default_config_type="hgnet_v2", + default_config_kwargs={"out_indices": [2, 3, 4]}, + **kwargs, + ) + + self.backbone_config = backbone_config + self.freeze_backbone_batch_norms = freeze_backbone_batch_norms + # encoder + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_attention_heads = encoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.encoder_layers = encoder_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.eval_size = eval_size + self.normalize_before = normalize_before + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.hidden_expansion = hidden_expansion + # decoder + self.d_model = d_model + self.num_queries = num_queries + 
self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_in_channels = decoder_in_channels + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.auxiliary_loss = auxiliary_loss + self.with_box_refine = with_box_refine + # Loss + self.matcher_alpha = matcher_alpha + self.matcher_gamma = matcher_gamma + self.matcher_class_cost = matcher_class_cost + self.matcher_bbox_cost = matcher_bbox_cost + self.matcher_giou_cost = matcher_giou_cost + self.use_focal_loss = use_focal_loss + self.focal_loss_alpha = focal_loss_alpha + self.focal_loss_gamma = focal_loss_gamma + self.weight_loss_vfl = weight_loss_vfl + self.weight_loss_bbox = weight_loss_bbox + self.weight_loss_giou = weight_loss_giou + self.weight_loss_fgl = weight_loss_fgl + self.weight_loss_ddf = weight_loss_ddf + self.eos_coefficient = eos_coefficient + # add the new attributes with the given values or defaults + self.eval_idx = eval_idx + self.layer_scale = layer_scale + self.max_num_bins = max_num_bins + self.reg_scale = reg_scale + self.depth_mult = depth_mult + self.decoder_offset_scale = decoder_offset_scale + self.decoder_method = decoder_method + self.top_prob_values = top_prob_values + self.lqe_hidden_dim = lqe_hidden_dim + self.lqe_layers = lqe_layers + self.up = up + self.tie_word_embeddings = tie_word_embeddings + + if isinstance(self.decoder_n_points, list): + if len(self.decoder_n_points) != self.num_feature_levels: + raise ValueError( + f"Length of decoder_n_points list ({len(self.decoder_n_points)}) must match num_feature_levels ({self.num_feature_levels})." + ) + + head_dim = self.d_model // self.decoder_attention_heads + if head_dim * self.decoder_attention_heads != self.d_model: + raise ValueError( + f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}" + ) + + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + + +__all__ = ["Deimv2Config"] diff --git a/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py b/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py new file mode 100644 index 000000000000..5aaecc844594 --- /dev/null +++ b/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py @@ -0,0 +1,451 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
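+
+# Example invocation (a sketch; the output folder and hub repo id are illustrative):
+#   python src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py \
+#       --model_name deimv2_hgnetv2_n_coco \
+#       --pytorch_dump_folder_path ./deimv2_hgnetv2_n_hf \
+#       --push_to_hub --repo_id your-namespace/deimv2-hgnetv2-n-coco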
+ +import argparse +import json +import re +from io import BytesIO +from pathlib import Path + +import httpx +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from safetensors.torch import load_file +from torchvision import transforms + +from transformers import Deimv2Config, Deimv2ForObjectDetection, RTDetrImageProcessor +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +MODEL_NAME_TO_HUB_REPO = { + "deimv2_hgnetv2_n_coco": "Intellindust/DEIMv2_HGNetv2_N_COCO", + "deimv2_hgnetv2_pico_coco": "Intellindust/DEIMv2_HGNetv2_PICO_COCO", + "deimv2_hgnetv2_femto_coco": "Intellindust/DEIMv2_HGNetv2_FEMTO_COCO", + "deimv2_hgnetv2_atto_coco": "Intellindust/DEIMv2_HGNetv2_ATTO_COCO", + "deimv2_dinov3_s_coco": "Intellindust/DEIMv2_DINOv3_S_COCO", + "deimv2_dinov3_m_coco": "Intellindust/DEIMv2_DINOv3_M_COCO", + "deimv2_dinov3_l_coco": "Intellindust/DEIMv2_DINOv3_L_COCO", + "deimv2_dinov3_x_coco": "Intellindust/DEIMv2_DINOv3_X_COCO", +} + + +def get_deimv2_config(model_name: str) -> Deimv2Config: + repo_id = MODEL_NAME_TO_HUB_REPO[model_name] + config_path = hf_hub_download(repo_id=repo_id, filename="config.json") + with open(config_path) as f: + orig_config = json.load(f) + + # COCO labels + id2label = json.load( + open(hf_hub_download("huggingface/label-files", "coco-detection-mmdet-id2label.json", repo_type="dataset")) + ) + id2label = {int(k): v for k, v in id2label.items()} + + decoder_cfg = orig_config["DEIMTransformer"] + if "HybridEncoder" in orig_config: + encoder_cfg = orig_config["HybridEncoder"] + elif "LiteEncoder" in orig_config: + raise ValueError( + "LiteEncoder variants (pico/femto/atto) are not yet supported. " + "The LiteEncoder uses a different architecture (AvgPool downsampling, GAP fusion, " + "RepNCSPELAN4 blocks) that requires a dedicated Deimv2LiteEncoder implementation. " + "Supported variants: deimv2_hgnetv2_n_coco and DINOv3 variants." + ) + else: + raise ValueError(f"No encoder config found. 
Available keys: {list(orig_config.keys())}") + + config = Deimv2Config() + config.num_labels = 80 + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + # Encoder settings + config.encoder_hidden_dim = encoder_cfg["hidden_dim"] + config.encoder_in_channels = encoder_cfg["in_channels"] + config.feat_strides = encoder_cfg["feat_strides"] + config.activation_function = encoder_cfg.get("act", "silu") + config.depth_mult = encoder_cfg.get("depth_mult", 1.0) + config.hidden_expansion = encoder_cfg.get("expansion", 1.0) + config.encoder_fuse_op = encoder_cfg.get("fuse_op", "sum") + config.encoder_ffn_dim = encoder_cfg["dim_feedforward"] + config.encoder_attention_heads = encoder_cfg["nhead"] + config.dropout = encoder_cfg.get("dropout", 0.0) + config.encode_proj_layers = encoder_cfg["use_encoder_idx"] + config.encoder_activation_function = encoder_cfg.get("enc_act", "gelu") + + # Decoder settings + config.d_model = decoder_cfg["hidden_dim"] + config.decoder_ffn_dim = decoder_cfg["dim_feedforward"] + config.decoder_layers = decoder_cfg["num_layers"] + config.num_feature_levels = decoder_cfg["num_levels"] + config.decoder_n_points = decoder_cfg["num_points"] + config.num_queries = decoder_cfg["num_queries"] + config.num_denoising = decoder_cfg.get("num_denoising", 100) + config.label_noise_ratio = decoder_cfg.get("label_noise_ratio", 0.5) + config.box_noise_scale = decoder_cfg.get("box_noise_scale", 1.0) + config.max_num_bins = decoder_cfg.get("reg_max", 32) + config.reg_scale = decoder_cfg.get("reg_scale", 4.0) + config.eval_idx = decoder_cfg.get("eval_idx", -1) + config.layer_scale = decoder_cfg.get("layer_scale", 1) + config.decoder_in_channels = decoder_cfg["feat_channels"] + config.eval_size = tuple(decoder_cfg["eval_spatial_size"]) if "eval_spatial_size" in decoder_cfg else None + config.decoder_activation_function = decoder_cfg.get("activation", "silu") + + # Backbone settings + if "HGNetv2" in orig_config: + backbone_cfg = orig_config["HGNetv2"] + backbone_name = backbone_cfg.get("name", "B0") + return_idx = backbone_cfg.get("return_idx", [2, 3]) + config.backbone_config.out_indices = [i + 1 for i in return_idx] + config.backbone_config.use_learnable_affine_block = backbone_cfg.get("use_lab", True) + + # Set backbone sizes based on the model variant + if backbone_name == "B0": + config.backbone_config.hidden_sizes = [128, 256, 512, 1024] + config.backbone_config.stem_channels = [3, 16, 16] + config.backbone_config.stage_in_channels = [16, 64, 256, 512] + config.backbone_config.stage_mid_channels = [16, 32, 64, 128] + config.backbone_config.stage_out_channels = [64, 256, 512, 1024] + config.backbone_config.stage_num_blocks = [1, 1, 2, 1] + config.backbone_config.stage_downsample = [False, True, True, True] + config.backbone_config.stage_light_block = [False, False, True, True] + config.backbone_config.stage_kernel_size = [3, 3, 5, 5] + config.backbone_config.stage_numb_of_layers = [3, 3, 3, 3] + elif backbone_name in ["B1", "B2"]: + config.backbone_config.hidden_sizes = [128, 256, 512, 1024] + config.backbone_config.stem_channels = [3, 16, 16] + config.backbone_config.stage_in_channels = [16, 64, 256, 512] + config.backbone_config.stage_mid_channels = [16, 32, 64, 128] + config.backbone_config.stage_out_channels = [64, 256, 512, 1024] + config.backbone_config.stage_num_blocks = [1, 1, 2, 1] + config.backbone_config.stage_downsample = [False, True, True, True] + config.backbone_config.stage_light_block = [False, False, True, True] + 
config.backbone_config.stage_kernel_size = [3, 3, 5, 5] + config.backbone_config.stage_numb_of_layers = [3, 3, 3, 3] + else: + raise ValueError(f"Unknown HGNetv2 variant: {backbone_name}") + + config.use_spatial_tuning_adapter = False + elif "DINOv3STAs" in orig_config: + raise ValueError( + "DINOv3 backbone variants are not yet supported. " + "The DINOv3+STA architecture requires ViT backbone key mappings and " + "STA adapter integration that are not yet implemented in the conversion script. " + "Supported variants: deimv2_hgnetv2_n_coco." + ) + else: + raise ValueError(f"Unknown backbone in config: {list(orig_config.keys())}") + + return config + + +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + # Backbone stem mappings + r"backbone\.stem\.(stem\w+)\.conv\.weight": r"model.backbone.model.embedder.\1.convolution.weight", + # Stem normalization + r"backbone\.stem\.(stem\w+)\.bn\.(weight|bias|running_mean|running_var)": r"model.backbone.model.embedder.\1.normalization.\2", + # Stem lab parameters + r"backbone\.stem\.(stem\w+)\.lab\.(scale|bias)": r"model.backbone.model.embedder.\1.lab.\2", + # Backbone stages mappings + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv\.weight": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.convolution.weight", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.bn\.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.normalization.\4", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.lab\.(scale|bias)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.lab.\4", + # Conv1/Conv2 layers + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv1\.conv\.weight": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv1.convolution.weight", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv1\.bn\.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv1.normalization.\4", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv1\.lab\.(scale|bias)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv1.lab.\4", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv2\.conv\.weight": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv2.convolution.weight", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv2\.bn\.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv2.normalization.\4", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv2\.lab\.(scale|bias)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv2.lab.\4", + # Backbone stages aggregation + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.aggregation\.(\d+)\.conv\.weight": r"model.backbone.model.encoder.stages.\1.blocks.\2.aggregation.\3.convolution.weight", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.aggregation\.(\d+)\.bn\.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.blocks.\2.aggregation.\3.normalization.\4", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.aggregation\.(\d+)\.lab\.(scale|bias)": r"model.backbone.model.encoder.stages.\1.blocks.\2.aggregation.\3.lab.\4", + # Downsample + r"backbone\.stages\.(\d+)\.downsample\.conv\.weight": r"model.backbone.model.encoder.stages.\1.downsample.convolution.weight", + r"backbone\.stages\.(\d+)\.downsample\.bn\.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.downsample.normalization.\2", + 
r"backbone\.stages\.(\d+)\.downsample\.lab\.(scale|bias)": r"model.backbone.model.encoder.stages.\1.downsample.lab.\2", + # Encoder mappings + # Input projections + r"encoder\.input_proj\.(\d+)\.conv\.weight": r"model.encoder_input_proj.\1.0.weight", + r"encoder\.input_proj\.(\d+)\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder_input_proj.\1.1.\2", + # AIFI transformer encoder layers + r"encoder\.encoder\.(\d+)\.layers\.0\.self_attn\.out_proj\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.self_attn.o_proj.\2", + r"encoder\.encoder\.(\d+)\.layers\.0\.linear1\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.mlp.layers.0.\2", + r"encoder\.encoder\.(\d+)\.layers\.0\.linear2\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.mlp.layers.1.\2", + r"encoder\.encoder\.(\d+)\.layers\.0\.norm1\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.self_attn_layer_norm.\2", + r"encoder\.encoder\.(\d+)\.layers\.0\.norm2\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.final_layer_norm.\2", + # Encoder projections and convolutions + r"encoder\.lateral_convs\.(\d+)\.conv\.weight": r"model.encoder.lateral_convs.\1.conv.weight", + r"encoder\.lateral_convs\.(\d+)\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.lateral_convs.\1.norm.\2", + # FPN blocks - complete structure + # Basic convolutions + r"encoder\.fpn_blocks\.(\d+)\.cv1\.conv\.weight": r"model.encoder.fpn_blocks.\1.conv1.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.conv1.norm.\2", + r"encoder\.fpn_blocks\.(\d+)\.cv4\.conv\.weight": r"model.encoder.fpn_blocks.\1.conv4.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv4\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.conv4.norm.\2", + # CSP Rep1 path + r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.conv1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep1.conv1.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep1.conv1.norm.\2", + r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep1.bottlenecks.\2.conv1.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep1.bottlenecks.\2.conv1.norm.\3", + r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep1.bottlenecks.\2.conv2.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep1.bottlenecks.\2.conv2.norm.\3", + # CSP Rep2 path + r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.conv1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep2.conv1.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep2.conv1.norm.\2", + r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep2.bottlenecks.\2.conv1.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep2.bottlenecks.\2.conv1.norm.\3", + r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep2.bottlenecks.\2.conv2.conv.weight", + 
r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep2.bottlenecks.\2.conv2.norm.\3", + # FPN trailing convs + r"encoder\.fpn_blocks\.(\d+)\.cv2\.1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep1.conv3.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv2\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep1.conv3.norm.\2", + r"encoder\.fpn_blocks\.(\d+)\.cv3\.1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep2.conv3.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv3\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep2.conv3.norm.\2", + # PAN blocks - complete structure + r"encoder\.pan_blocks\.(\d+)\.cv1\.conv\.weight": r"model.encoder.pan_blocks.\1.conv1.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.conv1.norm.\2", + r"encoder\.pan_blocks\.(\d+)\.cv4\.conv\.weight": r"model.encoder.pan_blocks.\1.conv4.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv4\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.conv4.norm.\2", + # CSP Rep1 path + r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.conv1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep1.conv1.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep1.conv1.norm.\2", + r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep1.bottlenecks.\2.conv1.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep1.bottlenecks.\2.conv1.norm.\3", + r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep1.bottlenecks.\2.conv2.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep1.bottlenecks.\2.conv2.norm.\3", + # CSP Rep2 path + r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.conv1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep2.conv1.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep2.conv1.norm.\2", + r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep2.bottlenecks.\2.conv1.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep2.bottlenecks.\2.conv1.norm.\3", + r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep2.bottlenecks.\2.conv2.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep2.bottlenecks.\2.conv2.norm.\3", + # PAN trailing convs + r"encoder\.pan_blocks\.(\d+)\.cv2\.1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep1.conv3.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv2\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep1.conv3.norm.\2", + r"encoder\.pan_blocks\.(\d+)\.cv3\.1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep2.conv3.conv.weight", + 
r"encoder\.pan_blocks\.(\d+)\.cv3\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep2.conv3.norm.\2", + # Downsample convolutions + r"encoder\.downsample_convs\.(\d+)\.0\.cv(\d+)\.conv\.weight": r"model.encoder.downsample_convs.\1.conv\2.conv.weight", + r"encoder\.downsample_convs\.(\d+)\.0\.cv(\d+)\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.downsample_convs.\1.conv\2.norm.\3", + # Decoder layers + r"decoder\.input_proj\.(\d+)\.0\.weight": r"model.decoder_input_proj.\1.0.weight", + r"decoder\.input_proj\.(\d+)\.1\.(weight|bias|running_mean|running_var)": r"model.decoder_input_proj.\1.1.\2", + r"decoder\.decoder\.layers\.(\d+)\.self_attn\.out_proj\.(weight|bias)": r"model.decoder.layers.\1.self_attn.o_proj.\2", + r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.sampling_offsets\.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.sampling_offsets.\2", + r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.attention_weights\.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.attention_weights.\2", + r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.value_proj\.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.value_proj.\2", + r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.output_proj\.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.output_proj.\2", + r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.num_points_scale": r"model.decoder.layers.\1.encoder_attn.num_points_scale", + r"decoder\.decoder\.layers\.(\d+)\.norm1\.scale": r"model.decoder.layers.\1.self_attn_layer_norm.scale", + r"decoder\.decoder\.layers\.(\d+)\.norm3\.scale": r"model.decoder.layers.\1.final_layer_norm.scale", + r"decoder\.decoder\.layers\.(\d+)\.swish_ffn\.w12\.(weight|bias)": r"model.decoder.layers.\1.mlp.w12.\2", + r"decoder\.decoder\.layers\.(\d+)\.swish_ffn\.w3\.(weight|bias)": r"model.decoder.layers.\1.mlp.w3.\2", + r"decoder\.decoder\.layers\.(\d+)\.gateway\.gate\.(weight|bias)": r"model.decoder.layers.\1.gateway.gate.\2", + r"decoder\.decoder\.layers\.(\d+)\.gateway\.norm\.scale": r"model.decoder.layers.\1.gateway.norm.scale", + # LQE layers + r"decoder\.decoder\.lqe_layers\.(\d+)\.reg_conf\.layers\.(\d+)\.(weight|bias)": r"model.decoder.lqe_layers.\1.reg_conf.layers.\2.\3", + # Decoder heads and projections + r"decoder\.dec_score_head\.(\d+)\.(weight|bias)": r"model.decoder.class_embed.\1.\2", + r"decoder\.dec_bbox_head\.(\d+)\.layers\.(\d+)\.(weight|bias)": r"model.decoder.bbox_embed.\1.layers.\2.\3", + r"decoder\.pre_bbox_head\.layers\.(\d+)\.(weight|bias)": r"model.decoder.pre_bbox_head.layers.\1.\2", + r"decoder\.query_pos_head\.layers\.(\d+)\.(weight|bias)": r"model.decoder.query_pos_head.layers.\1.\2", + # Encoder output and score heads + r"decoder\.enc_output\.proj\.(weight|bias)": r"model.enc_output.0.\1", + r"decoder\.enc_output\.norm\.(weight|bias)": r"model.enc_output.1.\1", + r"decoder\.enc_score_head\.(weight|bias)": r"model.enc_score_head.\1", + r"decoder\.enc_bbox_head\.layers\.(\d+)\.(weight|bias)": r"model.enc_bbox_head.layers.\1.\2", + # Denoising class embed + r"decoder\.denoising_class_embed\.weight": r"model.denoising_class_embed.weight", + # Decoder parameters + r"decoder\.decoder\.up": r"model.decoder.up", + r"decoder\.decoder\.reg_scale": r"model.decoder.reg_scale", +} + + +def convert_old_keys_to_new_keys(state_dict): + # Use the mapping to rename keys + for original_key, converted_key in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items(): + for key in list(state_dict.keys()): + new_key = re.sub(f"^{original_key}$", converted_key, 
key) + if new_key != key: + state_dict[new_key] = state_dict.pop(key) + return state_dict + + +def read_in_q_k_v(state_dict, config): + encoder_hidden_dim = config.encoder_hidden_dim + d_model = config.d_model + + # first: transformer encoder + for i in range(config.encoder_layers): + # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"encoder.encoder.{i}.layers.0.self_attn.in_proj_weight", None) + in_proj_bias = state_dict.pop(f"encoder.encoder.{i}.layers.0.self_attn.in_proj_bias", None) + if in_proj_weight is not None: + # next, add query, keys and values (in that order) to the state dict + state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.q_proj.weight"] = in_proj_weight[ + :encoder_hidden_dim + ] + state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.k_proj.weight"] = in_proj_weight[ + encoder_hidden_dim : 2 * encoder_hidden_dim + ] + state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.v_proj.weight"] = in_proj_weight[ + -encoder_hidden_dim: + ] + if in_proj_bias is not None: + state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.q_proj.bias"] = in_proj_bias[:encoder_hidden_dim] + state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.k_proj.bias"] = in_proj_bias[ + encoder_hidden_dim : 2 * encoder_hidden_dim + ] + state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.v_proj.bias"] = in_proj_bias[-encoder_hidden_dim:] + + # next: transformer decoder (which is a bit more complex because it also includes cross-attention) + for i in range(config.decoder_layers): + # read in weights + bias of input projection layer of self-attention + in_proj_weight = state_dict.pop(f"decoder.decoder.layers.{i}.self_attn.in_proj_weight", None) + in_proj_bias = state_dict.pop(f"decoder.decoder.layers.{i}.self_attn.in_proj_bias", None) + if in_proj_weight is not None: + # next, add query, keys and values (in that order) to the state dict + state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:d_model] + state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:d_model] + state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[d_model : 2 * d_model] + state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[d_model : 2 * d_model] + state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-d_model:] + state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-d_model:] + + +def load_original_state_dict(repo_id): + filepath = hf_hub_download(repo_id=repo_id, filename="model.safetensors") + return load_file(filepath) + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + with httpx.stream("GET", url) as response: + image = Image.open(BytesIO(response.read())) + return image + + +@torch.no_grad() +def convert_deimv2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, repo_id): + """ + Copy/paste/tweak model's weights to our Deimv2 structure. 
+ """ + hub_repo = MODEL_NAME_TO_HUB_REPO[model_name] + config = get_deimv2_config(model_name) + state_dict = load_original_state_dict(hub_repo) + + logger.info(f"Converting model {model_name} from {hub_repo}...") + logger.info(f"Original state dict has {len(state_dict)} keys") + + state_dict.pop("decoder.valid_mask", None) + state_dict.pop("decoder.anchors", None) + + # query, key and value matrices need special treatment + read_in_q_k_v(state_dict, config) + + state_dict = convert_old_keys_to_new_keys(state_dict) + + if "model.enc_output.0.weight" not in state_dict: + d_model = config.d_model + state_dict["model.enc_output.0.weight"] = torch.eye(d_model) + state_dict["model.enc_output.0.bias"] = torch.zeros(d_model) + state_dict["model.enc_output.1.weight"] = torch.ones(d_model) + state_dict["model.enc_output.1.bias"] = torch.zeros(d_model) + + # for two_stage + for key in list(state_dict.keys()): + if "bbox_embed" in key or ("class_embed" in key and "denoising_" not in key): + new_key = key.split("model.decoder.")[-1] + if new_key not in state_dict: + state_dict[new_key] = state_dict[key] + + model = Deimv2ForObjectDetection(config) + missing, unexpected = model.load_state_dict(state_dict, strict=False) + + if missing: + logger.warning(f"Missing keys ({len(missing)}): {missing[:10]}...") + if unexpected: + logger.warning(f"Unexpected keys ({len(unexpected)}): {unexpected[:10]}...") + + model.eval() + + image_processor = RTDetrImageProcessor() + img = prepare_img() + + transformations = transforms.Compose( + [ + transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + ] + ) + original_pixel_values = transformations(img).unsqueeze(0) + encoding = image_processor(images=img, return_tensors="pt") + pixel_values = encoding["pixel_values"] + + assert torch.allclose(original_pixel_values, pixel_values), "Image preprocessing mismatch!" 
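+    # Note (illustrative, not from the original script): the check above only holds because
+    # RTDetrImageProcessor's defaults (640x640 bilinear resize, rescale to [0, 1], no normalization)
+    # mirror the torchvision pipeline built just above. If a checkpoint were trained with a different
+    # input size or with normalization, the processor config saved below would need to be adjusted
+    # and this assert relaxed accordingly.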
+ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + pixel_values = pixel_values.to(device) + + outputs = model(pixel_values) + logger.info(f"Logits shape: {outputs.logits.shape}") + logger.info(f"Boxes shape: {outputs.pred_boxes.shape}") + logger.info(f"Logits sample: {outputs.logits[0, :3, :3]}") + logger.info(f"Boxes sample: {outputs.pred_boxes[0, :3, :3]}") + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + logger.info(f"Saving model to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + logger.info(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + push_repo = repo_id or f"deimv2-{model_name}" + logger.info(f"Pushing to hub: {push_repo}") + config.push_to_hub(repo_id=push_repo) + model.push_to_hub(repo_id=push_repo) + image_processor.push_to_hub(repo_id=push_repo) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + default="deimv2_hgnetv2_n_coco", + type=str, + choices=list(MODEL_NAME_TO_HUB_REPO.keys()), + help="Model name to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + help="Path to the output directory.", + ) + parser.add_argument("--push_to_hub", action="store_true", help="Whether to push to the hub.") + parser.add_argument("--repo_id", type=str, default=None, help="Hub repo_id to push to.") + args = parser.parse_args() + convert_deimv2_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.repo_id) diff --git a/src/transformers/models/deimv2/modeling_deimv2.py b/src/transformers/models/deimv2/modeling_deimv2.py new file mode 100644 index 000000000000..154170174227 --- /dev/null +++ b/src/transformers/models/deimv2/modeling_deimv2.py @@ -0,0 +1,2098 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/deimv2/modular_deimv2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_deimv2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from collections.abc import Callable +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from ... 
import initialization as init +from ...activations import ACT2CLS +from ...backbone_utils import load_backbone +from ...image_transforms import center_to_corners_format, corners_to_center_format +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import compile_compatible_method_lru_cache +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int +from ...utils.generic import can_return_tuple, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from .configuration_deimv2 import Deimv2Config + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the Deimv2Decoder. This class adds two attributes to + BaseModelOutputWithCrossAttentions, namely: + - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer) + - a stacked tensor of intermediate reference points. + """ +) +class Deimv2DecoderOutput(ModelOutput): + r""" + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate predicted corners (predicted corners of each layer of the decoder). + initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked initial reference points (initial reference points of each layer of the decoder). + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. 
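+
+    Example (an illustrative sketch only, assuming a `decoder_output` returned by `Deimv2Decoder.forward`):
+
+    ```python
+    final_logits = decoder_output.intermediate_logits[:, -1]           # (batch_size, num_queries, num_labels)
+    final_boxes = decoder_output.intermediate_reference_points[:, -1]  # (batch_size, num_queries, 4), normalized (cx, cy, w, h)
+    ```
+
+    During training the stacks also include the pre-refinement predictions from the first decoder layer, so
+    the layer dimension can be `config.decoder_layers + 1` rather than `config.decoder_layers`.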
+ """ + + last_hidden_state: torch.FloatTensor | None = None + intermediate_hidden_states: torch.FloatTensor | None = None + intermediate_logits: torch.FloatTensor | None = None + intermediate_reference_points: torch.FloatTensor | None = None + intermediate_predicted_corners: torch.FloatTensor | None = None + initial_reference_points: torch.FloatTensor | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + cross_attentions: tuple[torch.FloatTensor] | None = None + + +class Deimv2RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.float() + hidden_states = hidden_states * torch.rsqrt(hidden_states.pow(2).mean(-1, keepdim=True) + self.eps) + return (hidden_states * self.scale).to(input_dtype) + + +class Deimv2SwiGLUFFN(nn.Module): + def __init__(self, in_features: int, hidden_features: int, out_features: int): + super().__init__() + self.w12 = nn.Linear(in_features, 2 * hidden_features) + self.w3 = nn.Linear(hidden_features, out_features) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + x12 = self.w12(hidden_states) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +class Deimv2Gate(nn.Module): + def __init__(self, d_model: int): + super().__init__() + self.gate = nn.Linear(2 * d_model, 2 * d_model) + self.norm = Deimv2RMSNorm(d_model) + + def forward(self, second_residual: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor: + gate_input = torch.cat([second_residual, hidden_states], dim=-1) + gates = torch.sigmoid(self.gate(gate_input)) + gate1, gate2 = gates.chunk(2, dim=-1) + hidden_states = self.norm(gate1 * second_residual + gate2 * hidden_states) + return hidden_states + + +class Deimv2MLP(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act: str = "relu"): + super().__init__() + self.num_layers = num_layers + hidden_dims = [hidden_dim] * (num_layers - 1) + input_dims = [input_dim] + hidden_dims + output_dims = hidden_dims + [output_dim] + self.layers = nn.ModuleList(nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, output_dims)) + self.act = ACT2CLS[act]() + + def forward(self, stat_features: torch.Tensor) -> torch.Tensor: + for i, layer in enumerate(self.layers): + stat_features = self.act(layer(stat_features)) if i < self.num_layers - 1 else layer(stat_features) + return stat_features + + +def multi_scale_deformable_attention_v2( + value: Tensor, + value_spatial_shapes: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + num_points_list: list[int], + method="default", +) -> Tensor: + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points = sampling_locations.shape + value_list = ( + value.permute(0, 2, 3, 1) + .flatten(0, 1) + .split([height * width for height, width in value_spatial_shapes], dim=-1) + ) + # sampling_offsets [8, 480, 8, 12, 2] + if method == "default": + sampling_grids = 2 * sampling_locations - 1 + elif method == "discrete": + sampling_grids = sampling_locations + sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) + sampling_grids = sampling_grids.split(num_points_list, dim=-2) + sampling_value_list = [] + for level_id, (height, width) in 
enumerate(value_spatial_shapes): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[level_id] + # batch_size*num_heads, hidden_dim, num_queries, num_points + if method == "default": + sampling_value_l_ = nn.functional.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + elif method == "discrete": + sampling_coord = (sampling_grid_l_ * torch.tensor([[width, height]], device=value.device) + 0.5).to( + torch.int64 + ) + + # Separate clamping for x and y coordinates + sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1) + sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1) + + # Combine the clamped coordinates + sampling_coord = torch.stack([sampling_coord_x, sampling_coord_y], dim=-1) + sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2) + sampling_idx = ( + torch.arange(sampling_coord.shape[0], device=value.device) + .unsqueeze(-1) + .repeat(1, sampling_coord.shape[1]) + ) + sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] + sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape( + batch_size * num_heads, hidden_dim, num_queries, num_points_list[level_id] + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.permute(0, 2, 1, 3).reshape( + batch_size * num_heads, 1, num_queries, sum(num_points_list) + ) + output = ( + (torch.concat(sampling_value_list, dim=-1) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +class Deimv2MultiscaleDeformableAttention(nn.Module): + def __init__(self, config: Deimv2Config): + """ + D-Fine version of multiscale deformable attention + """ + super().__init__() + self.d_model = config.d_model + self.n_heads = config.decoder_attention_heads + self.n_levels = config.num_feature_levels + self.offset_scale = config.decoder_offset_scale + self.decoder_method = config.decoder_method + self.n_points = config.decoder_n_points + + if isinstance(self.n_points, list): + num_points_list = self.n_points + else: + num_points_list = [self.n_points for _ in range(self.n_levels)] + + self.num_points_list = num_points_list + num_points_scale = [1 / n for n in self.num_points_list for _ in range(n)] + self.register_buffer("num_points_scale", torch.tensor(num_points_scale, dtype=torch.float32)) + + self.total_points = self.n_heads * sum(self.num_points_list) + + self.sampling_offsets = nn.Linear(self.d_model, self.total_points * 2) + self.attention_weights = nn.Linear(self.d_model, self.total_points) + + self.ms_deformable_attn_core = multi_scale_deformable_attention_v2 + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + reference_points=None, + encoder_hidden_states=None, + 
spatial_shapes=None, + spatial_shapes_list=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + + torch_compilable_check( + (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == sequence_length, + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states", + ) + + # Reshape for multi-head attention + value = encoder_hidden_states.reshape(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + if attention_mask is not None: + value = value.masked_fill(~attention_mask[..., None], float(0)) + + sampling_offsets: torch.Tensor = self.sampling_offsets(hidden_states) + sampling_offsets = sampling_offsets.reshape( + batch_size, num_queries, self.n_heads, sum(self.num_points_list), 2 + ) + + attention_weights = self.attention_weights(hidden_states).reshape( + batch_size, num_queries, self.n_heads, sum(self.num_points_list) + ) + attention_weights = F.softmax(attention_weights, dim=-1) + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.tensor(spatial_shapes) + offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.n_levels, 1, 2) + sampling_locations = ( + reference_points.reshape(batch_size, sequence_length, 1, self.n_levels, 1, 2) + + sampling_offsets / offset_normalizer + ) + elif reference_points.shape[-1] == 4: + # reference_points [8, 480, None, 1, 4] + # sampling_offsets [8, 480, 8, 12, 2] + num_points_scale = self.num_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1) + offset = sampling_offsets * num_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale + sampling_locations = reference_points[:, :, None, :, :2] + offset + else: + raise ValueError( + f"Last dim of reference_points must be 2 or 4, but get {reference_points.shape[-1]} instead." + ) + + output = self.ms_deformable_attn_core( + value, + spatial_shapes_list, + sampling_locations, + attention_weights, + self.num_points_list, + self.decoder_method, + ) + + return output, attention_weights + + +class Deimv2ConvNormLayer(nn.Module): + def __init__( + self, + config: Deimv2Config, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + groups: int = 1, + padding: int | None = None, + activation: str | None = None, + ): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + groups=groups, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=False, + ) + self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = self.norm(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class Deimv2RepVggBlock(nn.Module): + """ + RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again". 
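+
+    At training time the block runs a 3x3 and a 1x1 branch in parallel and sums them; this implementation keeps
+    both branches at inference as well. The usual RepVGG reparameterization still applies in principle: once each
+    branch's BatchNorm has been folded into its convolution, the two kernels collapse into a single 3x3 kernel.
+    A minimal sketch of that equivalence (plain convolutions only, BatchNorm folding omitted):
+
+    ```python
+    import torch
+    import torch.nn.functional as F
+
+    k3 = torch.randn(64, 64, 3, 3)  # 3x3 branch kernel
+    k1 = torch.randn(64, 64, 1, 1)  # 1x1 branch kernel
+    fused = k3 + F.pad(k1, [1, 1, 1, 1])  # zero-pad the 1x1 kernel to 3x3 and add
+
+    x = torch.randn(2, 64, 32, 32)
+    two_branch = F.conv2d(x, k3, padding=1) + F.conv2d(x, k1, padding=0)
+    single_branch = F.conv2d(x, fused, padding=1)
+    assert torch.allclose(two_branch, single_branch, atol=1e-4)
+    ```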
+ """ + + def __init__(self, config: Deimv2Config, in_channels: int, out_channels: int): + super().__init__() + + activation = config.activation_function + hidden_channels = in_channels + self.conv1 = Deimv2ConvNormLayer(config, hidden_channels, out_channels, 3, 1, padding=1) + self.conv2 = Deimv2ConvNormLayer(config, hidden_channels, out_channels, 1, 1, padding=0) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, x): + y = self.conv1(x) + self.conv2(x) + return self.activation(y) + + +class Deimv2CSPRepLayer2(nn.Module): + def __init__( + self, config: Deimv2Config, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0 + ): + super().__init__() + activation = config.activation_function + hidden_channels = int(out_channels * expansion) + self.conv1 = Deimv2ConvNormLayer(config, in_channels, hidden_channels * 2, 1, 1, activation=activation) + self.bottlenecks = nn.ModuleList( + [Deimv2RepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)] + ) + self.conv3 = Deimv2ConvNormLayer(config, hidden_channels, out_channels, 3, 1, activation=activation) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + chunks = list(self.conv1(hidden_state).chunk(2, 1)) + bottleneck_out = chunks[1] + for bottleneck in self.bottlenecks: + bottleneck_out = bottleneck(bottleneck_out) + return self.conv3(chunks[0] + bottleneck_out) + + +class Deimv2RepNCSPELAN5(nn.Module): + def __init__(self, config: Deimv2Config, numb_blocks: int = 3): + super().__init__() + act = config.activation_function + c1 = config.encoder_hidden_dim + c2 = config.encoder_hidden_dim + c3 = config.encoder_hidden_dim * 2 + c4 = round(config.hidden_expansion * config.encoder_hidden_dim // 2) + self.conv_dim = c3 // 2 + self.conv1 = Deimv2ConvNormLayer(config, c1, c3, 1, 1, activation=act) + self.csp_rep1 = Deimv2CSPRepLayer2(config, c3 // 2, c4, num_blocks=numb_blocks) + self.csp_rep2 = Deimv2CSPRepLayer2(config, c4, c4, num_blocks=numb_blocks) + self.conv4 = Deimv2ConvNormLayer(config, c3 + (2 * c4), c2, 1, 1, activation=act) + + def forward(self, input_features: torch.Tensor) -> torch.Tensor: + split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1)) + branch1 = self.csp_rep1(split_features[-1]) + branch2 = self.csp_rep2(branch1) + split_features.extend([branch1, branch2]) + merged_features = torch.cat(split_features, 1) + return self.conv4(merged_features) + + +class Deimv2SCDown(nn.Module): + def __init__(self, config: Deimv2Config, kernel_size: int, stride: int): + super().__init__() + self.conv1 = Deimv2ConvNormLayer(config, config.encoder_hidden_dim, config.encoder_hidden_dim, 1, 1) + self.conv2 = Deimv2ConvNormLayer( + config, + config.encoder_hidden_dim, + config.encoder_hidden_dim, + kernel_size, + stride, + config.encoder_hidden_dim, + ) + + def forward(self, input_features: torch.Tensor) -> torch.Tensor: + input_features = self.conv1(input_features) + input_features = self.conv2(input_features) + return input_features + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float | None = None, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Deimv2SelfAttention(nn.Module): + """ + Multi-headed self-attention from 'Attention Is All You Need' paper. + + In DEIMV2, position embeddings are added to both queries and keys (but not values) in self-attention. + """ + + def __init__( + self, + config: Deimv2Config, + hidden_size: int, + num_attention_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.config = config + self.head_dim = hidden_size // num_attention_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = dropout + self.is_causal = False + + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Position embeddings are added to both queries and keys (but not values). + """ + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states + + query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Deimv2EncoderLayer(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + self.normalize_before = config.normalize_before + self.hidden_size = config.encoder_hidden_dim + + # self-attention + self.self_attn = Deimv2SelfAttention( + config=config, + hidden_size=self.hidden_size, + num_attention_heads=config.num_attention_heads, + dropout=config.dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) + self.dropout = config.dropout + self.mlp = Deimv2MLP( + self.hidden_size, config.encoder_ffn_dim, self.hidden_size, 2, config.encoder_activation_function + ) + self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + spatial_position_embeddings: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to 
the layer of shape `(batch, seq_len, hidden_size)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + spatial_position_embeddings (`torch.FloatTensor`, *optional*): + Spatial position embeddings (2D positional encodings of image locations), to be added to both + the queries and keys in self-attention (but not to values). + """ + residual = hidden_states + if self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=spatial_position_embeddings, + **kwargs, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + if self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + residual = hidden_states + + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + if self.training: + if not torch.isfinite(hidden_states).all(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + return hidden_states + + +class Deimv2SinePositionEmbedding(nn.Module): + """ + 2D sinusoidal position embedding used in RT-DETR hybrid encoder. + """ + + def __init__(self, embed_dim: int = 256, temperature: int = 10000): + super().__init__() + self.embed_dim = embed_dim + self.temperature = temperature + + @compile_compatible_method_lru_cache(maxsize=32) + def forward( + self, + width: int, + height: int, + device: torch.device | str, + dtype: torch.dtype, + ) -> torch.Tensor: + """ + Generate 2D sinusoidal position embeddings. + + Returns: + Position embeddings of shape (1, height*width, embed_dim) + """ + grid_w = torch.arange(torch_int(width), device=device).to(dtype) + grid_h = torch.arange(torch_int(height), device=device).to(dtype) + grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy") + if self.embed_dim % 4 != 0: + raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding") + pos_dim = self.embed_dim // 4 + omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim + omega = 1.0 / (self.temperature**omega) + + out_w = grid_w.flatten()[..., None] @ omega[None] + out_h = grid_h.flatten()[..., None] @ omega[None] + + return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :] + + +class Deimv2AIFILayer(nn.Module): + """ + AIFI (Attention-based Intra-scale Feature Interaction) layer used in RT-DETR hybrid encoder. 
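+
+    Only the feature levels listed in `config.encode_proj_layers` (typically just the lowest-resolution map) are
+    processed this way. A rough shape walk-through, assuming a 20x20 map with 256 channels:
+
+    ```python
+    import torch
+
+    feature_map = torch.randn(1, 256, 20, 20)          # (batch_size, encoder_hidden_dim, height, width)
+    tokens = feature_map.flatten(2).permute(0, 2, 1)   # (1, 400, 256) sequence fed to the encoder layers
+    # a (1, 400, 256) 2D sine position embedding is added to queries and keys inside each layer
+    processed = tokens.permute(0, 2, 1).reshape(1, 256, 20, 20)
+    ```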
+ """ + + def __init__(self, config: Deimv2Config): + super().__init__() + self.config = config + self.encoder_hidden_dim = config.encoder_hidden_dim + self.eval_size = config.eval_size + + self.position_embedding = Deimv2SinePositionEmbedding( + embed_dim=self.encoder_hidden_dim, + temperature=config.positional_encoding_temperature, + ) + self.layers = nn.ModuleList([Deimv2EncoderLayer(config) for _ in range(config.encoder_layers)]) + + def forward( + self, + hidden_states: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`): + Feature map to process. + """ + batch_size = hidden_states.shape[0] + height, width = hidden_states.shape[2:] + + hidden_states = hidden_states.flatten(2).permute(0, 2, 1) + + if self.training or self.eval_size is None: + pos_embed = self.position_embedding( + width=width, + height=height, + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + else: + pos_embed = None + + for layer in self.layers: + hidden_states = layer( + hidden_states, + attention_mask=None, + spatial_position_embeddings=pos_embed, + **kwargs, + ) + + hidden_states = ( + hidden_states.permute(0, 2, 1).reshape(batch_size, self.encoder_hidden_dim, height, width).contiguous() + ) + + return hidden_states + + +class Deimv2SpatialTuningAdapter(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + inplanes = config.sta_inplanes + self.stem = nn.Sequential( + nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(inplanes), + nn.GELU(), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + self.conv2 = nn.Sequential( + nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(2 * inplanes), + ) + self.conv3 = nn.Sequential( + nn.GELU(), + nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(4 * inplanes), + ) + self.conv4 = nn.Sequential( + nn.GELU(), + nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(4 * inplanes), + ) + + def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + c1 = self.stem(pixel_values) + c2 = self.conv2(c1) + c3 = self.conv3(c2) + c4 = self.conv4(c3) + return c2, c3, c4 + + +class Deimv2Integral(nn.Module): + """ + A static layer that calculates integral results from a distribution. + + This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`, + where Pr(n) is the softmax probability vector representing the discrete + distribution, and W(n) is the non-uniform Weighting Function. + + Args: + max_num_bins (int): Max number of the discrete bins. Default is 32. + It can be adjusted based on the dataset or task requirements. 
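+
+    A toy sketch of the computation (the real `project` vector comes from `weighting_function` and is
+    non-uniform; a uniform stand-in is used here purely to show the shapes):
+
+    ```python
+    import torch
+
+    max_num_bins = 32
+    pred_corners = torch.randn(2, 300, 4 * (max_num_bins + 1))  # (batch_size, num_queries, 4 * (bins + 1))
+    project = torch.linspace(-10.0, 10.0, max_num_bins + 1)     # stand-in for weighting_function(...)
+    prob = pred_corners.reshape(-1, max_num_bins + 1).softmax(dim=-1)
+    offsets = (prob @ project).reshape(2, 300, 4)               # expected distance for each of the 4 edges
+    ```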
+ """ + + def __init__(self, config: Deimv2Config): + super().__init__() + self.max_num_bins = config.max_num_bins + + def forward(self, pred_corners: torch.Tensor, project: torch.Tensor) -> torch.Tensor: + batch_size, num_queries, _ = pred_corners.shape + pred_corners = F.softmax(pred_corners.reshape(-1, self.max_num_bins + 1), dim=1) + pred_corners = F.linear(pred_corners, project.to(pred_corners.device)).reshape(-1, 4) + pred_corners = pred_corners.reshape(batch_size, num_queries, -1) + return pred_corners + + +class Deimv2LQE(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + self.top_prob_values = config.top_prob_values + self.max_num_bins = config.max_num_bins + self.reg_conf = Deimv2MLP(4 * (self.top_prob_values + 1), config.lqe_hidden_dim, 1, config.lqe_layers) + + def forward(self, scores: torch.Tensor, pred_corners: torch.Tensor) -> torch.Tensor: + batch_size, length, _ = pred_corners.size() + prob = F.softmax(pred_corners.reshape(batch_size, length, 4, self.max_num_bins + 1), dim=-1) + prob_topk, _ = prob.topk(self.top_prob_values, dim=-1) + stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1) + quality_score = self.reg_conf(stat.reshape(batch_size, length, -1)) + scores = scores + quality_score + return scores + + +class Deimv2DecoderLayer(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + self.hidden_size = config.d_model + + # self-attention + self.self_attn = Deimv2SelfAttention( + config=config, + hidden_size=self.hidden_size, + num_attention_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + ) + self.dropout = config.dropout + self.self_attn_layer_norm = Deimv2RMSNorm(config.d_model) + self.encoder_attn = Deimv2MultiscaleDeformableAttention(config=config) + self.mlp = Deimv2SwiGLUFFN(config.d_model, config.decoder_ffn_dim // 2, config.d_model) + self.final_layer_norm = Deimv2RMSNorm(config.d_model) + self.gateway = Deimv2Gate(config.d_model) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor | None = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, hidden_size)`. + object_queries_position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings for the object query slots. These are added to both queries and keys + in the self-attention layer (not values). + reference_points (`torch.FloatTensor`, *optional*): + Reference points. + spatial_shapes (`torch.LongTensor`, *optional*): + Spatial shapes. + level_start_index (`torch.LongTensor`, *optional*): + Level start index. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, hidden_size)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. 
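+
+        The cross-attention output is not merged back with a plain residual connection; it is fused with the
+        pre-attention hidden states by `Deimv2Gate`. A minimal sketch of that gating (illustrative tensors,
+        RMSNorm omitted):
+
+        ```python
+        import torch
+
+        d_model = 256
+        residual = torch.randn(1, 300, d_model)       # queries before cross-attention
+        cross_out = torch.randn(1, 300, d_model)      # deformable cross-attention output
+        gate_proj = torch.nn.Linear(2 * d_model, 2 * d_model)
+        gate1, gate2 = torch.sigmoid(gate_proj(torch.cat([residual, cross_out], dim=-1))).chunk(2, dim=-1)
+        fused = gate1 * residual + gate2 * cross_out  # then RMSNorm in the actual Deimv2Gate
+        ```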
+ """ + residual = hidden_states + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=encoder_attention_mask, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + + # Cross-Attention + hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings + hidden_states, _ = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.gateway(residual, hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states.clamp(min=-65504, max=65504)) + + return hidden_states + + +class Deimv2MLPPredictionHead(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act: str = "relu"): + super().__init__() + self.num_layers = num_layers + hidden_dims = [hidden_dim] * (num_layers - 1) + input_dims = [input_dim] + hidden_dims + output_dims = hidden_dims + [output_dim] + self.layers = nn.ModuleList(nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, output_dims)) + self.act = ACT2CLS[act]() + + def forward(self, stat_features: torch.Tensor) -> torch.Tensor: + for i, layer in enumerate(self.layers): + stat_features = self.act(layer(stat_features)) if i < self.num_layers - 1 else layer(stat_features) + return stat_features + + +@auto_docstring +class Deimv2PreTrainedModel(PreTrainedModel): + config: Deimv2Config + base_model_prefix = "deimv2" + main_input_name = "pixel_values" + input_modalities = ("image",) + _no_split_modules = [r"Deimv2HybridEncoder", r"Deimv2DecoderLayer"] + _supports_sdpa = True + _supports_flash_attn = True + _supports_attention_backend = True + _supports_flex_attn = True + + @torch.no_grad() + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (Deimv2ForObjectDetection, Deimv2Decoder)): + if module.class_embed is not None: + for layer in module.class_embed: + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias = float(-math.log((1 - prior_prob) / prior_prob)) + init.xavier_uniform_(layer.weight) + init.constant_(layer.bias, bias) + + if module.bbox_embed is not None: + for layer in module.bbox_embed: + init.constant_(layer.layers[-1].weight, 0) + init.constant_(layer.layers[-1].bias, 0) + + if hasattr(module, "reg_scale"): + init.constant_(module.reg_scale, self.config.reg_scale) + + if hasattr(module, "up"): + init.constant_(module.up, self.config.up) + + if isinstance(module, Deimv2MultiscaleDeformableAttention): + init.constant_(module.sampling_offsets.weight, 0.0) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( + 2.0 * math.pi / module.n_heads + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values + grid_init = grid_init.reshape(module.n_heads, 1, 2).tile([1, 
sum(module.num_points_list), 1]) + scaling = torch.concat([torch.arange(1, n + 1) for n in module.num_points_list]).reshape(1, -1, 1) + grid_init *= scaling + init.copy_(module.sampling_offsets.bias, grid_init.flatten()) + + init.constant_(module.attention_weights.weight, 0.0) + init.constant_(module.attention_weights.bias, 0.0) + + num_points_scale = [1 / n for n in module.num_points_list for _ in range(n)] + init.copy_(module.num_points_scale, torch.tensor(num_points_scale, dtype=torch.float32)) + + if isinstance(module, Deimv2Model): + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias = float(-math.log((1 - prior_prob) / prior_prob)) + init.xavier_uniform_(module.enc_score_head.weight) + init.constant_(module.enc_score_head.bias, bias) + + if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + init.zeros_(module.bias) + if getattr(module, "running_mean", None) is not None: + init.zeros_(module.running_mean) + init.ones_(module.running_var) + init.zeros_(module.num_batches_tracked) + + if isinstance(module, Deimv2Gate): + bias = float(-math.log((1 - 0.5) / 0.5)) + init.constant_(module.gate.bias, bias) + init.constant_(module.gate.weight, 0) + + if isinstance(module, Deimv2LQE): + init.constant_(module.reg_conf.layers[-1].bias, 0) + init.constant_(module.reg_conf.layers[-1].weight, 0) + + if isinstance(module, Deimv2SwiGLUFFN): + init.xavier_uniform_(module.w12.weight) + init.constant_(module.w12.bias, 0) + init.xavier_uniform_(module.w3.weight) + init.constant_(module.w3.bias, 0) + + if isinstance(module, Deimv2RMSNorm): + init.ones_(module.scale) + + if isinstance(module, nn.LayerNorm): + init.ones_(module.weight) + init.zeros_(module.bias) + + if hasattr(module, "weight_embedding") and self.config.learn_initial_query: + init.xavier_uniform_(module.weight_embedding.weight) + if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: + init.xavier_uniform_(module.denoising_class_embed.weight) + + +class Deimv2HybridEncoder(Deimv2PreTrainedModel): + """ + Hybrid encoder consisting of AIFI (Attention-based Intra-scale Feature Interaction) layers, + a top-down Feature Pyramid Network (FPN) and a bottom-up Path Aggregation Network (PAN). 
+ More details on the paper: https://huggingface.co/papers/2304.08069 + + Args: + config: Deimv2Config + """ + + _can_record_outputs = { + "hidden_states": Deimv2AIFILayer, + "attentions": Deimv2SelfAttention, + } + + def __init__(self, config: Deimv2Config): + super().__init__(config) + self.config = config + self.in_channels = config.encoder_in_channels + self.num_fpn_stages = len(self.in_channels) - 1 + self.feat_strides = config.feat_strides + self.encoder_hidden_dim = config.encoder_hidden_dim + self.encode_proj_layers = config.encode_proj_layers + self.positional_encoding_temperature = config.positional_encoding_temperature + self.eval_size = config.eval_size + self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels] + self.out_strides = self.feat_strides + self.encoder_fuse_op = config.encoder_fuse_op + + self.aifi = nn.ModuleList([Deimv2AIFILayer(config) for _ in range(len(self.encode_proj_layers))]) + + self.lateral_convs = nn.ModuleList() + self.fpn_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1, 0, -1): + lateral_layer = Deimv2ConvNormLayer(config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1) + self.lateral_convs.append(lateral_layer) + num_blocks = round(3 * config.depth_mult) + fpn_layer = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks) + self.fpn_blocks.append(fpn_layer) + + self.downsample_convs = nn.ModuleList() + self.pan_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1): + self.downsample_convs.append(Deimv2SCDown(config, 3, 2)) + num_blocks = round(3 * config.depth_mult) + self.pan_blocks.append(Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks)) + + self.post_init() + + @merge_with_config_defaults + @capture_outputs(tie_last_hidden_states=False) + def forward( + self, + inputs_embeds=None, + **kwargs, + ) -> BaseModelOutput: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. 
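+
+        Note: despite the name, in this implementation `inputs_embeds` is a list with one feature map of shape
+        `(batch_size, encoder_hidden_dim, height, width)` per level (e.g. strides 8, 16 and 32); flattening into
+        a sequence only happens inside `Deimv2AIFILayer` for the levels selected by `config.encode_proj_layers`.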
+ """ + feature_maps = inputs_embeds + + if self.config.encoder_layers > 0: + for i, enc_ind in enumerate(self.encode_proj_layers): + feature_maps[enc_ind] = self.aifi[i](feature_maps[enc_ind], **kwargs) + + fpn_feature_maps = [feature_maps[-1]] + for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)): + backbone_feature_map = feature_maps[self.num_fpn_stages - idx - 1] + top_fpn_feature_map = fpn_feature_maps[-1] + top_fpn_feature_map = lateral_conv(top_fpn_feature_map) + fpn_feature_maps[-1] = top_fpn_feature_map + top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest") + if self.encoder_fuse_op == "sum": + fused_feature_map = top_fpn_feature_map + backbone_feature_map + else: + fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1) + new_fpn_feature_map = fpn_block(fused_feature_map) + fpn_feature_maps.append(new_fpn_feature_map) + + fpn_feature_maps.reverse() + + pan_feature_maps = [fpn_feature_maps[0]] + for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)): + top_pan_feature_map = pan_feature_maps[-1] + fpn_feature_map = fpn_feature_maps[idx + 1] + downsampled_feature_map = downsample_conv(top_pan_feature_map) + if self.encoder_fuse_op == "sum": + fused_feature_map = downsampled_feature_map + fpn_feature_map + else: + fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1) + new_pan_feature_map = pan_block(fused_feature_map) + pan_feature_maps.append(new_pan_feature_map) + + return BaseModelOutput(last_hidden_state=pan_feature_maps) + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +def weighting_function(max_num_bins: int, up: torch.Tensor, reg_scale: int) -> torch.Tensor: + """ + Generates the non-uniform Weighting Function W(n) for bounding box regression. + + Args: + max_num_bins (int): Max number of the discrete bins. + up (Tensor): Controls upper bounds of the sequence, + where maximum offset is ±up * H / W. + reg_scale (float): Controls the curvature of the Weighting Function. + Larger values result in flatter weights near the central axis W(max_num_bins/2)=0 + and steeper weights at both ends. + Returns: + Tensor: Sequence of Weighting Function. + """ + upper_bound1 = abs(up[0]) * abs(reg_scale) + upper_bound2 = abs(up[0]) * abs(reg_scale) * 2 + step = (upper_bound1 + 1) ** (2 / (max_num_bins - 2)) + left_values = [-((step) ** i) + 1 for i in range(max_num_bins // 2 - 1, 0, -1)] + right_values = [(step) ** i - 1 for i in range(1, max_num_bins // 2)] + values = [-upper_bound2] + left_values + [torch.zeros_like(up[0][None])] + right_values + [upper_bound2] + values = torch.cat(values, 0) + return values + + +def distance2bbox(points, distance: torch.Tensor, reg_scale: float) -> torch.Tensor: + """ + Decodes edge-distances into bounding box coordinates. + + Args: + points (`torch.Tensor`): + (batch_size, num_boxes, 4) or (num_boxes, 4) format, representing [x_center, y_center, width, height] + distance (`torch.Tensor`): + (batch_size, num_boxes, 4) or (num_boxes, 4), representing distances from the point to the left, top, right, and bottom boundaries. + reg_scale (`float`): + Controls the curvature of the Weighting Function. 
+    Returns:
+        `torch.Tensor`: Bounding boxes in (batch_size, num_boxes, 4) or (num_boxes, 4) format, representing [x_center, y_center, width, height]
+    """
+    reg_scale = abs(reg_scale)
+    top_left_x = points[..., 0] - (0.5 * reg_scale + distance[..., 0]) * (points[..., 2] / reg_scale)
+    top_left_y = points[..., 1] - (0.5 * reg_scale + distance[..., 1]) * (points[..., 3] / reg_scale)
+    bottom_right_x = points[..., 0] + (0.5 * reg_scale + distance[..., 2]) * (points[..., 2] / reg_scale)
+    bottom_right_y = points[..., 1] + (0.5 * reg_scale + distance[..., 3]) * (points[..., 3] / reg_scale)
+
+    bboxes = torch.stack([top_left_x, top_left_y, bottom_right_x, bottom_right_y], -1)
+
+    return corners_to_center_format(bboxes)
+
+
+class Deimv2Decoder(Deimv2PreTrainedModel):
+    """
+    D-FINE Decoder implementing Fine-grained Distribution Refinement (FDR).
+
+    This decoder refines object detection predictions through iterative updates across multiple layers,
+    utilizing attention mechanisms, location quality estimators, and distribution refinement techniques
+    to improve bounding box accuracy and robustness.
+    """
+
+    _can_record_outputs = {
+        "hidden_states": Deimv2DecoderLayer,
+        "attentions": Deimv2SelfAttention,
+        "cross_attentions": Deimv2MultiscaleDeformableAttention,
+    }
+
+    def __init__(self, config: Deimv2Config):
+        super().__init__(config)
+        self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx
+
+        self.dropout = config.dropout
+        self.layers = nn.ModuleList(
+            [Deimv2DecoderLayer(config) for _ in range(config.decoder_layers)]
+            + [Deimv2DecoderLayer(config) for _ in range(config.decoder_layers - self.eval_idx - 1)]
+        )
+        self.query_pos_head = Deimv2MLP(4, config.d_model, config.d_model, 3, config.decoder_activation_function)
+
+        # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
+        self.bbox_embed = None
+        self.class_embed = None
+        self.reg_scale = nn.Parameter(torch.tensor([config.reg_scale]), requires_grad=False)
+        self.max_num_bins = config.max_num_bins
+        self.d_model = config.d_model
+        self.layer_scale = config.layer_scale
+        self.pre_bbox_head = Deimv2MLP(config.hidden_size, config.hidden_size, 4, 3)
+        self.integral = Deimv2Integral(config)
+        self.num_head = config.decoder_attention_heads
+        self.up = nn.Parameter(torch.tensor([config.up]), requires_grad=False)
+        self.lqe_layers = nn.ModuleList([Deimv2LQE(config) for _ in range(config.decoder_layers)])
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @merge_with_config_defaults
+    @capture_outputs
+    def forward(
+        self,
+        encoder_hidden_states: torch.Tensor,
+        reference_points: torch.Tensor,
+        inputs_embeds: torch.Tensor,
+        spatial_shapes,
+        level_start_index=None,
+        spatial_shapes_list=None,
+        encoder_attention_mask=None,
+        memory_mask=None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Deimv2DecoderOutput:
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+                The query embeddings that are passed into the decoder.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                of the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected + in `[0, 1]`: + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` is `as_two_stage` else `(batch_size, num_queries, 2)` or , *optional*): + Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area. + spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of the feature maps. + level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*): + Indexes for the start of each feature level. In range `[0, sequence_length]`. + """ + if inputs_embeds is not None: + hidden_states = inputs_embeds + + # decoder layers + intermediate = () + intermediate_reference_points = () + intermediate_logits = () + intermediate_predicted_corners = () + initial_reference_points = () + + output_detach = pred_corners_undetach = 0 + + project = weighting_function(self.max_num_bins, self.up, self.reg_scale) + ref_points_detach = F.sigmoid(reference_points) + + for i, decoder_layer in enumerate(self.layers): + ref_points_input = ref_points_detach.unsqueeze(2) + query_pos_embed = self.query_pos_head(ref_points_detach).clamp(min=-10, max=10) + + hidden_states = decoder_layer( + hidden_states, + position_embeddings=query_pos_embed, + reference_points=ref_points_input, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + + if i == 0: + # Initial bounding box predictions with inverse sigmoid refinement + new_reference_points = F.sigmoid( + self.pre_bbox_head(hidden_states) + inverse_sigmoid(ref_points_detach) + ) + ref_points_initial = new_reference_points.detach() + + # Refine bounding box corners using FDR, integrating previous layer's corrections + if self.bbox_embed is not None: + pred_corners = self.bbox_embed[i](hidden_states + output_detach) + pred_corners_undetach + inter_ref_bbox = distance2bbox( + ref_points_initial, self.integral(pred_corners, project), self.reg_scale + ) + pred_corners_undetach = pred_corners + ref_points_detach = inter_ref_bbox.detach() + + output_detach = hidden_states.detach() + + intermediate += (hidden_states,) + + if self.class_embed is not None and (self.training or i == self.eval_idx): + scores = self.class_embed[i](hidden_states) + # Add initial logits and reference points with pre-bbox head + if i == 0: + intermediate_logits += (scores,) + intermediate_reference_points += (new_reference_points,) + # Lqe does not affect the performance here. 
+ scores = self.lqe_layers[i](scores, pred_corners) + intermediate_logits += (scores,) + intermediate_reference_points += (inter_ref_bbox,) + initial_reference_points += (ref_points_initial,) + intermediate_predicted_corners += (pred_corners,) + + # Keep batch_size as first dimension + intermediate = torch.stack(intermediate) + if self.class_embed is not None and self.bbox_embed is not None: + intermediate_logits = torch.stack(intermediate_logits, dim=1) + intermediate_predicted_corners = torch.stack(intermediate_predicted_corners, dim=1) + initial_reference_points = torch.stack(initial_reference_points, dim=1) + intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1) + + return Deimv2DecoderOutput( + last_hidden_state=hidden_states, + intermediate_hidden_states=intermediate, + intermediate_logits=intermediate_logits, + intermediate_reference_points=intermediate_reference_points, + intermediate_predicted_corners=intermediate_predicted_corners, + initial_reference_points=initial_reference_points, + ) + + +class Deimv2FrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans. + """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the RT-DETR encoder-decoder model. + """ +) +class Deimv2ModelOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). 
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate predicted corners (predicted corners of each layer of the decoder). + initial_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points used for the first decoder layer. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`): + Logits of predicted bounding boxes coordinates in the encoder stage. + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. + denoising_meta_values (`dict`): + Extra dictionary for the denoising related values. + """ + + last_hidden_state: torch.FloatTensor | None = None + intermediate_hidden_states: torch.FloatTensor | None = None + intermediate_logits: torch.FloatTensor | None = None + intermediate_reference_points: torch.FloatTensor | None = None + intermediate_predicted_corners: torch.FloatTensor | None = None + initial_reference_points: torch.FloatTensor | None = None + decoder_hidden_states: tuple[torch.FloatTensor] | None = None + decoder_attentions: tuple[torch.FloatTensor] | None = None + cross_attentions: tuple[torch.FloatTensor] | None = None + encoder_last_hidden_state: torch.FloatTensor | None = None + encoder_hidden_states: tuple[torch.FloatTensor] | None = None + encoder_attentions: tuple[torch.FloatTensor] | None = None + init_reference_points: torch.FloatTensor | None = None + enc_topk_logits: torch.FloatTensor | None = None + enc_topk_bboxes: torch.FloatTensor | None = None + enc_outputs_class: torch.FloatTensor | None = None + enc_outputs_coord_logits: torch.FloatTensor | None = None + denoising_meta_values: dict | None = None + + +def replace_batch_norm(model): + r""" + Recursively replace all `torch.nn.BatchNorm2d` with `Deimv2FrozenBatchNorm2d`. 
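+    The running statistics and affine parameters of every replaced layer are copied over, so with the default
+    BatchNorm `eps` of 1e-5 the eval-mode outputs are unchanged; the replacement only removes the ability to
+    update those statistics during fine-tuning.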
+ + Args: + model (torch.nn.Module): + input model + """ + for name, module in model.named_children(): + if isinstance(module, nn.BatchNorm2d): + new_module = Deimv2FrozenBatchNorm2d(module.num_features) + + if module.weight.device != torch.device("meta"): + new_module.weight.copy_(module.weight) + new_module.bias.copy_(module.bias) + new_module.running_mean.copy_(module.running_mean) + new_module.running_var.copy_(module.running_var) + + model._modules[name] = new_module + + if len(list(module.children())) > 0: + replace_batch_norm(module) + + +class Deimv2ConvEncoder(nn.Module): + """ + Convolutional backbone using the modeling_deimv2_resnet.py. + + nn.BatchNorm2d layers are replaced by Deimv2FrozenBatchNorm2d as defined above. + https://github.com/lyuwenyu/RT-DETR/blob/main/Deimv2_pytorch/src/nn/backbone/presnet.py#L142 + """ + + def __init__(self, config): + super().__init__() + + backbone = load_backbone(config) + + if config.freeze_backbone_batch_norms: + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = self.model.channels + + def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): + # send pixel_values through the model to get list of feature maps + features = self.model(pixel_values).feature_maps + + out = [] + for feature_map in features: + # downsample pixel_mask to match shape of corresponding feature_map + mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] + out.append((feature_map, mask)) + return out + + +def get_contrastive_denoising_training_group( + targets, + num_classes, + num_queries, + class_embed, + num_denoising_queries=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, +): + """ + Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes. + + Args: + targets (`list[dict]`): + The target objects, each containing 'class_labels' and 'boxes' for objects in an image. + num_classes (`int`): + Total number of classes in the dataset. + num_queries (`int`): + Number of query slots in the transformer. + class_embed (`callable`): + A function or a model layer to embed class labels. + num_denoising_queries (`int`, *optional*, defaults to 100): + Number of denoising queries. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + Ratio of noise applied to labels. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale of noise applied to bounding boxes. + Returns: + `tuple` comprising various elements: + - **input_query_class** (`torch.FloatTensor`) -- + Class queries with applied label noise. + - **input_query_bbox** (`torch.FloatTensor`) -- + Bounding box queries with applied box noise. + - **attn_mask** (`torch.FloatTensor`) -- + Attention mask for separating denoising and reconstruction queries. + - **denoising_meta_values** (`dict`) -- + Metadata including denoising positive indices, number of groups, and split sizes. 
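+
+    Example (illustrative numbers only): for a batch with 2 and 3 ground-truth boxes per image and
+    `num_denoising_queries=100`, `max_gt_num` is 3, so `100 // 3 = 33` denoising groups are built. Each group
+    contains one positive and one negative copy of every (padded) ground truth, giving `3 * 2 * 33 = 198`
+    denoising queries that are prepended to the `num_queries` object queries; `attn_mask` therefore has shape
+    `(198 + num_queries, 198 + num_queries)`.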
+ """ + + if num_denoising_queries <= 0: + return None, None, None, None + + num_ground_truths = [len(t["class_labels"]) for t in targets] + device = targets[0]["class_labels"].device + + max_gt_num = max(num_ground_truths) + if max_gt_num == 0: + return None, None, None, None + + num_groups_denoising_queries = num_denoising_queries // max_gt_num + num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries + # pad gt to max_num of a batch + batch_size = len(num_ground_truths) + + input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device) + input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device) + pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device) + + for i in range(batch_size): + num_gt = num_ground_truths[i] + if num_gt > 0: + input_query_class[i, :num_gt] = targets[i]["class_labels"] + input_query_bbox[i, :num_gt] = targets[i]["boxes"] + pad_gt_mask[i, :num_gt] = 1 + # each group has positive and negative queries. + input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries]) + input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1]) + pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries]) + # positive and negative mask + negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device) + negative_gt_mask[:, max_gt_num:] = 1 + negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1]) + positive_gt_mask = 1 - negative_gt_mask + # contrastive denoising training positive index + positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask + denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1] + denoise_positive_idx = torch.split( + denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths] + ) + # total denoising queries + num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries) + + if label_noise_ratio > 0: + mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5) + # randomly put a new one here + new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype) + input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class) + + if box_noise_scale > 0: + known_bbox = center_to_corners_format(input_query_bbox) + diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale + rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0 + rand_part = torch.rand_like(input_query_bbox) + rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask) + rand_part *= rand_sign + known_bbox += rand_part * diff + known_bbox.clip_(min=0.0, max=1.0) + input_query_bbox = corners_to_center_format(known_bbox) + input_query_bbox = inverse_sigmoid(input_query_bbox) + + input_query_class = class_embed(input_query_class) + + target_size = num_denoising_queries + num_queries + attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device) + # match query cannot see the reconstruction + attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf + + # reconstructions cannot see each other + for i in range(num_groups_denoising_queries): + idx_block_start = max_gt_num * 2 * i + idx_block_end = max_gt_num * 2 * (i + 1) + attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf + attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] 
= -torch.inf + + denoising_meta_values = { + "dn_positive_idx": denoise_positive_idx, + "dn_num_group": num_groups_denoising_queries, + "dn_num_split": [num_denoising_queries, num_queries], + } + + return input_query_class, input_query_bbox, attn_mask, denoising_meta_values + + +@auto_docstring( + custom_intro=""" + RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top. + """ +) +class Deimv2Model(Deimv2PreTrainedModel): + def __init__(self, config: Deimv2Config): + super().__init__(config) + + # Create backbone + self.backbone = Deimv2ConvEncoder(config) + intermediate_channel_sizes = self.backbone.intermediate_channel_sizes + num_backbone_outs = len(config.decoder_in_channels) + encoder_input_proj_list = [] + for i in range(num_backbone_outs): + in_channels = intermediate_channel_sizes[i] + encoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(config.encoder_hidden_dim), + ) + ) + self.encoder_input_proj = nn.ModuleList(encoder_input_proj_list) + self.encoder = Deimv2HybridEncoder(config=config) + + # denoising part + if config.num_denoising > 0: + self.denoising_class_embed = nn.Embedding( + config.num_labels + 1, config.d_model, padding_idx=config.num_labels + ) + + # decoder embedding + if config.learn_initial_query: + self.weight_embedding = nn.Embedding(config.num_queries, config.d_model) + + # encoder head + self.enc_output = nn.Sequential( + nn.Linear(config.d_model, config.d_model), + nn.LayerNorm(config.d_model, eps=config.layer_norm_eps), + ) + self.enc_score_head = nn.Linear(config.d_model, config.num_labels) + self.enc_bbox_head = Deimv2MLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3) + + # init encoder output anchors and valid_mask + if config.anchor_image_size: + self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype) + num_backbone_outs = len(config.decoder_in_channels) + decoder_input_proj_list = [] + for i in range(num_backbone_outs): + in_channels = config.decoder_in_channels[i] + decoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False), + nn.BatchNorm2d(config.d_model, config.batch_norm_eps), + ) + ) + for _ in range(config.num_feature_levels - num_backbone_outs): + decoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(config.d_model, config.batch_norm_eps), + ) + ) + in_channels = config.d_model + self.decoder = Deimv2Decoder(config) + decoder_input_proj = [] + in_channels = config.decoder_in_channels[-1] + for _ in range(num_backbone_outs): + if config.hidden_size == config.decoder_in_channels[-1]: + decoder_input_proj.append(nn.Identity()) + else: + conv = nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False) + batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps) + decoder_input_proj.append(nn.Sequential(conv, batchnorm)) + for _ in range(config.num_feature_levels - num_backbone_outs): + if config.hidden_size == config.decoder_in_channels[-1]: + decoder_input_proj.append(nn.Identity()) + else: + conv = nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False) + batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps) + decoder_input_proj.append(nn.Sequential(conv, batchnorm)) + self.decoder_input_proj = nn.ModuleList(decoder_input_proj) + + if 
config.use_spatial_tuning_adapter: + self.spatial_tuning_adapter = Deimv2SpatialTuningAdapter(config) + + self.post_init() + + def freeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(True) + + @compile_compatible_method_lru_cache(maxsize=32) + def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32): + if spatial_shapes is None: + spatial_shapes = [ + [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)] + for s in self.config.feat_strides + ] + anchors = [] + for level, (height, width) in enumerate(spatial_shapes): + grid_y, grid_x = torch.meshgrid( + torch.arange(end=height, device=device).to(dtype), + torch.arange(end=width, device=device).to(dtype), + indexing="ij", + ) + grid_xy = torch.stack([grid_x, grid_y], -1) + grid_xy = grid_xy.unsqueeze(0) + 0.5 + grid_xy[..., 0] /= width + grid_xy[..., 1] /= height + wh = torch.ones_like(grid_xy) * grid_size * (2.0**level) + anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4)) + # define the valid range for anchor coordinates + eps = 1e-2 + anchors = torch.concat(anchors, 1) + valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) + anchors = torch.log(anchors / (1 - anchors)) + anchors = torch.where(valid_mask, anchors, torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device)) + + return anchors, valid_mask + + @auto_docstring + @can_return_tuple + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: torch.LongTensor | None = None, + encoder_outputs: torch.FloatTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: list[dict] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor] | Deimv2ModelOutput: + r""" + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + labels (`list[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. 
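+            For example (illustrative values only), a single image with two annotated objects could be passed as
+            `labels=[{"class_labels": torch.tensor([1, 3]), "boxes": torch.tensor([[0.50, 0.50, 0.20, 0.30], [0.20, 0.40, 0.10, 0.10]])}]`,
+            with `boxes` given in normalized `(center_x, center_y, width, height)` format.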
+ + Examples: + + ```python + >>> from transformers import AutoImageProcessor, Deimv2Model + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("PekingU/Deimv2_r50vd") + >>> model = Deimv2Model.from_pretrained("PekingU/Deimv2_r50vd") + + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 300, 256] + ```""" + if pixel_values is None and inputs_embeds is None: + raise ValueError("You have to specify either pixel_values or inputs_embeds") + + if inputs_embeds is None: + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), device=device) + features = self.backbone(pixel_values, pixel_mask) + proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] + else: + batch_size = inputs_embeds.shape[0] + device = inputs_embeds.device + proj_feats = inputs_embeds + + if encoder_outputs is None: + encoder_outputs = self.encoder( + proj_feats, + **kwargs, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput + elif not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # Equivalent to def _get_encoder_input + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/Deimv2_pytorch/src/zoo/Deimv2/Deimv2_decoder.py#L412 + sources = [] + for level, source in enumerate(encoder_outputs.last_hidden_state): + sources.append(self.decoder_input_proj[level](source)) + + # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage + if self.config.num_feature_levels > len(sources): + _len_sources = len(sources) + sources.append(self.decoder_input_proj[_len_sources](encoder_outputs.last_hidden_state)[-1]) + for i in range(_len_sources + 1, self.config.num_feature_levels): + sources.append(self.decoder_input_proj[i](encoder_outputs.last_hidden_state[-1])) + + # Prepare encoder inputs (by flattening) + source_flatten = [] + spatial_shapes_list = [] + spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long) + for level, source in enumerate(sources): + height, width = source.shape[-2:] + spatial_shapes[level, 0] = height + spatial_shapes[level, 1] = width + spatial_shapes_list.append((height, width)) + source = source.flatten(2).transpose(1, 2) + source_flatten.append(source) + source_flatten = torch.cat(source_flatten, 1) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + + # prepare denoising training + if self.training and self.config.num_denoising > 0 and labels is not None: + ( + denoising_class, + denoising_bbox_unact, + attention_mask, + denoising_meta_values, + ) = get_contrastive_denoising_training_group( + targets=labels, + num_classes=self.config.num_labels, + num_queries=self.config.num_queries, + class_embed=self.denoising_class_embed, + num_denoising_queries=self.config.num_denoising, + 
label_noise_ratio=self.config.label_noise_ratio, + box_noise_scale=self.config.box_noise_scale, + ) + else: + denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None + + batch_size = len(source_flatten) + device = source_flatten.device + dtype = source_flatten.dtype + + # prepare input for decoder + if self.training or self.config.anchor_image_size is None: + # Pass spatial_shapes as tuple to make it hashable and make sure + # lru_cache is working for generate_anchors() + spatial_shapes_tuple = tuple(spatial_shapes_list) + anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype) + else: + anchors, valid_mask = self.anchors, self.valid_mask + anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype) + + # use the valid_mask to selectively retain values in the feature map where the mask is `True` + memory = valid_mask.to(source_flatten.dtype) * source_flatten + + output_memory = self.enc_output(memory) + + enc_outputs_class = self.enc_score_head(output_memory) + enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors + + _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1) + + reference_points_unact = enc_outputs_coord_logits.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1]) + ) + + enc_topk_bboxes = F.sigmoid(reference_points_unact) + if denoising_bbox_unact is not None: + reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1) + + enc_topk_logits = enc_outputs_class.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]) + ) + + # extract region features + if self.config.learn_initial_query: + target = self.weight_embedding.tile([batch_size, 1, 1]) + else: + target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1])) + target = target.detach() + + if denoising_class is not None: + target = torch.concat([denoising_class, target], 1) + + init_reference_points = reference_points_unact.detach() + + # decoder + decoder_outputs = self.decoder( + inputs_embeds=target, + encoder_hidden_states=source_flatten, + encoder_attention_mask=attention_mask, + reference_points=init_reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + **kwargs, + ) + + return Deimv2ModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + intermediate_logits=decoder_outputs.intermediate_logits, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + intermediate_predicted_corners=decoder_outputs.intermediate_predicted_corners, + initial_reference_points=decoder_outputs.initial_reference_points, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + init_reference_points=init_reference_points, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + enc_outputs_class=enc_outputs_class, + enc_outputs_coord_logits=enc_outputs_coord_logits, + denoising_meta_values=denoising_meta_values, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Output 
type of [`Deimv2ForObjectDetection`].
+    """
+)
+class Deimv2ObjectDetectionOutput(ModelOutput):
+    r"""
+    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
+        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
+        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
+        scale-invariant IoU loss.
+    loss_dict (`Dict`, *optional*):
+        A dictionary containing the individual losses. Useful for logging.
+    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
+        Classification logits (including no-object) for all queries.
+    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
+        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
+        possible padding). You can use [`~Deimv2ImageProcessor.post_process_object_detection`] to retrieve the
+        unnormalized (absolute) bounding boxes.
+    auxiliary_outputs (`list[Dict]`, *optional*):
+        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
+        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
+        `pred_boxes`) for each decoder layer.
+    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+        Sequence of hidden-states at the output of the last layer of the decoder of the model.
+    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+        Stacked intermediate hidden states (output of each layer of the decoder).
+    intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`):
+        Stacked intermediate logits (logits of each layer of the decoder).
+    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+        Stacked intermediate reference points (reference points of each layer of the decoder).
+    intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+        Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
+    initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+        Stacked initial reference points (initial reference points of each layer of the decoder).
+    init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+        Initial reference points sent through the Transformer decoder.
+    enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+        Classification logits of the top-k proposals selected in the encoder stage.
+    enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+        Logits of predicted bounding boxes coordinates in the encoder.
+    enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+        Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+        picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+        foreground and background).
+    enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+        Logits of predicted bounding boxes coordinates in the first stage.
+    denoising_meta_values (`dict`):
+        Extra dictionary for the denoising related values.
+    """
+
+    loss: torch.FloatTensor | None = None
+    loss_dict: dict | None = None
+    logits: torch.FloatTensor | None = None
+    pred_boxes: torch.FloatTensor | None = None
+    auxiliary_outputs: list[dict] | None = None
+    last_hidden_state: torch.FloatTensor | None = None
+    intermediate_hidden_states: torch.FloatTensor | None = None
+    intermediate_logits: torch.FloatTensor | None = None
+    intermediate_reference_points: torch.FloatTensor | None = None
+    intermediate_predicted_corners: torch.FloatTensor | None = None
+    initial_reference_points: torch.FloatTensor | None = None
+    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
+    decoder_attentions: tuple[torch.FloatTensor] | None = None
+    cross_attentions: tuple[torch.FloatTensor] | None = None
+    encoder_last_hidden_state: torch.FloatTensor | None = None
+    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
+    encoder_attentions: tuple[torch.FloatTensor] | None = None
+    init_reference_points: torch.FloatTensor | None = None
+    enc_topk_logits: torch.FloatTensor | None = None
+    enc_topk_bboxes: torch.FloatTensor | None = None
+    enc_outputs_class: torch.FloatTensor | None = None
+    enc_outputs_coord_logits: torch.FloatTensor | None = None
+    denoising_meta_values: dict | None = None
+
+
+@auto_docstring(
+    custom_intro="""
+    DEIMv2 Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits to be further
+    decoded into scores and classes.
+ """ +) +class Deimv2ForObjectDetection(Deimv2PreTrainedModel): + _no_split_modules = None + _tied_weights_keys = { + r"bbox_embed.(?![0])\d+": r"bbox_embed.0", + r"class_embed.(?![0])\d+": r"^class_embed.0", + "class_embed": "model.decoder.class_embed", + "bbox_embed": "model.decoder.bbox_embed", + } + + def __init__(self, config: Deimv2Config): + super().__init__(config) + + self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx + self.model = Deimv2Model(config) + scaled_dim = round(config.layer_scale * config.hidden_size) + num_pred = config.decoder_layers + self.class_embed = nn.ModuleList([nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)]) + self.bbox_embed = nn.ModuleList( + [ + Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) + for _ in range(self.eval_idx + 1) + ] + + [ + Deimv2MLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3) + for _ in range(config.decoder_layers - self.eval_idx - 1) + ] + ) + + self.model.decoder.class_embed = self.class_embed + self.model.decoder.bbox_embed = self.bbox_embed + self.post_init() + + def _set_aux_loss(self, outputs_class, outputs_coord): + return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)] + + @auto_docstring + @can_return_tuple + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: torch.LongTensor | None = None, + encoder_outputs: torch.FloatTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: list[dict] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor] | Deimv2ObjectDetectionOutput: + r""" + Example: + + ```python + >>> import torch + >>> from transformers.image_utils import load_image + >>> from transformers import AutoImageProcessor, Deimv2ForObjectDetection + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = load_image(url) + + >>> image_processor = AutoImageProcessor.from_pretrained("Intellindust/DEIMv2_HGNetv2_N_COCO") + >>> model = Deimv2ForObjectDetection.from_pretrained("Intellindust/DEIMv2_HGNetv2_N_COCO") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> # forward pass + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 300, 80] + + >>> boxes = outputs.pred_boxes + >>> list(boxes.shape) + [1, 300, 4] + + >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> target_sizes = torch.tensor([image.size[::-1]]) + >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes) + >>> result = results[0] # first image in batch + + >>> for score, label, box in zip(result["scores"], result["labels"], result["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... print( + ... f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ... 
) + ``` + """ + outputs = self.model( + pixel_values, + pixel_mask=pixel_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + labels=labels, + **kwargs, + ) + + denoising_meta_values = outputs.denoising_meta_values if self.training else None + + outputs_class = outputs.intermediate_logits + outputs_coord = outputs.intermediate_reference_points + predicted_corners = outputs.intermediate_predicted_corners + initial_reference_points = outputs.initial_reference_points + + logits = outputs_class[:, -1] + pred_boxes = outputs_coord[:, -1] + + loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None + if labels is not None: + enc_topk_logits = outputs.enc_topk_logits + enc_topk_bboxes = outputs.enc_topk_bboxes + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, + labels, + self.device, + pred_boxes, + self.config, + outputs_class, + outputs_coord, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + denoising_meta_values=denoising_meta_values, + predicted_corners=predicted_corners, + initial_reference_points=initial_reference_points, + **kwargs, + ) + + return Deimv2ObjectDetectionOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=outputs.last_hidden_state, + intermediate_hidden_states=outputs.intermediate_hidden_states, + intermediate_logits=outputs.intermediate_logits, + intermediate_reference_points=outputs.intermediate_reference_points, + intermediate_predicted_corners=outputs.intermediate_predicted_corners, + initial_reference_points=outputs.initial_reference_points, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + init_reference_points=outputs.init_reference_points, + enc_topk_logits=outputs.enc_topk_logits, + enc_topk_bboxes=outputs.enc_topk_bboxes, + enc_outputs_class=outputs.enc_outputs_class, + enc_outputs_coord_logits=outputs.enc_outputs_coord_logits, + denoising_meta_values=outputs.denoising_meta_values, + ) + + +__all__ = ["Deimv2Model", "Deimv2PreTrainedModel", "Deimv2ForObjectDetection"] diff --git a/src/transformers/models/deimv2/modular_deimv2.py b/src/transformers/models/deimv2/modular_deimv2.py new file mode 100644 index 000000000000..be7cc90214e0 --- /dev/null +++ b/src/transformers/models/deimv2/modular_deimv2.py @@ -0,0 +1,819 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ... 
import initialization as init +from ...modeling_outputs import BaseModelOutput +from ...utils import logging +from ..d_fine.configuration_d_fine import DFineConfig +from ..d_fine.modeling_d_fine import ( + DFineConvNormLayer, + DFineDecoder, + DFineDecoderLayer, + DFineDecoderOutput, + DFineEncoderLayer, + DFineForObjectDetection, + DFineGate, + DFineHybridEncoder, + DFineIntegral, + DFineLQE, + DFineMLP, + DFineModel, + DFineMultiscaleDeformableAttention, + DFinePreTrainedModel, + DFineRepVggBlock, + DFineSCDown, +) +from ..rt_detr.modeling_rt_detr import RTDetrAIFILayer + + +logger = logging.get_logger(__name__) + + +class Deimv2Config(DFineConfig): + """ + This is the configuration class to store the configuration of a [`Deimv2Model`]. It is used to instantiate a + DEIMv2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of DEIMv2-L-COCO. + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + initializer_range (`float`, *optional*, defaults to 0.01): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_bias_prior_prob (`float`, *optional*): + The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`. + If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + batch_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the batch normalization layers. + backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `HGNetV2Config()`): + The configuration of the backbone model. + freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`): + Whether to freeze the batch normalization layers in the backbone. + encoder_hidden_dim (`int`, *optional*, defaults to 256): + Dimension of the layers in hybrid encoder. + encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`): + Multi level features input for encoder. + feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`): + Strides used in each feature map. + encoder_layers (`int`, *optional*, defaults to 1): + Total of layers to be used by the encoder. + encoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in encoder. + encoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + dropout (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`): + Indexes of the projected layers to be used in the encoder. + positional_encoding_temperature (`int`, *optional*, defaults to 10000): + The temperature parameter used to create the positional encodings. + encoder_activation_function (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. 
+ activation_function (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the general layer. + eval_size (`tuple[int, int]`, *optional*): + Height and width used to computes the effective height and width of the position embeddings after taking + into account the stride. + normalize_before (`bool`, *optional*, defaults to `False`): + Determine whether to apply layer normalization in the transformer encoder layer before self-attention and + feed-forward modules. + hidden_expansion (`float`, *optional*, defaults to 1.0): + Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer. + d_model (`int`, *optional*, defaults to 256): + Dimension of the layers exclude hybrid encoder. + num_queries (`int`, *optional*, defaults to 300): + Number of object queries. + decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`): + Multi level features dimension for decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of input feature levels. + decoder_n_points (`int`, *optional*, defaults to 4): + The number of sampled keys in each feature level for each attention head in the decoder. + decoder_layers (`int`, *optional*, defaults to 6): + Number of decoder layers. + decoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_activation_function (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the decoder. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_denoising (`int`, *optional*, defaults to 100): + The total number of denoising tasks or queries to be used for contrastive denoising. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + The fraction of denoising labels to which random noise should be added. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale or magnitude of noise to be added to the bounding boxes. + learn_initial_query (`bool`, *optional*, defaults to `False`): + Indicates whether the initial query embeddings for the decoder should be learned during training. + anchor_image_size (`tuple[int, int]`, *optional*): + Height and width of the input image used during evaluation to generate the bounding box anchors. + with_box_refine (`bool`, *optional*, defaults to `True`): + Whether to apply iterative bounding box refinement. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether the architecture has an encoder decoder structure. + matcher_alpha (`float`, *optional*, defaults to 0.25): + Parameter alpha used by the Hungarian Matcher. + matcher_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used by the Hungarian Matcher. + matcher_class_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the class loss used by the Hungarian Matcher. + matcher_bbox_cost (`float`, *optional*, defaults to 5.0): + The relative weight of the bounding box loss used by the Hungarian Matcher. + matcher_giou_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the giou loss of used by the Hungarian Matcher. + use_focal_loss (`bool`, *optional*, defaults to `True`): + Parameter informing if focal loss should be used. 
+ auxiliary_loss (`bool`, *optional*, defaults to `True`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + focal_loss_alpha (`float`, *optional*, defaults to 0.75): + Parameter alpha used to compute the focal loss. + focal_loss_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used to compute the focal loss. + weight_loss_vfl (`float`, *optional*, defaults to 1.0): + Relative weight of the varifocal loss in the object detection loss. + weight_loss_mal (`float`, *optional*, defaults to 1.0): + Relative weight of the matching auxiliary loss in the object detection loss. + weight_loss_bbox (`float`, *optional*, defaults to 5.0): + Relative weight of the L1 bounding box loss in the object detection loss. + weight_loss_giou (`float`, *optional*, defaults to 2.0): + Relative weight of the generalized IoU loss in the object detection loss. + weight_loss_fgl (`float`, *optional*, defaults to 0.15): + Relative weight of the fine-grained localization loss in the object detection loss. + weight_loss_ddf (`float`, *optional*, defaults to 1.5): + Relative weight of the decoupled distillation focal loss in the object detection loss. + eos_coefficient (`float`, *optional*, defaults to 0.0001): + Relative classification weight of the 'no-object' class in the object detection loss. + eval_idx (`int`, *optional*, defaults to -1): + Index of the decoder layer to use for evaluation. + layer_scale (`float`, *optional*, defaults to `1.0`): + Scaling factor for the hidden dimension in later decoder layers. + max_num_bins (`int`, *optional*, defaults to 32): + Maximum number of bins for the distribution-guided bounding box refinement. + reg_scale (`float`, *optional*, defaults to 4.0): + Scale factor for the regression distribution. + depth_mult (`float`, *optional*, defaults to 1.0): + Multiplier for the number of blocks in RepNCSPELAN5 layers. + top_prob_values (`int`, *optional*, defaults to 4): + Number of top probability values to consider from each corner's distribution. + lqe_hidden_dim (`int`, *optional*, defaults to 64): + Hidden dimension size for the Location Quality Estimator (LQE) network. + lqe_layers (`int`, *optional*, defaults to 2): + Number of layers in the Location Quality Estimator MLP. + decoder_offset_scale (`float`, *optional*, defaults to 0.5): + Offset scale used in deformable attention. + decoder_method (`str`, *optional*, defaults to `"default"`): + The method to use for the decoder: `"default"` or `"discrete"`. + up (`float`, *optional*, defaults to 0.5): + Controls the upper bounds of the Weighting Function. + use_dense_o2o (`bool`, *optional*, defaults to `True`): + Whether to use dense one-to-one matching across decoder layers. + mal_alpha (`float`, *optional*): + Alpha parameter for the Matching Auxiliary Loss (MAL). If `None`, uses `focal_loss_alpha`. + encoder_fuse_op (`str`, *optional*, defaults to `"sum"`): + Fusion operation used in the encoder FPN. DEIMv2 uses `"sum"` instead of D-Fine's `"cat"`. + use_spatial_tuning_adapter (`bool`, *optional*, defaults to `False`): + Whether to use the Spatial Tuning Adapter (STA) for DINOv2 backbone variants. + sta_inplanes (`int`, *optional*, defaults to 16): + Number of input planes for the STA convolutional stem. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings. 
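+
+    Examples:
+
+    The following is a minimal sketch of the usual configuration workflow (instantiating the config with its
+    defaults, which approximate the DEIMv2-L-COCO setup described above, and building a randomly initialized
+    model from it):
+
+    ```python
+    >>> from transformers import Deimv2Config, Deimv2Model
+
+    >>> # Initializing a DEIMv2 style configuration
+    >>> configuration = Deimv2Config()
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = Deimv2Model(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```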
+ """ + + model_type = "deimv2" + + def __init__( + self, + initializer_range=0.01, + initializer_bias_prior_prob=None, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + backbone_config=None, + freeze_backbone_batch_norms=True, + encoder_hidden_dim=256, + encoder_in_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=1024, + encoder_attention_heads=8, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + hidden_expansion=1.0, + d_model=256, + num_queries=300, + decoder_in_channels=[256, 256, 256], + decoder_ffn_dim=1024, + num_feature_levels=3, + decoder_n_points=4, + decoder_layers=6, + decoder_attention_heads=8, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + with_box_refine=True, + is_encoder_decoder=True, + matcher_alpha=0.25, + matcher_gamma=2.0, + matcher_class_cost=2.0, + matcher_bbox_cost=5.0, + matcher_giou_cost=2.0, + use_focal_loss=True, + auxiliary_loss=True, + focal_loss_alpha=0.75, + focal_loss_gamma=2.0, + weight_loss_vfl=1.0, + weight_loss_bbox=5.0, + weight_loss_giou=2.0, + weight_loss_fgl=0.15, + weight_loss_ddf=1.5, + eos_coefficient=1e-4, + eval_idx=-1, + layer_scale=1, + max_num_bins=32, + reg_scale=4.0, + depth_mult=1.0, + top_prob_values=4, + lqe_hidden_dim=64, + lqe_layers=2, + decoder_offset_scale=0.5, + decoder_method="default", + up=0.5, + weight_loss_mal=1.0, + use_dense_o2o=True, + mal_alpha=None, + encoder_fuse_op="sum", + use_spatial_tuning_adapter=False, + sta_inplanes=16, + tie_word_embeddings=True, + **kwargs, + ): + self.weight_loss_mal = weight_loss_mal + self.use_dense_o2o = use_dense_o2o + self.mal_alpha = mal_alpha + self.encoder_fuse_op = encoder_fuse_op + self.use_spatial_tuning_adapter = use_spatial_tuning_adapter + self.sta_inplanes = sta_inplanes + super().__init__( + initializer_range=initializer_range, + initializer_bias_prior_prob=initializer_bias_prior_prob, + layer_norm_eps=layer_norm_eps, + batch_norm_eps=batch_norm_eps, + backbone_config=backbone_config, + freeze_backbone_batch_norms=freeze_backbone_batch_norms, + encoder_hidden_dim=encoder_hidden_dim, + encoder_in_channels=encoder_in_channels, + feat_strides=feat_strides, + encoder_layers=encoder_layers, + encoder_ffn_dim=encoder_ffn_dim, + encoder_attention_heads=encoder_attention_heads, + dropout=dropout, + activation_dropout=activation_dropout, + encode_proj_layers=encode_proj_layers, + positional_encoding_temperature=positional_encoding_temperature, + encoder_activation_function=encoder_activation_function, + activation_function=activation_function, + eval_size=eval_size, + normalize_before=normalize_before, + hidden_expansion=hidden_expansion, + d_model=d_model, + num_queries=num_queries, + decoder_in_channels=decoder_in_channels, + decoder_ffn_dim=decoder_ffn_dim, + num_feature_levels=num_feature_levels, + decoder_n_points=decoder_n_points, + decoder_layers=decoder_layers, + decoder_attention_heads=decoder_attention_heads, + decoder_activation_function=decoder_activation_function, + attention_dropout=attention_dropout, + num_denoising=num_denoising, + label_noise_ratio=label_noise_ratio, + box_noise_scale=box_noise_scale, + learn_initial_query=learn_initial_query, + anchor_image_size=anchor_image_size, + with_box_refine=with_box_refine, + 
is_encoder_decoder=is_encoder_decoder, + matcher_alpha=matcher_alpha, + matcher_gamma=matcher_gamma, + matcher_class_cost=matcher_class_cost, + matcher_bbox_cost=matcher_bbox_cost, + matcher_giou_cost=matcher_giou_cost, + use_focal_loss=use_focal_loss, + auxiliary_loss=auxiliary_loss, + focal_loss_alpha=focal_loss_alpha, + focal_loss_gamma=focal_loss_gamma, + weight_loss_vfl=weight_loss_vfl, + weight_loss_bbox=weight_loss_bbox, + weight_loss_giou=weight_loss_giou, + weight_loss_fgl=weight_loss_fgl, + weight_loss_ddf=weight_loss_ddf, + eos_coefficient=eos_coefficient, + eval_idx=eval_idx, + layer_scale=layer_scale, + max_num_bins=max_num_bins, + reg_scale=reg_scale, + depth_mult=depth_mult, + top_prob_values=top_prob_values, + lqe_hidden_dim=lqe_hidden_dim, + lqe_layers=lqe_layers, + decoder_offset_scale=decoder_offset_scale, + decoder_method=decoder_method, + up=up, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class Deimv2DecoderOutput(DFineDecoderOutput): + pass + + +class Deimv2RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.float() + hidden_states = hidden_states * torch.rsqrt(hidden_states.pow(2).mean(-1, keepdim=True) + self.eps) + return (hidden_states * self.scale).to(input_dtype) + + +class Deimv2SwiGLUFFN(nn.Module): + def __init__(self, in_features: int, hidden_features: int, out_features: int): + super().__init__() + self.w12 = nn.Linear(in_features, 2 * hidden_features) + self.w3 = nn.Linear(hidden_features, out_features) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + x12 = self.w12(hidden_states) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +class Deimv2Gate(DFineGate): + def __init__(self, d_model: int): + super().__init__(d_model) + self.norm = Deimv2RMSNorm(d_model) + + +class Deimv2MLP(DFineMLP): + pass + + +class Deimv2MultiscaleDeformableAttention(DFineMultiscaleDeformableAttention): + pass + + +class Deimv2ConvNormLayer(DFineConvNormLayer): + pass + + +class Deimv2RepVggBlock(DFineRepVggBlock): + pass + + +class Deimv2CSPRepLayer2(nn.Module): + def __init__( + self, config: Deimv2Config, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0 + ): + super().__init__() + activation = config.activation_function + hidden_channels = int(out_channels * expansion) + self.conv1 = Deimv2ConvNormLayer(config, in_channels, hidden_channels * 2, 1, 1, activation=activation) + self.bottlenecks = nn.ModuleList( + [Deimv2RepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)] + ) + self.conv3 = Deimv2ConvNormLayer(config, hidden_channels, out_channels, 3, 1, activation=activation) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + chunks = list(self.conv1(hidden_state).chunk(2, 1)) + bottleneck_out = chunks[1] + for bottleneck in self.bottlenecks: + bottleneck_out = bottleneck(bottleneck_out) + return self.conv3(chunks[0] + bottleneck_out) + + +class Deimv2RepNCSPELAN5(nn.Module): + def __init__(self, config: Deimv2Config, numb_blocks: int = 3): + super().__init__() + act = config.activation_function + c1 = config.encoder_hidden_dim + c2 = config.encoder_hidden_dim + c3 = config.encoder_hidden_dim * 2 + c4 = round(config.hidden_expansion * config.encoder_hidden_dim // 2) + self.conv_dim = c3 // 2 + 
self.conv1 = Deimv2ConvNormLayer(config, c1, c3, 1, 1, activation=act) + self.csp_rep1 = Deimv2CSPRepLayer2(config, c3 // 2, c4, num_blocks=numb_blocks) + self.csp_rep2 = Deimv2CSPRepLayer2(config, c4, c4, num_blocks=numb_blocks) + self.conv4 = Deimv2ConvNormLayer(config, c3 + (2 * c4), c2, 1, 1, activation=act) + + def forward(self, input_features: torch.Tensor) -> torch.Tensor: + split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1)) + branch1 = self.csp_rep1(split_features[-1]) + branch2 = self.csp_rep2(branch1) + split_features.extend([branch1, branch2]) + merged_features = torch.cat(split_features, 1) + return self.conv4(merged_features) + + +class Deimv2SCDown(DFineSCDown): + pass + + +class Deimv2EncoderLayer(DFineEncoderLayer): + pass + + +class Deimv2AIFILayer(RTDetrAIFILayer): + pass + + +class Deimv2SpatialTuningAdapter(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + inplanes = config.sta_inplanes + self.stem = nn.Sequential( + nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(inplanes), + nn.GELU(), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + self.conv2 = nn.Sequential( + nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(2 * inplanes), + ) + self.conv3 = nn.Sequential( + nn.GELU(), + nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(4 * inplanes), + ) + self.conv4 = nn.Sequential( + nn.GELU(), + nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(4 * inplanes), + ) + + def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + c1 = self.stem(pixel_values) + c2 = self.conv2(c1) + c3 = self.conv3(c2) + c4 = self.conv4(c3) + return c2, c3, c4 + + +class Deimv2Integral(DFineIntegral): + pass + + +class Deimv2LQE(DFineLQE): + pass + + +class Deimv2DecoderLayer(DFineDecoderLayer): + def __init__(self, config: Deimv2Config): + super().__init__(config) + self.encoder_attn = Deimv2MultiscaleDeformableAttention(config=config) + self.gateway = Deimv2Gate(config.d_model) + self.self_attn_layer_norm = Deimv2RMSNorm(config.d_model) + self.final_layer_norm = Deimv2RMSNorm(config.d_model) + self.mlp = Deimv2SwiGLUFFN(config.d_model, config.decoder_ffn_dim // 2, config.d_model) + + +class Deimv2MLPPredictionHead(DFineMLP): + pass + + +class Deimv2PreTrainedModel(DFinePreTrainedModel): + @torch.no_grad() + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (Deimv2ForObjectDetection, Deimv2Decoder)): + if module.class_embed is not None: + for layer in module.class_embed: + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias = float(-math.log((1 - prior_prob) / prior_prob)) + init.xavier_uniform_(layer.weight) + init.constant_(layer.bias, bias) + + if module.bbox_embed is not None: + for layer in module.bbox_embed: + init.constant_(layer.layers[-1].weight, 0) + init.constant_(layer.layers[-1].bias, 0) + + if hasattr(module, "reg_scale"): + init.constant_(module.reg_scale, self.config.reg_scale) + + if hasattr(module, "up"): + init.constant_(module.up, self.config.up) + + if isinstance(module, Deimv2MultiscaleDeformableAttention): + init.constant_(module.sampling_offsets.weight, 0.0) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( + 
2.0 * math.pi / module.n_heads + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values + grid_init = grid_init.reshape(module.n_heads, 1, 2).tile([1, sum(module.num_points_list), 1]) + scaling = torch.concat([torch.arange(1, n + 1) for n in module.num_points_list]).reshape(1, -1, 1) + grid_init *= scaling + init.copy_(module.sampling_offsets.bias, grid_init.flatten()) + + init.constant_(module.attention_weights.weight, 0.0) + init.constant_(module.attention_weights.bias, 0.0) + + num_points_scale = [1 / n for n in module.num_points_list for _ in range(n)] + init.copy_(module.num_points_scale, torch.tensor(num_points_scale, dtype=torch.float32)) + + if isinstance(module, Deimv2Model): + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias = float(-math.log((1 - prior_prob) / prior_prob)) + init.xavier_uniform_(module.enc_score_head.weight) + init.constant_(module.enc_score_head.bias, bias) + + if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + init.zeros_(module.bias) + if getattr(module, "running_mean", None) is not None: + init.zeros_(module.running_mean) + init.ones_(module.running_var) + init.zeros_(module.num_batches_tracked) + + if isinstance(module, Deimv2Gate): + bias = float(-math.log((1 - 0.5) / 0.5)) + init.constant_(module.gate.bias, bias) + init.constant_(module.gate.weight, 0) + + if isinstance(module, Deimv2LQE): + init.constant_(module.reg_conf.layers[-1].bias, 0) + init.constant_(module.reg_conf.layers[-1].weight, 0) + + if isinstance(module, Deimv2SwiGLUFFN): + init.xavier_uniform_(module.w12.weight) + init.constant_(module.w12.bias, 0) + init.xavier_uniform_(module.w3.weight) + init.constant_(module.w3.bias, 0) + + if isinstance(module, Deimv2RMSNorm): + init.ones_(module.scale) + + if isinstance(module, nn.LayerNorm): + init.ones_(module.weight) + init.zeros_(module.bias) + + if hasattr(module, "weight_embedding") and self.config.learn_initial_query: + init.xavier_uniform_(module.weight_embedding.weight) + if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: + init.xavier_uniform_(module.denoising_class_embed.weight) + + +class Deimv2HybridEncoder(DFineHybridEncoder): + def __init__(self, config: Deimv2Config): + Deimv2PreTrainedModel.__init__(config) + self.config = config + self.in_channels = config.encoder_in_channels + self.num_fpn_stages = len(self.in_channels) - 1 + self.feat_strides = config.feat_strides + self.encoder_hidden_dim = config.encoder_hidden_dim + self.encode_proj_layers = config.encode_proj_layers + self.positional_encoding_temperature = config.positional_encoding_temperature + self.eval_size = config.eval_size + self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels] + self.out_strides = self.feat_strides + self.encoder_fuse_op = config.encoder_fuse_op + + self.aifi = nn.ModuleList([Deimv2AIFILayer(config) for _ in range(len(self.encode_proj_layers))]) + + self.lateral_convs = nn.ModuleList() + self.fpn_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1, 0, -1): + lateral_layer = Deimv2ConvNormLayer(config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1) + self.lateral_convs.append(lateral_layer) + num_blocks = round(3 * config.depth_mult) + fpn_layer = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks) + self.fpn_blocks.append(fpn_layer) + + 
self.downsample_convs = nn.ModuleList() + self.pan_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1): + self.downsample_convs.append(Deimv2SCDown(config, 3, 2)) + num_blocks = round(3 * config.depth_mult) + self.pan_blocks.append(Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks)) + + self.post_init() + + def forward( + self, + inputs_embeds=None, + **kwargs, + ) -> BaseModelOutput: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + """ + feature_maps = inputs_embeds + + if self.config.encoder_layers > 0: + for i, enc_ind in enumerate(self.encode_proj_layers): + feature_maps[enc_ind] = self.aifi[i](feature_maps[enc_ind], **kwargs) + + fpn_feature_maps = [feature_maps[-1]] + for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)): + backbone_feature_map = feature_maps[self.num_fpn_stages - idx - 1] + top_fpn_feature_map = fpn_feature_maps[-1] + top_fpn_feature_map = lateral_conv(top_fpn_feature_map) + fpn_feature_maps[-1] = top_fpn_feature_map + top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest") + if self.encoder_fuse_op == "sum": + fused_feature_map = top_fpn_feature_map + backbone_feature_map + else: + fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1) + new_fpn_feature_map = fpn_block(fused_feature_map) + fpn_feature_maps.append(new_fpn_feature_map) + + fpn_feature_maps.reverse() + + pan_feature_maps = [fpn_feature_maps[0]] + for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)): + top_pan_feature_map = pan_feature_maps[-1] + fpn_feature_map = fpn_feature_maps[idx + 1] + downsampled_feature_map = downsample_conv(top_pan_feature_map) + if self.encoder_fuse_op == "sum": + fused_feature_map = downsampled_feature_map + fpn_feature_map + else: + fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1) + new_pan_feature_map = pan_block(fused_feature_map) + pan_feature_maps.append(new_pan_feature_map) + + return BaseModelOutput(last_hidden_state=pan_feature_maps) + + +class Deimv2Decoder(DFineDecoder): + def __init__(self, config: Deimv2Config): + self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx + super().__init__(config=config) + self.query_pos_head = Deimv2MLP(4, config.d_model, config.d_model, 3, config.decoder_activation_function) + self.reg_scale = nn.Parameter(torch.tensor([config.reg_scale]), requires_grad=False) + self.max_num_bins = config.max_num_bins + self.d_model = config.d_model + self.layer_scale = config.layer_scale + self.pre_bbox_head = Deimv2MLP(config.hidden_size, config.hidden_size, 4, 3) + self.integral = Deimv2Integral(config) + self.num_head = config.decoder_attention_heads + self.up = nn.Parameter(torch.tensor([config.up]), requires_grad=False) + self.lqe_layers = nn.ModuleList([Deimv2LQE(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + [Deimv2DecoderLayer(config) for _ in range(config.decoder_layers)] + + [Deimv2DecoderLayer(config) for _ in range(config.decoder_layers - self.eval_idx - 1)] + ) + + +class Deimv2Model(DFineModel): + def __init__(self, config: Deimv2Config): + super().__init__(config) + del self.decoder_input_proj + self.encoder = Deimv2HybridEncoder(config=config) + num_backbone_outs = len(config.decoder_in_channels) + 
decoder_input_proj = [] + in_channels = config.decoder_in_channels[-1] + for _ in range(num_backbone_outs): + if config.hidden_size == config.decoder_in_channels[-1]: + decoder_input_proj.append(nn.Identity()) + else: + conv = nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False) + batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps) + decoder_input_proj.append(nn.Sequential(conv, batchnorm)) + for _ in range(config.num_feature_levels - num_backbone_outs): + if config.hidden_size == config.decoder_in_channels[-1]: + decoder_input_proj.append(nn.Identity()) + else: + conv = nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False) + batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps) + decoder_input_proj.append(nn.Sequential(conv, batchnorm)) + self.decoder_input_proj = nn.ModuleList(decoder_input_proj) + self.decoder = Deimv2Decoder(config) + + if config.use_spatial_tuning_adapter: + self.spatial_tuning_adapter = Deimv2SpatialTuningAdapter(config) + + +class Deimv2ForObjectDetection(DFineForObjectDetection): + _no_split_modules = None + _tied_weights_keys = { + r"bbox_embed.(?![0])\d+": r"bbox_embed.0", + r"class_embed.(?![0])\d+": r"^class_embed.0", + "class_embed": "model.decoder.class_embed", + "bbox_embed": "model.decoder.bbox_embed", + } + + def __init__(self, config: Deimv2Config): + Deimv2PreTrainedModel.__init__(self, config) + + self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx + self.model = Deimv2Model(config) + scaled_dim = round(config.layer_scale * config.hidden_size) + num_pred = config.decoder_layers + self.class_embed = nn.ModuleList([nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)]) + self.bbox_embed = nn.ModuleList( + [ + Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) + for _ in range(self.eval_idx + 1) + ] + + [ + Deimv2MLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3) + for _ in range(config.decoder_layers - self.eval_idx - 1) + ] + ) + + self.model.decoder.class_embed = self.class_embed + self.model.decoder.bbox_embed = self.bbox_embed + self.post_init() + + def forward(**super_kwargs): + r""" + Example: + + ```python + >>> import torch + >>> from transformers.image_utils import load_image + >>> from transformers import AutoImageProcessor, Deimv2ForObjectDetection + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = load_image(url) + + >>> image_processor = AutoImageProcessor.from_pretrained("Intellindust/DEIMv2_HGNetv2_N_COCO") + >>> model = Deimv2ForObjectDetection.from_pretrained("Intellindust/DEIMv2_HGNetv2_N_COCO") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> # forward pass + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 300, 80] + + >>> boxes = outputs.pred_boxes + >>> list(boxes.shape) + [1, 300, 4] + + >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> target_sizes = torch.tensor([image.size[::-1]]) + >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes) + >>> result = results[0] # first image in batch + + >>> for score, label, box in zip(result["scores"], result["labels"], result["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... print( + ... 
f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ... ) + ``` + """ + super().forward(**super_kwargs) + + +__all__ = [ + "Deimv2Config", + "Deimv2Model", + "Deimv2PreTrainedModel", + "Deimv2ForObjectDetection", +] diff --git a/tests/models/deimv2/__init__.py b/tests/models/deimv2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/deimv2/test_modeling_deimv2.py b/tests/models/deimv2/test_modeling_deimv2.py new file mode 100644 index 000000000000..9b612c7833d1 --- /dev/null +++ b/tests/models/deimv2/test_modeling_deimv2.py @@ -0,0 +1,655 @@ +# coding = utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch DEIMv2 model.""" + +import copy +import inspect +import math +import tempfile +import unittest + +from parameterized import parameterized + +from transformers import ( + Deimv2Config, + HGNetV2Config, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import ( + require_torch, + require_torch_accelerator, + slow, + torch_device, +) + + +if is_torch_available(): + import torch + + from transformers import Deimv2ForObjectDetection, Deimv2Model + +if is_vision_available(): + from PIL import Image + + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +class Deimv2ModelTester: + def __init__( + self, + parent, + batch_size=3, + is_training=True, + use_labels=True, + n_targets=3, + num_labels=10, + initializer_range=0.02, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + backbone_config=None, + encoder_hidden_dim=32, + encoder_in_channels=[128, 256, 512], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=64, + encoder_attention_heads=2, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + d_model=32, + num_queries=30, + decoder_in_channels=[32, 32, 32], + decoder_ffn_dim=64, + num_feature_levels=3, + decoder_n_points=[3, 6, 3], + decoder_n_levels=3, + decoder_layers=2, + decoder_attention_heads=2, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=0, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + image_size=64, + disable_custom_kernels=True, + with_box_refine=True, + decoder_offset_scale=0.5, + eval_idx=-1, + layer_scale=1, + reg_max=32, + reg_scale=4.0, + depth_mult=0.34, + hidden_expansion=0.5, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = 3 + self.is_training = is_training + self.use_labels = use_labels + self.n_targets = n_targets + self.num_labels = num_labels + self.initializer_range = initializer_range + self.layer_norm_eps = 
layer_norm_eps + self.batch_norm_eps = batch_norm_eps + self.backbone_config = backbone_config + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_layers = encoder_layers + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.eval_size = eval_size + self.normalize_before = normalize_before + self.d_model = d_model + self.num_queries = num_queries + self.decoder_in_channels = decoder_in_channels + self.decoder_ffn_dim = decoder_ffn_dim + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_n_levels = decoder_n_levels + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.decoder_offset_scale = decoder_offset_scale + self.eval_idx = eval_idx + self.layer_scale = layer_scale + self.reg_max = reg_max + self.reg_scale = reg_scale + self.depth_mult = depth_mult + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.image_size = image_size + self.disable_custom_kernels = disable_custom_kernels + self.with_box_refine = with_box_refine + self.hidden_expansion = hidden_expansion + + self.encoder_seq_length = math.ceil(self.image_size / 32) * math.ceil(self.image_size / 32) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device) + + labels = None + if self.use_labels: + labels = [] + for i in range(self.batch_size): + target = {} + target["class_labels"] = torch.randint( + high=self.num_labels, size=(self.n_targets,), device=torch_device + ) + target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device) + labels.append(target) + + config = self.get_config() + config.num_labels = self.num_labels + return config, pixel_values, pixel_mask, labels + + def get_config(self): + hidden_sizes = [64, 128, 256, 512] + backbone_config = HGNetV2Config( + stage_in_channels=[16, 64, 128, 256], + stage_mid_channels=[16, 32, 64, 128], + stage_out_channels=[64, 128, 256, 512], + stage_num_blocks=[1, 1, 2, 1], + stage_downsample=[False, True, True, True], + stage_light_block=[False, False, True, True], + stage_kernel_size=[3, 3, 5, 5], + stage_numb_of_layers=[3, 3, 3, 3], + embeddings_size=10, + hidden_sizes=hidden_sizes, + depths=[1, 1, 2, 1], + out_features=["stage2", "stage3", "stage4"], + out_indices=[2, 3, 4], + stem_channels=[3, 16, 16], + use_lab=True, + ) + return Deimv2Config( + backbone_config=backbone_config, + encoder_hidden_dim=self.encoder_hidden_dim, + encoder_in_channels=self.encoder_in_channels, + feat_strides=self.feat_strides, + encoder_layers=self.encoder_layers, + encoder_ffn_dim=self.encoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + dropout=self.dropout, + 
activation_dropout=self.activation_dropout, + encode_proj_layers=self.encode_proj_layers, + positional_encoding_temperature=self.positional_encoding_temperature, + encoder_activation_function=self.encoder_activation_function, + activation_function=self.activation_function, + eval_size=self.eval_size, + normalize_before=self.normalize_before, + d_model=self.d_model, + num_queries=self.num_queries, + decoder_in_channels=self.decoder_in_channels, + decoder_ffn_dim=self.decoder_ffn_dim, + num_feature_levels=self.num_feature_levels, + decoder_n_points=self.decoder_n_points, + decoder_n_levels=self.decoder_n_levels, + decoder_layers=self.decoder_layers, + decoder_attention_heads=self.decoder_attention_heads, + decoder_activation_function=self.decoder_activation_function, + decoder_offset_scale=self.decoder_offset_scale, + eval_idx=self.eval_idx, + layer_scale=self.layer_scale, + reg_max=self.reg_max, + reg_scale=self.reg_scale, + depth_mult=self.depth_mult, + attention_dropout=self.attention_dropout, + num_denoising=self.num_denoising, + label_noise_ratio=self.label_noise_ratio, + box_noise_scale=self.box_noise_scale, + learn_initial_query=self.learn_initial_query, + anchor_image_size=self.anchor_image_size, + image_size=self.image_size, + disable_custom_kernels=self.disable_custom_kernels, + with_box_refine=self.with_box_refine, + ) + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + def create_and_check_deimv2_model(self, config, pixel_values, pixel_mask, labels): + model = Deimv2Model(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.d_model)) + + def create_and_check_deimv2_object_detection_head_model(self, config, pixel_values, pixel_mask, labels): + model = Deimv2ForObjectDetection(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels) + + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + +@require_torch +class Deimv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (Deimv2Model, Deimv2ForObjectDetection) if is_torch_available() else () + pipeline_model_mapping = ( + {"image-feature-extraction": Deimv2Model, "object-detection": Deimv2ForObjectDetection} + if is_torch_available() + else {} + ) + is_encoder_decoder = True + + test_missing_keys = False + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "Deimv2ForObjectDetection": + labels = [] + for i in range(self.model_tester.batch_size): + target = {} + target["class_labels"] = torch.ones( + 
size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long + ) + target["boxes"] = torch.ones( + self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float + ) + labels.append(target) + inputs_dict["labels"] = labels + + return inputs_dict + + def setUp(self): + self.model_tester = Deimv2ModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=Deimv2Config, + has_text_modality=False, + common_properties=["hidden_size", "num_attention_heads"], + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_deimv2_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deimv2_model(*config_and_inputs) + + def test_deimv2_object_detection_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deimv2_object_detection_head_model(*config_and_inputs) + + @unittest.skip(reason="Deimv2 doesn't work well with `nn.DataParallel") + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip(reason="Deimv2 does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Deimv2 does not use test_inputs_embeds_matches_input_ids") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip(reason="Deimv2 does not support input and output embeddings") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="Deimv2 does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="Deimv2 does not use token embeddings") + def test_resize_tokens_embeddings(self): + pass + + @unittest.skip(reason="Feed forward chunking is not implemented") + def test_feed_forward_chunking(self): + pass + + @unittest.skip(reason="Weight tying is hardcoded (module_x = module_y) and always `True`") + def test_load_save_without_tied_weights(self): + pass + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class._from_config(config, attn_implementation="eager") + config = model.config + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions + self.assertEqual(len(attentions), self.model_tester.encoder_layers) + + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions + self.assertEqual(len(attentions), self.model_tester.encoder_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [ + self.model_tester.encoder_attention_heads, + self.model_tester.encoder_seq_length, + self.model_tester.encoder_seq_length, + ], + ) + out_len = len(outputs) + + correct_outlen = 15 + + if "labels" in inputs_dict: + correct_outlen += 1 + if model_class.__name__ == "Deimv2ForObjectDetection": + correct_outlen += 2 + + self.assertEqual(out_len, correct_outlen) + + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + 
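# Decoder self-attention weights are shaped (..., heads, queries, queries); the deformable
# cross-attention weights checked further below instead end in the total number of
# sampling points per query (sum of decoder_n_points across feature levels).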
self.assertEqual(len(decoder_attentions), self.model_tester.decoder_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [ + self.model_tester.decoder_attention_heads, + self.model_tester.num_queries, + self.model_tester.num_queries, + ], + ) + + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.decoder_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_queries, + self.model_tester.decoder_attention_heads, + self.model_tester.decoder_n_levels * self.model_tester.decoder_n_points + if isinstance(self.model_tester.decoder_n_points, int) + else sum(self.model_tester.decoder_n_points), + ], + ) + + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + else: + added_hidden_states = 2 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions + + self.assertEqual(len(self_attentions), self.model_tester.encoder_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [ + self.model_tester.encoder_attention_heads, + self.model_tester.encoder_seq_length, + self.model_tester.encoder_seq_length, + ], + ) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", len(self.model_tester.encoder_in_channels) - 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[1].shape[-2:]), + [ + self.model_tester.image_size // self.model_tester.feat_strides[-1], + self.model_tester.image_size // self.model_tester.feat_strides[-1], + ], + ) + + if config.is_encoder_decoder: + hidden_states = outputs.decoder_hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.decoder_layers + 1 + ) + + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.num_queries, self.model_tester.d_model], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + inputs = 
self._prepare_for_class(inputs_dict, model_class) + + outputs = model(**inputs) + + output = outputs[0] + + encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_attentions = outputs.encoder_attentions[0] + encoder_hidden_states.retain_grad() + encoder_attentions.retain_grad() + + decoder_attentions = outputs.decoder_attentions[0] + decoder_attentions.retain_grad() + + cross_attentions = outputs.cross_attentions[0] + cross_attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(encoder_hidden_states.grad) + self.assertIsNotNone(encoder_attentions.grad) + self.assertIsNotNone(decoder_attentions.grad) + self.assertIsNotNone(cross_attentions.grad) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + arg_names = [*signature.parameters.keys()] + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_backbone_selection(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def _validate_backbone_init(config): + for model_class in self.all_model_classes: + model = model_class(copy.deepcopy(config)) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if model_class.__name__ == "Deimv2ForObjectDetection": + expected_shape = ( + self.model_tester.batch_size, + self.model_tester.num_queries, + self.model_tester.num_labels, + ) + self.assertEqual(outputs.logits.shape, expected_shape) + self.assertEqual(len(model.model.backbone.intermediate_channel_sizes), 3) + else: + self.assertEqual(len(model.backbone.intermediate_channel_sizes), 3) + + self.assertTrue(outputs) + + config_dict = config.to_dict() + config_dict["encoder_in_channels"] = [24, 40, 432] + config_dict["backbone"] = "tf_mobilenetv3_small_075" + config_dict["backbone_config"] = None + config_dict["use_timm_backbone"] = True + config_dict["backbone_kwargs"] = {"out_indices": [2, 3, 4]} + config = config.__class__(**config_dict) + _validate_backbone_init(config) + + config_dict = config.to_dict() + config_dict["backbone"] = "microsoft/resnet-18" + config_dict["backbone_config"] = None + config_dict["use_timm_backbone"] = False + config_dict["use_pretrained_backbone"] = True + config_dict["backbone_kwargs"] = {"out_indices": [2, 3, 4]} + config = config.__class__(**config_dict) + _validate_backbone_init(config) + + @parameterized.expand(["float32", "float16", "bfloat16"]) + @require_torch_accelerator + @slow + def test_inference_with_different_dtypes(self, dtype_str): + dtype = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[dtype_str] + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device).to(dtype) + model.eval() + for key, tensor in inputs_dict.items(): + if tensor.dtype == torch.float32: + inputs_dict[key] = tensor.to(dtype) + with torch.no_grad(): + _ = model(**self._prepare_for_class(inputs_dict, model_class)) + + @parameterized.expand(["float32", "float16", "bfloat16"]) + @require_torch_accelerator + @slow + def test_inference_equivalence_for_static_and_dynamic_anchors(self, dtype_str): + dtype = { + "float32": torch.float32, + "float16": torch.float16, + 
"bfloat16": torch.bfloat16, + }[dtype_str] + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + h, w = inputs_dict["pixel_values"].shape[-2:] + + for key, tensor in inputs_dict.items(): + if tensor.dtype == torch.float32: + inputs_dict[key] = tensor.to(dtype) + + for model_class in self.all_model_classes: + with tempfile.TemporaryDirectory() as tmpdirname: + model_class(config).save_pretrained(tmpdirname) + model_static = model_class.from_pretrained( + tmpdirname, anchor_image_size=[h, w], device_map=torch_device, dtype=dtype + ).eval() + model_dynamic = model_class.from_pretrained( + tmpdirname, anchor_image_size=None, device_map=torch_device, dtype=dtype + ).eval() + + self.assertIsNotNone(model_static.config.anchor_image_size) + self.assertIsNone(model_dynamic.config.anchor_image_size) + + with torch.no_grad(): + outputs_static = model_static(**self._prepare_for_class(inputs_dict, model_class)) + outputs_dynamic = model_dynamic(**self._prepare_for_class(inputs_dict, model_class)) + + torch.testing.assert_close( + outputs_static.last_hidden_state, outputs_dynamic.last_hidden_state, rtol=1e-4, atol=1e-4 + ) + + +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image From ddc1bd71b93335e8789e9768e285f2917463cd20 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Fri, 27 Feb 2026 22:25:05 +0400 Subject: [PATCH 02/17] fix: Fix ci/circleci: check_repository_consistency --- src/transformers/models/deimv2/modular_deimv2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/deimv2/modular_deimv2.py b/src/transformers/models/deimv2/modular_deimv2.py index be7cc90214e0..05bd3b8d4da5 100644 --- a/src/transformers/models/deimv2/modular_deimv2.py +++ b/src/transformers/models/deimv2/modular_deimv2.py @@ -154,8 +154,6 @@ class Deimv2Config(DFineConfig): Parameter gamma used to compute the focal loss. weight_loss_vfl (`float`, *optional*, defaults to 1.0): Relative weight of the varifocal loss in the object detection loss. - weight_loss_mal (`float`, *optional*, defaults to 1.0): - Relative weight of the matching auxiliary loss in the object detection loss. weight_loss_bbox (`float`, *optional*, defaults to 5.0): Relative weight of the L1 bounding box loss in the object detection loss. weight_loss_giou (`float`, *optional*, defaults to 2.0): @@ -188,6 +186,8 @@ class Deimv2Config(DFineConfig): The method to use for the decoder: `"default"` or `"discrete"`. up (`float`, *optional*, defaults to 0.5): Controls the upper bounds of the Weighting Function. + weight_loss_mal (`float`, *optional*, defaults to 1.0): + Relative weight of the matching auxiliary loss in the object detection loss. use_dense_o2o (`bool`, *optional*, defaults to `True`): Whether to use dense one-to-one matching across decoder layers. 
mal_alpha (`float`, *optional*): From 85c7356ece2be84c5aa18d08e5a94513e8114cce Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Sun, 1 Mar 2026 11:33:22 +0400 Subject: [PATCH 03/17] feat: Add support and test harness for all variants --- .../models/deimv2/configuration_deimv2.py | 210 +++- ...eimv2_original_pytorch_checkpoint_to_hf.py | 424 ++++++- .../models/deimv2/modeling_deimv2.py | 843 ++++++++++---- .../models/deimv2/modular_deimv2.py | 549 ++++++++- tests/models/deimv2/test_modeling_deimv2.py | 1010 ++++++++++++++++- 5 files changed, 2749 insertions(+), 287 deletions(-) diff --git a/src/transformers/models/deimv2/configuration_deimv2.py b/src/transformers/models/deimv2/configuration_deimv2.py index 56936d2cdd7a..d98372c44509 100644 --- a/src/transformers/models/deimv2/configuration_deimv2.py +++ b/src/transformers/models/deimv2/configuration_deimv2.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ...backbone_utils import consolidate_backbone_kwargs_to_config +from ...backbone_utils import BackboneConfigMixin, consolidate_backbone_kwargs_to_config from ...configuration_utils import PreTrainedConfig from ..auto import AutoConfig @@ -28,7 +28,8 @@ class Deimv2Config(PreTrainedConfig): """ This is the configuration class to store the configuration of a [`Deimv2Model`]. It is used to instantiate a DEIMv2 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of DEIMv2-L-COCO. + with the defaults will yield a similar configuration to that of DEIMv2-HGNetv2-N-COCO + [Intellindust/DEIMv2_HGNetv2_N_COCO](https://huggingface.co/Intellindust/DEIMv2_HGNetv2_N_COCO). Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PreTrainedConfig`] for more information. @@ -177,6 +178,29 @@ class Deimv2Config(PreTrainedConfig): Whether to use the Spatial Tuning Adapter (STA) for DINOv2 backbone variants. sta_inplanes (`int`, *optional*, defaults to 16): Number of input planes for the STA convolutional stem. + encoder_type (`str`, *optional*, defaults to `"hybrid"`): + Type of encoder to use. `"hybrid"` uses the full HybridEncoder with AIFI, FPN, and PAN. + `"lite"` uses the lightweight LiteEncoder with GAP fusion for smaller variants (Atto, Femto, Pico). + use_gateway (`bool`, *optional*, defaults to `True`): + Whether to use the gateway mechanism (cross-attention gating) in decoder layers. When `False`, + uses RMSNorm on the encoder attention output instead. + share_bbox_head (`bool`, *optional*, defaults to `False`): + Whether to share the bounding box prediction head across all decoder layers. + backbone_type (`str`, *optional*, defaults to `"hgnetv2"`): + Type of backbone to use. `"hgnetv2"` uses HGNetV2, `"dinov3"` uses DINOv3 ViT backbone with STA. + dinov3_backbone_config (`dict`, *optional*): + Configuration dictionary for the DINOv3 ViT backbone. Passed as kwargs to `DINOv3ViTConfig`. + dinov3_interaction_indexes (`list[int]`, *optional*): + Layer indices in the DINOv3 ViT backbone from which to extract intermediate features. + dinov3_hidden_dim (`int`, *optional*): + Hidden dimension for the DINOv3 backbone projection convolutions. If `None`, uses `hidden_size` from + the DINOv3 ViT config. 
+ dinov3_apply_layernorm (`bool`, *optional*, defaults to `False`): + Whether to apply LayerNorm to intermediate features extracted from the DINOv3 ViT backbone. + encoder_has_trailing_conv (`bool`, *optional*, defaults to `True`): + Whether the encoder's CSP blocks include a trailing 3x3 convolution after the bottleneck path. + `True` for RepNCSPELAN4 (used by HGNetV2 N and LiteEncoder variants). + `False` for RepNCSPELAN5 (used by DINOv3 variants). tie_word_embeddings (`bool`, *optional*, defaults to `True`): Whether to tie weight embeddings. """ @@ -261,6 +285,15 @@ def __init__( encoder_fuse_op="sum", use_spatial_tuning_adapter=False, sta_inplanes=16, + encoder_type="hybrid", + use_gateway=True, + share_bbox_head=False, + backbone_type="hgnetv2", + dinov3_backbone_config=None, + dinov3_interaction_indexes=None, + dinov3_hidden_dim=None, + dinov3_apply_layernorm=False, + encoder_has_trailing_conv=True, tie_word_embeddings=True, **kwargs, ): @@ -270,6 +303,15 @@ def __init__( self.encoder_fuse_op = encoder_fuse_op self.use_spatial_tuning_adapter = use_spatial_tuning_adapter self.sta_inplanes = sta_inplanes + self.encoder_type = encoder_type + self.use_gateway = use_gateway + self.share_bbox_head = share_bbox_head + self.backbone_type = backbone_type + self.dinov3_backbone_config = dinov3_backbone_config + self.dinov3_interaction_indexes = dinov3_interaction_indexes + self.dinov3_hidden_dim = dinov3_hidden_dim + self.dinov3_apply_layernorm = dinov3_apply_layernorm + self.encoder_has_trailing_conv = encoder_has_trailing_conv self.initializer_range = initializer_range self.initializer_bias_prior_prob = initializer_bias_prior_prob self.layer_norm_eps = layer_norm_eps @@ -362,4 +404,168 @@ def __init__( super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) +class Deimv2DINOv3ViTConfig(BackboneConfigMixin, PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DINOv3Model`]. It is used to instantiate an + DINOv3 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the DINOv3 + [facebook/dinov3-vits16-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vits16-pretrain-lvd1689m) architecture. + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_size (`int`, *optional*, defaults to 384): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 1536): + Dimensionality of the "intermediate" (i.e., feed-forward) layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 6): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + rope_theta (`float`, *optional*, defaults to 100.0): + The base period of the RoPE embeddings. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + query_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the query projection. + key_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the key projection. + value_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the value projection. + proj_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the output projection. + mlp_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the MLP layers. + layerscale_value (`float`, *optional*, defaults to 1.0): + Initial value to use for layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Stochastic depth rate per sample (when applied in the main path of residual layers). + use_gated_mlp (`bool`, *optional*, defaults to `False`): + Whether to use the SwiGLU feedforward neural network. + num_register_tokens (`int`, *optional*, defaults to 0): + The number of register tokens. + pos_embed_shift (`float`, *optional*): + Amount to randomly shift position embedding coordinates in [-shift, shift], + applied only in training mode if not `None`. + pos_embed_jitter (`float`, *optional*): + Amount to randomly jitter position embedding coordinates in log-uniform value in [1/jitter, jitter], + applied only in training mode if not `None`. + pos_embed_rescale (`float`, *optional*, defaults to 2.0): + Amount to randomly rescale position embedding coordinates in log-uniform value in [1/rescale, rescale], + applied only in training mode if not `None`. + out_features (`list[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). Will default to the last stage if unset. + out_indices (`list[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. + (depending on how many stages the model has). Will default to the last stage if unset. + apply_layernorm (`bool`, *optional*, defaults to `True`): + Whether to apply layer normalization to the feature maps when used as backbone. + reshape_hidden_states (`bool`, *optional*, defaults to `True`): + Whether to reshape the hidden states to spatial dimensions when used as backbone. 
+ + Example: + + ```python + >>> from transformers import Deimv2DINOv3ViTConfig, Deimv2DINOv3ViTModel + + >>> # Initializing a DINOv3 ViT-small style configuration + >>> config = Deimv2DINOv3ViTConfig() + + >>> # Initializing a model (with random weights) from the config + >>> model = Deimv2DINOv3ViTModel(config) + + >>> # Accessing the model config + >>> config = model.config + ```""" + + model_type = "deimv2_dinov3_vit" + + def __init__( + self, + patch_size: int = 16, + hidden_size: int = 384, + intermediate_size: int = 1536, + num_hidden_layers: int = 12, + num_attention_heads: int = 6, + hidden_act: str = "gelu", + attention_dropout: float = 0.0, + initializer_range: float = 0.02, + layer_norm_eps: float = 1e-5, + rope_theta: float = 100.0, + image_size: int = 224, + num_channels: int = 3, + query_bias: bool = True, + key_bias: bool = False, + value_bias: bool = True, + proj_bias: bool = True, + mlp_bias: bool = True, + layerscale_value: float = 1.0, + drop_path_rate: float = 0.0, + use_gated_mlp: bool = False, + num_register_tokens: int = 0, + # train augs + pos_embed_shift: float | None = None, + pos_embed_jitter: float | None = None, + pos_embed_rescale: float | None = 2.0, + out_features: list[str] | None = None, + out_indices: list[int] | None = None, + apply_layernorm: bool = True, + reshape_hidden_states: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.layerscale_value = layerscale_value + self.drop_path_rate = drop_path_rate + self.use_gated_mlp = use_gated_mlp + self.rope_theta = rope_theta + self.query_bias = query_bias + self.key_bias = key_bias + self.value_bias = value_bias + self.proj_bias = proj_bias + self.mlp_bias = mlp_bias + self.num_register_tokens = num_register_tokens + + # train augs + self.pos_embed_shift = pos_embed_shift + self.pos_embed_jitter = pos_embed_jitter + self.pos_embed_rescale = pos_embed_rescale + # Initialize backbone-specific configuration + self.apply_layernorm = apply_layernorm + self.reshape_hidden_states = reshape_hidden_states + + # Initialize backbone stage names + stage_names = ["stem"] + [f"stage{i}" for i in range(1, num_hidden_layers + 1)] + self.stage_names = stage_names + + # Initialize backbone features/indices + self.set_output_features_output_indices(out_indices=out_indices, out_features=out_features) + + __all__ = ["Deimv2Config"] diff --git a/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py b/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py index 5aaecc844594..a6125952deeb 100644 --- a/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py +++ b/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py @@ -60,13 +60,10 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: decoder_cfg = orig_config["DEIMTransformer"] if "HybridEncoder" in orig_config: encoder_cfg = orig_config["HybridEncoder"] + encoder_type = "hybrid" elif "LiteEncoder" in orig_config: - raise ValueError( - "LiteEncoder variants (pico/femto/atto) are not yet supported. 
" - "The LiteEncoder uses a different architecture (AvgPool downsampling, GAP fusion, " - "RepNCSPELAN4 blocks) that requires a dedicated Deimv2LiteEncoder implementation. " - "Supported variants: deimv2_hgnetv2_n_coco and DINOv3 variants." - ) + encoder_cfg = orig_config["LiteEncoder"] + encoder_type = "lite" else: raise ValueError(f"No encoder config found. Available keys: {list(orig_config.keys())}") @@ -76,18 +73,34 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: config.label2id = {v: k for k, v in id2label.items()} # Encoder settings - config.encoder_hidden_dim = encoder_cfg["hidden_dim"] - config.encoder_in_channels = encoder_cfg["in_channels"] - config.feat_strides = encoder_cfg["feat_strides"] - config.activation_function = encoder_cfg.get("act", "silu") - config.depth_mult = encoder_cfg.get("depth_mult", 1.0) - config.hidden_expansion = encoder_cfg.get("expansion", 1.0) - config.encoder_fuse_op = encoder_cfg.get("fuse_op", "sum") - config.encoder_ffn_dim = encoder_cfg["dim_feedforward"] - config.encoder_attention_heads = encoder_cfg["nhead"] - config.dropout = encoder_cfg.get("dropout", 0.0) - config.encode_proj_layers = encoder_cfg["use_encoder_idx"] - config.encoder_activation_function = encoder_cfg.get("enc_act", "gelu") + config.encoder_type = encoder_type + if encoder_type == "lite": + config.encoder_hidden_dim = encoder_cfg["hidden_dim"] + config.encoder_in_channels = encoder_cfg["in_channels"] + config.feat_strides = encoder_cfg.get("feat_strides", [16]) + config.activation_function = encoder_cfg.get("act", "silu") + config.depth_mult = encoder_cfg.get("depth_mult", 1.0) + config.hidden_expansion = encoder_cfg.get("expansion", 1.0) + config.encoder_fuse_op = "sum" + config.encoder_ffn_dim = 1024 + config.encoder_attention_heads = 8 + config.dropout = 0.0 + config.encode_proj_layers = [2] + config.encoder_activation_function = "gelu" + config.encoder_layers = 0 + else: + config.encoder_hidden_dim = encoder_cfg["hidden_dim"] + config.encoder_in_channels = encoder_cfg["in_channels"] + config.feat_strides = encoder_cfg.get("feat_strides", [8, 16, 32]) + config.activation_function = encoder_cfg.get("act", "silu") + config.depth_mult = encoder_cfg.get("depth_mult", 1.0) + config.hidden_expansion = encoder_cfg.get("expansion", 1.0) + config.encoder_fuse_op = encoder_cfg.get("fuse_op", "sum") + config.encoder_ffn_dim = encoder_cfg.get("dim_feedforward", 1024) + config.encoder_attention_heads = encoder_cfg.get("nhead", 8) + config.dropout = encoder_cfg.get("dropout", 0.0) + config.encode_proj_layers = encoder_cfg.get("use_encoder_idx", [2]) + config.encoder_activation_function = encoder_cfg.get("enc_act", "gelu") # Decoder settings config.d_model = decoder_cfg["hidden_dim"] @@ -106,6 +119,8 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: config.decoder_in_channels = decoder_cfg["feat_channels"] config.eval_size = tuple(decoder_cfg["eval_spatial_size"]) if "eval_spatial_size" in decoder_cfg else None config.decoder_activation_function = decoder_cfg.get("activation", "silu") + config.share_bbox_head = decoder_cfg.get("share_bbox_head", False) + config.use_gateway = decoder_cfg.get("use_gateway", True) # Backbone settings if "HGNetv2" in orig_config: @@ -115,8 +130,7 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: config.backbone_config.out_indices = [i + 1 for i in return_idx] config.backbone_config.use_learnable_affine_block = backbone_cfg.get("use_lab", True) - # Set backbone sizes based on the model variant - if backbone_name == "B0": + if 
backbone_name in ["B0", "B1", "B2"]: config.backbone_config.hidden_sizes = [128, 256, 512, 1024] config.backbone_config.stem_channels = [3, 16, 16] config.backbone_config.stage_in_channels = [16, 64, 256, 512] @@ -127,28 +141,128 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: config.backbone_config.stage_light_block = [False, False, True, True] config.backbone_config.stage_kernel_size = [3, 3, 5, 5] config.backbone_config.stage_numb_of_layers = [3, 3, 3, 3] - elif backbone_name in ["B1", "B2"]: - config.backbone_config.hidden_sizes = [128, 256, 512, 1024] + elif backbone_name == "Atto": + config.backbone_config.hidden_sizes = [64, 256, 256] config.backbone_config.stem_channels = [3, 16, 16] - config.backbone_config.stage_in_channels = [16, 64, 256, 512] - config.backbone_config.stage_mid_channels = [16, 32, 64, 128] - config.backbone_config.stage_out_channels = [64, 256, 512, 1024] - config.backbone_config.stage_num_blocks = [1, 1, 2, 1] - config.backbone_config.stage_downsample = [False, True, True, True] - config.backbone_config.stage_light_block = [False, False, True, True] - config.backbone_config.stage_kernel_size = [3, 3, 5, 5] - config.backbone_config.stage_numb_of_layers = [3, 3, 3, 3] + config.backbone_config.stage_in_channels = [16, 64, 256] + config.backbone_config.stage_mid_channels = [16, 32, 64] + config.backbone_config.stage_out_channels = [64, 256, 256] + config.backbone_config.stage_num_blocks = [1, 1, 1] + config.backbone_config.stage_downsample = [False, True, True] + config.backbone_config.stage_light_block = [False, False, True] + config.backbone_config.stage_kernel_size = [3, 3, 3] + config.backbone_config.stage_numb_of_layers = [3, 3, 3] + elif backbone_name == "Femto": + config.backbone_config.hidden_sizes = [64, 256, 512] + config.backbone_config.stem_channels = [3, 16, 16] + config.backbone_config.stage_in_channels = [16, 64, 256] + config.backbone_config.stage_mid_channels = [16, 32, 64] + config.backbone_config.stage_out_channels = [64, 256, 512] + config.backbone_config.stage_num_blocks = [1, 1, 1] + config.backbone_config.stage_downsample = [False, True, True] + config.backbone_config.stage_light_block = [False, False, True] + config.backbone_config.stage_kernel_size = [3, 3, 5] + config.backbone_config.stage_numb_of_layers = [3, 3, 3] + elif backbone_name == "Pico": + config.backbone_config.hidden_sizes = [64, 256, 512] + config.backbone_config.stem_channels = [3, 16, 16] + config.backbone_config.stage_in_channels = [16, 64, 256] + config.backbone_config.stage_mid_channels = [16, 32, 64] + config.backbone_config.stage_out_channels = [64, 256, 512] + config.backbone_config.stage_num_blocks = [1, 1, 2] + config.backbone_config.stage_downsample = [False, True, True] + config.backbone_config.stage_light_block = [False, False, True] + config.backbone_config.stage_kernel_size = [3, 3, 5] + config.backbone_config.stage_numb_of_layers = [3, 3, 3] else: raise ValueError(f"Unknown HGNetv2 variant: {backbone_name}") + num_stages = len(config.backbone_config.hidden_sizes) + config.backbone_config.depths = config.backbone_config.stage_numb_of_layers + config.backbone_config.stage_names = ["stem"] + [f"stage{i}" for i in range(1, num_stages + 1)] + config.use_spatial_tuning_adapter = False elif "DINOv3STAs" in orig_config: - raise ValueError( - "DINOv3 backbone variants are not yet supported. " - "The DINOv3+STA architecture requires ViT backbone key mappings and " - "STA adapter integration that are not yet implemented in the conversion script. 
" - "Supported variants: deimv2_hgnetv2_n_coco." - ) + dinov3_cfg = orig_config["DINOv3STAs"] + name = dinov3_cfg["name"] + config.backbone_type = "dinov3" + config.dinov3_interaction_indexes = dinov3_cfg["interaction_indexes"] + config.sta_inplanes = dinov3_cfg.get("conv_inplane") or 16 + config.use_spatial_tuning_adapter = True + + is_dinov3 = "dinov3" in name + + DINOV3_PRESETS = { + "vit_tiny": { + "ffn_ratio": 4, + "use_gated_mlp": False, + "layerscale_value": 1.0, + "num_register_tokens": 0, + "pos_embed_rescale": None, + "key_bias": False, + }, + "vit_tinyplus": { + "ffn_ratio": 4, + "use_gated_mlp": False, + "layerscale_value": 1.0, + "num_register_tokens": 0, + "pos_embed_rescale": None, + "key_bias": False, + }, + "dinov3_vits16": { + "vit_hidden_size": 384, + "vit_num_heads": 6, + "ffn_ratio": 4, + "use_gated_mlp": False, + "layerscale_value": 1e-5, + "num_register_tokens": 4, + "pos_embed_rescale": 2.0, + "key_bias": True, + }, + "dinov3_vits16plus": { + "vit_hidden_size": 384, + "vit_num_heads": 6, + "ffn_ratio": 6, + "use_gated_mlp": True, + "layerscale_value": 1e-5, + "num_register_tokens": 4, + "pos_embed_rescale": 2.0, + "key_bias": True, + }, + } + preset = DINOV3_PRESETS[name] + + if is_dinov3: + vit_hidden_size = preset["vit_hidden_size"] + vit_num_heads = preset["vit_num_heads"] + else: + vit_hidden_size = dinov3_cfg.get("embed_dim") or 192 + vit_num_heads = dinov3_cfg.get("num_heads") or 3 + + config.dinov3_hidden_dim = dinov3_cfg.get("hidden_dim") or vit_hidden_size + + ffn_ratio = preset["ffn_ratio"] + if preset["use_gated_mlp"]: + hidden_features = vit_hidden_size * ffn_ratio + d = int(hidden_features * 2 / 3) + intermediate_size = d + (-d % 8) + else: + intermediate_size = vit_hidden_size * ffn_ratio + + config.dinov3_backbone_config = { + "hidden_size": vit_hidden_size, + "num_attention_heads": vit_num_heads, + "num_hidden_layers": 12, + "intermediate_size": intermediate_size, + "layerscale_value": preset["layerscale_value"], + "use_gated_mlp": preset["use_gated_mlp"], + "num_register_tokens": preset["num_register_tokens"], + "pos_embed_rescale": preset["pos_embed_rescale"], + "key_bias": preset["key_bias"], + "rope_theta": 100.0, + } + config.dinov3_apply_layernorm = is_dinov3 + config.encoder_has_trailing_conv = False else: raise ValueError(f"Unknown backbone in config: {list(orig_config.keys())}") @@ -280,10 +394,132 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: r"decoder\.decoder\.reg_scale": r"model.decoder.reg_scale", } +LITE_ENCODER_KEY_MAPPING = { + # LiteEncoder input_proj + r"encoder\.input_proj\.(\d+)\.conv\.weight": r"model.encoder.input_proj.\1.0.weight", + r"encoder\.input_proj\.(\d+)\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.input_proj.\1.1.\2", + # Downsamples + r"encoder\.down_sample(\d+)\.1\.weight": r"model.encoder.down_sample\1.1.weight", + r"encoder\.down_sample(\d+)\.2\.(weight|bias|running_mean|running_var)": r"model.encoder.down_sample\1.2.\2", + # GAP_Fusion + r"encoder\.bi_fusion\.cv\.conv\.weight": r"model.encoder.bi_fusion.cv.conv.weight", + r"encoder\.bi_fusion\.cv\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.bi_fusion.cv.norm.\1", + # FPN block (RepNCSPELAN4) + r"encoder\.fpn_block\.cv1\.conv\.weight": r"model.encoder.fpn_block.conv1.conv.weight", + r"encoder\.fpn_block\.cv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.conv1.norm.\1", + r"encoder\.fpn_block\.cv4\.conv\.weight": r"model.encoder.fpn_block.conv4.conv.weight", + 
r"encoder\.fpn_block\.cv4\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.conv4.norm.\1", + r"encoder\.fpn_block\.cv2\.0\.conv1\.conv\.weight": r"model.encoder.fpn_block.csp_rep1.conv1.conv.weight", + r"encoder\.fpn_block\.cv2\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep1.conv1.norm.\1", + r"encoder\.fpn_block\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.fpn_block.csp_rep1.bottlenecks.\1.conv1.conv.weight", + r"encoder\.fpn_block\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep1.bottlenecks.\1.conv1.norm.\2", + r"encoder\.fpn_block\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.fpn_block.csp_rep1.bottlenecks.\1.conv2.conv.weight", + r"encoder\.fpn_block\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep1.bottlenecks.\1.conv2.norm.\2", + r"encoder\.fpn_block\.cv2\.1\.conv\.weight": r"model.encoder.fpn_block.csp_rep1.conv3.conv.weight", + r"encoder\.fpn_block\.cv2\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep1.conv3.norm.\1", + r"encoder\.fpn_block\.cv3\.0\.conv1\.conv\.weight": r"model.encoder.fpn_block.csp_rep2.conv1.conv.weight", + r"encoder\.fpn_block\.cv3\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep2.conv1.norm.\1", + r"encoder\.fpn_block\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.fpn_block.csp_rep2.bottlenecks.\1.conv1.conv.weight", + r"encoder\.fpn_block\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep2.bottlenecks.\1.conv1.norm.\2", + r"encoder\.fpn_block\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.fpn_block.csp_rep2.bottlenecks.\1.conv2.conv.weight", + r"encoder\.fpn_block\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep2.bottlenecks.\1.conv2.norm.\2", + r"encoder\.fpn_block\.cv3\.1\.conv\.weight": r"model.encoder.fpn_block.csp_rep2.conv3.conv.weight", + r"encoder\.fpn_block\.cv3\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep2.conv3.norm.\1", + # PAN block (same structure as FPN) + r"encoder\.pan_block\.cv1\.conv\.weight": r"model.encoder.pan_block.conv1.conv.weight", + r"encoder\.pan_block\.cv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.conv1.norm.\1", + r"encoder\.pan_block\.cv4\.conv\.weight": r"model.encoder.pan_block.conv4.conv.weight", + r"encoder\.pan_block\.cv4\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.conv4.norm.\1", + r"encoder\.pan_block\.cv2\.0\.conv1\.conv\.weight": r"model.encoder.pan_block.csp_rep1.conv1.conv.weight", + r"encoder\.pan_block\.cv2\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep1.conv1.norm.\1", + r"encoder\.pan_block\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.pan_block.csp_rep1.bottlenecks.\1.conv1.conv.weight", + r"encoder\.pan_block\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep1.bottlenecks.\1.conv1.norm.\2", + r"encoder\.pan_block\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.pan_block.csp_rep1.bottlenecks.\1.conv2.conv.weight", + 
r"encoder\.pan_block\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep1.bottlenecks.\1.conv2.norm.\2", + r"encoder\.pan_block\.cv2\.1\.conv\.weight": r"model.encoder.pan_block.csp_rep1.conv3.conv.weight", + r"encoder\.pan_block\.cv2\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep1.conv3.norm.\1", + r"encoder\.pan_block\.cv3\.0\.conv1\.conv\.weight": r"model.encoder.pan_block.csp_rep2.conv1.conv.weight", + r"encoder\.pan_block\.cv3\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep2.conv1.norm.\1", + r"encoder\.pan_block\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.pan_block.csp_rep2.bottlenecks.\1.conv1.conv.weight", + r"encoder\.pan_block\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep2.bottlenecks.\1.conv1.norm.\2", + r"encoder\.pan_block\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.pan_block.csp_rep2.bottlenecks.\1.conv2.conv.weight", + r"encoder\.pan_block\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep2.bottlenecks.\1.conv2.norm.\2", + r"encoder\.pan_block\.cv3\.1\.conv\.weight": r"model.encoder.pan_block.csp_rep2.conv3.conv.weight", + r"encoder\.pan_block\.cv3\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep2.conv3.norm.\1", +} + +DECODER_NO_GATEWAY_KEY_MAPPING = { + r"decoder\.decoder\.layers\.(\d+)\.norm2\.scale": r"model.decoder.layers.\1.encoder_attn_layer_norm.scale", +} + +DINOV3_KEY_MAPPING = { + # ViT embeddings + r"backbone\.dinov3\.patch_embed\.proj\.(weight|bias)": r"model.dinov3_backbone.embeddings.patch_embeddings.\1", + r"backbone\.dinov3\.cls_token": r"model.dinov3_backbone.embeddings.cls_token", + r"backbone\.dinov3\.storage_tokens": r"model.dinov3_backbone.embeddings.register_tokens", + r"backbone\.dinov3\.mask_token": r"model.dinov3_backbone.embeddings.mask_token", + # ViT blocks + r"backbone\.dinov3\.blocks\.(\d+)\.norm1\.(weight|bias)": r"model.dinov3_backbone.layers.\1.norm1.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.norm2\.(weight|bias)": r"model.dinov3_backbone.layers.\1.norm2.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.attn\.qkv\.(weight|bias)": r"model.dinov3_backbone.layers.\1.attention.qkv.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.attn\.proj\.(weight|bias)": r"model.dinov3_backbone.layers.\1.attention.o_proj.\2", + # Standard MLP (S/M/L) + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.fc1\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.up_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.fc2\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.down_proj.\2", + # SwiGLU MLP (X only) + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w1\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.gate_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w2\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.up_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w3\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.down_proj.\2", + # LayerScale (L/X only) + r"backbone\.dinov3\.blocks\.(\d+)\.ls1\.gamma": r"model.dinov3_backbone.layers.\1.layer_scale1.lambda1", + r"backbone\.dinov3\.blocks\.(\d+)\.ls2\.gamma": r"model.dinov3_backbone.layers.\1.layer_scale2.lambda1", + # Norm (L/X only) + r"backbone\.dinov3\.norm\.(weight|bias)": r"model.dinov3_backbone.norm.\1", + # STA adapter + 
r"backbone\.sta\.stem\.0\.(weight|bias)": r"model.dinov3_backbone.sta.stem.0.\1", + r"backbone\.sta\.stem\.1\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.sta.stem.1.\1", + r"backbone\.sta\.conv2\.0\.(weight)": r"model.dinov3_backbone.sta.conv2.0.\1", + r"backbone\.sta\.conv2\.1\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.sta.conv2.1.\1", + r"backbone\.sta\.conv3\.1\.(weight)": r"model.dinov3_backbone.sta.conv3.1.\1", + r"backbone\.sta\.conv3\.2\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.sta.conv3.2.\1", + r"backbone\.sta\.conv4\.1\.(weight)": r"model.dinov3_backbone.sta.conv4.1.\1", + r"backbone\.sta\.conv4\.2\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.sta.conv4.2.\1", + # Projection convs/norms + r"backbone\.convs\.(\d+)\.weight": r"model.dinov3_backbone.convs.\1.weight", + r"backbone\.norms\.(\d+)\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.norms.\1.\2", +} + -def convert_old_keys_to_new_keys(state_dict): - # Use the mapping to rename keys - for original_key, converted_key in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items(): +def convert_old_keys_to_new_keys(state_dict, config=None): + mapping = dict(ORIGINAL_TO_CONVERTED_KEY_MAPPING) + + if config is not None: + if config.encoder_type == "lite": + for k in list(mapping.keys()): + if ( + k.startswith(r"encoder\.input_proj") + or k.startswith(r"encoder\.lateral") + or k.startswith(r"encoder\.fpn_blocks") + or k.startswith(r"encoder\.pan_blocks") + or k.startswith(r"encoder\.downsample") + or k.startswith(r"encoder\.encoder") + ): + del mapping[k] + mapping.update(LITE_ENCODER_KEY_MAPPING) + + if not config.use_gateway: + mapping.update(DECODER_NO_GATEWAY_KEY_MAPPING) + for k in list(mapping.keys()): + if "gateway" in k: + del mapping[k] + + if config.backbone_type == "dinov3": + for k in list(mapping.keys()): + if k.startswith(r"backbone\."): + del mapping[k] + mapping.update(DINOV3_KEY_MAPPING) + + for original_key, converted_key in mapping.items(): for key in list(state_dict.keys()): new_key = re.sub(f"^{original_key}$", converted_key, key) if new_key != key: @@ -333,6 +569,39 @@ def read_in_q_k_v(state_dict, config): state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-d_model:] +def strip_dinov3_model_prefix(state_dict): + for key in list(state_dict.keys()): + if "backbone.dinov3._model." 
in key: + new_key = key.replace("backbone.dinov3._model.", "backbone.dinov3.") + state_dict[new_key] = state_dict.pop(key) + return state_dict + + +def read_in_q_k_v_vit(state_dict, config): + from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig + + vit_config = DINOv3ViTConfig(**config.dinov3_backbone_config) + has_key_bias = config.dinov3_backbone_config.get("key_bias", True) + prefix = "model.dinov3_backbone" + for i in range(vit_config.num_hidden_layers): + qkv_key = f"{prefix}.layers.{i}.attention.qkv.weight" + if qkv_key in state_dict: + qkv_w = state_dict.pop(qkv_key) + q, k, v = qkv_w.chunk(3, dim=0) + state_dict[f"{prefix}.layers.{i}.attention.q_proj.weight"] = q + state_dict[f"{prefix}.layers.{i}.attention.k_proj.weight"] = k + state_dict[f"{prefix}.layers.{i}.attention.v_proj.weight"] = v + + qkv_bias_key = f"{prefix}.layers.{i}.attention.qkv.bias" + if qkv_bias_key in state_dict: + qkv_b = state_dict.pop(qkv_bias_key) + q_b, k_b, v_b = qkv_b.chunk(3, dim=0) + state_dict[f"{prefix}.layers.{i}.attention.q_proj.bias"] = q_b + if has_key_bias: + state_dict[f"{prefix}.layers.{i}.attention.k_proj.bias"] = k_b + state_dict[f"{prefix}.layers.{i}.attention.v_proj.bias"] = v_b + + def load_original_state_dict(repo_id): filepath = hf_hub_download(repo_id=repo_id, filename="model.safetensors") return load_file(filepath) @@ -361,10 +630,26 @@ def convert_deimv2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, state_dict.pop("decoder.valid_mask", None) state_dict.pop("decoder.anchors", None) + for key in list(state_dict.keys()): + if key.endswith(".num_batches_tracked"): + state_dict.pop(key) + + if config.backbone_type == "dinov3": + strip_dinov3_model_prefix(state_dict) + for key in list(state_dict.keys()): + if "rope_embed.periods" in key or "qkv.bias_mask" in key: + state_dict.pop(key) + # query, key and value matrices need special treatment read_in_q_k_v(state_dict, config) - state_dict = convert_old_keys_to_new_keys(state_dict) + state_dict = convert_old_keys_to_new_keys(state_dict, config) + + if config.backbone_type == "dinov3": + read_in_q_k_v_vit(state_dict, config) + mask_key = "model.dinov3_backbone.embeddings.mask_token" + if mask_key in state_dict and state_dict[mask_key].dim() == 2: + state_dict[mask_key] = state_dict[mask_key].unsqueeze(1) if "model.enc_output.0.weight" not in state_dict: d_model = config.d_model @@ -373,37 +658,74 @@ def convert_deimv2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, state_dict["model.enc_output.1.weight"] = torch.ones(d_model) state_dict["model.enc_output.1.bias"] = torch.zeros(d_model) + if config.share_bbox_head: + num_decoder_layers = config.decoder_layers + for key in list(state_dict.keys()): + if "model.decoder.bbox_embed.0." 
in key: + for i in range(1, num_decoder_layers): + new_key = key.replace("bbox_embed.0.", f"bbox_embed.{i}.") + if new_key not in state_dict: + state_dict[new_key] = state_dict[key] + # for two_stage for key in list(state_dict.keys()): if "bbox_embed" in key or ("class_embed" in key and "denoising_" not in key): new_key = key.split("model.decoder.")[-1] - if new_key not in state_dict: + if new_key != key and new_key not in state_dict: state_dict[new_key] = state_dict[key] model = Deimv2ForObjectDetection(config) missing, unexpected = model.load_state_dict(state_dict, strict=False) - if missing: - logger.warning(f"Missing keys ({len(missing)}): {missing[:10]}...") + expected_missing = {"mask_token", "register_tokens", "layer_scale1", "layer_scale2"} + unexpected_missing = [k for k in missing if not any(e in k for e in expected_missing)] + if unexpected_missing: + logger.warning(f"Missing keys ({len(unexpected_missing)}): {unexpected_missing[:10]}...") + elif missing: + logger.info( + f"All {len(missing)} missing keys are expected model-init defaults (mask_token, register_tokens, layer_scale)" + ) if unexpected: logger.warning(f"Unexpected keys ({len(unexpected)}): {unexpected[:10]}...") model.eval() - image_processor = RTDetrImageProcessor() + is_dinov3 = config.backbone_type == "dinov3" + if is_dinov3: + image_processor = RTDetrImageProcessor( + do_normalize=True, + image_mean=[0.485, 0.456, 0.406], + image_std=[0.229, 0.224, 0.225], + ) + else: + image_processor = RTDetrImageProcessor() + img = prepare_img() - transformations = transforms.Compose( - [ - transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR), - transforms.ToTensor(), - ] - ) + if is_dinov3: + transformations = transforms.Compose( + [ + transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + else: + transformations = transforms.Compose( + [ + transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + ] + ) original_pixel_values = transformations(img).unsqueeze(0) encoding = image_processor(images=img, return_tensors="pt") pixel_values = encoding["pixel_values"] - assert torch.allclose(original_pixel_values, pixel_values), "Image preprocessing mismatch!" + if not torch.allclose(original_pixel_values, pixel_values, atol=1e-4): + max_diff = (original_pixel_values - pixel_values).abs().max().item() + logger.warning(f"Image preprocessing mismatch! Max diff: {max_diff:.6f}") + if max_diff > 1e-2: + raise ValueError(f"Image preprocessing mismatch too large: {max_diff}") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) diff --git a/src/transformers/models/deimv2/modeling_deimv2.py b/src/transformers/models/deimv2/modeling_deimv2.py index 154170174227..9173609eb0df 100644 --- a/src/transformers/models/deimv2/modeling_deimv2.py +++ b/src/transformers/models/deimv2/modeling_deimv2.py @@ -21,23 +21,24 @@ from collections.abc import Callable from dataclasses import dataclass +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from ... 
import initialization as init -from ...activations import ACT2CLS -from ...backbone_utils import load_backbone +from ...activations import ACT2CLS, ACT2FN from ...image_transforms import center_to_corners_format, corners_to_center_format +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import compile_compatible_method_lru_cache from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int -from ...utils.generic import can_return_tuple, merge_with_config_defaults +from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from .configuration_deimv2 import Deimv2Config +from .configuration_deimv2 import Deimv2Config, Deimv2DINOv3ViTConfig @dataclass @@ -351,6 +352,399 @@ def forward(self, x): return self.activation(y) +class Deimv2DINOv3ViTEmbeddings(nn.Module): + """ + Construct the CLS token, mask token, position and patch embeddings. + """ + + def __init__(self, config: Deimv2DINOv3ViTConfig): + super().__init__() + self.config = config + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.register_tokens = nn.Parameter(torch.empty(1, config.num_register_tokens, config.hidden_size)) + self.patch_embeddings = nn.Conv2d( + config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size + ) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embeddings.weight.dtype + + # (batch_size, num_channels, height, width) -> (batch_size, num_patches, hidden_size) + patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2) + + if bool_masked_pos is not None: + mask_token = self.mask_token.to(patch_embeddings.dtype) + patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings) + + # Add CLS and register tokens + cls_token = self.cls_token.expand(batch_size, -1, -1) + register_tokens = self.register_tokens.expand(batch_size, -1, -1) + embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) + + return embeddings + + +class Deimv2DINOv3ViTLayerScale(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size)) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return hidden_state * self.lambda1 + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float | None = None, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, **kwargs +) -> tuple[torch.Tensor, torch.Tensor]: + """Applies Rotary Position Embedding to the query and key tensors, but only to the patch tokens, + ignoring the prefix tokens (cls token and register tokens). + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + + num_tokens = q.shape[-2] + num_patches = sin.shape[-2] + num_prefix_tokens = num_tokens - num_patches # cls token + register tokens + + q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2) + k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2) + + # apply rope only to patch tokens + q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin) + k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin) + + q = torch.cat((q_prefix_tokens, q_patches), dim=-2) + k = torch.cat((k_prefix_tokens, k_patches), dim=-2) + + return q, k + + +class Deimv2DINOv3ViTAttention(nn.Module): + """ + Multi-headed attention compatible with ALL_ATTENTION_FUNCTIONS. 
+ """ + + def __init__(self, config: Deimv2DINOv3ViTConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.is_causal = False + + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.dropout = config.attention_dropout + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.key_bias) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.value_bias) + + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.query_bias) + self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Input shape: Batch x Time x Channel""" + + batch_size, patches, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class Deimv2DINOv3ViTDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float | None = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +class Deimv2DINOv3ViTMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) + + +class Deimv2DINOv3ViTGatedMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Deimv2DINOv3ViTLayer(GradientCheckpointingLayer): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: Deimv2DINOv3ViTConfig): + super().__init__() + + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = Deimv2DINOv3ViTAttention(config) + self.layer_scale1 = Deimv2DINOv3ViTLayerScale(config) + self.drop_path = ( + Deimv2DINOv3ViTDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + ) + + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.use_gated_mlp: + self.mlp = Deimv2DINOv3ViTGatedMLP(config) + else: + self.mlp = Deimv2DINOv3ViTMLP(config) + self.layer_scale2 = Deimv2DINOv3ViTLayerScale(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + # Attention with residual connection + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states, _ = self.attention( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) + hidden_states = self.layer_scale1(hidden_states) + hidden_states = self.drop_path(hidden_states) + residual + + # MLP with residual connection + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.layer_scale2(hidden_states) + hidden_states = 
self.drop_path(hidden_states) + residual + + return hidden_states + + +@compile_compatible_method_lru_cache(maxsize=32) +def get_patches_center_coordinates( + num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device +) -> torch.Tensor: + """ + Computes the 2D coordinates of the centers of image patches, normalized to the range [-1, +1]. + The center of each patch is exactly halfway between its top-left and bottom-right corners. + + Args: + num_patches_h (int): Number of patches along the vertical (height) axis. + num_patches_w (int): Number of patches along the horizontal (width) axis. + dtype (torch.dtype): The desired data type of the returned tensor. + + Returns: + torch.Tensor: A tensor of shape (height * width, 2), where each row contains the (y, x) + coordinates of a patch center, normalized to [-1, +1]. + """ + coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device) + coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device) + coords_h = coords_h / num_patches_h + coords_w = coords_w / num_patches_w + # (height, width, 2) -> (height * width, 2) + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) + coords = coords.flatten(0, 1) + # Shift range [0, 1] to [-1, +1] + coords = 2.0 * coords - 1.0 + return coords + + +def augment_patches_center_coordinates( + coords: torch.Tensor, + shift: float | None = None, + jitter: float | None = None, + rescale: float | None = None, +) -> torch.Tensor: + # Shift coords by adding a uniform value in [-shift, shift] + if shift is not None: + shift_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype) + shift_hw = shift_hw.uniform_(-shift, shift) + coords = coords + shift_hw + + # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter] + if jitter is not None: + jitter_range = np.log(jitter) + jitter_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype) + jitter_hw = jitter_hw.uniform_(-jitter_range, jitter_range).exp() + coords = coords * jitter_hw + + # Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale] + if rescale is not None: + rescale_range = np.log(rescale) + rescale_hw = torch.empty(1, device=coords.device, dtype=coords.dtype) + rescale_hw = rescale_hw.uniform_(-rescale_range, rescale_range).exp() + coords = coords * rescale_hw + + return coords + + +class Deimv2DINOv3ViTRopePositionEmbedding(nn.Module): + inv_freq: torch.Tensor + + def __init__(self, config: Deimv2DINOv3ViTConfig): + super().__init__() + + self.config = config + self.base = config.rope_theta + self.head_dim = config.hidden_size // config.num_attention_heads + self.num_patches_h = config.image_size // config.patch_size + self.num_patches_w = config.image_size // config.patch_size + + inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32) # (head_dim / 4,) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + _, _, height, width = pixel_values.shape + num_patches_h = height // self.config.patch_size + num_patches_w = width // self.config.patch_size + + device = pixel_values.device + device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu" + + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + # Although we could precompute static patch_coords from image_size and patch_size in the config, + # the model 
was trained with random_scale, so it can process images of varying sizes. + # Therefore, it's better to compute patch_coords dynamically (with lru_cache). + patch_coords = get_patches_center_coordinates( + num_patches_h, num_patches_w, dtype=torch.float32, device=device + ) + if self.training: + patch_coords = augment_patches_center_coordinates( + patch_coords, + shift=self.config.pos_embed_shift, + jitter=self.config.pos_embed_jitter, + rescale=self.config.pos_embed_rescale, + ) + + # (height * width, 2, head_dim / 4) -> (height * width, head_dim / 2) -> (height * width, head_dim) + angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :] + angles = angles.flatten(1, 2) + angles = angles.tile(2) + + cos = torch.cos(angles) + sin = torch.sin(angles) + + dtype = pixel_values.dtype + return cos.to(dtype=dtype), sin.to(dtype=dtype) + + class Deimv2CSPRepLayer2(nn.Module): def __init__( self, config: Deimv2Config, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0 @@ -362,7 +756,10 @@ def __init__( self.bottlenecks = nn.ModuleList( [Deimv2RepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)] ) - self.conv3 = Deimv2ConvNormLayer(config, hidden_channels, out_channels, 3, 1, activation=activation) + if config.encoder_has_trailing_conv: + self.conv3 = Deimv2ConvNormLayer(config, hidden_channels, out_channels, 3, 1, activation=activation) + else: + self.conv3 = nn.Identity() def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: chunks = list(self.conv1(hidden_state).chunk(2, 1)) @@ -373,13 +770,21 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: class Deimv2RepNCSPELAN5(nn.Module): - def __init__(self, config: Deimv2Config, numb_blocks: int = 3): + def __init__( + self, + config: Deimv2Config, + numb_blocks: int = 3, + c1: int | None = None, + c2: int | None = None, + c3: int | None = None, + c4: int | None = None, + ): super().__init__() act = config.activation_function - c1 = config.encoder_hidden_dim - c2 = config.encoder_hidden_dim - c3 = config.encoder_hidden_dim * 2 - c4 = round(config.hidden_expansion * config.encoder_hidden_dim // 2) + c1 = c1 if c1 is not None else config.encoder_hidden_dim + c2 = c2 if c2 is not None else config.encoder_hidden_dim + c3 = c3 if c3 is not None else config.encoder_hidden_dim * 2 + c4 = c4 if c4 is not None else round(config.hidden_expansion * config.encoder_hidden_dim // 2) self.conv_dim = c3 // 2 self.conv1 = Deimv2ConvNormLayer(config, c1, c3, 1, 1, activation=act) self.csp_rep1 = Deimv2CSPRepLayer2(config, c3 // 2, c4, num_blocks=numb_blocks) @@ -414,34 +819,6 @@ def forward(self, input_features: torch.Tensor) -> torch.Tensor: return input_features -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor | None, - scaling: float | None = None, - dropout: float = 0.0, - **kwargs: Unpack[TransformersKwargs], -): - if scaling is None: - scaling = query.size(-1) ** -0.5 - - # Take the dot product between "query" and "key" to get the raw attention scores. 
- attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - - attn_output = torch.matmul(attn_weights, value) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - class Deimv2SelfAttention(nn.Module): """ Multi-headed self-attention from 'Attention Is All You Need' paper. @@ -708,6 +1085,135 @@ def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso return c2, c3, c4 +class Deimv2GAPFusion(nn.Module): + def __init__(self, config: Deimv2Config, channels: int): + super().__init__() + self.cv = Deimv2ConvNormLayer(config, channels, channels, 1, 1, activation=config.activation_function) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return self.cv(hidden_state + F.adaptive_avg_pool2d(hidden_state, 1)) + + +class Deimv2LiteEncoder(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + hidden_dim = config.encoder_hidden_dim + act = config.activation_function + + self.input_proj = nn.ModuleList() + for in_channel in config.encoder_in_channels: + self.input_proj.append( + nn.Sequential( + nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(hidden_dim), + ) + ) + + self.down_sample1 = nn.Sequential( + nn.AvgPool2d(kernel_size=3, stride=2, padding=1), + nn.Conv2d(hidden_dim, hidden_dim, 1, 1, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.SiLU(inplace=True) if act == "silu" else nn.ReLU(inplace=True), + ) + self.down_sample2 = nn.Sequential( + nn.AvgPool2d(kernel_size=3, stride=2, padding=1), + nn.Conv2d(hidden_dim, hidden_dim, 1, 1, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.SiLU(inplace=True) if act == "silu" else nn.ReLU(inplace=True), + ) + + self.bi_fusion = Deimv2GAPFusion(config, hidden_dim) + + c1, c2 = hidden_dim, hidden_dim + c3 = hidden_dim * 2 + c4 = round(config.hidden_expansion * hidden_dim // 2) + num_blocks = round(3 * config.depth_mult) + self.fpn_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks, c1=c1, c2=c2, c3=c3, c4=c4) + self.pan_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks, c1=c1, c2=c2, c3=c3, c4=c4) + + def forward(self, inputs_embeds=None, **kwargs) -> BaseModelOutput: + feats = inputs_embeds + proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] + proj_feats.append(self.down_sample1(proj_feats[-1])) + + proj_feats[-1] = self.bi_fusion(proj_feats[-1]) + + outs = [] + fuse_feat = proj_feats[0] + F.interpolate(proj_feats[1], scale_factor=2.0, mode="nearest") + outs.append(self.fpn_block(fuse_feat)) + + fuse_feat = proj_feats[1] + self.down_sample2(outs[-1]) + outs.append(self.pan_block(fuse_feat)) + + return BaseModelOutput(last_hidden_state=outs) + + +class Deimv2DINOv3Backbone(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + vit_config = Deimv2DINOv3ViTConfig(**config.dinov3_backbone_config) + + self.embeddings = Deimv2DINOv3ViTEmbeddings(vit_config) + self.rope_embeddings = Deimv2DINOv3ViTRopePositionEmbedding(vit_config) + self.layers = nn.ModuleList([Deimv2DINOv3ViTLayer(vit_config) for _ in range(vit_config.num_hidden_layers)]) + + self.apply_layernorm = config.dinov3_apply_layernorm + if self.apply_layernorm: + self.norm = nn.LayerNorm(vit_config.hidden_size, eps=vit_config.layer_norm_eps) + + 
self.interaction_indexes = config.dinov3_interaction_indexes + self.patch_size = vit_config.patch_size + self.num_prefix_tokens = 1 + vit_config.num_register_tokens + + self.sta = Deimv2SpatialTuningAdapter(config) + + embed_dim = vit_config.hidden_size + hidden_dim = config.dinov3_hidden_dim or embed_dim + sta_ch = config.sta_inplanes + self.convs = nn.ModuleList( + [ + nn.Conv2d(embed_dim + sta_ch * 2, hidden_dim, 1, bias=False), + nn.Conv2d(embed_dim + sta_ch * 4, hidden_dim, 1, bias=False), + nn.Conv2d(embed_dim + sta_ch * 4, hidden_dim, 1, bias=False), + ] + ) + self.norms = nn.ModuleList([nn.BatchNorm2d(hidden_dim) for _ in range(3)]) + + def forward(self, pixel_values: torch.Tensor) -> list[torch.Tensor]: + hidden_states = self.embeddings(pixel_values) + position_embeddings = self.rope_embeddings(pixel_values) + + intermediate_outputs = [] + for i, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, position_embeddings=position_embeddings) + if i in self.interaction_indexes: + out = self.norm(hidden_states) if self.apply_layernorm else hidden_states + intermediate_outputs.append(out) + + batch_size = pixel_values.shape[0] + h_patches = pixel_values.shape[2] // self.patch_size + w_patches = pixel_values.shape[3] // self.patch_size + + sem_feats = [] + num_scales = len(intermediate_outputs) + for i, feat in enumerate(intermediate_outputs): + patch_tokens = feat[:, self.num_prefix_tokens :] + spatial = patch_tokens.transpose(1, 2).reshape(batch_size, -1, h_patches, w_patches).contiguous() + resize_h = int(h_patches * 2 ** (num_scales - 2 - i)) + resize_w = int(w_patches * 2 ** (num_scales - 2 - i)) + spatial = F.interpolate(spatial, size=[resize_h, resize_w], mode="bilinear", align_corners=False) + sem_feats.append(spatial) + + detail_feats = self.sta(pixel_values) + + outs = [] + for i, (sem_feat, detail_feat) in enumerate(zip(sem_feats, detail_feats)): + fused = torch.cat([sem_feat, detail_feat], dim=1) + outs.append(self.norms[i](self.convs[i](fused))) + + return outs + + class Deimv2Integral(nn.Module): """ A static layer that calculates integral results from a distribution. 
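For reference, the scale arithmetic used by `Deimv2DINOv3Backbone.forward` above can be traced with a small sketch; the 640x640 input, 16-pixel patches, and three interaction indexes are illustrative assumptions, not values taken from a released checkpoint.

```python
# Sketch of the resize factors computed in Deimv2DINOv3Backbone.forward.
# Assumed example values: 640x640 input, patch_size=16, len(dinov3_interaction_indexes) == 3.
patch_size = 16
height = width = 640
num_scales = 3

h_patches = height // patch_size  # 40
w_patches = width // patch_size   # 40

for i in range(num_scales):
    factor = 2 ** (num_scales - 2 - i)
    print(i, int(h_patches * factor), int(w_patches * factor))
# 0 -> 80x80 (stride 8), 1 -> 40x40 (stride 16), 2 -> 20x20 (stride 32);
# each map is then concatenated with the matching STA detail feature and projected by a 1x1 conv.
```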
@@ -767,7 +1273,13 @@ def __init__(self, config: Deimv2Config): self.encoder_attn = Deimv2MultiscaleDeformableAttention(config=config) self.mlp = Deimv2SwiGLUFFN(config.d_model, config.decoder_ffn_dim // 2, config.d_model) self.final_layer_norm = Deimv2RMSNorm(config.d_model) + # gate self.gateway = Deimv2Gate(config.d_model) + if config.use_gateway: + self.gateway = Deimv2Gate(config.d_model) + else: + del self.gateway + self.encoder_attn_layer_norm = Deimv2RMSNorm(config.d_model) def forward( self, @@ -778,7 +1290,7 @@ def forward( spatial_shapes_list=None, encoder_hidden_states: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, - **kwargs: Unpack[TransformersKwargs], + **kwargs, ) -> torch.Tensor: """ Args: @@ -801,7 +1313,6 @@ def forward( """ residual = hidden_states - # Self Attention hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=encoder_attention_mask, @@ -815,7 +1326,6 @@ def forward( residual = hidden_states - # Cross-Attention hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings hidden_states, _ = self.encoder_attn( hidden_states=hidden_states, @@ -825,9 +1335,13 @@ def forward( spatial_shapes_list=spatial_shapes_list, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = self.gateway(residual, hidden_states) - # Fully Connected + if hasattr(self, "gateway"): + hidden_states = self.gateway(residual, hidden_states) + else: + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + residual = hidden_states hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states @@ -942,6 +1456,20 @@ def _init_weights(self, module): init.ones_(module.weight) init.zeros_(module.bias) + if isinstance(module, Deimv2DINOv3ViTEmbeddings): + init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) + if module.config.num_register_tokens > 0: + init.trunc_normal_(module.register_tokens, mean=0.0, std=self.config.initializer_range) + init.zeros_(module.mask_token) + + if isinstance(module, Deimv2DINOv3ViTLayerScale) and self.config.dinov3_backbone_config is not None: + layerscale_value = self.config.dinov3_backbone_config.get("layerscale_value", 1.0) + init.constant_(module.lambda1, layerscale_value) + + if isinstance(module, Deimv2DINOv3ViTRopePositionEmbedding): + inv_freq = 1 / module.base ** torch.arange(0, 1, 4 / module.head_dim, dtype=torch.float32) + init.copy_(module.inv_freq, inv_freq) + if hasattr(module, "weight_embedding") and self.config.learn_initial_query: init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: @@ -1259,45 +1787,6 @@ def forward( ) -class Deimv2FrozenBatchNorm2d(nn.Module): - """ - BatchNorm2d where the batch statistics and the affine parameters are fixed. - - Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than - torchvision.models.resnet[18,34,50,101] produce nans. 
- """ - - def __init__(self, n): - super().__init__() - self.register_buffer("weight", torch.ones(n)) - self.register_buffer("bias", torch.zeros(n)) - self.register_buffer("running_mean", torch.zeros(n)) - self.register_buffer("running_var", torch.ones(n)) - - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - num_batches_tracked_key = prefix + "num_batches_tracked" - if num_batches_tracked_key in state_dict: - del state_dict[num_batches_tracked_key] - - super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) - - def forward(self, x): - # move reshapes to the beginning - # to make it user-friendly - weight = self.weight.reshape(1, -1, 1, 1) - bias = self.bias.reshape(1, -1, 1, 1) - running_var = self.running_var.reshape(1, -1, 1, 1) - running_mean = self.running_mean.reshape(1, -1, 1, 1) - epsilon = 1e-5 - scale = weight * (running_var + epsilon).rsqrt() - bias = bias - running_mean * scale - return x * scale + bias - - @dataclass @auto_docstring( custom_intro=""" @@ -1356,62 +1845,6 @@ class Deimv2ModelOutput(ModelOutput): denoising_meta_values: dict | None = None -def replace_batch_norm(model): - r""" - Recursively replace all `torch.nn.BatchNorm2d` with `Deimv2FrozenBatchNorm2d`. - - Args: - model (torch.nn.Module): - input model - """ - for name, module in model.named_children(): - if isinstance(module, nn.BatchNorm2d): - new_module = Deimv2FrozenBatchNorm2d(module.num_features) - - if module.weight.device != torch.device("meta"): - new_module.weight.copy_(module.weight) - new_module.bias.copy_(module.bias) - new_module.running_mean.copy_(module.running_mean) - new_module.running_var.copy_(module.running_var) - - model._modules[name] = new_module - - if len(list(module.children())) > 0: - replace_batch_norm(module) - - -class Deimv2ConvEncoder(nn.Module): - """ - Convolutional backbone using the modeling_deimv2_resnet.py. - - nn.BatchNorm2d layers are replaced by Deimv2FrozenBatchNorm2d as defined above. 
- https://github.com/lyuwenyu/RT-DETR/blob/main/Deimv2_pytorch/src/nn/backbone/presnet.py#L142 - """ - - def __init__(self, config): - super().__init__() - - backbone = load_backbone(config) - - if config.freeze_backbone_batch_norms: - # replace batch norm by frozen batch norm - with torch.no_grad(): - replace_batch_norm(backbone) - self.model = backbone - self.intermediate_channel_sizes = self.model.channels - - def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): - # send pixel_values through the model to get list of feature maps - features = self.model(pixel_values).feature_maps - - out = [] - for feature_map in features: - # downsample pixel_mask to match shape of corresponding feature_map - mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] - out.append((feature_map, mask)) - return out - - def get_contrastive_denoising_training_group( targets, num_classes, @@ -1544,62 +1977,48 @@ class Deimv2Model(Deimv2PreTrainedModel): def __init__(self, config: Deimv2Config): super().__init__(config) - # Create backbone - self.backbone = Deimv2ConvEncoder(config) - intermediate_channel_sizes = self.backbone.intermediate_channel_sizes - num_backbone_outs = len(config.decoder_in_channels) - encoder_input_proj_list = [] - for i in range(num_backbone_outs): - in_channels = intermediate_channel_sizes[i] - encoder_input_proj_list.append( - nn.Sequential( - nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False), - nn.BatchNorm2d(config.encoder_hidden_dim), - ) - ) - self.encoder_input_proj = nn.ModuleList(encoder_input_proj_list) - self.encoder = Deimv2HybridEncoder(config=config) + if config.backbone_type == "dinov3": + self.dinov3_backbone = Deimv2DINOv3Backbone(config) + else: + from ..d_fine.modeling_d_fine import DFineConvEncoder + + self.backbone = DFineConvEncoder(config) + if config.encoder_type != "lite": + intermediate_channel_sizes = self.backbone.intermediate_channel_sizes + encoder_input_proj = [] + for in_channel in intermediate_channel_sizes: + encoder_input_proj.append( + nn.Sequential( + nn.Conv2d(in_channel, config.encoder_hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(config.encoder_hidden_dim), + ) + ) + self.encoder_input_proj = nn.ModuleList(encoder_input_proj) + + if config.encoder_type == "lite": + self.encoder = Deimv2LiteEncoder(config) + else: + self.encoder = Deimv2HybridEncoder(config=config) - # denoising part if config.num_denoising > 0: self.denoising_class_embed = nn.Embedding( config.num_labels + 1, config.d_model, padding_idx=config.num_labels ) - # decoder embedding if config.learn_initial_query: self.weight_embedding = nn.Embedding(config.num_queries, config.d_model) - # encoder head self.enc_output = nn.Sequential( nn.Linear(config.d_model, config.d_model), nn.LayerNorm(config.d_model, eps=config.layer_norm_eps), ) self.enc_score_head = nn.Linear(config.d_model, config.num_labels) - self.enc_bbox_head = Deimv2MLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3) + self.enc_bbox_head = Deimv2MLP(config.d_model, config.d_model, 4, 3) - # init encoder output anchors and valid_mask if config.anchor_image_size: self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype) + num_backbone_outs = len(config.decoder_in_channels) - decoder_input_proj_list = [] - for i in range(num_backbone_outs): - in_channels = config.decoder_in_channels[i] - decoder_input_proj_list.append( - nn.Sequential( - nn.Conv2d(in_channels, config.d_model, kernel_size=1, 
bias=False), - nn.BatchNorm2d(config.d_model, config.batch_norm_eps), - ) - ) - for _ in range(config.num_feature_levels - num_backbone_outs): - decoder_input_proj_list.append( - nn.Sequential( - nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False), - nn.BatchNorm2d(config.d_model, config.batch_norm_eps), - ) - ) - in_channels = config.d_model - self.decoder = Deimv2Decoder(config) decoder_input_proj = [] in_channels = config.decoder_in_channels[-1] for _ in range(num_backbone_outs): @@ -1617,8 +2036,9 @@ def __init__(self, config: Deimv2Config): batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps) decoder_input_proj.append(nn.Sequential(conv, batchnorm)) self.decoder_input_proj = nn.ModuleList(decoder_input_proj) + self.decoder = Deimv2Decoder(config) - if config.use_spatial_tuning_adapter: + if config.use_spatial_tuning_adapter and config.backbone_type != "dinov3": self.spatial_tuning_adapter = Deimv2SpatialTuningAdapter(config) self.post_init() @@ -1669,7 +2089,7 @@ def forward( encoder_outputs: torch.FloatTensor | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: list[dict] | None = None, - **kwargs: Unpack[TransformersKwargs], + **kwargs, ) -> tuple[torch.FloatTensor] | Deimv2ModelOutput: r""" inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): @@ -1710,8 +2130,15 @@ def forward( device = pixel_values.device if pixel_mask is None: pixel_mask = torch.ones(((batch_size, height, width)), device=device) - features = self.backbone(pixel_values, pixel_mask) - proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] + + if self.config.backbone_type == "dinov3": + proj_feats = self.dinov3_backbone(pixel_values) + elif self.config.encoder_type == "lite": + features = self.backbone(pixel_values, pixel_mask) + proj_feats = [source for source, mask in features] + else: + features = self.backbone(pixel_values, pixel_mask) + proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] else: batch_size = inputs_embeds.shape[0] device = inputs_embeds.device @@ -1722,7 +2149,6 @@ def forward( proj_feats, **kwargs, ) - # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput elif not isinstance(encoder_outputs, BaseModelOutput): encoder_outputs = BaseModelOutput( last_hidden_state=encoder_outputs[0], @@ -1730,20 +2156,16 @@ def forward( attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - # Equivalent to def _get_encoder_input - # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/Deimv2_pytorch/src/zoo/Deimv2/Deimv2_decoder.py#L412 sources = [] for level, source in enumerate(encoder_outputs.last_hidden_state): sources.append(self.decoder_input_proj[level](source)) - # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage if self.config.num_feature_levels > len(sources): _len_sources = len(sources) - sources.append(self.decoder_input_proj[_len_sources](encoder_outputs.last_hidden_state)[-1]) + sources.append(self.decoder_input_proj[_len_sources](encoder_outputs.last_hidden_state[-1])) for i in range(_len_sources + 1, self.config.num_feature_levels): sources.append(self.decoder_input_proj[i](encoder_outputs.last_hidden_state[-1])) - # Prepare encoder inputs (by flattening) source_flatten = [] spatial_shapes_list = [] spatial_shapes = torch.empty((len(sources), 2), device=device, 
dtype=torch.long) @@ -1757,8 +2179,9 @@ def forward( source_flatten = torch.cat(source_flatten, 1) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) - # prepare denoising training if self.training and self.config.num_denoising > 0 and labels is not None: + from ..d_fine.modeling_d_fine import get_contrastive_denoising_training_group + ( denoising_class, denoising_bbox_unact, @@ -1780,17 +2203,13 @@ def forward( device = source_flatten.device dtype = source_flatten.dtype - # prepare input for decoder if self.training or self.config.anchor_image_size is None: - # Pass spatial_shapes as tuple to make it hashable and make sure - # lru_cache is working for generate_anchors() spatial_shapes_tuple = tuple(spatial_shapes_list) anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype) else: anchors, valid_mask = self.anchors, self.valid_mask anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype) - # use the valid_mask to selectively retain values in the feature map where the mask is `True` memory = valid_mask.to(source_flatten.dtype) * source_flatten output_memory = self.enc_output(memory) @@ -1812,7 +2231,6 @@ def forward( dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]) ) - # extract region features if self.config.learn_initial_query: target = self.weight_embedding.tile([batch_size, 1, 1]) else: @@ -1824,7 +2242,8 @@ def forward( init_reference_points = reference_points_unact.detach() - # decoder + from ..d_fine.modeling_d_fine import DFineModelOutput + decoder_outputs = self.decoder( inputs_embeds=target, encoder_hidden_states=source_flatten, @@ -1836,7 +2255,7 @@ def forward( **kwargs, ) - return Deimv2ModelOutput( + return DFineModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, intermediate_logits=decoder_outputs.intermediate_logits, @@ -1959,16 +2378,20 @@ def __init__(self, config: Deimv2Config): scaled_dim = round(config.layer_scale * config.hidden_size) num_pred = config.decoder_layers self.class_embed = nn.ModuleList([nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)]) - self.bbox_embed = nn.ModuleList( - [ - Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) - for _ in range(self.eval_idx + 1) - ] - + [ - Deimv2MLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3) - for _ in range(config.decoder_layers - self.eval_idx - 1) - ] - ) + if config.share_bbox_head: + shared_bbox = Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) + self.bbox_embed = nn.ModuleList([shared_bbox] * num_pred) + else: + self.bbox_embed = nn.ModuleList( + [ + Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) + for _ in range(self.eval_idx + 1) + ] + + [ + Deimv2MLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3) + for _ in range(config.decoder_layers - self.eval_idx - 1) + ] + ) self.model.decoder.class_embed = self.class_embed self.model.decoder.bbox_embed = self.bbox_embed @@ -2094,5 +2517,17 @@ def forward( denoising_meta_values=outputs.denoising_meta_values, ) + @property + def _tied_weights_keys(self): + keys = { + r"class_embed.(?![0])\d+": r"^class_embed.0", + "class_embed": "model.decoder.class_embed", + "bbox_embed": "model.decoder.bbox_embed", + } + if getattr(self.config, "share_bbox_head", False): + 
keys[r"model\.decoder\.bbox_embed\.(?![0])\d+"] = r"model.decoder.bbox_embed.0" + keys[r"bbox_embed.(?![0])\d+"] = r"bbox_embed.0" + return keys + __all__ = ["Deimv2Model", "Deimv2PreTrainedModel", "Deimv2ForObjectDetection"] diff --git a/src/transformers/models/deimv2/modular_deimv2.py b/src/transformers/models/deimv2/modular_deimv2.py index 05bd3b8d4da5..5a7edf292378 100644 --- a/src/transformers/models/deimv2/modular_deimv2.py +++ b/src/transformers/models/deimv2/modular_deimv2.py @@ -39,6 +39,13 @@ DFineRepVggBlock, DFineSCDown, ) +from ..dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig +from ..dinov3_vit.modeling_dinov3_vit import ( + DINOv3ViTEmbeddings, + DINOv3ViTLayer, + DINOv3ViTLayerScale, + DINOv3ViTRopePositionEmbedding, +) from ..rt_detr.modeling_rt_detr import RTDetrAIFILayer @@ -49,7 +56,8 @@ class Deimv2Config(DFineConfig): """ This is the configuration class to store the configuration of a [`Deimv2Model`]. It is used to instantiate a DEIMv2 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of DEIMv2-L-COCO. + with the defaults will yield a similar configuration to that of DEIMv2-HGNetv2-N-COCO + [Intellindust/DEIMv2_HGNetv2_N_COCO](https://huggingface.co/Intellindust/DEIMv2_HGNetv2_N_COCO). Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PreTrainedConfig`] for more information. @@ -198,6 +206,29 @@ class Deimv2Config(DFineConfig): Whether to use the Spatial Tuning Adapter (STA) for DINOv2 backbone variants. sta_inplanes (`int`, *optional*, defaults to 16): Number of input planes for the STA convolutional stem. + encoder_type (`str`, *optional*, defaults to `"hybrid"`): + Type of encoder to use. `"hybrid"` uses the full HybridEncoder with AIFI, FPN, and PAN. + `"lite"` uses the lightweight LiteEncoder with GAP fusion for smaller variants (Atto, Femto, Pico). + use_gateway (`bool`, *optional*, defaults to `True`): + Whether to use the gateway mechanism (cross-attention gating) in decoder layers. When `False`, + uses RMSNorm on the encoder attention output instead. + share_bbox_head (`bool`, *optional*, defaults to `False`): + Whether to share the bounding box prediction head across all decoder layers. + backbone_type (`str`, *optional*, defaults to `"hgnetv2"`): + Type of backbone to use. `"hgnetv2"` uses HGNetV2, `"dinov3"` uses DINOv3 ViT backbone with STA. + dinov3_backbone_config (`dict`, *optional*): + Configuration dictionary for the DINOv3 ViT backbone. Passed as kwargs to `DINOv3ViTConfig`. + dinov3_interaction_indexes (`list[int]`, *optional*): + Layer indices in the DINOv3 ViT backbone from which to extract intermediate features. + dinov3_hidden_dim (`int`, *optional*): + Hidden dimension for the DINOv3 backbone projection convolutions. If `None`, uses `hidden_size` from + the DINOv3 ViT config. + dinov3_apply_layernorm (`bool`, *optional*, defaults to `False`): + Whether to apply LayerNorm to intermediate features extracted from the DINOv3 ViT backbone. + encoder_has_trailing_conv (`bool`, *optional*, defaults to `True`): + Whether the encoder's CSP blocks include a trailing 3x3 convolution after the bottleneck path. + `True` for RepNCSPELAN4 (used by HGNetV2 N and LiteEncoder variants). + `False` for RepNCSPELAN5 (used by DINOv3 variants). tie_word_embeddings (`bool`, *optional*, defaults to `True`): Whether to tie weight embeddings. 
""" @@ -276,6 +307,15 @@ def __init__( encoder_fuse_op="sum", use_spatial_tuning_adapter=False, sta_inplanes=16, + encoder_type="hybrid", + use_gateway=True, + share_bbox_head=False, + backbone_type="hgnetv2", + dinov3_backbone_config=None, + dinov3_interaction_indexes=None, + dinov3_hidden_dim=None, + dinov3_apply_layernorm=False, + encoder_has_trailing_conv=True, tie_word_embeddings=True, **kwargs, ): @@ -285,6 +325,15 @@ def __init__( self.encoder_fuse_op = encoder_fuse_op self.use_spatial_tuning_adapter = use_spatial_tuning_adapter self.sta_inplanes = sta_inplanes + self.encoder_type = encoder_type + self.use_gateway = use_gateway + self.share_bbox_head = share_bbox_head + self.backbone_type = backbone_type + self.dinov3_backbone_config = dinov3_backbone_config + self.dinov3_interaction_indexes = dinov3_interaction_indexes + self.dinov3_hidden_dim = dinov3_hidden_dim + self.dinov3_apply_layernorm = dinov3_apply_layernorm + self.encoder_has_trailing_conv = encoder_has_trailing_conv super().__init__( initializer_range=initializer_range, initializer_bias_prior_prob=initializer_bias_prior_prob, @@ -355,6 +404,10 @@ def __init__( ) +class Deimv2DINOv3ViTConfig(DINOv3ViTConfig): + model_type = "deimv2_dinov3_vit" + + class Deimv2DecoderOutput(DFineDecoderOutput): pass @@ -407,6 +460,22 @@ class Deimv2RepVggBlock(DFineRepVggBlock): pass +class Deimv2DINOv3ViTEmbeddings(DINOv3ViTEmbeddings): + pass + + +class Deimv2DINOv3ViTLayerScale(DINOv3ViTLayerScale): + pass + + +class Deimv2DINOv3ViTLayer(DINOv3ViTLayer): + pass + + +class Deimv2DINOv3ViTRopePositionEmbedding(DINOv3ViTRopePositionEmbedding): + pass + + class Deimv2CSPRepLayer2(nn.Module): def __init__( self, config: Deimv2Config, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0 @@ -418,7 +487,10 @@ def __init__( self.bottlenecks = nn.ModuleList( [Deimv2RepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)] ) - self.conv3 = Deimv2ConvNormLayer(config, hidden_channels, out_channels, 3, 1, activation=activation) + if config.encoder_has_trailing_conv: + self.conv3 = Deimv2ConvNormLayer(config, hidden_channels, out_channels, 3, 1, activation=activation) + else: + self.conv3 = nn.Identity() def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: chunks = list(self.conv1(hidden_state).chunk(2, 1)) @@ -429,13 +501,21 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: class Deimv2RepNCSPELAN5(nn.Module): - def __init__(self, config: Deimv2Config, numb_blocks: int = 3): + def __init__( + self, + config: Deimv2Config, + numb_blocks: int = 3, + c1: int | None = None, + c2: int | None = None, + c3: int | None = None, + c4: int | None = None, + ): super().__init__() act = config.activation_function - c1 = config.encoder_hidden_dim - c2 = config.encoder_hidden_dim - c3 = config.encoder_hidden_dim * 2 - c4 = round(config.hidden_expansion * config.encoder_hidden_dim // 2) + c1 = c1 if c1 is not None else config.encoder_hidden_dim + c2 = c2 if c2 is not None else config.encoder_hidden_dim + c3 = c3 if c3 is not None else config.encoder_hidden_dim * 2 + c4 = c4 if c4 is not None else round(config.hidden_expansion * config.encoder_hidden_dim // 2) self.conv_dim = c3 // 2 self.conv1 = Deimv2ConvNormLayer(config, c1, c3, 1, 1, activation=act) self.csp_rep1 = Deimv2CSPRepLayer2(config, c3 // 2, c4, num_blocks=numb_blocks) @@ -496,6 +576,135 @@ def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso return c2, c3, c4 +class Deimv2GAPFusion(nn.Module): + def 
__init__(self, config: Deimv2Config, channels: int): + super().__init__() + self.cv = Deimv2ConvNormLayer(config, channels, channels, 1, 1, activation=config.activation_function) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return self.cv(hidden_state + F.adaptive_avg_pool2d(hidden_state, 1)) + + +class Deimv2LiteEncoder(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + hidden_dim = config.encoder_hidden_dim + act = config.activation_function + + self.input_proj = nn.ModuleList() + for in_channel in config.encoder_in_channels: + self.input_proj.append( + nn.Sequential( + nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(hidden_dim), + ) + ) + + self.down_sample1 = nn.Sequential( + nn.AvgPool2d(kernel_size=3, stride=2, padding=1), + nn.Conv2d(hidden_dim, hidden_dim, 1, 1, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.SiLU(inplace=True) if act == "silu" else nn.ReLU(inplace=True), + ) + self.down_sample2 = nn.Sequential( + nn.AvgPool2d(kernel_size=3, stride=2, padding=1), + nn.Conv2d(hidden_dim, hidden_dim, 1, 1, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.SiLU(inplace=True) if act == "silu" else nn.ReLU(inplace=True), + ) + + self.bi_fusion = Deimv2GAPFusion(config, hidden_dim) + + c1, c2 = hidden_dim, hidden_dim + c3 = hidden_dim * 2 + c4 = round(config.hidden_expansion * hidden_dim // 2) + num_blocks = round(3 * config.depth_mult) + self.fpn_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks, c1=c1, c2=c2, c3=c3, c4=c4) + self.pan_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks, c1=c1, c2=c2, c3=c3, c4=c4) + + def forward(self, inputs_embeds=None, **kwargs) -> BaseModelOutput: + feats = inputs_embeds + proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] + proj_feats.append(self.down_sample1(proj_feats[-1])) + + proj_feats[-1] = self.bi_fusion(proj_feats[-1]) + + outs = [] + fuse_feat = proj_feats[0] + F.interpolate(proj_feats[1], scale_factor=2.0, mode="nearest") + outs.append(self.fpn_block(fuse_feat)) + + fuse_feat = proj_feats[1] + self.down_sample2(outs[-1]) + outs.append(self.pan_block(fuse_feat)) + + return BaseModelOutput(last_hidden_state=outs) + + +class Deimv2DINOv3Backbone(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + vit_config = Deimv2DINOv3ViTConfig(**config.dinov3_backbone_config) + + self.embeddings = Deimv2DINOv3ViTEmbeddings(vit_config) + self.rope_embeddings = Deimv2DINOv3ViTRopePositionEmbedding(vit_config) + self.layers = nn.ModuleList([Deimv2DINOv3ViTLayer(vit_config) for _ in range(vit_config.num_hidden_layers)]) + + self.apply_layernorm = config.dinov3_apply_layernorm + if self.apply_layernorm: + self.norm = nn.LayerNorm(vit_config.hidden_size, eps=vit_config.layer_norm_eps) + + self.interaction_indexes = config.dinov3_interaction_indexes + self.patch_size = vit_config.patch_size + self.num_prefix_tokens = 1 + vit_config.num_register_tokens + + self.sta = Deimv2SpatialTuningAdapter(config) + + embed_dim = vit_config.hidden_size + hidden_dim = config.dinov3_hidden_dim or embed_dim + sta_ch = config.sta_inplanes + self.convs = nn.ModuleList( + [ + nn.Conv2d(embed_dim + sta_ch * 2, hidden_dim, 1, bias=False), + nn.Conv2d(embed_dim + sta_ch * 4, hidden_dim, 1, bias=False), + nn.Conv2d(embed_dim + sta_ch * 4, hidden_dim, 1, bias=False), + ] + ) + self.norms = nn.ModuleList([nn.BatchNorm2d(hidden_dim) for _ in range(3)]) + + def forward(self, pixel_values: torch.Tensor) -> list[torch.Tensor]: + hidden_states = 
self.embeddings(pixel_values) + position_embeddings = self.rope_embeddings(pixel_values) + + intermediate_outputs = [] + for i, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, position_embeddings=position_embeddings) + if i in self.interaction_indexes: + out = self.norm(hidden_states) if self.apply_layernorm else hidden_states + intermediate_outputs.append(out) + + batch_size = pixel_values.shape[0] + h_patches = pixel_values.shape[2] // self.patch_size + w_patches = pixel_values.shape[3] // self.patch_size + + sem_feats = [] + num_scales = len(intermediate_outputs) + for i, feat in enumerate(intermediate_outputs): + patch_tokens = feat[:, self.num_prefix_tokens :] + spatial = patch_tokens.transpose(1, 2).reshape(batch_size, -1, h_patches, w_patches).contiguous() + resize_h = int(h_patches * 2 ** (num_scales - 2 - i)) + resize_w = int(w_patches * 2 ** (num_scales - 2 - i)) + spatial = F.interpolate(spatial, size=[resize_h, resize_w], mode="bilinear", align_corners=False) + sem_feats.append(spatial) + + detail_feats = self.sta(pixel_values) + + outs = [] + for i, (sem_feat, detail_feat) in enumerate(zip(sem_feats, detail_feats)): + fused = torch.cat([sem_feat, detail_feat], dim=1) + outs.append(self.norms[i](self.convs[i](fused))) + + return outs + + class Deimv2Integral(DFineIntegral): pass @@ -508,10 +717,63 @@ class Deimv2DecoderLayer(DFineDecoderLayer): def __init__(self, config: Deimv2Config): super().__init__(config) self.encoder_attn = Deimv2MultiscaleDeformableAttention(config=config) - self.gateway = Deimv2Gate(config.d_model) self.self_attn_layer_norm = Deimv2RMSNorm(config.d_model) self.final_layer_norm = Deimv2RMSNorm(config.d_model) self.mlp = Deimv2SwiGLUFFN(config.d_model, config.decoder_ffn_dim // 2, config.d_model) + if config.use_gateway: + self.gateway = Deimv2Gate(config.d_model) + else: + del self.gateway + self.encoder_attn_layer_norm = Deimv2RMSNorm(config.d_model) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor | None = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=encoder_attention_mask, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + + hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings + hidden_states, _ = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if hasattr(self, "gateway"): + hidden_states = self.gateway(residual, hidden_states) + else: + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states.clamp(min=-65504, max=65504)) + + return hidden_states 
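A minimal, self-contained sketch of the fusion path that the `Deimv2DecoderLayer.forward` above implements (self-attention → RMSNorm → deformable cross-attention → gate or plain residual → SwiGLU FFN → RMSNorm with an fp16-safe clamp). This is illustrative only: the `Gate` below is a simplified stand-in for `Deimv2Gate`, plain `nn.Linear` layers replace the deformable attention, and all dimensions are toy values, not the library implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    # Root-mean-square norm (no mean subtraction), as used in the decoder layers.
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)


class SwiGLUFFN(nn.Module):
    # Gated feed-forward: silu(W1 x) * (W3 x), projected back with W2.
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)
        self.w3 = nn.Linear(dim, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Gate(nn.Module):
    # Simplified stand-in for Deimv2Gate: learned per-channel mixing of the residual
    # and the cross-attention output, followed by a norm. The real gate may differ.
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(2 * dim, 2 * dim)
        self.norm = RMSNorm(dim)

    def forward(self, residual, update):
        gates = torch.sigmoid(self.proj(torch.cat([residual, update], dim=-1)))
        g1, g2 = gates.chunk(2, dim=-1)
        return self.norm(g1 * residual + g2 * update)


use_gateway = False  # mirrors config.use_gateway
dim, num_queries = 32, 10
queries = torch.randn(2, num_queries, dim)    # decoder queries after self-attn + RMSNorm
cross_out = torch.randn(2, num_queries, dim)  # stand-in for the deformable cross-attn output

if use_gateway:
    fused = Gate(dim)(queries, cross_out)
else:
    fused = RMSNorm(dim)(queries + cross_out)  # plain residual path when use_gateway=False

ffn = SwiGLUFFN(dim, hidden_dim=64)
out = RMSNorm(dim)(fused + ffn(fused)).clamp(min=-65504, max=65504)  # fp16-safe clamp
print(out.shape)  # torch.Size([2, 10, 32])
```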
class Deimv2MLPPredictionHead(DFineMLP): @@ -597,6 +859,20 @@ def _init_weights(self, module): init.ones_(module.weight) init.zeros_(module.bias) + if isinstance(module, Deimv2DINOv3ViTEmbeddings): + init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) + if module.config.num_register_tokens > 0: + init.trunc_normal_(module.register_tokens, mean=0.0, std=self.config.initializer_range) + init.zeros_(module.mask_token) + + if isinstance(module, Deimv2DINOv3ViTLayerScale) and self.config.dinov3_backbone_config is not None: + layerscale_value = self.config.dinov3_backbone_config.get("layerscale_value", 1.0) + init.constant_(module.lambda1, layerscale_value) + + if isinstance(module, Deimv2DINOv3ViTRopePositionEmbedding): + inv_freq = 1 / module.base ** torch.arange(0, 1, 4 / module.head_dim, dtype=torch.float32) + init.copy_(module.inv_freq, inv_freq) + if hasattr(module, "weight_embedding") and self.config.learn_initial_query: init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: @@ -605,7 +881,7 @@ def _init_weights(self, module): class Deimv2HybridEncoder(DFineHybridEncoder): def __init__(self, config: Deimv2Config): - Deimv2PreTrainedModel.__init__(config) + Deimv2PreTrainedModel.__init__(self, config) self.config = config self.in_channels = config.encoder_in_channels self.num_fpn_stages = len(self.in_channels) - 1 @@ -707,9 +983,49 @@ def __init__(self, config: Deimv2Config): class Deimv2Model(DFineModel): def __init__(self, config: Deimv2Config): - super().__init__(config) - del self.decoder_input_proj - self.encoder = Deimv2HybridEncoder(config=config) + Deimv2PreTrainedModel.__init__(self, config) + + if config.backbone_type == "dinov3": + self.dinov3_backbone = Deimv2DINOv3Backbone(config) + else: + from ..d_fine.modeling_d_fine import DFineConvEncoder + + self.backbone = DFineConvEncoder(config) + if config.encoder_type != "lite": + intermediate_channel_sizes = self.backbone.intermediate_channel_sizes + encoder_input_proj = [] + for in_channel in intermediate_channel_sizes: + encoder_input_proj.append( + nn.Sequential( + nn.Conv2d(in_channel, config.encoder_hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(config.encoder_hidden_dim), + ) + ) + self.encoder_input_proj = nn.ModuleList(encoder_input_proj) + + if config.encoder_type == "lite": + self.encoder = Deimv2LiteEncoder(config) + else: + self.encoder = Deimv2HybridEncoder(config=config) + + if config.num_denoising > 0: + self.denoising_class_embed = nn.Embedding( + config.num_labels + 1, config.d_model, padding_idx=config.num_labels + ) + + if config.learn_initial_query: + self.weight_embedding = nn.Embedding(config.num_queries, config.d_model) + + self.enc_output = nn.Sequential( + nn.Linear(config.d_model, config.d_model), + nn.LayerNorm(config.d_model, eps=config.layer_norm_eps), + ) + self.enc_score_head = nn.Linear(config.d_model, config.num_labels) + self.enc_bbox_head = Deimv2MLP(config.d_model, config.d_model, 4, 3) + + if config.anchor_image_size: + self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype) + num_backbone_outs = len(config.decoder_in_channels) decoder_input_proj = [] in_channels = config.decoder_in_channels[-1] @@ -730,18 +1046,189 @@ def __init__(self, config: Deimv2Config): self.decoder_input_proj = nn.ModuleList(decoder_input_proj) self.decoder = Deimv2Decoder(config) - if config.use_spatial_tuning_adapter: + if config.use_spatial_tuning_adapter and config.backbone_type != 
"dinov3": self.spatial_tuning_adapter = Deimv2SpatialTuningAdapter(config) + self.post_init() + + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: torch.LongTensor | None = None, + encoder_outputs: torch.FloatTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: list[dict] | None = None, + **kwargs, + ): + if pixel_values is None and inputs_embeds is None: + raise ValueError("You have to specify either pixel_values or inputs_embeds") + + if inputs_embeds is None: + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), device=device) + + if self.config.backbone_type == "dinov3": + proj_feats = self.dinov3_backbone(pixel_values) + elif self.config.encoder_type == "lite": + features = self.backbone(pixel_values, pixel_mask) + proj_feats = [source for source, mask in features] + else: + features = self.backbone(pixel_values, pixel_mask) + proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] + else: + batch_size = inputs_embeds.shape[0] + device = inputs_embeds.device + proj_feats = inputs_embeds + + if encoder_outputs is None: + encoder_outputs = self.encoder( + proj_feats, + **kwargs, + ) + elif not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + sources = [] + for level, source in enumerate(encoder_outputs.last_hidden_state): + sources.append(self.decoder_input_proj[level](source)) + + if self.config.num_feature_levels > len(sources): + _len_sources = len(sources) + sources.append(self.decoder_input_proj[_len_sources](encoder_outputs.last_hidden_state[-1])) + for i in range(_len_sources + 1, self.config.num_feature_levels): + sources.append(self.decoder_input_proj[i](encoder_outputs.last_hidden_state[-1])) + + source_flatten = [] + spatial_shapes_list = [] + spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long) + for level, source in enumerate(sources): + height, width = source.shape[-2:] + spatial_shapes[level, 0] = height + spatial_shapes[level, 1] = width + spatial_shapes_list.append((height, width)) + source = source.flatten(2).transpose(1, 2) + source_flatten.append(source) + source_flatten = torch.cat(source_flatten, 1) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + + if self.training and self.config.num_denoising > 0 and labels is not None: + from ..d_fine.modeling_d_fine import get_contrastive_denoising_training_group + + ( + denoising_class, + denoising_bbox_unact, + attention_mask, + denoising_meta_values, + ) = get_contrastive_denoising_training_group( + targets=labels, + num_classes=self.config.num_labels, + num_queries=self.config.num_queries, + class_embed=self.denoising_class_embed, + num_denoising_queries=self.config.num_denoising, + label_noise_ratio=self.config.label_noise_ratio, + box_noise_scale=self.config.box_noise_scale, + ) + else: + denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None + + batch_size = len(source_flatten) + device = source_flatten.device + dtype = source_flatten.dtype + + if self.training or self.config.anchor_image_size is None: + spatial_shapes_tuple = 
tuple(spatial_shapes_list) + anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype) + else: + anchors, valid_mask = self.anchors, self.valid_mask + anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype) + + memory = valid_mask.to(source_flatten.dtype) * source_flatten + + output_memory = self.enc_output(memory) + + enc_outputs_class = self.enc_score_head(output_memory) + enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors + + _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1) + + reference_points_unact = enc_outputs_coord_logits.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1]) + ) + + enc_topk_bboxes = F.sigmoid(reference_points_unact) + if denoising_bbox_unact is not None: + reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1) + + enc_topk_logits = enc_outputs_class.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]) + ) + + if self.config.learn_initial_query: + target = self.weight_embedding.tile([batch_size, 1, 1]) + else: + target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1])) + target = target.detach() + + if denoising_class is not None: + target = torch.concat([denoising_class, target], 1) + + init_reference_points = reference_points_unact.detach() + + from ..d_fine.modeling_d_fine import DFineModelOutput + + decoder_outputs = self.decoder( + inputs_embeds=target, + encoder_hidden_states=source_flatten, + encoder_attention_mask=attention_mask, + reference_points=init_reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + **kwargs, + ) + + return DFineModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + intermediate_logits=decoder_outputs.intermediate_logits, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + intermediate_predicted_corners=decoder_outputs.intermediate_predicted_corners, + initial_reference_points=decoder_outputs.initial_reference_points, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + init_reference_points=init_reference_points, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + enc_outputs_class=enc_outputs_class, + enc_outputs_coord_logits=enc_outputs_coord_logits, + denoising_meta_values=denoising_meta_values, + ) + class Deimv2ForObjectDetection(DFineForObjectDetection): _no_split_modules = None - _tied_weights_keys = { - r"bbox_embed.(?![0])\d+": r"bbox_embed.0", - r"class_embed.(?![0])\d+": r"^class_embed.0", - "class_embed": "model.decoder.class_embed", - "bbox_embed": "model.decoder.bbox_embed", - } + + @property + def _tied_weights_keys(self): + keys = { + r"class_embed.(?![0])\d+": r"^class_embed.0", + "class_embed": "model.decoder.class_embed", + "bbox_embed": "model.decoder.bbox_embed", + } + if getattr(self.config, "share_bbox_head", False): + keys[r"model\.decoder\.bbox_embed\.(?![0])\d+"] = r"model.decoder.bbox_embed.0" + keys[r"bbox_embed.(?![0])\d+"] = r"bbox_embed.0" + return 
keys def __init__(self, config: Deimv2Config): Deimv2PreTrainedModel.__init__(self, config) @@ -751,16 +1238,20 @@ def __init__(self, config: Deimv2Config): scaled_dim = round(config.layer_scale * config.hidden_size) num_pred = config.decoder_layers self.class_embed = nn.ModuleList([nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)]) - self.bbox_embed = nn.ModuleList( - [ - Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) - for _ in range(self.eval_idx + 1) - ] - + [ - Deimv2MLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3) - for _ in range(config.decoder_layers - self.eval_idx - 1) - ] - ) + if config.share_bbox_head: + shared_bbox = Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) + self.bbox_embed = nn.ModuleList([shared_bbox] * num_pred) + else: + self.bbox_embed = nn.ModuleList( + [ + Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) + for _ in range(self.eval_idx + 1) + ] + + [ + Deimv2MLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3) + for _ in range(config.decoder_layers - self.eval_idx - 1) + ] + ) self.model.decoder.class_embed = self.class_embed self.model.decoder.bbox_embed = self.bbox_embed diff --git a/tests/models/deimv2/test_modeling_deimv2.py b/tests/models/deimv2/test_modeling_deimv2.py index 9b612c7833d1..db4ce97086b4 100644 --- a/tests/models/deimv2/test_modeling_deimv2.py +++ b/tests/models/deimv2/test_modeling_deimv2.py @@ -46,7 +46,12 @@ from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_modeling_common import ( + TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, + ModelTesterMixin, + _test_eager_matches_sdpa_inference, + floats_tensor, +) from ...test_pipeline_mixin import PipelineTesterMixin @@ -650,6 +655,1009 @@ def test_inference_equivalence_for_static_and_dynamic_anchors(self, dtype_str): ) +class Deimv2LiteEncoderModelTester: + def __init__( + self, + parent, + batch_size=3, + is_training=True, + use_labels=True, + n_targets=3, + num_labels=10, + initializer_range=0.02, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + encoder_hidden_dim=32, + encoder_in_channels=[256], + feat_strides=[16, 32], + dropout=0.0, + activation_dropout=0.0, + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + d_model=32, + num_queries=10, + decoder_in_channels=[32, 32], + decoder_ffn_dim=64, + num_feature_levels=2, + decoder_n_points=[4, 2], + decoder_n_levels=2, + decoder_layers=2, + decoder_attention_heads=2, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=0, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + image_size=64, + disable_custom_kernels=True, + with_box_refine=True, + decoder_offset_scale=0.5, + eval_idx=-1, + layer_scale=1, + reg_max=32, + reg_scale=4.0, + depth_mult=0.34, + hidden_expansion=0.5, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = 3 + self.is_training = is_training + self.use_labels = use_labels + self.n_targets = n_targets + self.num_labels = num_labels + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + 
self.encoder_layers = 0 + self.encoder_ffn_dim = 64 + self.encoder_attention_heads = 2 + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = [] + self.positional_encoding_temperature = positional_encoding_temperature + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.eval_size = eval_size + self.normalize_before = normalize_before + self.d_model = d_model + self.num_queries = num_queries + self.decoder_in_channels = decoder_in_channels + self.decoder_ffn_dim = decoder_ffn_dim + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_n_levels = decoder_n_levels + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.decoder_offset_scale = decoder_offset_scale + self.eval_idx = eval_idx + self.layer_scale = layer_scale + self.reg_max = reg_max + self.reg_scale = reg_scale + self.depth_mult = depth_mult + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.image_size = image_size + self.disable_custom_kernels = disable_custom_kernels + self.with_box_refine = with_box_refine + self.hidden_expansion = hidden_expansion + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device) + + labels = None + if self.use_labels: + labels = [] + for i in range(self.batch_size): + target = {} + target["class_labels"] = torch.randint( + high=self.num_labels, size=(self.n_targets,), device=torch_device + ) + target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device) + labels.append(target) + + config = self.get_config() + config.num_labels = self.num_labels + return config, pixel_values, pixel_mask, labels + + def get_config(self): + backbone_config = HGNetV2Config( + stage_in_channels=[16, 64, 128], + stage_mid_channels=[16, 32, 64], + stage_out_channels=[64, 128, 256], + stage_num_blocks=[1, 1, 1], + stage_downsample=[False, True, True], + stage_light_block=[False, False, True], + stage_kernel_size=[3, 3, 3], + stage_numb_of_layers=[3, 3, 3], + embeddings_size=10, + hidden_sizes=[64, 128, 256], + depths=[1, 1, 1], + out_features=["stage3"], + out_indices=[3], + stem_channels=[3, 16, 16], + use_lab=True, + ) + return Deimv2Config( + backbone_config=backbone_config, + encoder_hidden_dim=self.encoder_hidden_dim, + encoder_in_channels=self.encoder_in_channels, + feat_strides=self.feat_strides, + encoder_layers=self.encoder_layers, + encoder_ffn_dim=self.encoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + dropout=self.dropout, + activation_dropout=self.activation_dropout, + encode_proj_layers=self.encode_proj_layers, + positional_encoding_temperature=self.positional_encoding_temperature, + encoder_activation_function=self.encoder_activation_function, + activation_function=self.activation_function, + eval_size=self.eval_size, + normalize_before=self.normalize_before, + d_model=self.d_model, + num_queries=self.num_queries, + decoder_in_channels=self.decoder_in_channels, + decoder_ffn_dim=self.decoder_ffn_dim, + 
num_feature_levels=self.num_feature_levels, + decoder_n_points=self.decoder_n_points, + decoder_n_levels=self.decoder_n_levels, + decoder_layers=self.decoder_layers, + decoder_attention_heads=self.decoder_attention_heads, + decoder_activation_function=self.decoder_activation_function, + decoder_offset_scale=self.decoder_offset_scale, + eval_idx=self.eval_idx, + layer_scale=self.layer_scale, + reg_max=self.reg_max, + reg_scale=self.reg_scale, + depth_mult=self.depth_mult, + attention_dropout=self.attention_dropout, + num_denoising=self.num_denoising, + label_noise_ratio=self.label_noise_ratio, + box_noise_scale=self.box_noise_scale, + learn_initial_query=self.learn_initial_query, + anchor_image_size=self.anchor_image_size, + image_size=self.image_size, + disable_custom_kernels=self.disable_custom_kernels, + with_box_refine=self.with_box_refine, + encoder_type="lite", + use_gateway=False, + share_bbox_head=False, + ) + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + def create_and_check_deimv2_model(self, config, pixel_values, pixel_mask, labels): + model = Deimv2Model(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.d_model)) + + def create_and_check_deimv2_object_detection_head_model(self, config, pixel_values, pixel_mask, labels): + model = Deimv2ForObjectDetection(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels) + + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + +@require_torch +class Deimv2LiteEncoderModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (Deimv2Model, Deimv2ForObjectDetection) if is_torch_available() else () + pipeline_model_mapping = {} + is_encoder_decoder = True + + test_missing_keys = False + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "Deimv2ForObjectDetection": + labels = [] + for i in range(self.model_tester.batch_size): + target = {} + target["class_labels"] = torch.ones( + size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long + ) + target["boxes"] = torch.ones( + self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float + ) + labels.append(target) + inputs_dict["labels"] = labels + + return inputs_dict + + def setUp(self): + self.model_tester = Deimv2LiteEncoderModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=Deimv2Config, + has_text_modality=False, + common_properties=["hidden_size", "num_attention_heads"], + ) + + def test_config(self): + 
self.config_tester.run_common_tests() + + def test_deimv2_lite_encoder_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deimv2_model(*config_and_inputs) + + def test_deimv2_lite_encoder_object_detection_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deimv2_object_detection_head_model(*config_and_inputs) + + @unittest.skip(reason="Deimv2 doesn't work well with `nn.DataParallel") + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip(reason="Deimv2 does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Deimv2 does not use test_inputs_embeds_matches_input_ids") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip(reason="Deimv2 does not support input and output embeddings") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="Deimv2 does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="Deimv2 does not use token embeddings") + def test_resize_tokens_embeddings(self): + pass + + @unittest.skip(reason="Feed forward chunking is not implemented") + def test_feed_forward_chunking(self): + pass + + @unittest.skip(reason="Weight tying is hardcoded (module_x = module_y) and always `True`") + def test_load_save_without_tied_weights(self): + pass + + @unittest.skip(reason="LiteEncoder has no AIFI layers, so no encoder attentions are produced") + def test_attention_outputs(self): + pass + + @unittest.skip(reason="LiteEncoder has no encoder attentions for gradient retention check") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip( + reason="LiteEncoder expects exactly 1 backbone feature map (single scale), but test_backbone_selection hardcodes 3-output backbones (out_indices=[2,3,4])" + ) + def test_backbone_selection(self): + pass + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if config.is_encoder_decoder: + hidden_states = outputs.decoder_hidden_states + + expected_num_layers = self.model_tester.decoder_layers + 1 + + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.num_queries, self.model_tester.d_model], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + arg_names = [*signature.parameters.keys()] + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + +class Deimv2DINOv3ModelTester: + def __init__( + self, + parent, + batch_size=3, + is_training=True, + use_labels=True, + 
n_targets=3, + num_labels=10, + initializer_range=0.02, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + encoder_hidden_dim=32, + encoder_in_channels=[32, 32, 32], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=64, + encoder_attention_heads=2, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + d_model=32, + num_queries=30, + decoder_in_channels=[32, 32, 32], + decoder_ffn_dim=64, + num_feature_levels=3, + decoder_n_points=4, + decoder_n_levels=3, + decoder_layers=2, + decoder_attention_heads=2, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=0, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + image_size=64, + disable_custom_kernels=True, + with_box_refine=True, + decoder_offset_scale=0.5, + eval_idx=-1, + layer_scale=1, + reg_max=32, + reg_scale=4.0, + depth_mult=0.34, + hidden_expansion=0.5, + sta_inplanes=8, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = 3 + self.is_training = is_training + self.use_labels = use_labels + self.n_targets = n_targets + self.num_labels = num_labels + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_layers = encoder_layers + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.eval_size = eval_size + self.normalize_before = normalize_before + self.d_model = d_model + self.num_queries = num_queries + self.decoder_in_channels = decoder_in_channels + self.decoder_ffn_dim = decoder_ffn_dim + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_n_levels = decoder_n_levels + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.decoder_offset_scale = decoder_offset_scale + self.eval_idx = eval_idx + self.layer_scale = layer_scale + self.reg_max = reg_max + self.reg_scale = reg_scale + self.depth_mult = depth_mult + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.image_size = image_size + self.disable_custom_kernels = disable_custom_kernels + self.with_box_refine = with_box_refine + self.hidden_expansion = hidden_expansion + self.sta_inplanes = sta_inplanes + + self.encoder_seq_length = math.ceil(self.image_size / 32) * math.ceil(self.image_size / 32) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device) + + labels = None + if 
self.use_labels: + labels = [] + for i in range(self.batch_size): + target = {} + target["class_labels"] = torch.randint( + high=self.num_labels, size=(self.n_targets,), device=torch_device + ) + target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device) + labels.append(target) + + config = self.get_config() + config.num_labels = self.num_labels + return config, pixel_values, pixel_mask, labels + + def get_config(self): + hidden_sizes = [64, 128, 256, 512] + backbone_config = HGNetV2Config( + stage_in_channels=[16, 64, 128, 256], + stage_mid_channels=[16, 32, 64, 128], + stage_out_channels=[64, 128, 256, 512], + stage_num_blocks=[1, 1, 2, 1], + stage_downsample=[False, True, True, True], + stage_light_block=[False, False, True, True], + stage_kernel_size=[3, 3, 5, 5], + stage_numb_of_layers=[3, 3, 3, 3], + embeddings_size=10, + hidden_sizes=hidden_sizes, + depths=[1, 1, 2, 1], + out_features=["stage2", "stage3", "stage4"], + out_indices=[2, 3, 4], + stem_channels=[3, 16, 16], + use_lab=True, + ) + dinov3_backbone_config = { + "hidden_size": 32, + "num_attention_heads": 2, + "num_hidden_layers": 4, + "intermediate_size": 64, + "num_register_tokens": 1, + "layerscale_value": 1.0, + "use_gated_mlp": False, + "rope_theta": 100.0, + "patch_size": 16, + "image_size": self.image_size, + } + return Deimv2Config( + backbone_config=backbone_config, + encoder_hidden_dim=self.encoder_hidden_dim, + encoder_in_channels=self.encoder_in_channels, + feat_strides=self.feat_strides, + encoder_layers=self.encoder_layers, + encoder_ffn_dim=self.encoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + dropout=self.dropout, + activation_dropout=self.activation_dropout, + encode_proj_layers=self.encode_proj_layers, + positional_encoding_temperature=self.positional_encoding_temperature, + encoder_activation_function=self.encoder_activation_function, + activation_function=self.activation_function, + eval_size=self.eval_size, + normalize_before=self.normalize_before, + d_model=self.d_model, + num_queries=self.num_queries, + decoder_in_channels=self.decoder_in_channels, + decoder_ffn_dim=self.decoder_ffn_dim, + num_feature_levels=self.num_feature_levels, + decoder_n_points=self.decoder_n_points, + decoder_n_levels=self.decoder_n_levels, + decoder_layers=self.decoder_layers, + decoder_attention_heads=self.decoder_attention_heads, + decoder_activation_function=self.decoder_activation_function, + decoder_offset_scale=self.decoder_offset_scale, + eval_idx=self.eval_idx, + layer_scale=self.layer_scale, + reg_max=self.reg_max, + reg_scale=self.reg_scale, + depth_mult=self.depth_mult, + attention_dropout=self.attention_dropout, + num_denoising=self.num_denoising, + label_noise_ratio=self.label_noise_ratio, + box_noise_scale=self.box_noise_scale, + learn_initial_query=self.learn_initial_query, + anchor_image_size=self.anchor_image_size, + image_size=self.image_size, + disable_custom_kernels=self.disable_custom_kernels, + with_box_refine=self.with_box_refine, + backbone_type="dinov3", + dinov3_backbone_config=dinov3_backbone_config, + dinov3_interaction_indexes=[1, 2, 3], + dinov3_hidden_dim=self.encoder_hidden_dim, + dinov3_apply_layernorm=False, + sta_inplanes=self.sta_inplanes, + use_spatial_tuning_adapter=True, + encoder_has_trailing_conv=False, + ) + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + def create_and_check_deimv2_model(self, 
config, pixel_values, pixel_mask, labels): + model = Deimv2Model(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.d_model)) + + def create_and_check_deimv2_object_detection_head_model(self, config, pixel_values, pixel_mask, labels): + model = Deimv2ForObjectDetection(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels) + + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + +@require_torch +class Deimv2DINOv3ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (Deimv2Model, Deimv2ForObjectDetection) if is_torch_available() else () + pipeline_model_mapping = ( + {"image-feature-extraction": Deimv2Model, "object-detection": Deimv2ForObjectDetection} + if is_torch_available() + else {} + ) + is_encoder_decoder = True + + test_missing_keys = False + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "Deimv2ForObjectDetection": + labels = [] + for i in range(self.model_tester.batch_size): + target = {} + target["class_labels"] = torch.ones( + size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long + ) + target["boxes"] = torch.ones( + self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float + ) + labels.append(target) + inputs_dict["labels"] = labels + + return inputs_dict + + def setUp(self): + self.model_tester = Deimv2DINOv3ModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=Deimv2Config, + has_text_modality=False, + common_properties=["hidden_size", "num_attention_heads"], + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_deimv2_dinov3_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deimv2_model(*config_and_inputs) + + def test_deimv2_dinov3_object_detection_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deimv2_object_detection_head_model(*config_and_inputs) + + @unittest.skip(reason="Deimv2 doesn't work well with `nn.DataParallel") + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip(reason="Deimv2 does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Deimv2 does not use test_inputs_embeds_matches_input_ids") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip(reason="Deimv2 does not support input and output embeddings") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="Deimv2 does not support input and output embeddings") + def test_model_common_attributes(self): + 
pass + + @unittest.skip(reason="Deimv2 does not use token embeddings") + def test_resize_tokens_embeddings(self): + pass + + @unittest.skip(reason="Feed forward chunking is not implemented") + def test_feed_forward_chunking(self): + pass + + @unittest.skip(reason="Weight tying is hardcoded (module_x = module_y) and always `True`") + def test_load_save_without_tied_weights(self): + pass + + @unittest.skip(reason="DINOv3 backbone does not support timm/HF backbone selection") + def test_backbone_selection(self): + pass + + @unittest.skip( + reason="DINOv3 backbone with RoPE and LayerScale produces numerical differences beyond tolerance during offloading" + ) + def test_cpu_offload(self): + pass + + @unittest.skip( + reason="DINOv3 backbone with RoPE and LayerScale produces numerical differences beyond tolerance during offloading" + ) + def test_disk_offload_bin(self): + pass + + @unittest.skip( + reason="DINOv3 backbone with RoPE and LayerScale produces numerical differences beyond tolerance during offloading" + ) + def test_disk_offload_safetensors(self): + pass + + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) + def test_eager_matches_sdpa_inference( + self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels + ): + atols = { + ("cpu", False, torch.float32): 1e-4, + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-4, + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-4, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-4, + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + rtols = { + ("cpu", False, torch.float32): 1e-3, + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-3, + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-3, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-3, + ("cuda", True, torch.bfloat16): 3e-2, + ("cuda", True, torch.float16): 5e-3, + } + _test_eager_matches_sdpa_inference( + self, + name, + torch_dtype, + padding_side, + use_attention_mask, + output_attentions, + enable_kernels, + atols=atols, + rtols=rtols, + ) + + def test_batching_equivalence(self): + super().test_batching_equivalence(atol=1e-4, rtol=1e-4) + + @unittest.skip(reason="Deimv2 is not a generative encoder-decoder model and has no decoder_input_ids") + def test_flex_attention_with_grads(self): + pass + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class._from_config(config, attn_implementation="eager") + config = model.config + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions + self.assertEqual(len(attentions), self.model_tester.encoder_layers) + + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = 
model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions + self.assertEqual(len(attentions), self.model_tester.encoder_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [ + self.model_tester.encoder_attention_heads, + self.model_tester.encoder_seq_length, + self.model_tester.encoder_seq_length, + ], + ) + out_len = len(outputs) + + correct_outlen = 15 + + if "labels" in inputs_dict: + correct_outlen += 1 + if model_class.__name__ == "Deimv2ForObjectDetection": + correct_outlen += 2 + + self.assertEqual(out_len, correct_outlen) + + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.decoder_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [ + self.model_tester.decoder_attention_heads, + self.model_tester.num_queries, + self.model_tester.num_queries, + ], + ) + + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.decoder_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_queries, + self.model_tester.decoder_attention_heads, + self.model_tester.decoder_n_levels * self.model_tester.decoder_n_points + if isinstance(self.model_tester.decoder_n_points, int) + else sum(self.model_tester.decoder_n_points), + ], + ) + + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + else: + added_hidden_states = 2 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions + + self.assertEqual(len(self_attentions), self.model_tester.encoder_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [ + self.model_tester.encoder_attention_heads, + self.model_tester.encoder_seq_length, + self.model_tester.encoder_seq_length, + ], + ) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", len(self.model_tester.encoder_in_channels) - 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[1].shape[-2:]), + [ + self.model_tester.image_size // self.model_tester.feat_strides[-1], + self.model_tester.image_size // self.model_tester.feat_strides[-1], + ], + ) + + if config.is_encoder_decoder: + hidden_states = outputs.decoder_hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.decoder_layers + 1 + ) + + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.num_queries, self.model_tester.d_model], + ) + + 
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + inputs = self._prepare_for_class(inputs_dict, model_class) + + outputs = model(**inputs) + + output = outputs[0] + + encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_attentions = outputs.encoder_attentions[0] + encoder_hidden_states.retain_grad() + encoder_attentions.retain_grad() + + decoder_attentions = outputs.decoder_attentions[0] + decoder_attentions.retain_grad() + + cross_attentions = outputs.cross_attentions[0] + cross_attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(encoder_hidden_states.grad) + self.assertIsNotNone(encoder_attentions.grad) + self.assertIsNotNone(decoder_attentions.grad) + self.assertIsNotNone(cross_attentions.grad) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + arg_names = [*signature.parameters.keys()] + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + @parameterized.expand(["float32", "float16", "bfloat16"]) + @require_torch_accelerator + @slow + def test_inference_with_different_dtypes(self, dtype_str): + dtype = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[dtype_str] + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device).to(dtype) + model.eval() + for key, tensor in inputs_dict.items(): + if tensor.dtype == torch.float32: + inputs_dict[key] = tensor.to(dtype) + with torch.no_grad(): + _ = model(**self._prepare_for_class(inputs_dict, model_class)) + + @parameterized.expand(["float32", "float16", "bfloat16"]) + @require_torch_accelerator + @slow + def test_inference_equivalence_for_static_and_dynamic_anchors(self, dtype_str): + dtype = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[dtype_str] + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + h, w = inputs_dict["pixel_values"].shape[-2:] + + for key, tensor in inputs_dict.items(): + if tensor.dtype == torch.float32: + inputs_dict[key] = tensor.to(dtype) + + for model_class in self.all_model_classes: + with tempfile.TemporaryDirectory() as tmpdirname: + model_class(config).save_pretrained(tmpdirname) + model_static = model_class.from_pretrained( + tmpdirname, anchor_image_size=[h, w], device_map=torch_device, dtype=dtype + ).eval() + model_dynamic = model_class.from_pretrained( + tmpdirname, anchor_image_size=None, device_map=torch_device, dtype=dtype + ).eval() + + self.assertIsNotNone(model_static.config.anchor_image_size) + 
self.assertIsNone(model_dynamic.config.anchor_image_size) + + with torch.no_grad(): + outputs_static = model_static(**self._prepare_for_class(inputs_dict, model_class)) + outputs_dynamic = model_dynamic(**self._prepare_for_class(inputs_dict, model_class)) + + torch.testing.assert_close( + outputs_static.last_hidden_state, outputs_dynamic.last_hidden_state, rtol=5e-3, atol=5e-3 + ) + + def prepare_img(): image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") return image From adc4079e540e1549581ca75bae3cd13de7e85afa Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Sun, 1 Mar 2026 11:46:55 +0400 Subject: [PATCH 04/17] fix: Fix ci/circleci: check_repository_consistency --- utils/check_config_attributes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index c301d841ab2f..413f5e6d61fa 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -96,6 +96,8 @@ "SwitchTransformersConfig": True, "DetrConfig": True, "DFineConfig": True, + "Deimv2Config": True, + "Deimv2DINOv3ViTConfig": True, "GroundingDinoConfig": True, "MMGroundingDinoConfig": True, "RTDetrConfig": True, From 39d300e87c7988c2c9da4826c7ef7ee11e1332bc Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Tue, 17 Mar 2026 18:46:31 +0400 Subject: [PATCH 05/17] refactor: Resolve review comments --- .../models/deimv2/configuration_deimv2.py | 195 +--- ...eimv2_original_pytorch_checkpoint_to_hf.py | 157 ++-- .../models/deimv2/modeling_deimv2.py | 856 ++++++------------ .../models/deimv2/modular_deimv2.py | 354 ++------ tests/models/deimv2/test_modeling_deimv2.py | 52 +- utils/check_config_attributes.py | 1 - 6 files changed, 465 insertions(+), 1150 deletions(-) diff --git a/src/transformers/models/deimv2/configuration_deimv2.py b/src/transformers/models/deimv2/configuration_deimv2.py index d98372c44509..962c34a74596 100644 --- a/src/transformers/models/deimv2/configuration_deimv2.py +++ b/src/transformers/models/deimv2/configuration_deimv2.py @@ -17,7 +17,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ...backbone_utils import BackboneConfigMixin, consolidate_backbone_kwargs_to_config + +from ...backbone_utils import consolidate_backbone_kwargs_to_config from ...configuration_utils import PreTrainedConfig from ..auto import AutoConfig @@ -45,7 +46,8 @@ class Deimv2Config(PreTrainedConfig): batch_norm_eps (`float`, *optional*, defaults to 1e-05): The epsilon used by the batch normalization layers. backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `HGNetV2Config()`): - The configuration of the backbone model. + The configuration of the backbone model. For HGNetV2 variants, use `HGNetV2Config`. + For DINOv3 variants, use `DINOv3ViTConfig`. freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`): Whether to freeze the batch normalization layers in the backbone. encoder_hidden_dim (`int`, *optional*, defaults to 256): @@ -174,8 +176,6 @@ class Deimv2Config(PreTrainedConfig): Alpha parameter for the Matching Auxiliary Loss (MAL). If `None`, uses `focal_loss_alpha`. encoder_fuse_op (`str`, *optional*, defaults to `"sum"`): Fusion operation used in the encoder FPN. DEIMv2 uses `"sum"` instead of D-Fine's `"cat"`. 
- use_spatial_tuning_adapter (`bool`, *optional*, defaults to `False`): - Whether to use the Spatial Tuning Adapter (STA) for DINOv2 backbone variants. sta_inplanes (`int`, *optional*, defaults to 16): Number of input planes for the STA convolutional stem. encoder_type (`str`, *optional*, defaults to `"hybrid"`): @@ -186,17 +186,6 @@ class Deimv2Config(PreTrainedConfig): uses RMSNorm on the encoder attention output instead. share_bbox_head (`bool`, *optional*, defaults to `False`): Whether to share the bounding box prediction head across all decoder layers. - backbone_type (`str`, *optional*, defaults to `"hgnetv2"`): - Type of backbone to use. `"hgnetv2"` uses HGNetV2, `"dinov3"` uses DINOv3 ViT backbone with STA. - dinov3_backbone_config (`dict`, *optional*): - Configuration dictionary for the DINOv3 ViT backbone. Passed as kwargs to `DINOv3ViTConfig`. - dinov3_interaction_indexes (`list[int]`, *optional*): - Layer indices in the DINOv3 ViT backbone from which to extract intermediate features. - dinov3_hidden_dim (`int`, *optional*): - Hidden dimension for the DINOv3 backbone projection convolutions. If `None`, uses `hidden_size` from - the DINOv3 ViT config. - dinov3_apply_layernorm (`bool`, *optional*, defaults to `False`): - Whether to apply LayerNorm to intermediate features extracted from the DINOv3 ViT backbone. encoder_has_trailing_conv (`bool`, *optional*, defaults to `True`): Whether the encoder's CSP blocks include a trailing 3x3 convolution after the bottleneck path. `True` for RepNCSPELAN4 (used by HGNetV2 N and LiteEncoder variants). @@ -283,16 +272,10 @@ def __init__( use_dense_o2o=True, mal_alpha=None, encoder_fuse_op="sum", - use_spatial_tuning_adapter=False, sta_inplanes=16, encoder_type="hybrid", use_gateway=True, share_bbox_head=False, - backbone_type="hgnetv2", - dinov3_backbone_config=None, - dinov3_interaction_indexes=None, - dinov3_hidden_dim=None, - dinov3_apply_layernorm=False, encoder_has_trailing_conv=True, tie_word_embeddings=True, **kwargs, @@ -301,16 +284,10 @@ def __init__( self.use_dense_o2o = use_dense_o2o self.mal_alpha = mal_alpha self.encoder_fuse_op = encoder_fuse_op - self.use_spatial_tuning_adapter = use_spatial_tuning_adapter self.sta_inplanes = sta_inplanes self.encoder_type = encoder_type self.use_gateway = use_gateway self.share_bbox_head = share_bbox_head - self.backbone_type = backbone_type - self.dinov3_backbone_config = dinov3_backbone_config - self.dinov3_interaction_indexes = dinov3_interaction_indexes - self.dinov3_hidden_dim = dinov3_hidden_dim - self.dinov3_apply_layernorm = dinov3_apply_layernorm self.encoder_has_trailing_conv = encoder_has_trailing_conv self.initializer_range = initializer_range self.initializer_bias_prior_prob = initializer_bias_prior_prob @@ -404,168 +381,4 @@ def __init__( super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) -class Deimv2DINOv3ViTConfig(BackboneConfigMixin, PreTrainedConfig): - r""" - This is the configuration class to store the configuration of a [`DINOv3Model`]. It is used to instantiate an - DINOv3 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the DINOv3 - [facebook/dinov3-vits16-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vits16-pretrain-lvd1689m) architecture. - - Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PreTrainedConfig`] for more information. 
- - Args: - patch_size (`int`, *optional*, defaults to 16): - The size (resolution) of each patch. - hidden_size (`int`, *optional*, defaults to 384): - Dimensionality of the encoder layers and the pooler layer. - intermediate_size (`int`, *optional*, defaults to 1536): - Dimensionality of the "intermediate" (i.e., feed-forward) layer. - num_hidden_layers (`int`, *optional*, defaults to 12): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 6): - Number of attention heads for each attention layer in the Transformer encoder. - hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): - The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, - `"relu"`, `"selu"` and `"gelu_new"` are supported. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - layer_norm_eps (`float`, *optional*, defaults to 1e-05): - The epsilon used by the layer normalization layers. - rope_theta (`float`, *optional*, defaults to 100.0): - The base period of the RoPE embeddings. - image_size (`int`, *optional*, defaults to 224): - The size (resolution) of each image. - num_channels (`int`, *optional*, defaults to 3): - The number of input channels. - query_bias (`bool`, *optional*, defaults to `True`): - Whether to add a bias to the query projection. - key_bias (`bool`, *optional*, defaults to `False`): - Whether to add a bias to the key projection. - value_bias (`bool`, *optional*, defaults to `True`): - Whether to add a bias to the value projection. - proj_bias (`bool`, *optional*, defaults to `True`): - Whether to add a bias to the output projection. - mlp_bias (`bool`, *optional*, defaults to `True`): - Whether to add a bias to the MLP layers. - layerscale_value (`float`, *optional*, defaults to 1.0): - Initial value to use for layer scale. - drop_path_rate (`float`, *optional*, defaults to 0.0): - Stochastic depth rate per sample (when applied in the main path of residual layers). - use_gated_mlp (`bool`, *optional*, defaults to `False`): - Whether to use the SwiGLU feedforward neural network. - num_register_tokens (`int`, *optional*, defaults to 0): - The number of register tokens. - pos_embed_shift (`float`, *optional*): - Amount to randomly shift position embedding coordinates in [-shift, shift], - applied only in training mode if not `None`. - pos_embed_jitter (`float`, *optional*): - Amount to randomly jitter position embedding coordinates in log-uniform value in [1/jitter, jitter], - applied only in training mode if not `None`. - pos_embed_rescale (`float`, *optional*, defaults to 2.0): - Amount to randomly rescale position embedding coordinates in log-uniform value in [1/rescale, rescale], - applied only in training mode if not `None`. - out_features (`list[str]`, *optional*): - If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. - (depending on how many stages the model has). Will default to the last stage if unset. - out_indices (`list[int]`, *optional*): - If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. - (depending on how many stages the model has). Will default to the last stage if unset. 
- apply_layernorm (`bool`, *optional*, defaults to `True`): - Whether to apply layer normalization to the feature maps when used as backbone. - reshape_hidden_states (`bool`, *optional*, defaults to `True`): - Whether to reshape the hidden states to spatial dimensions when used as backbone. - - Example: - - ```python - >>> from transformers import Deimv2DINOv3ViTConfig, Deimv2DINOv3ViTModel - - >>> # Initializing a DINOv3 ViT-small style configuration - >>> config = Deimv2DINOv3ViTConfig() - - >>> # Initializing a model (with random weights) from the config - >>> model = Deimv2DINOv3ViTModel(config) - - >>> # Accessing the model config - >>> config = model.config - ```""" - - model_type = "deimv2_dinov3_vit" - - def __init__( - self, - patch_size: int = 16, - hidden_size: int = 384, - intermediate_size: int = 1536, - num_hidden_layers: int = 12, - num_attention_heads: int = 6, - hidden_act: str = "gelu", - attention_dropout: float = 0.0, - initializer_range: float = 0.02, - layer_norm_eps: float = 1e-5, - rope_theta: float = 100.0, - image_size: int = 224, - num_channels: int = 3, - query_bias: bool = True, - key_bias: bool = False, - value_bias: bool = True, - proj_bias: bool = True, - mlp_bias: bool = True, - layerscale_value: float = 1.0, - drop_path_rate: float = 0.0, - use_gated_mlp: bool = False, - num_register_tokens: int = 0, - # train augs - pos_embed_shift: float | None = None, - pos_embed_jitter: float | None = None, - pos_embed_rescale: float | None = 2.0, - out_features: list[str] | None = None, - out_indices: list[int] | None = None, - apply_layernorm: bool = True, - reshape_hidden_states: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.attention_dropout = attention_dropout - self.initializer_range = initializer_range - self.layer_norm_eps = layer_norm_eps - self.layerscale_value = layerscale_value - self.drop_path_rate = drop_path_rate - self.use_gated_mlp = use_gated_mlp - self.rope_theta = rope_theta - self.query_bias = query_bias - self.key_bias = key_bias - self.value_bias = value_bias - self.proj_bias = proj_bias - self.mlp_bias = mlp_bias - self.num_register_tokens = num_register_tokens - - # train augs - self.pos_embed_shift = pos_embed_shift - self.pos_embed_jitter = pos_embed_jitter - self.pos_embed_rescale = pos_embed_rescale - # Initialize backbone-specific configuration - self.apply_layernorm = apply_layernorm - self.reshape_hidden_states = reshape_hidden_states - - # Initialize backbone stage names - stage_names = ["stem"] + [f"stage{i}" for i in range(1, num_hidden_layers + 1)] - self.stage_names = stage_names - - # Initialize backbone features/indices - self.set_output_features_output_indices(out_indices=out_indices, out_features=out_features) - - __all__ = ["Deimv2Config"] diff --git a/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py b/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py index a6125952deeb..a4ba5aeda67f 100644 --- a/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py +++ b/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py @@ -26,6 +26,7 @@ from torchvision import transforms from transformers import 
Deimv2Config, Deimv2ForObjectDetection, RTDetrImageProcessor +from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig from transformers.utils import logging @@ -180,15 +181,11 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: num_stages = len(config.backbone_config.hidden_sizes) config.backbone_config.depths = config.backbone_config.stage_numb_of_layers config.backbone_config.stage_names = ["stem"] + [f"stage{i}" for i in range(1, num_stages + 1)] - - config.use_spatial_tuning_adapter = False elif "DINOv3STAs" in orig_config: dinov3_cfg = orig_config["DINOv3STAs"] name = dinov3_cfg["name"] - config.backbone_type = "dinov3" - config.dinov3_interaction_indexes = dinov3_cfg["interaction_indexes"] + interaction_indexes = dinov3_cfg["interaction_indexes"] config.sta_inplanes = dinov3_cfg.get("conv_inplane") or 16 - config.use_spatial_tuning_adapter = True is_dinov3 = "dinov3" in name @@ -239,8 +236,6 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: vit_hidden_size = dinov3_cfg.get("embed_dim") or 192 vit_num_heads = dinov3_cfg.get("num_heads") or 3 - config.dinov3_hidden_dim = dinov3_cfg.get("hidden_dim") or vit_hidden_size - ffn_ratio = preset["ffn_ratio"] if preset["use_gated_mlp"]: hidden_features = vit_hidden_size * ffn_ratio @@ -249,19 +244,22 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: else: intermediate_size = vit_hidden_size * ffn_ratio - config.dinov3_backbone_config = { - "hidden_size": vit_hidden_size, - "num_attention_heads": vit_num_heads, - "num_hidden_layers": 12, - "intermediate_size": intermediate_size, - "layerscale_value": preset["layerscale_value"], - "use_gated_mlp": preset["use_gated_mlp"], - "num_register_tokens": preset["num_register_tokens"], - "pos_embed_rescale": preset["pos_embed_rescale"], - "key_bias": preset["key_bias"], - "rope_theta": 100.0, - } - config.dinov3_apply_layernorm = is_dinov3 + out_indices = [idx + 1 for idx in interaction_indexes] + config.backbone_config = DINOv3ViTConfig( + hidden_size=vit_hidden_size, + num_attention_heads=vit_num_heads, + num_hidden_layers=12, + intermediate_size=intermediate_size, + layerscale_value=preset["layerscale_value"], + use_gated_mlp=preset["use_gated_mlp"], + num_register_tokens=preset["num_register_tokens"], + pos_embed_rescale=preset["pos_embed_rescale"], + key_bias=preset["key_bias"], + rope_theta=100.0, + out_indices=out_indices, + apply_layernorm=is_dinov3, + reshape_hidden_states=True, + ) config.encoder_has_trailing_conv = False else: raise ValueError(f"Unknown backbone in config: {list(orig_config.keys())}") @@ -297,8 +295,8 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: r"backbone\.stages\.(\d+)\.downsample\.lab\.(scale|bias)": r"model.backbone.model.encoder.stages.\1.downsample.lab.\2", # Encoder mappings # Input projections - r"encoder\.input_proj\.(\d+)\.conv\.weight": r"model.encoder_input_proj.\1.0.weight", - r"encoder\.input_proj\.(\d+)\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder_input_proj.\1.1.\2", + r"encoder\.input_proj\.(\d+)\.conv\.weight": r"model.backbone.encoder_input_proj.\1.0.weight", + r"encoder\.input_proj\.(\d+)\.norm\.(weight|bias|running_mean|running_var)": r"model.backbone.encoder_input_proj.\1.1.\2", # AIFI transformer encoder layers r"encoder\.encoder\.(\d+)\.layers\.0\.self_attn\.out_proj\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.self_attn.o_proj.\2", r"encoder\.encoder\.(\d+)\.layers\.0\.linear1\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.mlp.layers.0.\2", @@ -369,12 +367,12 
@@ def get_deimv2_config(model_name: str) -> Deimv2Config: r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.value_proj\.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.value_proj.\2", r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.output_proj\.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.output_proj.\2", r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.num_points_scale": r"model.decoder.layers.\1.encoder_attn.num_points_scale", - r"decoder\.decoder\.layers\.(\d+)\.norm1\.scale": r"model.decoder.layers.\1.self_attn_layer_norm.scale", - r"decoder\.decoder\.layers\.(\d+)\.norm3\.scale": r"model.decoder.layers.\1.final_layer_norm.scale", + r"decoder\.decoder\.layers\.(\d+)\.norm1\.scale": r"model.decoder.layers.\1.self_attn_layer_norm.weight", + r"decoder\.decoder\.layers\.(\d+)\.norm3\.scale": r"model.decoder.layers.\1.final_layer_norm.weight", r"decoder\.decoder\.layers\.(\d+)\.swish_ffn\.w12\.(weight|bias)": r"model.decoder.layers.\1.mlp.w12.\2", r"decoder\.decoder\.layers\.(\d+)\.swish_ffn\.w3\.(weight|bias)": r"model.decoder.layers.\1.mlp.w3.\2", r"decoder\.decoder\.layers\.(\d+)\.gateway\.gate\.(weight|bias)": r"model.decoder.layers.\1.gateway.gate.\2", - r"decoder\.decoder\.layers\.(\d+)\.gateway\.norm\.scale": r"model.decoder.layers.\1.gateway.norm.scale", + r"decoder\.decoder\.layers\.(\d+)\.gateway\.norm\.scale": r"model.decoder.layers.\1.gateway.norm.weight", # LQE layers r"decoder\.decoder\.lqe_layers\.(\d+)\.reg_conf\.layers\.(\d+)\.(weight|bias)": r"model.decoder.lqe_layers.\1.reg_conf.layers.\2.\3", # Decoder heads and projections @@ -449,50 +447,63 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: } DECODER_NO_GATEWAY_KEY_MAPPING = { - r"decoder\.decoder\.layers\.(\d+)\.norm2\.scale": r"model.decoder.layers.\1.encoder_attn_layer_norm.scale", + r"decoder\.decoder\.layers\.(\d+)\.norm2\.scale": r"model.decoder.layers.\1.encoder_attn_layer_norm.weight", } DINOV3_KEY_MAPPING = { # ViT embeddings - r"backbone\.dinov3\.patch_embed\.proj\.(weight|bias)": r"model.dinov3_backbone.embeddings.patch_embeddings.\1", - r"backbone\.dinov3\.cls_token": r"model.dinov3_backbone.embeddings.cls_token", - r"backbone\.dinov3\.storage_tokens": r"model.dinov3_backbone.embeddings.register_tokens", - r"backbone\.dinov3\.mask_token": r"model.dinov3_backbone.embeddings.mask_token", + r"backbone\.dinov3\.patch_embed\.proj\.(weight|bias)": r"model.backbone.backbone.embeddings.patch_embeddings.\1", + r"backbone\.dinov3\.cls_token": r"model.backbone.backbone.embeddings.cls_token", + r"backbone\.dinov3\.storage_tokens": r"model.backbone.backbone.embeddings.register_tokens", + r"backbone\.dinov3\.mask_token": r"model.backbone.backbone.embeddings.mask_token", # ViT blocks - r"backbone\.dinov3\.blocks\.(\d+)\.norm1\.(weight|bias)": r"model.dinov3_backbone.layers.\1.norm1.\2", - r"backbone\.dinov3\.blocks\.(\d+)\.norm2\.(weight|bias)": r"model.dinov3_backbone.layers.\1.norm2.\2", - r"backbone\.dinov3\.blocks\.(\d+)\.attn\.qkv\.(weight|bias)": r"model.dinov3_backbone.layers.\1.attention.qkv.\2", - r"backbone\.dinov3\.blocks\.(\d+)\.attn\.proj\.(weight|bias)": r"model.dinov3_backbone.layers.\1.attention.o_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.norm1\.(weight|bias)": r"model.backbone.backbone.layer.\1.norm1.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.norm2\.(weight|bias)": r"model.backbone.backbone.layer.\1.norm2.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.attn\.qkv\.(weight|bias)": r"model.backbone.backbone.layer.\1.attention.qkv.\2", + 
r"backbone\.dinov3\.blocks\.(\d+)\.attn\.proj\.(weight|bias)": r"model.backbone.backbone.layer.\1.attention.o_proj.\2", # Standard MLP (S/M/L) - r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.fc1\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.up_proj.\2", - r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.fc2\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.down_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.fc1\.(weight|bias)": r"model.backbone.backbone.layer.\1.mlp.up_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.fc2\.(weight|bias)": r"model.backbone.backbone.layer.\1.mlp.down_proj.\2", # SwiGLU MLP (X only) - r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w1\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.gate_proj.\2", - r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w2\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.up_proj.\2", - r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w3\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.down_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w1\.(weight|bias)": r"model.backbone.backbone.layer.\1.mlp.gate_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w2\.(weight|bias)": r"model.backbone.backbone.layer.\1.mlp.up_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w3\.(weight|bias)": r"model.backbone.backbone.layer.\1.mlp.down_proj.\2", # LayerScale (L/X only) - r"backbone\.dinov3\.blocks\.(\d+)\.ls1\.gamma": r"model.dinov3_backbone.layers.\1.layer_scale1.lambda1", - r"backbone\.dinov3\.blocks\.(\d+)\.ls2\.gamma": r"model.dinov3_backbone.layers.\1.layer_scale2.lambda1", + r"backbone\.dinov3\.blocks\.(\d+)\.ls1\.gamma": r"model.backbone.backbone.layer.\1.layer_scale1.lambda1", + r"backbone\.dinov3\.blocks\.(\d+)\.ls2\.gamma": r"model.backbone.backbone.layer.\1.layer_scale2.lambda1", # Norm (L/X only) - r"backbone\.dinov3\.norm\.(weight|bias)": r"model.dinov3_backbone.norm.\1", + r"backbone\.dinov3\.norm\.(weight|bias)": r"model.backbone.backbone.norm.\1", # STA adapter - r"backbone\.sta\.stem\.0\.(weight|bias)": r"model.dinov3_backbone.sta.stem.0.\1", - r"backbone\.sta\.stem\.1\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.sta.stem.1.\1", - r"backbone\.sta\.conv2\.0\.(weight)": r"model.dinov3_backbone.sta.conv2.0.\1", - r"backbone\.sta\.conv2\.1\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.sta.conv2.1.\1", - r"backbone\.sta\.conv3\.1\.(weight)": r"model.dinov3_backbone.sta.conv3.1.\1", - r"backbone\.sta\.conv3\.2\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.sta.conv3.2.\1", - r"backbone\.sta\.conv4\.1\.(weight)": r"model.dinov3_backbone.sta.conv4.1.\1", - r"backbone\.sta\.conv4\.2\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.sta.conv4.2.\1", + r"backbone\.sta\.stem\.0\.(weight|bias)": r"model.backbone.sta.stem.0.\1", + r"backbone\.sta\.stem\.1\.(weight|bias|running_mean|running_var)": r"model.backbone.sta.stem.1.\1", + r"backbone\.sta\.conv2\.0\.(weight)": r"model.backbone.sta.conv2.0.\1", + r"backbone\.sta\.conv2\.1\.(weight|bias|running_mean|running_var)": r"model.backbone.sta.conv2.1.\1", + r"backbone\.sta\.conv3\.1\.(weight)": r"model.backbone.sta.conv3.1.\1", + r"backbone\.sta\.conv3\.2\.(weight|bias|running_mean|running_var)": r"model.backbone.sta.conv3.2.\1", + r"backbone\.sta\.conv4\.1\.(weight)": r"model.backbone.sta.conv4.1.\1", + r"backbone\.sta\.conv4\.2\.(weight|bias|running_mean|running_var)": r"model.backbone.sta.conv4.2.\1", # Projection convs/norms - r"backbone\.convs\.(\d+)\.weight": r"model.dinov3_backbone.convs.\1.weight", - 
r"backbone\.norms\.(\d+)\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.norms.\1.\2", + r"backbone\.convs\.(\d+)\.weight": r"model.backbone.convs.\1.weight", + r"backbone\.norms\.(\d+)\.(weight|bias|running_mean|running_var)": r"model.backbone.norms.\1.\2", } +def split_swiglu_weights(state_dict): + for key in list(state_dict.keys()): + if ".mlp.w12." in key: + w12 = state_dict.pop(key) + gate, up = w12.chunk(2, dim=0) + state_dict[key.replace(".w12.", ".gate_proj.")] = gate + state_dict[key.replace(".w12.", ".up_proj.")] = up + elif ".mlp.w3." in key: + state_dict[key.replace(".w3.", ".down_proj.")] = state_dict.pop(key) + + def convert_old_keys_to_new_keys(state_dict, config=None): mapping = dict(ORIGINAL_TO_CONVERTED_KEY_MAPPING) + is_dinov3 = getattr(config.backbone_config, "model_type", None) == "dinov3_vit" if config else False + if config is not None: if config.encoder_type == "lite": for k in list(mapping.keys()): @@ -513,9 +524,9 @@ def convert_old_keys_to_new_keys(state_dict, config=None): if "gateway" in k: del mapping[k] - if config.backbone_type == "dinov3": + if is_dinov3: for k in list(mapping.keys()): - if k.startswith(r"backbone\."): + if k.startswith(r"backbone\.") or k.startswith(r"encoder\.input_proj"): del mapping[k] mapping.update(DINOV3_KEY_MAPPING) @@ -578,28 +589,25 @@ def strip_dinov3_model_prefix(state_dict): def read_in_q_k_v_vit(state_dict, config): - from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig - - vit_config = DINOv3ViTConfig(**config.dinov3_backbone_config) - has_key_bias = config.dinov3_backbone_config.get("key_bias", True) - prefix = "model.dinov3_backbone" - for i in range(vit_config.num_hidden_layers): - qkv_key = f"{prefix}.layers.{i}.attention.qkv.weight" + has_key_bias = getattr(config.backbone_config, "key_bias", True) + prefix = "model.backbone.backbone" + for i in range(config.backbone_config.num_hidden_layers): + qkv_key = f"{prefix}.layer.{i}.attention.qkv.weight" if qkv_key in state_dict: qkv_w = state_dict.pop(qkv_key) q, k, v = qkv_w.chunk(3, dim=0) - state_dict[f"{prefix}.layers.{i}.attention.q_proj.weight"] = q - state_dict[f"{prefix}.layers.{i}.attention.k_proj.weight"] = k - state_dict[f"{prefix}.layers.{i}.attention.v_proj.weight"] = v + state_dict[f"{prefix}.layer.{i}.attention.q_proj.weight"] = q + state_dict[f"{prefix}.layer.{i}.attention.k_proj.weight"] = k + state_dict[f"{prefix}.layer.{i}.attention.v_proj.weight"] = v - qkv_bias_key = f"{prefix}.layers.{i}.attention.qkv.bias" + qkv_bias_key = f"{prefix}.layer.{i}.attention.qkv.bias" if qkv_bias_key in state_dict: qkv_b = state_dict.pop(qkv_bias_key) q_b, k_b, v_b = qkv_b.chunk(3, dim=0) - state_dict[f"{prefix}.layers.{i}.attention.q_proj.bias"] = q_b + state_dict[f"{prefix}.layer.{i}.attention.q_proj.bias"] = q_b if has_key_bias: - state_dict[f"{prefix}.layers.{i}.attention.k_proj.bias"] = k_b - state_dict[f"{prefix}.layers.{i}.attention.v_proj.bias"] = v_b + state_dict[f"{prefix}.layer.{i}.attention.k_proj.bias"] = k_b + state_dict[f"{prefix}.layer.{i}.attention.v_proj.bias"] = v_b def load_original_state_dict(repo_id): @@ -634,7 +642,9 @@ def convert_deimv2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, if key.endswith(".num_batches_tracked"): state_dict.pop(key) - if config.backbone_type == "dinov3": + is_dinov3 = getattr(config.backbone_config, "model_type", None) == "dinov3_vit" + + if is_dinov3: strip_dinov3_model_prefix(state_dict) for key in list(state_dict.keys()): if "rope_embed.periods" in key or 
"qkv.bias_mask" in key: @@ -645,9 +655,11 @@ def convert_deimv2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, state_dict = convert_old_keys_to_new_keys(state_dict, config) - if config.backbone_type == "dinov3": + split_swiglu_weights(state_dict) + + if is_dinov3: read_in_q_k_v_vit(state_dict, config) - mask_key = "model.dinov3_backbone.embeddings.mask_token" + mask_key = "model.backbone.backbone.embeddings.mask_token" if mask_key in state_dict and state_dict[mask_key].dim() == 2: state_dict[mask_key] = state_dict[mask_key].unsqueeze(1) @@ -677,7 +689,7 @@ def convert_deimv2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, model = Deimv2ForObjectDetection(config) missing, unexpected = model.load_state_dict(state_dict, strict=False) - expected_missing = {"mask_token", "register_tokens", "layer_scale1", "layer_scale2"} + expected_missing = {"mask_token", "register_tokens", "layer_scale1", "layer_scale2", "backbone.norm"} unexpected_missing = [k for k in missing if not any(e in k for e in expected_missing)] if unexpected_missing: logger.warning(f"Missing keys ({len(unexpected_missing)}): {unexpected_missing[:10]}...") @@ -690,7 +702,6 @@ def convert_deimv2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, model.eval() - is_dinov3 = config.backbone_type == "dinov3" if is_dinov3: image_processor = RTDetrImageProcessor( do_normalize=True, diff --git a/src/transformers/models/deimv2/modeling_deimv2.py b/src/transformers/models/deimv2/modeling_deimv2.py index 9173609eb0df..525695807d90 100644 --- a/src/transformers/models/deimv2/modeling_deimv2.py +++ b/src/transformers/models/deimv2/modeling_deimv2.py @@ -17,28 +17,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import math from collections.abc import Callable from dataclasses import dataclass -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from ... import initialization as init -from ...activations import ACT2CLS, ACT2FN +from ...activations import ACT2CLS +from ...backbone_utils import load_backbone from ...image_transforms import center_to_corners_format, corners_to_center_format -from ...modeling_layers import GradientCheckpointingLayer +from ...integrations import use_kernel_forward_from_hub from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import compile_compatible_method_lru_cache from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int -from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults +from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from .configuration_deimv2 import Deimv2Config, Deimv2DINOv3ViTConfig +from .configuration_deimv2 import Deimv2Config @dataclass @@ -79,30 +80,94 @@ class Deimv2DecoderOutput(ModelOutput): cross_attentions: tuple[torch.FloatTensor] | None = None +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the RT-DETR encoder-decoder model. 
+ """ +) +class Deimv2ModelOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate predicted corners (predicted corners of each layer of the decoder). + initial_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points used for the first decoder layer. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`): + Logits of predicted bounding boxes coordinates in the encoder stage. + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. + denoising_meta_values (`dict`): + Extra dictionary for the denoising related values. 
+ """ + + last_hidden_state: torch.FloatTensor | None = None + intermediate_hidden_states: torch.FloatTensor | None = None + intermediate_logits: torch.FloatTensor | None = None + intermediate_reference_points: torch.FloatTensor | None = None + intermediate_predicted_corners: torch.FloatTensor | None = None + initial_reference_points: torch.FloatTensor | None = None + decoder_hidden_states: tuple[torch.FloatTensor] | None = None + decoder_attentions: tuple[torch.FloatTensor] | None = None + cross_attentions: tuple[torch.FloatTensor] | None = None + encoder_last_hidden_state: torch.FloatTensor | None = None + encoder_hidden_states: tuple[torch.FloatTensor] | None = None + encoder_attentions: tuple[torch.FloatTensor] | None = None + init_reference_points: torch.FloatTensor | None = None + enc_topk_logits: torch.FloatTensor | None = None + enc_topk_bboxes: torch.FloatTensor | None = None + enc_outputs_class: torch.FloatTensor | None = None + enc_outputs_coord_logits: torch.FloatTensor | None = None + denoising_meta_values: dict | None = None + + +@use_kernel_forward_from_hub("RMSNorm") class Deimv2RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + Deimv2RMSNorm is equivalent to T5LayerNorm + """ super().__init__() - self.eps = eps - self.scale = nn.Parameter(torch.ones(dim)) + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_dtype = hidden_states.dtype - hidden_states = hidden_states.float() - hidden_states = hidden_states * torch.rsqrt(hidden_states.pow(2).mean(-1, keepdim=True) + self.eps) - return (hidden_states * self.scale).to(input_dtype) + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class Deimv2SwiGLUFFN(nn.Module): def __init__(self, in_features: int, hidden_features: int, out_features: int): super().__init__() - self.w12 = nn.Linear(in_features, 2 * hidden_features) - self.w3 = nn.Linear(hidden_features, out_features) + self.gate_proj = nn.Linear(in_features, hidden_features) + self.up_proj = nn.Linear(in_features, hidden_features) + self.down_proj = nn.Linear(hidden_features, out_features) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - x12 = self.w12(hidden_states) - x1, x2 = x12.chunk(2, dim=-1) - hidden = F.silu(x1) * x2 - return self.w3(hidden) + return self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) class Deimv2Gate(nn.Module): @@ -352,400 +417,7 @@ def forward(self, x): return self.activation(y) -class Deimv2DINOv3ViTEmbeddings(nn.Module): - """ - Construct the CLS token, mask token, position and patch embeddings. 
- """ - - def __init__(self, config: Deimv2DINOv3ViTConfig): - super().__init__() - self.config = config - self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) - self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) - self.register_tokens = nn.Parameter(torch.empty(1, config.num_register_tokens, config.hidden_size)) - self.patch_embeddings = nn.Conv2d( - config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size - ) - - def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None) -> torch.Tensor: - batch_size = pixel_values.shape[0] - target_dtype = self.patch_embeddings.weight.dtype - - # (batch_size, num_channels, height, width) -> (batch_size, num_patches, hidden_size) - patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) - patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2) - - if bool_masked_pos is not None: - mask_token = self.mask_token.to(patch_embeddings.dtype) - patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings) - - # Add CLS and register tokens - cls_token = self.cls_token.expand(batch_size, -1, -1) - register_tokens = self.register_tokens.expand(batch_size, -1, -1) - embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) - - return embeddings - - -class Deimv2DINOv3ViTLayerScale(nn.Module): - def __init__(self, config) -> None: - super().__init__() - self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size)) - - def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: - return hidden_state * self.lambda1 - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor | None, - scaling: float | None = None, - dropout: float = 0.0, - **kwargs: Unpack[TransformersKwargs], -): - if scaling is None: - scaling = query.size(-1) ** -0.5 - - # Take the dot product between "query" and "key" to get the raw attention scores. - attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - - attn_output = torch.matmul(attn_weights, value) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -def apply_rotary_pos_emb( - q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, **kwargs -) -> tuple[torch.Tensor, torch.Tensor]: - """Applies Rotary Position Embedding to the query and key tensors, but only to the patch tokens, - ignoring the prefix tokens (cls token and register tokens). - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - - num_tokens = q.shape[-2] - num_patches = sin.shape[-2] - num_prefix_tokens = num_tokens - num_patches # cls token + register tokens - - q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2) - k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2) - - # apply rope only to patch tokens - q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin) - k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin) - - q = torch.cat((q_prefix_tokens, q_patches), dim=-2) - k = torch.cat((k_prefix_tokens, k_patches), dim=-2) - - return q, k - - -class Deimv2DINOv3ViTAttention(nn.Module): - """ - Multi-headed attention compatible with ALL_ATTENTION_FUNCTIONS. - """ - - def __init__(self, config: Deimv2DINOv3ViTConfig): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - self.is_causal = False - - self.scaling = self.head_dim**-0.5 - self.is_causal = False - - self.dropout = config.attention_dropout - self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.key_bias) - self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.value_bias) - - self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.query_bias) - self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, - **kwargs: Unpack[TransformersKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - batch_size, patches, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( - self.config._attn_implementation, eager_attention_forward - ) - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() - attn_output = self.o_proj(attn_output) - - return attn_output, attn_weights - - -def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: - """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
- - """ - if drop_prob == 0.0 or not training: - return input - keep_prob = 1 - drop_prob - shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) - random_tensor.floor_() # binarize - output = input.div(keep_prob) * random_tensor - return output - - -class Deimv2DINOv3ViTDropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob: float | None = None) -> None: - super().__init__() - self.drop_prob = drop_prob - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - return drop_path(hidden_states, self.drop_prob, self.training) - - def extra_repr(self) -> str: - return f"p={self.drop_prob}" - - -class Deimv2DINOv3ViTMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.up_proj(x))) - - -class Deimv2DINOv3ViTGatedMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -class Deimv2DINOv3ViTLayer(GradientCheckpointingLayer): - """This corresponds to the Block class in the original implementation.""" - - def __init__(self, config: Deimv2DINOv3ViTConfig): - super().__init__() - - self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = Deimv2DINOv3ViTAttention(config) - self.layer_scale1 = Deimv2DINOv3ViTLayerScale(config) - self.drop_path = ( - Deimv2DINOv3ViTDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() - ) - - self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.use_gated_mlp: - self.mlp = Deimv2DINOv3ViTGatedMLP(config) - else: - self.mlp = Deimv2DINOv3ViTMLP(config) - self.layer_scale2 = Deimv2DINOv3ViTLayerScale(config) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, - ) -> torch.Tensor: - # Attention with residual connection - residual = hidden_states - hidden_states = self.norm1(hidden_states) - hidden_states, _ = self.attention( - hidden_states, - attention_mask=attention_mask, - position_embeddings=position_embeddings, - ) - hidden_states = self.layer_scale1(hidden_states) - hidden_states = self.drop_path(hidden_states) + residual - - # MLP with residual connection - residual = hidden_states - hidden_states = self.norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = self.layer_scale2(hidden_states) - hidden_states = 
self.drop_path(hidden_states) + residual - - return hidden_states - - -@compile_compatible_method_lru_cache(maxsize=32) -def get_patches_center_coordinates( - num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device -) -> torch.Tensor: - """ - Computes the 2D coordinates of the centers of image patches, normalized to the range [-1, +1]. - The center of each patch is exactly halfway between its top-left and bottom-right corners. - - Args: - num_patches_h (int): Number of patches along the vertical (height) axis. - num_patches_w (int): Number of patches along the horizontal (width) axis. - dtype (torch.dtype): The desired data type of the returned tensor. - - Returns: - torch.Tensor: A tensor of shape (height * width, 2), where each row contains the (y, x) - coordinates of a patch center, normalized to [-1, +1]. - """ - coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device) - coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device) - coords_h = coords_h / num_patches_h - coords_w = coords_w / num_patches_w - # (height, width, 2) -> (height * width, 2) - coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) - coords = coords.flatten(0, 1) - # Shift range [0, 1] to [-1, +1] - coords = 2.0 * coords - 1.0 - return coords - - -def augment_patches_center_coordinates( - coords: torch.Tensor, - shift: float | None = None, - jitter: float | None = None, - rescale: float | None = None, -) -> torch.Tensor: - # Shift coords by adding a uniform value in [-shift, shift] - if shift is not None: - shift_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype) - shift_hw = shift_hw.uniform_(-shift, shift) - coords = coords + shift_hw - - # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter] - if jitter is not None: - jitter_range = np.log(jitter) - jitter_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype) - jitter_hw = jitter_hw.uniform_(-jitter_range, jitter_range).exp() - coords = coords * jitter_hw - - # Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale] - if rescale is not None: - rescale_range = np.log(rescale) - rescale_hw = torch.empty(1, device=coords.device, dtype=coords.dtype) - rescale_hw = rescale_hw.uniform_(-rescale_range, rescale_range).exp() - coords = coords * rescale_hw - - return coords - - -class Deimv2DINOv3ViTRopePositionEmbedding(nn.Module): - inv_freq: torch.Tensor - - def __init__(self, config: Deimv2DINOv3ViTConfig): - super().__init__() - - self.config = config - self.base = config.rope_theta - self.head_dim = config.hidden_size // config.num_attention_heads - self.num_patches_h = config.image_size // config.patch_size - self.num_patches_w = config.image_size // config.patch_size - - inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32) # (head_dim / 4,) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - _, _, height, width = pixel_values.shape - num_patches_h = height // self.config.patch_size - num_patches_w = width // self.config.patch_size - - device = pixel_values.device - device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu" - - with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - # Although we could precompute static patch_coords from image_size and patch_size in the config, - # the model 
was trained with random_scale, so it can process images of varying sizes. - # Therefore, it's better to compute patch_coords dynamically (with lru_cache). - patch_coords = get_patches_center_coordinates( - num_patches_h, num_patches_w, dtype=torch.float32, device=device - ) - if self.training: - patch_coords = augment_patches_center_coordinates( - patch_coords, - shift=self.config.pos_embed_shift, - jitter=self.config.pos_embed_jitter, - rescale=self.config.pos_embed_rescale, - ) - - # (height * width, 2, head_dim / 4) -> (height * width, head_dim / 2) -> (height * width, head_dim) - angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :] - angles = angles.flatten(1, 2) - angles = angles.tile(2) - - cos = torch.cos(angles) - sin = torch.sin(angles) - - dtype = pixel_values.dtype - return cos.to(dtype=dtype), sin.to(dtype=dtype) - - -class Deimv2CSPRepLayer2(nn.Module): +class Deimv2CSPRepLayer(nn.Module): def __init__( self, config: Deimv2Config, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0 ): @@ -762,34 +434,27 @@ def __init__( self.conv3 = nn.Identity() def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: - chunks = list(self.conv1(hidden_state).chunk(2, 1)) - bottleneck_out = chunks[1] + hidden_state_1, hidden_state_2 = self.conv1(hidden_state).chunk(2, 1) for bottleneck in self.bottlenecks: - bottleneck_out = bottleneck(bottleneck_out) - return self.conv3(chunks[0] + bottleneck_out) + hidden_state_2 = bottleneck(hidden_state_2) + return self.conv3(hidden_state_1 + hidden_state_2) class Deimv2RepNCSPELAN5(nn.Module): - def __init__( - self, - config: Deimv2Config, - numb_blocks: int = 3, - c1: int | None = None, - c2: int | None = None, - c3: int | None = None, - c4: int | None = None, - ): + def __init__(self, config: Deimv2Config, numb_blocks: int = 3): super().__init__() - act = config.activation_function - c1 = c1 if c1 is not None else config.encoder_hidden_dim - c2 = c2 if c2 is not None else config.encoder_hidden_dim - c3 = c3 if c3 is not None else config.encoder_hidden_dim * 2 - c4 = c4 if c4 is not None else round(config.hidden_expansion * config.encoder_hidden_dim // 2) - self.conv_dim = c3 // 2 - self.conv1 = Deimv2ConvNormLayer(config, c1, c3, 1, 1, activation=act) - self.csp_rep1 = Deimv2CSPRepLayer2(config, c3 // 2, c4, num_blocks=numb_blocks) - self.csp_rep2 = Deimv2CSPRepLayer2(config, c4, c4, num_blocks=numb_blocks) - self.conv4 = Deimv2ConvNormLayer(config, c3 + (2 * c4), c2, 1, 1, activation=act) + activation = config.activation_function + in_channels = config.encoder_hidden_dim + out_channels = config.encoder_hidden_dim + split_channels = config.encoder_hidden_dim * 2 + csp_channels = round(config.hidden_expansion * config.encoder_hidden_dim // 2) + self.conv_dim = split_channels // 2 + self.conv1 = Deimv2ConvNormLayer(config, in_channels, split_channels, 1, 1, activation=activation) + self.csp_rep1 = Deimv2CSPRepLayer(config, split_channels // 2, csp_channels, num_blocks=numb_blocks) + self.csp_rep2 = Deimv2CSPRepLayer(config, csp_channels, csp_channels, num_blocks=numb_blocks) + self.conv4 = Deimv2ConvNormLayer( + config, split_channels + (2 * csp_channels), out_channels, 1, 1, activation=activation + ) def forward(self, input_features: torch.Tensor) -> torch.Tensor: split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1)) @@ -819,6 +484,34 @@ def forward(self, input_features: torch.Tensor) -> torch.Tensor: return input_features +def eager_attention_forward( + 
module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float | None = None, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Deimv2SelfAttention(nn.Module): """ Multi-headed self-attention from 'Attention Is All You Need' paper. @@ -1078,11 +771,11 @@ def __init__(self, config: Deimv2Config): ) def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - c1 = self.stem(pixel_values) - c2 = self.conv2(c1) - c3 = self.conv3(c2) - c4 = self.conv4(c3) - return c2, c3, c4 + hidden_states = self.stem(pixel_values) + feature_map_8 = self.conv2(hidden_states) + feature_map_16 = self.conv3(feature_map_8) + feature_map_32 = self.conv4(feature_map_16) + return feature_map_8, feature_map_16, feature_map_32 class Deimv2GAPFusion(nn.Module): @@ -1124,12 +817,9 @@ def __init__(self, config: Deimv2Config): self.bi_fusion = Deimv2GAPFusion(config, hidden_dim) - c1, c2 = hidden_dim, hidden_dim - c3 = hidden_dim * 2 - c4 = round(config.hidden_expansion * hidden_dim // 2) num_blocks = round(3 * config.depth_mult) - self.fpn_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks, c1=c1, c2=c2, c3=c3, c4=c4) - self.pan_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks, c1=c1, c2=c2, c3=c3, c4=c4) + self.fpn_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks) + self.pan_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks) def forward(self, inputs_embeds=None, **kwargs) -> BaseModelOutput: feats = inputs_embeds @@ -1148,27 +838,115 @@ def forward(self, inputs_embeds=None, **kwargs) -> BaseModelOutput: return BaseModelOutput(last_hidden_state=outs) -class Deimv2DINOv3Backbone(nn.Module): - def __init__(self, config: Deimv2Config): +class Deimv2FrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans. 
+ """ + + def __init__(self, n): super().__init__() - vit_config = Deimv2DINOv3ViTConfig(**config.dinov3_backbone_config) + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) - self.embeddings = Deimv2DINOv3ViTEmbeddings(vit_config) - self.rope_embeddings = Deimv2DINOv3ViTRopePositionEmbedding(vit_config) - self.layers = nn.ModuleList([Deimv2DINOv3ViTLayer(vit_config) for _ in range(vit_config.num_hidden_layers)]) + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] - self.apply_layernorm = config.dinov3_apply_layernorm - if self.apply_layernorm: - self.norm = nn.LayerNorm(vit_config.hidden_size, eps=vit_config.layer_norm_eps) + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +def replace_batch_norm(model): + r""" + Recursively replace all `torch.nn.BatchNorm2d` with `Deimv2FrozenBatchNorm2d`. - self.interaction_indexes = config.dinov3_interaction_indexes - self.patch_size = vit_config.patch_size - self.num_prefix_tokens = 1 + vit_config.num_register_tokens + Args: + model (torch.nn.Module): + input model + """ + for name, module in model.named_children(): + if isinstance(module, nn.BatchNorm2d): + new_module = Deimv2FrozenBatchNorm2d(module.num_features) + + if module.weight.device != torch.device("meta"): + new_module.weight.copy_(module.weight) + new_module.bias.copy_(module.bias) + new_module.running_mean.copy_(module.running_mean) + new_module.running_var.copy_(module.running_var) + + model._modules[name] = new_module + + if len(list(module.children())) > 0: + replace_batch_norm(module) + + +class Deimv2ConvEncoder(nn.Module): + """ + Convolutional backbone using the modeling_deimv2_resnet.py. + + nn.BatchNorm2d layers are replaced by Deimv2FrozenBatchNorm2d as defined above. 
+ https://github.com/lyuwenyu/RT-DETR/blob/main/Deimv2_pytorch/src/nn/backbone/presnet.py#L142 + """ + + def __init__(self, config): + super().__init__() + + backbone = load_backbone(config) + + if config.freeze_backbone_batch_norms: + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = self.model.channels + if config.encoder_type != "lite": + encoder_input_proj = [] + for in_channel in self.intermediate_channel_sizes: + encoder_input_proj.append( + nn.Sequential( + nn.Conv2d(in_channel, config.encoder_hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(config.encoder_hidden_dim), + ) + ) + self.encoder_input_proj = nn.ModuleList(encoder_input_proj) + + def forward(self, pixel_values: torch.Tensor) -> list[torch.Tensor]: + features = self.model(pixel_values).feature_maps + if hasattr(self, "encoder_input_proj"): + return [self.encoder_input_proj[i](feat) for i, feat in enumerate(features)] + return list(features) + + +class Deimv2DINOv3ConvEncoder(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + self.backbone = load_backbone(config) self.sta = Deimv2SpatialTuningAdapter(config) - embed_dim = vit_config.hidden_size - hidden_dim = config.dinov3_hidden_dim or embed_dim + embed_dim = config.backbone_config.hidden_size + hidden_dim = config.encoder_hidden_dim sta_ch = config.sta_inplanes self.convs = nn.ModuleList( [ @@ -1180,28 +958,19 @@ def __init__(self, config: Deimv2Config): self.norms = nn.ModuleList([nn.BatchNorm2d(hidden_dim) for _ in range(3)]) def forward(self, pixel_values: torch.Tensor) -> list[torch.Tensor]: - hidden_states = self.embeddings(pixel_values) - position_embeddings = self.rope_embeddings(pixel_values) - - intermediate_outputs = [] - for i, layer in enumerate(self.layers): - hidden_states = layer(hidden_states, position_embeddings=position_embeddings) - if i in self.interaction_indexes: - out = self.norm(hidden_states) if self.apply_layernorm else hidden_states - intermediate_outputs.append(out) + backbone_output = self.backbone(pixel_values) + feature_maps = backbone_output.feature_maps - batch_size = pixel_values.shape[0] - h_patches = pixel_values.shape[2] // self.patch_size - w_patches = pixel_values.shape[3] // self.patch_size + patch_size = self.backbone.config.patch_size + h_patches = pixel_values.shape[2] // patch_size + w_patches = pixel_values.shape[3] // patch_size sem_feats = [] - num_scales = len(intermediate_outputs) - for i, feat in enumerate(intermediate_outputs): - patch_tokens = feat[:, self.num_prefix_tokens :] - spatial = patch_tokens.transpose(1, 2).reshape(batch_size, -1, h_patches, w_patches).contiguous() + num_scales = len(feature_maps) + for i, feat in enumerate(feature_maps): resize_h = int(h_patches * 2 ** (num_scales - 2 - i)) resize_w = int(w_patches * 2 ** (num_scales - 2 - i)) - spatial = F.interpolate(spatial, size=[resize_h, resize_w], mode="bilinear", align_corners=False) + spatial = F.interpolate(feat, size=[resize_h, resize_w], mode="bilinear", align_corners=False) sem_feats.append(spatial) detail_feats = self.sta(pixel_values) @@ -1275,6 +1044,7 @@ def __init__(self, config: Deimv2Config): self.final_layer_norm = Deimv2RMSNorm(config.d_model) # gate self.gateway = Deimv2Gate(config.d_model) + self.use_gateway = config.use_gateway if config.use_gateway: self.gateway = Deimv2Gate(config.d_model) else: @@ -1336,7 +1106,7 @@ def forward( ) hidden_states = nn.functional.dropout(hidden_states, 
p=self.dropout, training=self.training) - if hasattr(self, "gateway"): + if self.use_gateway: hidden_states = self.gateway(residual, hidden_states) else: hidden_states = residual + hidden_states @@ -1345,7 +1115,8 @@ def forward( residual = hidden_states hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states.clamp(min=-65504, max=65504)) + clamp_value = torch.finfo(hidden_states.dtype).max + hidden_states = self.final_layer_norm(hidden_states.clamp(min=-clamp_value, max=clamp_value)) return hidden_states @@ -1381,6 +1152,7 @@ class Deimv2PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" + # initialize linear layer bias value according to a given probability value. if isinstance(module, (Deimv2ForObjectDetection, Deimv2Decoder)): if module.class_embed is not None: for layer in module.class_embed: @@ -1443,38 +1215,26 @@ def _init_weights(self, module): init.constant_(module.reg_conf.layers[-1].bias, 0) init.constant_(module.reg_conf.layers[-1].weight, 0) - if isinstance(module, Deimv2SwiGLUFFN): - init.xavier_uniform_(module.w12.weight) - init.constant_(module.w12.bias, 0) - init.xavier_uniform_(module.w3.weight) - init.constant_(module.w3.bias, 0) - - if isinstance(module, Deimv2RMSNorm): - init.ones_(module.scale) - if isinstance(module, nn.LayerNorm): init.ones_(module.weight) init.zeros_(module.bias) - if isinstance(module, Deimv2DINOv3ViTEmbeddings): - init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) - if module.config.num_register_tokens > 0: - init.trunc_normal_(module.register_tokens, mean=0.0, std=self.config.initializer_range) - init.zeros_(module.mask_token) - - if isinstance(module, Deimv2DINOv3ViTLayerScale) and self.config.dinov3_backbone_config is not None: - layerscale_value = self.config.dinov3_backbone_config.get("layerscale_value", 1.0) - init.constant_(module.lambda1, layerscale_value) - - if isinstance(module, Deimv2DINOv3ViTRopePositionEmbedding): - inv_freq = 1 / module.base ** torch.arange(0, 1, 4 / module.head_dim, dtype=torch.float32) - init.copy_(module.inv_freq, inv_freq) - if hasattr(module, "weight_embedding") and self.config.learn_initial_query: init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: init.xavier_uniform_(module.denoising_class_embed.weight) + if isinstance(module, Deimv2SwiGLUFFN): + init.xavier_uniform_(module.gate_proj.weight) + init.constant_(module.gate_proj.bias, 0) + init.xavier_uniform_(module.up_proj.weight) + init.constant_(module.up_proj.bias, 0) + init.xavier_uniform_(module.down_proj.weight) + init.constant_(module.down_proj.bias, 0) + + if isinstance(module, Deimv2RMSNorm): + init.ones_(module.weight) + class Deimv2HybridEncoder(Deimv2PreTrainedModel): """ @@ -1787,64 +1547,6 @@ def forward( ) -@dataclass -@auto_docstring( - custom_intro=""" - Base class for outputs of the RT-DETR encoder-decoder model. - """ -) -class Deimv2ModelOutput(ModelOutput): - r""" - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the decoder of the model. - intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): - Stacked intermediate hidden states (output of each layer of the decoder). 
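For context on the clamp change earlier in this hunk: the previous hard-coded bound of 65504 is exactly the float16 maximum, so switching to `torch.finfo` simply makes the bound follow the activation dtype. A short standalone illustration:

```python
import torch

# 65504 is the float16 maximum; torch.finfo gives the right bound for any dtype.
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.float32).max)   # ~3.4028e+38

hidden_states = torch.randn(2, 3, dtype=torch.float32) * 1e20
clamp_value = torch.finfo(hidden_states.dtype).max
hidden_states = hidden_states.clamp(min=-clamp_value, max=clamp_value)
```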
- intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): - Stacked intermediate logits (logits of each layer of the decoder). - intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): - Stacked intermediate reference points (reference points of each layer of the decoder). - intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): - Stacked intermediate predicted corners (predicted corners of each layer of the decoder). - initial_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): - Initial reference points used for the first decoder layer. - init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): - Initial reference points sent through the Transformer decoder. - enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): - Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are - picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e. - foreground and background). - enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`): - Logits of predicted bounding boxes coordinates in the encoder stage. - enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): - Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are - picked as region proposals in the first stage. Output of bounding box binary classification (i.e. - foreground and background). - enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): - Logits of predicted bounding boxes coordinates in the first stage. - denoising_meta_values (`dict`): - Extra dictionary for the denoising related values. 
- """ - - last_hidden_state: torch.FloatTensor | None = None - intermediate_hidden_states: torch.FloatTensor | None = None - intermediate_logits: torch.FloatTensor | None = None - intermediate_reference_points: torch.FloatTensor | None = None - intermediate_predicted_corners: torch.FloatTensor | None = None - initial_reference_points: torch.FloatTensor | None = None - decoder_hidden_states: tuple[torch.FloatTensor] | None = None - decoder_attentions: tuple[torch.FloatTensor] | None = None - cross_attentions: tuple[torch.FloatTensor] | None = None - encoder_last_hidden_state: torch.FloatTensor | None = None - encoder_hidden_states: tuple[torch.FloatTensor] | None = None - encoder_attentions: tuple[torch.FloatTensor] | None = None - init_reference_points: torch.FloatTensor | None = None - enc_topk_logits: torch.FloatTensor | None = None - enc_topk_bboxes: torch.FloatTensor | None = None - enc_outputs_class: torch.FloatTensor | None = None - enc_outputs_coord_logits: torch.FloatTensor | None = None - denoising_meta_values: dict | None = None - - def get_contrastive_denoising_training_group( targets, num_classes, @@ -1977,23 +1679,11 @@ class Deimv2Model(Deimv2PreTrainedModel): def __init__(self, config: Deimv2Config): super().__init__(config) - if config.backbone_type == "dinov3": - self.dinov3_backbone = Deimv2DINOv3Backbone(config) + is_dinov3 = getattr(config.backbone_config, "model_type", None) == "dinov3_vit" + if is_dinov3: + self.backbone = Deimv2DINOv3ConvEncoder(config) else: - from ..d_fine.modeling_d_fine import DFineConvEncoder - - self.backbone = DFineConvEncoder(config) - if config.encoder_type != "lite": - intermediate_channel_sizes = self.backbone.intermediate_channel_sizes - encoder_input_proj = [] - for in_channel in intermediate_channel_sizes: - encoder_input_proj.append( - nn.Sequential( - nn.Conv2d(in_channel, config.encoder_hidden_dim, kernel_size=1, bias=False), - nn.BatchNorm2d(config.encoder_hidden_dim), - ) - ) - self.encoder_input_proj = nn.ModuleList(encoder_input_proj) + self.backbone = Deimv2ConvEncoder(config) if config.encoder_type == "lite": self.encoder = Deimv2LiteEncoder(config) @@ -2038,9 +1728,6 @@ def __init__(self, config: Deimv2Config): self.decoder_input_proj = nn.ModuleList(decoder_input_proj) self.decoder = Deimv2Decoder(config) - if config.use_spatial_tuning_adapter and config.backbone_type != "dinov3": - self.spatial_tuning_adapter = Deimv2SpatialTuningAdapter(config) - self.post_init() def freeze_backbone(self): @@ -2131,14 +1818,7 @@ def forward( if pixel_mask is None: pixel_mask = torch.ones(((batch_size, height, width)), device=device) - if self.config.backbone_type == "dinov3": - proj_feats = self.dinov3_backbone(pixel_values) - elif self.config.encoder_type == "lite": - features = self.backbone(pixel_values, pixel_mask) - proj_feats = [source for source, mask in features] - else: - features = self.backbone(pixel_values, pixel_mask) - proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] + proj_feats = self.backbone(pixel_values) else: batch_size = inputs_embeds.shape[0] device = inputs_embeds.device @@ -2180,8 +1860,6 @@ def forward( level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) if self.training and self.config.num_denoising > 0 and labels is not None: - from ..d_fine.modeling_d_fine import get_contrastive_denoising_training_group - ( denoising_class, denoising_bbox_unact, @@ -2242,8 +1920,6 @@ def forward( init_reference_points = 
reference_points_unact.detach() - from ..d_fine.modeling_d_fine import DFineModelOutput - decoder_outputs = self.decoder( inputs_embeds=target, encoder_hidden_states=source_flatten, @@ -2255,7 +1931,7 @@ def forward( **kwargs, ) - return DFineModelOutput( + return Deimv2ModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, intermediate_logits=decoder_outputs.intermediate_logits, diff --git a/src/transformers/models/deimv2/modular_deimv2.py b/src/transformers/models/deimv2/modular_deimv2.py index 5a7edf292378..536f0c10937d 100644 --- a/src/transformers/models/deimv2/modular_deimv2.py +++ b/src/transformers/models/deimv2/modular_deimv2.py @@ -11,17 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math import torch import torch.nn as nn import torch.nn.functional as F from ... import initialization as init +from ...backbone_utils import load_backbone from ...modeling_outputs import BaseModelOutput from ...utils import logging +from ..auto import AutoConfig from ..d_fine.configuration_d_fine import DFineConfig from ..d_fine.modeling_d_fine import ( + DFineAIFILayer, + DFineConvEncoder, DFineConvNormLayer, DFineDecoder, DFineDecoderLayer, @@ -34,19 +37,14 @@ DFineLQE, DFineMLP, DFineModel, + DFineModelOutput, DFineMultiscaleDeformableAttention, DFinePreTrainedModel, DFineRepVggBlock, DFineSCDown, + get_contrastive_denoising_training_group, ) -from ..dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig -from ..dinov3_vit.modeling_dinov3_vit import ( - DINOv3ViTEmbeddings, - DINOv3ViTLayer, - DINOv3ViTLayerScale, - DINOv3ViTRopePositionEmbedding, -) -from ..rt_detr.modeling_rt_detr import RTDetrAIFILayer +from ..llama.modeling_llama import LlamaRMSNorm logger = logging.get_logger(__name__) @@ -73,7 +71,8 @@ class Deimv2Config(DFineConfig): batch_norm_eps (`float`, *optional*, defaults to 1e-05): The epsilon used by the batch normalization layers. backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `HGNetV2Config()`): - The configuration of the backbone model. + The configuration of the backbone model. For HGNetV2 variants, use `HGNetV2Config`. + For DINOv3 variants, use `DINOv3ViTConfig`. freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`): Whether to freeze the batch normalization layers in the backbone. encoder_hidden_dim (`int`, *optional*, defaults to 256): @@ -202,8 +201,6 @@ class Deimv2Config(DFineConfig): Alpha parameter for the Matching Auxiliary Loss (MAL). If `None`, uses `focal_loss_alpha`. encoder_fuse_op (`str`, *optional*, defaults to `"sum"`): Fusion operation used in the encoder FPN. DEIMv2 uses `"sum"` instead of D-Fine's `"cat"`. - use_spatial_tuning_adapter (`bool`, *optional*, defaults to `False`): - Whether to use the Spatial Tuning Adapter (STA) for DINOv2 backbone variants. sta_inplanes (`int`, *optional*, defaults to 16): Number of input planes for the STA convolutional stem. encoder_type (`str`, *optional*, defaults to `"hybrid"`): @@ -214,17 +211,6 @@ class Deimv2Config(DFineConfig): uses RMSNorm on the encoder attention output instead. share_bbox_head (`bool`, *optional*, defaults to `False`): Whether to share the bounding box prediction head across all decoder layers. - backbone_type (`str`, *optional*, defaults to `"hgnetv2"`): - Type of backbone to use. 
`"hgnetv2"` uses HGNetV2, `"dinov3"` uses DINOv3 ViT backbone with STA. - dinov3_backbone_config (`dict`, *optional*): - Configuration dictionary for the DINOv3 ViT backbone. Passed as kwargs to `DINOv3ViTConfig`. - dinov3_interaction_indexes (`list[int]`, *optional*): - Layer indices in the DINOv3 ViT backbone from which to extract intermediate features. - dinov3_hidden_dim (`int`, *optional*): - Hidden dimension for the DINOv3 backbone projection convolutions. If `None`, uses `hidden_size` from - the DINOv3 ViT config. - dinov3_apply_layernorm (`bool`, *optional*, defaults to `False`): - Whether to apply LayerNorm to intermediate features extracted from the DINOv3 ViT backbone. encoder_has_trailing_conv (`bool`, *optional*, defaults to `True`): Whether the encoder's CSP blocks include a trailing 3x3 convolution after the bottleneck path. `True` for RepNCSPELAN4 (used by HGNetV2 N and LiteEncoder variants). @@ -234,6 +220,7 @@ class Deimv2Config(DFineConfig): """ model_type = "deimv2" + sub_configs = {"backbone_config": AutoConfig} def __init__( self, @@ -305,16 +292,10 @@ def __init__( use_dense_o2o=True, mal_alpha=None, encoder_fuse_op="sum", - use_spatial_tuning_adapter=False, sta_inplanes=16, encoder_type="hybrid", use_gateway=True, share_bbox_head=False, - backbone_type="hgnetv2", - dinov3_backbone_config=None, - dinov3_interaction_indexes=None, - dinov3_hidden_dim=None, - dinov3_apply_layernorm=False, encoder_has_trailing_conv=True, tie_word_embeddings=True, **kwargs, @@ -323,16 +304,10 @@ def __init__( self.use_dense_o2o = use_dense_o2o self.mal_alpha = mal_alpha self.encoder_fuse_op = encoder_fuse_op - self.use_spatial_tuning_adapter = use_spatial_tuning_adapter self.sta_inplanes = sta_inplanes self.encoder_type = encoder_type self.use_gateway = use_gateway self.share_bbox_head = share_bbox_head - self.backbone_type = backbone_type - self.dinov3_backbone_config = dinov3_backbone_config - self.dinov3_interaction_indexes = dinov3_interaction_indexes - self.dinov3_hidden_dim = dinov3_hidden_dim - self.dinov3_apply_layernorm = dinov3_apply_layernorm self.encoder_has_trailing_conv = encoder_has_trailing_conv super().__init__( initializer_range=initializer_range, @@ -404,38 +379,27 @@ def __init__( ) -class Deimv2DINOv3ViTConfig(DINOv3ViTConfig): - model_type = "deimv2_dinov3_vit" - - class Deimv2DecoderOutput(DFineDecoderOutput): pass -class Deimv2RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.scale = nn.Parameter(torch.ones(dim)) +class Deimv2ModelOutput(DFineModelOutput): + pass - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.float() - hidden_states = hidden_states * torch.rsqrt(hidden_states.pow(2).mean(-1, keepdim=True) + self.eps) - return (hidden_states * self.scale).to(input_dtype) + +class Deimv2RMSNorm(LlamaRMSNorm): + pass class Deimv2SwiGLUFFN(nn.Module): def __init__(self, in_features: int, hidden_features: int, out_features: int): super().__init__() - self.w12 = nn.Linear(in_features, 2 * hidden_features) - self.w3 = nn.Linear(hidden_features, out_features) + self.gate_proj = nn.Linear(in_features, hidden_features) + self.up_proj = nn.Linear(in_features, hidden_features) + self.down_proj = nn.Linear(hidden_features, out_features) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - x12 = self.w12(hidden_states) - x1, x2 = x12.chunk(2, dim=-1) - hidden = F.silu(x1) * x2 - return self.w3(hidden) + return 
self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) class Deimv2Gate(DFineGate): @@ -460,23 +424,7 @@ class Deimv2RepVggBlock(DFineRepVggBlock): pass -class Deimv2DINOv3ViTEmbeddings(DINOv3ViTEmbeddings): - pass - - -class Deimv2DINOv3ViTLayerScale(DINOv3ViTLayerScale): - pass - - -class Deimv2DINOv3ViTLayer(DINOv3ViTLayer): - pass - - -class Deimv2DINOv3ViTRopePositionEmbedding(DINOv3ViTRopePositionEmbedding): - pass - - -class Deimv2CSPRepLayer2(nn.Module): +class Deimv2CSPRepLayer(nn.Module): def __init__( self, config: Deimv2Config, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0 ): @@ -493,34 +441,27 @@ def __init__( self.conv3 = nn.Identity() def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: - chunks = list(self.conv1(hidden_state).chunk(2, 1)) - bottleneck_out = chunks[1] + hidden_state_1, hidden_state_2 = self.conv1(hidden_state).chunk(2, 1) for bottleneck in self.bottlenecks: - bottleneck_out = bottleneck(bottleneck_out) - return self.conv3(chunks[0] + bottleneck_out) + hidden_state_2 = bottleneck(hidden_state_2) + return self.conv3(hidden_state_1 + hidden_state_2) class Deimv2RepNCSPELAN5(nn.Module): - def __init__( - self, - config: Deimv2Config, - numb_blocks: int = 3, - c1: int | None = None, - c2: int | None = None, - c3: int | None = None, - c4: int | None = None, - ): + def __init__(self, config: Deimv2Config, numb_blocks: int = 3): super().__init__() - act = config.activation_function - c1 = c1 if c1 is not None else config.encoder_hidden_dim - c2 = c2 if c2 is not None else config.encoder_hidden_dim - c3 = c3 if c3 is not None else config.encoder_hidden_dim * 2 - c4 = c4 if c4 is not None else round(config.hidden_expansion * config.encoder_hidden_dim // 2) - self.conv_dim = c3 // 2 - self.conv1 = Deimv2ConvNormLayer(config, c1, c3, 1, 1, activation=act) - self.csp_rep1 = Deimv2CSPRepLayer2(config, c3 // 2, c4, num_blocks=numb_blocks) - self.csp_rep2 = Deimv2CSPRepLayer2(config, c4, c4, num_blocks=numb_blocks) - self.conv4 = Deimv2ConvNormLayer(config, c3 + (2 * c4), c2, 1, 1, activation=act) + activation = config.activation_function + in_channels = config.encoder_hidden_dim + out_channels = config.encoder_hidden_dim + split_channels = config.encoder_hidden_dim * 2 + csp_channels = round(config.hidden_expansion * config.encoder_hidden_dim // 2) + self.conv_dim = split_channels // 2 + self.conv1 = Deimv2ConvNormLayer(config, in_channels, split_channels, 1, 1, activation=activation) + self.csp_rep1 = Deimv2CSPRepLayer(config, split_channels // 2, csp_channels, num_blocks=numb_blocks) + self.csp_rep2 = Deimv2CSPRepLayer(config, csp_channels, csp_channels, num_blocks=numb_blocks) + self.conv4 = Deimv2ConvNormLayer( + config, split_channels + (2 * csp_channels), out_channels, 1, 1, activation=activation + ) def forward(self, input_features: torch.Tensor) -> torch.Tensor: split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1)) @@ -539,7 +480,7 @@ class Deimv2EncoderLayer(DFineEncoderLayer): pass -class Deimv2AIFILayer(RTDetrAIFILayer): +class Deimv2AIFILayer(DFineAIFILayer): pass @@ -569,11 +510,11 @@ def __init__(self, config: Deimv2Config): ) def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - c1 = self.stem(pixel_values) - c2 = self.conv2(c1) - c3 = self.conv3(c2) - c4 = self.conv4(c3) - return c2, c3, c4 + hidden_states = self.stem(pixel_values) + feature_map_8 = self.conv2(hidden_states) + feature_map_16 = 
self.conv3(feature_map_8) + feature_map_32 = self.conv4(feature_map_16) + return feature_map_8, feature_map_16, feature_map_32 class Deimv2GAPFusion(nn.Module): @@ -615,12 +556,9 @@ def __init__(self, config: Deimv2Config): self.bi_fusion = Deimv2GAPFusion(config, hidden_dim) - c1, c2 = hidden_dim, hidden_dim - c3 = hidden_dim * 2 - c4 = round(config.hidden_expansion * hidden_dim // 2) num_blocks = round(3 * config.depth_mult) - self.fpn_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks, c1=c1, c2=c2, c3=c3, c4=c4) - self.pan_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks, c1=c1, c2=c2, c3=c3, c4=c4) + self.fpn_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks) + self.pan_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks) def forward(self, inputs_embeds=None, **kwargs) -> BaseModelOutput: feats = inputs_embeds @@ -639,27 +577,36 @@ def forward(self, inputs_embeds=None, **kwargs) -> BaseModelOutput: return BaseModelOutput(last_hidden_state=outs) -class Deimv2DINOv3Backbone(nn.Module): - def __init__(self, config: Deimv2Config): - super().__init__() - vit_config = Deimv2DINOv3ViTConfig(**config.dinov3_backbone_config) +class Deimv2ConvEncoder(DFineConvEncoder): + def __init__(self, config): + super().__init__(config) + if config.encoder_type != "lite": + encoder_input_proj = [] + for in_channel in self.intermediate_channel_sizes: + encoder_input_proj.append( + nn.Sequential( + nn.Conv2d(in_channel, config.encoder_hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(config.encoder_hidden_dim), + ) + ) + self.encoder_input_proj = nn.ModuleList(encoder_input_proj) - self.embeddings = Deimv2DINOv3ViTEmbeddings(vit_config) - self.rope_embeddings = Deimv2DINOv3ViTRopePositionEmbedding(vit_config) - self.layers = nn.ModuleList([Deimv2DINOv3ViTLayer(vit_config) for _ in range(vit_config.num_hidden_layers)]) + def forward(self, pixel_values: torch.Tensor) -> list[torch.Tensor]: + features = self.model(pixel_values).feature_maps + if hasattr(self, "encoder_input_proj"): + return [self.encoder_input_proj[i](feat) for i, feat in enumerate(features)] + return list(features) - self.apply_layernorm = config.dinov3_apply_layernorm - if self.apply_layernorm: - self.norm = nn.LayerNorm(vit_config.hidden_size, eps=vit_config.layer_norm_eps) - self.interaction_indexes = config.dinov3_interaction_indexes - self.patch_size = vit_config.patch_size - self.num_prefix_tokens = 1 + vit_config.num_register_tokens +class Deimv2DINOv3ConvEncoder(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + self.backbone = load_backbone(config) self.sta = Deimv2SpatialTuningAdapter(config) - embed_dim = vit_config.hidden_size - hidden_dim = config.dinov3_hidden_dim or embed_dim + embed_dim = config.backbone_config.hidden_size + hidden_dim = config.encoder_hidden_dim sta_ch = config.sta_inplanes self.convs = nn.ModuleList( [ @@ -671,28 +618,19 @@ def __init__(self, config: Deimv2Config): self.norms = nn.ModuleList([nn.BatchNorm2d(hidden_dim) for _ in range(3)]) def forward(self, pixel_values: torch.Tensor) -> list[torch.Tensor]: - hidden_states = self.embeddings(pixel_values) - position_embeddings = self.rope_embeddings(pixel_values) - - intermediate_outputs = [] - for i, layer in enumerate(self.layers): - hidden_states = layer(hidden_states, position_embeddings=position_embeddings) - if i in self.interaction_indexes: - out = self.norm(hidden_states) if self.apply_layernorm else hidden_states - intermediate_outputs.append(out) + backbone_output = 
self.backbone(pixel_values) + feature_maps = backbone_output.feature_maps - batch_size = pixel_values.shape[0] - h_patches = pixel_values.shape[2] // self.patch_size - w_patches = pixel_values.shape[3] // self.patch_size + patch_size = self.backbone.config.patch_size + h_patches = pixel_values.shape[2] // patch_size + w_patches = pixel_values.shape[3] // patch_size sem_feats = [] - num_scales = len(intermediate_outputs) - for i, feat in enumerate(intermediate_outputs): - patch_tokens = feat[:, self.num_prefix_tokens :] - spatial = patch_tokens.transpose(1, 2).reshape(batch_size, -1, h_patches, w_patches).contiguous() + num_scales = len(feature_maps) + for i, feat in enumerate(feature_maps): resize_h = int(h_patches * 2 ** (num_scales - 2 - i)) resize_w = int(w_patches * 2 ** (num_scales - 2 - i)) - spatial = F.interpolate(spatial, size=[resize_h, resize_w], mode="bilinear", align_corners=False) + spatial = F.interpolate(feat, size=[resize_h, resize_w], mode="bilinear", align_corners=False) sem_feats.append(spatial) detail_feats = self.sta(pixel_values) @@ -720,6 +658,7 @@ def __init__(self, config: Deimv2Config): self.self_attn_layer_norm = Deimv2RMSNorm(config.d_model) self.final_layer_norm = Deimv2RMSNorm(config.d_model) self.mlp = Deimv2SwiGLUFFN(config.d_model, config.decoder_ffn_dim // 2, config.d_model) + self.use_gateway = config.use_gateway if config.use_gateway: self.gateway = Deimv2Gate(config.d_model) else: @@ -762,7 +701,7 @@ def forward( ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if hasattr(self, "gateway"): + if self.use_gateway: hidden_states = self.gateway(residual, hidden_states) else: hidden_states = residual + hidden_states @@ -771,7 +710,8 @@ def forward( residual = hidden_states hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states.clamp(min=-65504, max=65504)) + clamp_value = torch.finfo(hidden_states.dtype).max + hidden_states = self.final_layer_norm(hidden_states.clamp(min=-clamp_value, max=clamp_value)) return hidden_states @@ -783,100 +723,18 @@ class Deimv2MLPPredictionHead(DFineMLP): class Deimv2PreTrainedModel(DFinePreTrainedModel): @torch.no_grad() def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (Deimv2ForObjectDetection, Deimv2Decoder)): - if module.class_embed is not None: - for layer in module.class_embed: - prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) - bias = float(-math.log((1 - prior_prob) / prior_prob)) - init.xavier_uniform_(layer.weight) - init.constant_(layer.bias, bias) - - if module.bbox_embed is not None: - for layer in module.bbox_embed: - init.constant_(layer.layers[-1].weight, 0) - init.constant_(layer.layers[-1].bias, 0) - - if hasattr(module, "reg_scale"): - init.constant_(module.reg_scale, self.config.reg_scale) - - if hasattr(module, "up"): - init.constant_(module.up, self.config.up) - - if isinstance(module, Deimv2MultiscaleDeformableAttention): - init.constant_(module.sampling_offsets.weight, 0.0) - default_dtype = torch.get_default_dtype() - thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( - 2.0 * math.pi / module.n_heads - ) - grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) - grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values - grid_init = grid_init.reshape(module.n_heads, 1, 2).tile([1, sum(module.num_points_list), 1]) - scaling = torch.concat([torch.arange(1, n 
+ 1) for n in module.num_points_list]).reshape(1, -1, 1) - grid_init *= scaling - init.copy_(module.sampling_offsets.bias, grid_init.flatten()) - - init.constant_(module.attention_weights.weight, 0.0) - init.constant_(module.attention_weights.bias, 0.0) - - num_points_scale = [1 / n for n in module.num_points_list for _ in range(n)] - init.copy_(module.num_points_scale, torch.tensor(num_points_scale, dtype=torch.float32)) - - if isinstance(module, Deimv2Model): - prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) - bias = float(-math.log((1 - prior_prob) / prior_prob)) - init.xavier_uniform_(module.enc_score_head.weight) - init.constant_(module.enc_score_head.bias, bias) - - if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - init.zeros_(module.bias) - if getattr(module, "running_mean", None) is not None: - init.zeros_(module.running_mean) - init.ones_(module.running_var) - init.zeros_(module.num_batches_tracked) - - if isinstance(module, Deimv2Gate): - bias = float(-math.log((1 - 0.5) / 0.5)) - init.constant_(module.gate.bias, bias) - init.constant_(module.gate.weight, 0) - - if isinstance(module, Deimv2LQE): - init.constant_(module.reg_conf.layers[-1].bias, 0) - init.constant_(module.reg_conf.layers[-1].weight, 0) + super()._init_weights(module) if isinstance(module, Deimv2SwiGLUFFN): - init.xavier_uniform_(module.w12.weight) - init.constant_(module.w12.bias, 0) - init.xavier_uniform_(module.w3.weight) - init.constant_(module.w3.bias, 0) + init.xavier_uniform_(module.gate_proj.weight) + init.constant_(module.gate_proj.bias, 0) + init.xavier_uniform_(module.up_proj.weight) + init.constant_(module.up_proj.bias, 0) + init.xavier_uniform_(module.down_proj.weight) + init.constant_(module.down_proj.bias, 0) if isinstance(module, Deimv2RMSNorm): - init.ones_(module.scale) - - if isinstance(module, nn.LayerNorm): init.ones_(module.weight) - init.zeros_(module.bias) - - if isinstance(module, Deimv2DINOv3ViTEmbeddings): - init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) - if module.config.num_register_tokens > 0: - init.trunc_normal_(module.register_tokens, mean=0.0, std=self.config.initializer_range) - init.zeros_(module.mask_token) - - if isinstance(module, Deimv2DINOv3ViTLayerScale) and self.config.dinov3_backbone_config is not None: - layerscale_value = self.config.dinov3_backbone_config.get("layerscale_value", 1.0) - init.constant_(module.lambda1, layerscale_value) - - if isinstance(module, Deimv2DINOv3ViTRopePositionEmbedding): - inv_freq = 1 / module.base ** torch.arange(0, 1, 4 / module.head_dim, dtype=torch.float32) - init.copy_(module.inv_freq, inv_freq) - - if hasattr(module, "weight_embedding") and self.config.learn_initial_query: - init.xavier_uniform_(module.weight_embedding.weight) - if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: - init.xavier_uniform_(module.denoising_class_embed.weight) class Deimv2HybridEncoder(DFineHybridEncoder): @@ -985,23 +843,11 @@ class Deimv2Model(DFineModel): def __init__(self, config: Deimv2Config): Deimv2PreTrainedModel.__init__(self, config) - if config.backbone_type == "dinov3": - self.dinov3_backbone = Deimv2DINOv3Backbone(config) + is_dinov3 = getattr(config.backbone_config, "model_type", None) == "dinov3_vit" + if is_dinov3: + self.backbone = Deimv2DINOv3ConvEncoder(config) else: - from ..d_fine.modeling_d_fine import 
DFineConvEncoder - - self.backbone = DFineConvEncoder(config) - if config.encoder_type != "lite": - intermediate_channel_sizes = self.backbone.intermediate_channel_sizes - encoder_input_proj = [] - for in_channel in intermediate_channel_sizes: - encoder_input_proj.append( - nn.Sequential( - nn.Conv2d(in_channel, config.encoder_hidden_dim, kernel_size=1, bias=False), - nn.BatchNorm2d(config.encoder_hidden_dim), - ) - ) - self.encoder_input_proj = nn.ModuleList(encoder_input_proj) + self.backbone = Deimv2ConvEncoder(config) if config.encoder_type == "lite": self.encoder = Deimv2LiteEncoder(config) @@ -1046,9 +892,6 @@ def __init__(self, config: Deimv2Config): self.decoder_input_proj = nn.ModuleList(decoder_input_proj) self.decoder = Deimv2Decoder(config) - if config.use_spatial_tuning_adapter and config.backbone_type != "dinov3": - self.spatial_tuning_adapter = Deimv2SpatialTuningAdapter(config) - self.post_init() def forward( @@ -1069,14 +912,7 @@ def forward( if pixel_mask is None: pixel_mask = torch.ones(((batch_size, height, width)), device=device) - if self.config.backbone_type == "dinov3": - proj_feats = self.dinov3_backbone(pixel_values) - elif self.config.encoder_type == "lite": - features = self.backbone(pixel_values, pixel_mask) - proj_feats = [source for source, mask in features] - else: - features = self.backbone(pixel_values, pixel_mask) - proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] + proj_feats = self.backbone(pixel_values) else: batch_size = inputs_embeds.shape[0] device = inputs_embeds.device @@ -1118,8 +954,6 @@ def forward( level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) if self.training and self.config.num_denoising > 0 and labels is not None: - from ..d_fine.modeling_d_fine import get_contrastive_denoising_training_group - ( denoising_class, denoising_bbox_unact, @@ -1180,8 +1014,6 @@ def forward( init_reference_points = reference_points_unact.detach() - from ..d_fine.modeling_d_fine import DFineModelOutput - decoder_outputs = self.decoder( inputs_embeds=target, encoder_hidden_states=source_flatten, @@ -1193,7 +1025,7 @@ def forward( **kwargs, ) - return DFineModelOutput( + return Deimv2ModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, intermediate_logits=decoder_outputs.intermediate_logits, diff --git a/tests/models/deimv2/test_modeling_deimv2.py b/tests/models/deimv2/test_modeling_deimv2.py index db4ce97086b4..8ec713a524ef 100644 --- a/tests/models/deimv2/test_modeling_deimv2.py +++ b/tests/models/deimv2/test_modeling_deimv2.py @@ -24,6 +24,7 @@ from transformers import ( Deimv2Config, + DINOv3ViTConfig, HGNetV2Config, is_torch_available, is_vision_available, @@ -1136,36 +1137,21 @@ def prepare_config_and_inputs(self): return config, pixel_values, pixel_mask, labels def get_config(self): - hidden_sizes = [64, 128, 256, 512] - backbone_config = HGNetV2Config( - stage_in_channels=[16, 64, 128, 256], - stage_mid_channels=[16, 32, 64, 128], - stage_out_channels=[64, 128, 256, 512], - stage_num_blocks=[1, 1, 2, 1], - stage_downsample=[False, True, True, True], - stage_light_block=[False, False, True, True], - stage_kernel_size=[3, 3, 5, 5], - stage_numb_of_layers=[3, 3, 3, 3], - embeddings_size=10, - hidden_sizes=hidden_sizes, - depths=[1, 1, 2, 1], - out_features=["stage2", "stage3", "stage4"], + backbone_config = DINOv3ViTConfig( + hidden_size=32, + num_attention_heads=2, + 
num_hidden_layers=4, + intermediate_size=64, + num_register_tokens=1, + layerscale_value=1.0, + use_gated_mlp=False, + rope_theta=100.0, + patch_size=16, + image_size=self.image_size, out_indices=[2, 3, 4], - stem_channels=[3, 16, 16], - use_lab=True, + apply_layernorm=False, + reshape_hidden_states=True, ) - dinov3_backbone_config = { - "hidden_size": 32, - "num_attention_heads": 2, - "num_hidden_layers": 4, - "intermediate_size": 64, - "num_register_tokens": 1, - "layerscale_value": 1.0, - "use_gated_mlp": False, - "rope_theta": 100.0, - "patch_size": 16, - "image_size": self.image_size, - } return Deimv2Config( backbone_config=backbone_config, encoder_hidden_dim=self.encoder_hidden_dim, @@ -1207,13 +1193,7 @@ def get_config(self): image_size=self.image_size, disable_custom_kernels=self.disable_custom_kernels, with_box_refine=self.with_box_refine, - backbone_type="dinov3", - dinov3_backbone_config=dinov3_backbone_config, - dinov3_interaction_indexes=[1, 2, 3], - dinov3_hidden_dim=self.encoder_hidden_dim, - dinov3_apply_layernorm=False, sta_inplanes=self.sta_inplanes, - use_spatial_tuning_adapter=True, encoder_has_trailing_conv=False, ) @@ -1355,6 +1335,10 @@ def test_disk_offload_bin(self): def test_disk_offload_safetensors(self): pass + @unittest.skip(reason="DINOv3 backbone with RoPE and dynamic interpolation causes torch.compile inductor overflow") + def test_sdpa_can_compile_dynamic(self): + pass + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) def test_eager_matches_sdpa_inference( self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 413f5e6d61fa..345d8172c049 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -97,7 +97,6 @@ "DetrConfig": True, "DFineConfig": True, "Deimv2Config": True, - "Deimv2DINOv3ViTConfig": True, "GroundingDinoConfig": True, "MMGroundingDinoConfig": True, "RTDetrConfig": True, From 4ad0dc5bd9e4c0ab04a150ca77e05dbd9df3c07f Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Thu, 19 Mar 2026 15:58:56 +0400 Subject: [PATCH 06/17] refactor: Resolve second review round --- docs/source/en/model_doc/deimv2.md | 2 +- src/transformers/models/deimv2/__init__.py | 2 +- .../models/deimv2/configuration_deimv2.py | 542 +++++++----------- ...eimv2_original_pytorch_checkpoint_to_hf.py | 136 ++--- .../models/deimv2/modeling_deimv2.py | 49 +- .../models/deimv2/modular_deimv2.py | 491 +++++----------- tests/models/deimv2/test_modeling_deimv2.py | 148 ++--- 7 files changed, 519 insertions(+), 851 deletions(-) diff --git a/docs/source/en/model_doc/deimv2.md b/docs/source/en/model_doc/deimv2.md index 3d4e4a4a77b1..67cb46398921 100644 --- a/docs/source/en/model_doc/deimv2.md +++ b/docs/source/en/model_doc/deimv2.md @@ -1,4 +1,4 @@ - +*This model was released on 2025-09-25 and added to Hugging Face Transformers on 2026-04-22.* # DEIMv2 diff --git a/src/transformers/loss/loss_deimv2.py b/src/transformers/loss/loss_deimv2.py index 083e32fd2a82..88c1aa94d1fa 100644 --- a/src/transformers/loss/loss_deimv2.py +++ b/src/transformers/loss/loss_deimv2.py @@ -109,16 +109,29 @@ def get_loss(self, loss, outputs, targets, indices, num_boxes): return loss_map[loss](outputs, targets, indices, num_boxes) def forward(self, outputs, targets): + """ + This performs the loss computation. + + Args: + outputs (`dict`, *optional*): + Dictionary of tensors, see the output specification of the model for the format. 
+ targets (`list[dict]`, *optional*): + List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the + losses applied, see each loss' doc. + """ if not self.use_dense_one_to_one: return super().forward(outputs, targets) + # Retrieve the matching between the outputs of the last layer and the targets outputs_without_aux = {k: v for k, v in outputs.items() if "auxiliary_outputs" not in k} indices = self.matcher(outputs_without_aux, targets) + # Compute the average number of target boxes across all nodes, for normalization purposes num_boxes = sum(len(t["class_labels"]) for t in targets) num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) num_boxes = torch.clamp(num_boxes, min=1).item() + # Handle auxiliary outputs matching cached_indices = [] indices_aux_list = [] if "auxiliary_outputs" in outputs: @@ -127,11 +140,13 @@ def forward(self, outputs, targets): cached_indices.append(aux_indices) indices_aux_list.append(aux_indices) + # Dense one-to-one matching indices_go = self._get_dense_o2o_indices(indices, indices_aux_list) num_boxes_go = sum(len(x[0]) for x in indices_go) num_boxes_go = torch.as_tensor([num_boxes_go], dtype=torch.float, device=next(iter(outputs.values())).device) num_boxes_go = torch.clamp(num_boxes_go, min=1).item() + # Compute all the requested losses losses = {} for loss in self.losses: use_union = loss in ("boxes", "local") @@ -141,6 +156,7 @@ def forward(self, outputs, targets): l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} losses.update(l_dict) + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. if "auxiliary_outputs" in outputs: for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]): for loss in self.losses: @@ -152,6 +168,7 @@ def forward(self, outputs, targets): l_dict = {k + f"_aux_{i}": v for k, v in l_dict.items()} losses.update(l_dict) + # In case of cdn auxiliary losses. For deimv2 if "dn_auxiliary_outputs" in outputs: if "denoising_meta_values" not in outputs: raise ValueError( @@ -187,55 +204,50 @@ def Deimv2ForObjectDetectionLoss( ): criterion = Deimv2Loss(config) criterion.to(device) - outputs_loss = {} + + outputs_loss = {"logits": logits, "pred_boxes": pred_boxes.clamp(min=0, max=1)} auxiliary_outputs = None - if config.auxiliary_loss: - if denoising_meta_values is not None: - dn_out_coord, outputs_coord = torch.split( - outputs_coord.clamp(min=0, max=1), denoising_meta_values["dn_num_split"], dim=2 - ) - dn_out_class, outputs_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2) - # https://github.com/Intellindust-AI-Lab/DEIMv2/blob/main/engine/deim/deim_decoder.py#L562-L571 - # The original splits denoising queries in the decoder; here it happens in the loss since the decoder returns unsplit tensors. 
- _, logits = torch.split(logits, denoising_meta_values["dn_num_split"], dim=1) - _, pred_boxes = torch.split(pred_boxes, denoising_meta_values["dn_num_split"], dim=1) - dn_out_corners, out_corners = torch.split(predicted_corners, denoising_meta_values["dn_num_split"], dim=2) - dn_out_refs, out_refs = torch.split(initial_reference_points, denoising_meta_values["dn_num_split"], dim=2) - - outputs_loss["logits"] = logits - outputs_loss["pred_boxes"] = pred_boxes.clamp(min=0, max=1) - - auxiliary_outputs = _set_aux_loss2( - outputs_class[:, :-1].transpose(0, 1), - outputs_coord[:, :-1].transpose(0, 1), - out_corners[:, :-1].transpose(0, 1), - out_refs[:, :-1].transpose(0, 1), - out_corners[:, -1], - outputs_class[:, -1], - ) - - outputs_loss["auxiliary_outputs"] = auxiliary_outputs - outputs_loss["auxiliary_outputs"].extend( - _set_aux_loss([enc_topk_logits], [enc_topk_bboxes.clamp(min=0, max=1)]) - ) - - dn_auxiliary_outputs = _set_aux_loss2( - dn_out_class.transpose(0, 1), - dn_out_coord.transpose(0, 1), - dn_out_corners.transpose(0, 1), - dn_out_refs.transpose(0, 1), - dn_out_corners[:, -1], - dn_out_class[:, -1], - ) - outputs_loss["dn_auxiliary_outputs"] = dn_auxiliary_outputs - outputs_loss["denoising_meta_values"] = denoising_meta_values - else: - outputs_loss["logits"] = logits - outputs_loss["pred_boxes"] = pred_boxes.clamp(min=0, max=1) - else: + + if config.auxiliary_loss and denoising_meta_values is not None: + dn_out_coord, outputs_coord = torch.split( + outputs_coord.clamp(min=0, max=1), denoising_meta_values["dn_num_split"], dim=2 + ) + dn_out_class, outputs_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2) + # https://github.com/Intellindust-AI-Lab/DEIMv2/blob/main/engine/deim/deim_decoder.py#L562-L571 + # The original splits denoising queries in the decoder; here it happens in the loss since the decoder returns unsplit tensors. 
+ _, logits = torch.split(logits, denoising_meta_values["dn_num_split"], dim=1) + _, pred_boxes = torch.split(pred_boxes, denoising_meta_values["dn_num_split"], dim=1) + dn_out_corners, out_corners = torch.split(predicted_corners, denoising_meta_values["dn_num_split"], dim=2) + dn_out_refs, out_refs = torch.split(initial_reference_points, denoising_meta_values["dn_num_split"], dim=2) + outputs_loss["logits"] = logits outputs_loss["pred_boxes"] = pred_boxes.clamp(min=0, max=1) + auxiliary_outputs = _set_aux_loss2( + outputs_class[:, :-1].transpose(0, 1), + outputs_coord[:, :-1].transpose(0, 1), + out_corners[:, :-1].transpose(0, 1), + out_refs[:, :-1].transpose(0, 1), + out_corners[:, -1], + outputs_class[:, -1], + ) + + outputs_loss["auxiliary_outputs"] = auxiliary_outputs + outputs_loss["auxiliary_outputs"].extend( + _set_aux_loss([enc_topk_logits], [enc_topk_bboxes.clamp(min=0, max=1)]) + ) + + dn_auxiliary_outputs = _set_aux_loss2( + dn_out_class.transpose(0, 1), + dn_out_coord.transpose(0, 1), + dn_out_corners.transpose(0, 1), + dn_out_refs.transpose(0, 1), + dn_out_corners[:, -1], + dn_out_class[:, -1], + ) + outputs_loss["dn_auxiliary_outputs"] = dn_auxiliary_outputs + outputs_loss["denoising_meta_values"] = denoising_meta_values + loss_dict = criterion(outputs_loss, labels) loss = sum(loss_dict.values()) diff --git a/src/transformers/models/deimv2/modeling_deimv2.py b/src/transformers/models/deimv2/modeling_deimv2.py index 04dd673ed826..070592ff08f0 100644 --- a/src/transformers/models/deimv2/modeling_deimv2.py +++ b/src/transformers/models/deimv2/modeling_deimv2.py @@ -1158,6 +1158,7 @@ def _init_weights(self, module): init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: init.xavier_uniform_(module.denoising_class_embed.weight) + PreTrainedModel._init_weights(self, module) if isinstance(module, Deimv2SwiGLUFFN): init.xavier_uniform_(module.gate_proj.weight) @@ -1167,9 +1168,6 @@ def _init_weights(self, module): init.xavier_uniform_(module.down_proj.weight) init.constant_(module.down_proj.bias, 0) - if isinstance(module, Deimv2RMSNorm): - init.ones_(module.weight) - class Deimv2LiteEncoder(Deimv2PreTrainedModel): # LiteEncoder has no transformer layers, so hidden_states are recorded from the conv projections. @@ -2023,8 +2021,6 @@ class Deimv2ObjectDetectionOutput(ModelOutput): """ ) class Deimv2ForObjectDetection(Deimv2PreTrainedModel): - # When using clones, all layers > 0 will be clones, but layer 0 *is* required - # We can't initialize the model on meta device as some weights are modified during the initialization _no_split_modules = None _tied_weights_keys = { r"bbox_embed.(?![0])\d+": r"bbox_embed.0", diff --git a/src/transformers/models/deimv2/modular_deimv2.py b/src/transformers/models/deimv2/modular_deimv2.py index 59ad178d8278..271f2756df67 100644 --- a/src/transformers/models/deimv2/modular_deimv2.py +++ b/src/transformers/models/deimv2/modular_deimv2.py @@ -22,6 +22,7 @@ from ... 
import initialization as init from ...backbone_utils import load_backbone from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging from ...utils.generic import merge_with_config_defaults @@ -491,6 +492,7 @@ class Deimv2PreTrainedModel(DFinePreTrainedModel): @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) + PreTrainedModel._init_weights(self, module) if isinstance(module, Deimv2SwiGLUFFN): init.xavier_uniform_(module.gate_proj.weight) @@ -500,9 +502,6 @@ def _init_weights(self, module): init.xavier_uniform_(module.down_proj.weight) init.constant_(module.down_proj.bias, 0) - if isinstance(module, Deimv2RMSNorm): - init.ones_(module.weight) - class Deimv2LiteEncoder(Deimv2PreTrainedModel): # LiteEncoder has no transformer layers, so hidden_states are recorded from the conv projections. @@ -854,6 +853,8 @@ def forward( class Deimv2ForObjectDetection(DFineForObjectDetection): + _no_split_modules = None + @property def _tied_weights_keys(self): keys = { From 943f4bb7dcacbeb0424f21ed77ac00511622b7d7 Mon Sep 17 00:00:00 2001 From: vasqu Date: Wed, 22 Apr 2026 13:07:32 +0200 Subject: [PATCH 16/17] fixup their init weights --- .../models/d_fine/modeling_d_fine.py | 14 +------------- src/transformers/models/d_fine/modular_d_fine.py | 15 ++------------- .../models/deimv2/modeling_deimv2.py | 16 +--------------- src/transformers/models/deimv2/modular_deimv2.py | 4 +--- 4 files changed, 5 insertions(+), 44 deletions(-) diff --git a/src/transformers/models/d_fine/modeling_d_fine.py b/src/transformers/models/d_fine/modeling_d_fine.py index 1c758f8b1dcd..f1d23356fb2b 100644 --- a/src/transformers/models/d_fine/modeling_d_fine.py +++ b/src/transformers/models/d_fine/modeling_d_fine.py @@ -874,6 +874,7 @@ class DFinePreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" + super()._init_weights(module) # initialize linear layer bias value according to a given probability value. 
if isinstance(module, (DFineForObjectDetection, DFineDecoder)): if module.class_embed is not None: @@ -919,15 +920,6 @@ def _init_weights(self, module): init.xavier_uniform_(module.enc_score_head.weight) init.constant_(module.enc_score_head.bias, bias) - if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - init.zeros_(module.bias) - if getattr(module, "running_mean", None) is not None: - init.zeros_(module.running_mean) - init.ones_(module.running_var) - init.zeros_(module.num_batches_tracked) - if isinstance(module, DFineGate): bias = float(-math.log((1 - 0.5) / 0.5)) init.constant_(module.gate.bias, bias) @@ -937,10 +929,6 @@ def _init_weights(self, module): init.constant_(module.reg_conf.layers[-1].bias, 0) init.constant_(module.reg_conf.layers[-1].weight, 0) - if isinstance(module, nn.LayerNorm): - init.ones_(module.weight) - init.zeros_(module.bias) - if hasattr(module, "weight_embedding") and self.config.learn_initial_query: init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: diff --git a/src/transformers/models/d_fine/modular_d_fine.py b/src/transformers/models/d_fine/modular_d_fine.py index ba5798ad93cb..49289f075037 100644 --- a/src/transformers/models/d_fine/modular_d_fine.py +++ b/src/transformers/models/d_fine/modular_d_fine.py @@ -23,6 +23,7 @@ from ...backbone_utils import consolidate_backbone_kwargs_to_config from ...configuration_utils import PreTrainedConfig from ...image_transforms import corners_to_center_format +from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging, torch_compilable_check from ..auto import AutoConfig @@ -678,6 +679,7 @@ class DFinePreTrainedModel(RTDetrPreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" + PreTrainedModel._init_weights(self, module) # initialize linear layer bias value according to a given probability value. 
if isinstance(module, (DFineForObjectDetection, DFineDecoder)): if module.class_embed is not None: @@ -723,15 +725,6 @@ def _init_weights(self, module): init.xavier_uniform_(module.enc_score_head.weight) init.constant_(module.enc_score_head.bias, bias) - if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - init.zeros_(module.bias) - if getattr(module, "running_mean", None) is not None: - init.zeros_(module.running_mean) - init.ones_(module.running_var) - init.zeros_(module.num_batches_tracked) - if isinstance(module, DFineGate): bias = float(-math.log((1 - 0.5) / 0.5)) init.constant_(module.gate.bias, bias) @@ -741,10 +734,6 @@ def _init_weights(self, module): init.constant_(module.reg_conf.layers[-1].bias, 0) init.constant_(module.reg_conf.layers[-1].weight, 0) - if isinstance(module, nn.LayerNorm): - init.ones_(module.weight) - init.zeros_(module.bias) - if hasattr(module, "weight_embedding") and self.config.learn_initial_query: init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: diff --git a/src/transformers/models/deimv2/modeling_deimv2.py b/src/transformers/models/deimv2/modeling_deimv2.py index 070592ff08f0..fe0f002890c5 100644 --- a/src/transformers/models/deimv2/modeling_deimv2.py +++ b/src/transformers/models/deimv2/modeling_deimv2.py @@ -1087,6 +1087,7 @@ class Deimv2PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" + super()._init_weights(module) # initialize linear layer bias value according to a given probability value. if isinstance(module, (Deimv2ForObjectDetection, Deimv2Decoder)): if module.class_embed is not None: @@ -1132,15 +1133,6 @@ def _init_weights(self, module): init.xavier_uniform_(module.enc_score_head.weight) init.constant_(module.enc_score_head.bias, bias) - if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - init.zeros_(module.bias) - if getattr(module, "running_mean", None) is not None: - init.zeros_(module.running_mean) - init.ones_(module.running_var) - init.zeros_(module.num_batches_tracked) - if isinstance(module, Deimv2Gate): bias = float(-math.log((1 - 0.5) / 0.5)) init.constant_(module.gate.bias, bias) @@ -1150,15 +1142,10 @@ def _init_weights(self, module): init.constant_(module.reg_conf.layers[-1].bias, 0) init.constant_(module.reg_conf.layers[-1].weight, 0) - if isinstance(module, nn.LayerNorm): - init.ones_(module.weight) - init.zeros_(module.bias) - if hasattr(module, "weight_embedding") and self.config.learn_initial_query: init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: init.xavier_uniform_(module.denoising_class_embed.weight) - PreTrainedModel._init_weights(self, module) if isinstance(module, Deimv2SwiGLUFFN): init.xavier_uniform_(module.gate_proj.weight) @@ -2021,7 +2008,6 @@ class Deimv2ObjectDetectionOutput(ModelOutput): """ ) class Deimv2ForObjectDetection(Deimv2PreTrainedModel): - _no_split_modules = None _tied_weights_keys = { r"bbox_embed.(?![0])\d+": r"bbox_embed.0", r"class_embed.(?![0])\d+": r"^class_embed.0", diff --git a/src/transformers/models/deimv2/modular_deimv2.py b/src/transformers/models/deimv2/modular_deimv2.py index 271f2756df67..e675fd12114c 100644 --- 
a/src/transformers/models/deimv2/modular_deimv2.py +++ b/src/transformers/models/deimv2/modular_deimv2.py @@ -22,7 +22,6 @@ from ... import initialization as init from ...backbone_utils import load_backbone from ...modeling_outputs import ModelOutput -from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging from ...utils.generic import merge_with_config_defaults @@ -492,7 +491,6 @@ class Deimv2PreTrainedModel(DFinePreTrainedModel): @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) - PreTrainedModel._init_weights(self, module) if isinstance(module, Deimv2SwiGLUFFN): init.xavier_uniform_(module.gate_proj.weight) @@ -853,7 +851,7 @@ def forward( class Deimv2ForObjectDetection(DFineForObjectDetection): - _no_split_modules = None + _no_split_modules = AttributeError() # Don't have the same restriction as DFine @property def _tied_weights_keys(self): From fb1f387c209f0c0374b0579b7f78017cdc34cbd3 Mon Sep 17 00:00:00 2001 From: Harshal Janjani Date: Thu, 23 Apr 2026 19:00:16 +0400 Subject: [PATCH 17/17] fix: Fix loss coupling issue --- src/transformers/loss/loss_deimv2.py | 61 ++++++++++++++++------------ 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/src/transformers/loss/loss_deimv2.py b/src/transformers/loss/loss_deimv2.py index 88c1aa94d1fa..5c8573f7da44 100644 --- a/src/transformers/loss/loss_deimv2.py +++ b/src/transformers/loss/loss_deimv2.py @@ -208,28 +208,34 @@ def Deimv2ForObjectDetectionLoss( outputs_loss = {"logits": logits, "pred_boxes": pred_boxes.clamp(min=0, max=1)} auxiliary_outputs = None - if config.auxiliary_loss and denoising_meta_values is not None: - dn_out_coord, outputs_coord = torch.split( - outputs_coord.clamp(min=0, max=1), denoising_meta_values["dn_num_split"], dim=2 - ) - dn_out_class, outputs_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2) - # https://github.com/Intellindust-AI-Lab/DEIMv2/blob/main/engine/deim/deim_decoder.py#L562-L571 - # The original splits denoising queries in the decoder; here it happens in the loss since the decoder returns unsplit tensors. - _, logits = torch.split(logits, denoising_meta_values["dn_num_split"], dim=1) - _, pred_boxes = torch.split(pred_boxes, denoising_meta_values["dn_num_split"], dim=1) - dn_out_corners, out_corners = torch.split(predicted_corners, denoising_meta_values["dn_num_split"], dim=2) - dn_out_refs, out_refs = torch.split(initial_reference_points, denoising_meta_values["dn_num_split"], dim=2) - - outputs_loss["logits"] = logits - outputs_loss["pred_boxes"] = pred_boxes.clamp(min=0, max=1) + if config.auxiliary_loss: + if denoising_meta_values is not None: + dn_out_coord, normal_out_coord = torch.split( + outputs_coord.clamp(min=0, max=1), denoising_meta_values["dn_num_split"], dim=2 + ) + dn_out_class, normal_out_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2) + # https://github.com/Intellindust-AI-Lab/DEIMv2/blob/main/engine/deim/deim_decoder.py#L562-L571 + # The original splits denoising queries in the decoder; here it happens in the loss since the decoder returns unsplit tensors. 
+ _, normal_logits = torch.split(logits, denoising_meta_values["dn_num_split"], dim=1) + _, normal_pred_boxes = torch.split(pred_boxes, denoising_meta_values["dn_num_split"], dim=1) + dn_out_corners, out_corners = torch.split(predicted_corners, denoising_meta_values["dn_num_split"], dim=2) + dn_out_refs, out_refs = torch.split(initial_reference_points, denoising_meta_values["dn_num_split"], dim=2) + + outputs_loss["logits"] = normal_logits + outputs_loss["pred_boxes"] = normal_pred_boxes.clamp(min=0, max=1) + else: + normal_out_coord = outputs_coord.clamp(min=0, max=1) + normal_out_class = outputs_class + out_corners = predicted_corners + out_refs = initial_reference_points auxiliary_outputs = _set_aux_loss2( - outputs_class[:, :-1].transpose(0, 1), - outputs_coord[:, :-1].transpose(0, 1), + normal_out_class[:, :-1].transpose(0, 1), + normal_out_coord[:, :-1].transpose(0, 1), out_corners[:, :-1].transpose(0, 1), out_refs[:, :-1].transpose(0, 1), out_corners[:, -1], - outputs_class[:, -1], + normal_out_class[:, -1], ) outputs_loss["auxiliary_outputs"] = auxiliary_outputs @@ -237,16 +243,17 @@ def Deimv2ForObjectDetectionLoss( _set_aux_loss([enc_topk_logits], [enc_topk_bboxes.clamp(min=0, max=1)]) ) - dn_auxiliary_outputs = _set_aux_loss2( - dn_out_class.transpose(0, 1), - dn_out_coord.transpose(0, 1), - dn_out_corners.transpose(0, 1), - dn_out_refs.transpose(0, 1), - dn_out_corners[:, -1], - dn_out_class[:, -1], - ) - outputs_loss["dn_auxiliary_outputs"] = dn_auxiliary_outputs - outputs_loss["denoising_meta_values"] = denoising_meta_values + if denoising_meta_values is not None: + dn_auxiliary_outputs = _set_aux_loss2( + dn_out_class.transpose(0, 1), + dn_out_coord.transpose(0, 1), + dn_out_corners.transpose(0, 1), + dn_out_refs.transpose(0, 1), + dn_out_corners[:, -1], + dn_out_class[:, -1], + ) + outputs_loss["dn_auxiliary_outputs"] = dn_auxiliary_outputs + outputs_loss["denoising_meta_values"] = denoising_meta_values loss_dict = criterion(outputs_loss, labels)