From eaef822134ab632f7b78ea558b0c05b570dfa77d Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Fri, 27 Feb 2026 22:06:58 +0400 Subject: [PATCH 01/25] init: Add files (v1) --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/deimv2.md | 68 + src/transformers/loss/loss_deimv2.py | 167 ++ src/transformers/loss/loss_utils.py | 2 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + .../models/auto/image_processing_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 2 + src/transformers/models/deimv2/__init__.py | 29 + .../models/deimv2/configuration_deimv2.py | 365 +++ ...eimv2_original_pytorch_checkpoint_to_hf.py | 451 ++++ .../models/deimv2/modeling_deimv2.py | 2098 +++++++++++++++++ .../models/deimv2/modular_deimv2.py | 819 +++++++ tests/models/deimv2/__init__.py | 0 tests/models/deimv2/test_modeling_deimv2.py | 655 +++++ 15 files changed, 4662 insertions(+) create mode 100644 docs/source/en/model_doc/deimv2.md create mode 100644 src/transformers/loss/loss_deimv2.py create mode 100644 src/transformers/models/deimv2/__init__.py create mode 100644 src/transformers/models/deimv2/configuration_deimv2.py create mode 100644 src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py create mode 100644 src/transformers/models/deimv2/modeling_deimv2.py create mode 100644 src/transformers/models/deimv2/modular_deimv2.py create mode 100644 tests/models/deimv2/__init__.py create mode 100644 tests/models/deimv2/test_modeling_deimv2.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 61f34ccb891c..1668ed109a40 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -842,6 +842,8 @@ title: DAB-DETR - local: model_doc/deformable_detr title: Deformable DETR + - local: model_doc/deimv2 + title: DEIMv2 - local: model_doc/deit title: DeiT - local: model_doc/depth_anything diff --git a/docs/source/en/model_doc/deimv2.md b/docs/source/en/model_doc/deimv2.md 
new file mode 100644 index 000000000000..3d4e4a4a77b1 --- /dev/null +++ b/docs/source/en/model_doc/deimv2.md @@ -0,0 +1,68 @@ + + +# DEIMv2 + +## Overview + +DEIMv2 (DETR with Improved Matching v2) was proposed in [DEIMv2: Real-Time Object Detection Meets DINOv3](https://huggingface.co/papers/2509.20787) by Shihua Huang, Yongjie Hou, Longfei Liu, Xuanlong Yu, and Xi Shen. + +DEIMv2 builds upon D-FINE's distribution-based bounding box refinement approach, adding several key innovations: +- **SwiGLU FFN**: Replaces the standard MLP in decoder layers with a SwiGLU-gated feed-forward network. +- **RMSNorm**: Uses RMSNorm instead of LayerNorm in decoder layers for improved training stability. +- **RepNCSPELAN5**: An enhanced 5-branch CSP-ELAN encoder block (vs D-Fine's 4-branch RepNCSPELAN4). +- **Matching Auxiliary Loss (MAL)**: A focal-style BCE loss with IoU-weighted targets replacing VFL. +- **Dense O2O Matching**: Unified matching across decoder layers for improved training convergence. + +## Usage + +```python +from transformers import AutoImageProcessor, Deimv2ForObjectDetection +from transformers.image_utils import load_image + +url = "http://images.cocodataset.org/val2017/000000039769.jpg" +image = load_image(url) + +# TODO: Replace with Transformers-compatible ckpts once uploaded. 
+image_processor = AutoImageProcessor.from_pretrained("Intellindust/DEIMv2_HGNetv2_N_COCO") +model = Deimv2ForObjectDetection.from_pretrained("Intellindust/DEIMv2_HGNetv2_N_COCO") + +inputs = image_processor(images=image, return_tensors="pt") +outputs = model(**inputs) + +results = image_processor.post_process_object_detection( + outputs, threshold=0.5, target_sizes=[image.size[::-1]] +) + +for result in results: + for score, label, box in zip(result["scores"], result["labels"], result["boxes"]): + box = [round(i, 2) for i in box.tolist()] + print(f"Detected {model.config.id2label[label.item()]} with confidence {round(score.item(), 3)} at location {box}") +``` + +## Deimv2Config + +[[autodoc]] Deimv2Config + +## Deimv2Model + +[[autodoc]] Deimv2Model + - forward + +## Deimv2ForObjectDetection + +[[autodoc]] Deimv2ForObjectDetection + - forward diff --git a/src/transformers/loss/loss_deimv2.py b/src/transformers/loss/loss_deimv2.py new file mode 100644 index 000000000000..ec71299a3d0a --- /dev/null +++ b/src/transformers/loss/loss_deimv2.py @@ -0,0 +1,167 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import torch +import torch.nn.functional as F + +from ..utils import is_vision_available +from .loss_d_fine import DFineLoss, _set_aux_loss, _set_aux_loss2 +from .loss_for_object_detection import box_iou + + +if is_vision_available(): + from transformers.image_transforms import center_to_corners_format + + +class Deimv2Loss(DFineLoss): + def __init__(self, config): + super().__init__(config) + self.weight_dict = { + "loss_mal": config.weight_loss_mal, + "loss_bbox": config.weight_loss_bbox, + "loss_giou": config.weight_loss_giou, + "loss_fgl": config.weight_loss_fgl, + "loss_ddf": config.weight_loss_ddf, + } + self.losses = ["mal", "boxes", "local"] + self.mal_alpha = config.mal_alpha + self.use_dense_o2o = config.use_dense_o2o + + def loss_labels_mal(self, outputs, targets, indices, num_boxes): + idx = self._get_source_permutation_idx(indices) + + src_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0) + ious, _ = box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes)) + ious = torch.diag(ious).detach() + + src_logits = outputs["logits"] + target_classes_o = torch.cat([t["class_labels"][i] for t, (_, i) in zip(targets, indices)]) + target_classes = torch.full( + src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device + ) + target_classes[idx] = target_classes_o + target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1] + + target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype) + target_score_o[idx] = ious.to(target_score_o.dtype) + target_score = target_score_o.unsqueeze(-1) * target + + pred_score = F.sigmoid(src_logits).detach() + target_score = target_score.pow(self.gamma) + if self.mal_alpha is not None: + weight = self.mal_alpha * pred_score.pow(self.gamma) * (1 - target) + target + else: + weight = pred_score.pow(self.gamma) * (1 - target) + target + + loss = 
F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction="none") + loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes + return {"loss_mal": loss} + + def _get_dense_o2o_indices(self, indices, indices_aux_list): + results = [] + for indices_aux in indices_aux_list: + indices = [ + (torch.cat([idx1[0], idx2[0]]), torch.cat([idx1[1], idx2[1]])) + for idx1, idx2 in zip(indices.copy(), indices_aux.copy()) + ] + + for ind in [torch.cat([idx[0][:, None], idx[1][:, None]], 1) for idx in indices]: + unique, counts = torch.unique(ind, return_counts=True, dim=0) + count_sort_indices = torch.argsort(counts, descending=True) + unique_sorted = unique[count_sort_indices] + column_to_row = {} + for idx_pair in unique_sorted: + row_idx, col_idx = idx_pair[0].item(), idx_pair[1].item() + if row_idx not in column_to_row: + column_to_row[row_idx] = col_idx + final_rows = torch.tensor(list(column_to_row.keys()), device=ind.device) + final_cols = torch.tensor(list(column_to_row.values()), device=ind.device) + results.append((final_rows.long(), final_cols.long())) + return results + + def get_loss(self, loss, outputs, targets, indices, num_boxes): + loss_map = { + "cardinality": self.loss_cardinality, + "local": self.loss_local, + "boxes": self.loss_boxes, + "focal": self.loss_labels_focal, + "vfl": self.loss_labels_vfl, + "mal": self.loss_labels_mal, + } + if loss not in loss_map: + raise ValueError(f"Loss {loss} not supported") + return loss_map[loss](outputs, targets, indices, num_boxes) + + +def Deimv2ForObjectDetectionLoss( + logits, + labels, + device, + pred_boxes, + config, + outputs_class=None, + outputs_coord=None, + enc_topk_logits=None, + enc_topk_bboxes=None, + denoising_meta_values=None, + predicted_corners=None, + initial_reference_points=None, + **kwargs, +): + criterion = Deimv2Loss(config) + criterion.to(device) + outputs_loss = {} + outputs_loss["logits"] = logits + outputs_loss["pred_boxes"] = pred_boxes.clamp(min=0, max=1) + 
auxiliary_outputs = None + if config.auxiliary_loss: + if denoising_meta_values is not None: + dn_out_coord, outputs_coord = torch.split( + outputs_coord.clamp(min=0, max=1), denoising_meta_values["dn_num_split"], dim=2 + ) + dn_out_class, outputs_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2) + dn_out_corners, out_corners = torch.split(predicted_corners, denoising_meta_values["dn_num_split"], dim=2) + dn_out_refs, out_refs = torch.split(initial_reference_points, denoising_meta_values["dn_num_split"], dim=2) + + auxiliary_outputs = _set_aux_loss2( + outputs_class[:, :-1].transpose(0, 1), + outputs_coord[:, :-1].transpose(0, 1), + out_corners[:, :-1].transpose(0, 1), + out_refs[:, :-1].transpose(0, 1), + out_corners[:, -1], + outputs_class[:, -1], + ) + + outputs_loss["auxiliary_outputs"] = auxiliary_outputs + outputs_loss["auxiliary_outputs"].extend( + _set_aux_loss([enc_topk_logits], [enc_topk_bboxes.clamp(min=0, max=1)]) + ) + + dn_auxiliary_outputs = _set_aux_loss2( + dn_out_class.transpose(0, 1), + dn_out_coord.transpose(0, 1), + dn_out_corners.transpose(0, 1), + dn_out_refs.transpose(0, 1), + dn_out_corners[:, -1], + dn_out_class[:, -1], + ) + outputs_loss["dn_auxiliary_outputs"] = dn_auxiliary_outputs + outputs_loss["denoising_meta_values"] = denoising_meta_values + + loss_dict = criterion(outputs_loss, labels) + + loss = sum(loss_dict.values()) + return loss, loss_dict, auxiliary_outputs diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index df269477e9ec..51564d299e55 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -19,6 +19,7 @@ from .loss_d_fine import DFineForObjectDetectionLoss from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss +from .loss_deimv2 import Deimv2ForObjectDetectionLoss from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss from 
.loss_grounding_dino import GroundingDinoForObjectDetectionLoss from .loss_lw_detr import LwDetrForObjectDetectionLoss @@ -163,6 +164,7 @@ def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs): "RTDetrForObjectDetection": RTDetrForObjectDetectionLoss, "RTDetrV2ForObjectDetection": RTDetrForObjectDetectionLoss, "DFineForObjectDetection": DFineForObjectDetectionLoss, + "Deimv2ForObjectDetection": Deimv2ForObjectDetectionLoss, "CsmForConditionalGeneration": ForCausalLMLoss, "LwDetrForObjectDetection": LwDetrForObjectDetectionLoss, } diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 791a23149934..31a949e80738 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -94,6 +94,7 @@ from .deepseek_vl import * from .deepseek_vl_hybrid import * from .deformable_detr import * + from .deimv2 import * from .deit import * from .deprecated import * from .depth_anything import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 4a493e759d16..3a7ca87b3dce 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -112,6 +112,7 @@ ("deepseek_vl", "DeepseekVLConfig"), ("deepseek_vl_hybrid", "DeepseekVLHybridConfig"), ("deformable_detr", "DeformableDetrConfig"), + ("deimv2", "Deimv2Config"), ("deit", "DeiTConfig"), ("depth_anything", "DepthAnythingConfig"), ("depth_pro", "DepthProConfig"), @@ -595,6 +596,7 @@ ("deepseek_vl", "DeepseekVL"), ("deepseek_vl_hybrid", "DeepseekVLHybrid"), ("deformable_detr", "Deformable DETR"), + ("deimv2", "DEIMv2"), ("deit", "DeiT"), ("deplot", "DePlot"), ("depth_anything", "Depth Anything"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index e4d8f08963d6..512992a6ddf0 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ 
b/src/transformers/models/auto/image_processing_auto.py @@ -88,6 +88,7 @@ ("deepseek_vl", ("DeepseekVLImageProcessor", "DeepseekVLImageProcessorFast")), ("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor", "DeepseekVLHybridImageProcessorFast")), ("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")), + ("deimv2", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")), ("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")), ("depth_anything", ("DPTImageProcessor", "DPTImageProcessorFast")), ("depth_pro", ("DepthProImageProcessor", "DepthProImageProcessorFast")), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 1fbab01f684b..34854698f6d6 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -116,6 +116,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("deepseek_vl", "DeepseekVLModel"), ("deepseek_vl_hybrid", "DeepseekVLHybridModel"), ("deformable_detr", "DeformableDetrModel"), + ("deimv2", "Deimv2Model"), ("deit", "DeiTModel"), ("depth_pro", "DepthProModel"), ("detr", "DetrModel"), @@ -1068,6 +1069,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("d_fine", "DFineForObjectDetection"), ("dab-detr", "DabDetrForObjectDetection"), ("deformable_detr", "DeformableDetrForObjectDetection"), + ("deimv2", "Deimv2ForObjectDetection"), ("detr", "DetrForObjectDetection"), ("lw_detr", "LwDetrForObjectDetection"), ("pp_doclayout_v2", "PPDocLayoutV2ForObjectDetection"), diff --git a/src/transformers/models/deimv2/__init__.py b/src/transformers/models/deimv2/__init__.py new file mode 100644 index 000000000000..07caf7c91851 --- /dev/null +++ b/src/transformers/models/deimv2/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_deimv2 import * + from .modeling_deimv2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/deimv2/configuration_deimv2.py b/src/transformers/models/deimv2/configuration_deimv2.py new file mode 100644 index 000000000000..56936d2cdd7a --- /dev/null +++ b/src/transformers/models/deimv2/configuration_deimv2.py @@ -0,0 +1,365 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/deimv2/modular_deimv2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_deimv2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ...backbone_utils import consolidate_backbone_kwargs_to_config +from ...configuration_utils import PreTrainedConfig +from ..auto import AutoConfig + + +# TODO: Attribute map assignment logic should be fixed in modular +# as well as super() call parsing because otherwise we cannot re-write args after initialization +class Deimv2Config(PreTrainedConfig): + """ + This is the configuration class to store the configuration of a [`Deimv2Model`]. It is used to instantiate a + DEIMv2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of DEIMv2-L-COCO. + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + initializer_range (`float`, *optional*, defaults to 0.01): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_bias_prior_prob (`float`, *optional*): + The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`. + If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + batch_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the batch normalization layers. 
+ backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `HGNetV2Config()`): + The configuration of the backbone model. + freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`): + Whether to freeze the batch normalization layers in the backbone. + encoder_hidden_dim (`int`, *optional*, defaults to 256): + Dimension of the layers in hybrid encoder. + encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`): + Multi level features input for encoder. + feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`): + Strides used in each feature map. + encoder_layers (`int`, *optional*, defaults to 1): + Total of layers to be used by the encoder. + encoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in encoder. + encoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + dropout (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`): + Indexes of the projected layers to be used in the encoder. + positional_encoding_temperature (`int`, *optional*, defaults to 10000): + The temperature parameter used to create the positional encodings. + encoder_activation_function (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. + activation_function (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the general layer. + eval_size (`tuple[int, int]`, *optional*): + Height and width used to compute the effective height and width of the position embeddings after taking + into account the stride. 
+ normalize_before (`bool`, *optional*, defaults to `False`): + Determine whether to apply layer normalization in the transformer encoder layer before self-attention and + feed-forward modules. + hidden_expansion (`float`, *optional*, defaults to 1.0): + Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer. + d_model (`int`, *optional*, defaults to 256): + Dimension of the layers exclude hybrid encoder. + num_queries (`int`, *optional*, defaults to 300): + Number of object queries. + decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`): + Multi level features dimension for decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of input feature levels. + decoder_n_points (`int`, *optional*, defaults to 4): + The number of sampled keys in each feature level for each attention head in the decoder. + decoder_layers (`int`, *optional*, defaults to 6): + Number of decoder layers. + decoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_activation_function (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the decoder. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_denoising (`int`, *optional*, defaults to 100): + The total number of denoising tasks or queries to be used for contrastive denoising. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + The fraction of denoising labels to which random noise should be added. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale or magnitude of noise to be added to the bounding boxes. 
+ learn_initial_query (`bool`, *optional*, defaults to `False`): + Indicates whether the initial query embeddings for the decoder should be learned during training. + anchor_image_size (`tuple[int, int]`, *optional*): + Height and width of the input image used during evaluation to generate the bounding box anchors. + with_box_refine (`bool`, *optional*, defaults to `True`): + Whether to apply iterative bounding box refinement. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether the architecture has an encoder decoder structure. + matcher_alpha (`float`, *optional*, defaults to 0.25): + Parameter alpha used by the Hungarian Matcher. + matcher_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used by the Hungarian Matcher. + matcher_class_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the class loss used by the Hungarian Matcher. + matcher_bbox_cost (`float`, *optional*, defaults to 5.0): + The relative weight of the bounding box loss used by the Hungarian Matcher. + matcher_giou_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the giou loss used by the Hungarian Matcher. + use_focal_loss (`bool`, *optional*, defaults to `True`): + Parameter informing if focal loss should be used. + auxiliary_loss (`bool`, *optional*, defaults to `True`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + focal_loss_alpha (`float`, *optional*, defaults to 0.75): + Parameter alpha used to compute the focal loss. + focal_loss_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used to compute the focal loss. + weight_loss_vfl (`float`, *optional*, defaults to 1.0): + Relative weight of the varifocal loss in the object detection loss. + weight_loss_bbox (`float`, *optional*, defaults to 5.0): + Relative weight of the L1 bounding box loss in the object detection loss. 
+ weight_loss_giou (`float`, *optional*, defaults to 2.0): + Relative weight of the generalized IoU loss in the object detection loss. + weight_loss_fgl (`float`, *optional*, defaults to 0.15): + Relative weight of the fine-grained localization loss in the object detection loss. + weight_loss_ddf (`float`, *optional*, defaults to 1.5): + Relative weight of the decoupled distillation focal loss in the object detection loss. + eos_coefficient (`float`, *optional*, defaults to 0.0001): + Relative classification weight of the 'no-object' class in the object detection loss. + eval_idx (`int`, *optional*, defaults to -1): + Index of the decoder layer to use for evaluation. + layer_scale (`float`, *optional*, defaults to `1.0`): + Scaling factor for the hidden dimension in later decoder layers. + max_num_bins (`int`, *optional*, defaults to 32): + Maximum number of bins for the distribution-guided bounding box refinement. + reg_scale (`float`, *optional*, defaults to 4.0): + Scale factor for the regression distribution. + depth_mult (`float`, *optional*, defaults to 1.0): + Multiplier for the number of blocks in RepNCSPELAN5 layers. + top_prob_values (`int`, *optional*, defaults to 4): + Number of top probability values to consider from each corner's distribution. + lqe_hidden_dim (`int`, *optional*, defaults to 64): + Hidden dimension size for the Location Quality Estimator (LQE) network. + lqe_layers (`int`, *optional*, defaults to 2): + Number of layers in the Location Quality Estimator MLP. + decoder_offset_scale (`float`, *optional*, defaults to 0.5): + Offset scale used in deformable attention. + decoder_method (`str`, *optional*, defaults to `"default"`): + The method to use for the decoder: `"default"` or `"discrete"`. + up (`float`, *optional*, defaults to 0.5): + Controls the upper bounds of the Weighting Function. + weight_loss_mal (`float`, *optional*, defaults to 1.0): + Relative weight of the matching auxiliary loss in the object detection loss. 
+ use_dense_o2o (`bool`, *optional*, defaults to `True`): + Whether to use dense one-to-one matching across decoder layers. + mal_alpha (`float`, *optional*): + Alpha parameter for the Matching Auxiliary Loss (MAL). If `None`, uses `focal_loss_alpha`. + encoder_fuse_op (`str`, *optional*, defaults to `"sum"`): + Fusion operation used in the encoder FPN. DEIMv2 uses `"sum"` instead of D-Fine's `"cat"`. + use_spatial_tuning_adapter (`bool`, *optional*, defaults to `False`): + Whether to use the Spatial Tuning Adapter (STA) for DINOv2 backbone variants. + sta_inplanes (`int`, *optional*, defaults to 16): + Number of input planes for the STA convolutional stem. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings. + """ + + model_type = "deimv2" + sub_configs = {"backbone_config": AutoConfig} + layer_types = ["basic", "bottleneck"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + initializer_range=0.01, + initializer_bias_prior_prob=None, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + backbone_config=None, + freeze_backbone_batch_norms=True, + encoder_hidden_dim=256, + encoder_in_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=1024, + encoder_attention_heads=8, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + hidden_expansion=1.0, + d_model=256, + num_queries=300, + decoder_in_channels=[256, 256, 256], + decoder_ffn_dim=1024, + num_feature_levels=3, + decoder_n_points=4, + decoder_layers=6, + decoder_attention_heads=8, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + with_box_refine=True, 
+ is_encoder_decoder=True, + matcher_alpha=0.25, + matcher_gamma=2.0, + matcher_class_cost=2.0, + matcher_bbox_cost=5.0, + matcher_giou_cost=2.0, + use_focal_loss=True, + auxiliary_loss=True, + focal_loss_alpha=0.75, + focal_loss_gamma=2.0, + weight_loss_vfl=1.0, + weight_loss_bbox=5.0, + weight_loss_giou=2.0, + weight_loss_fgl=0.15, + weight_loss_ddf=1.5, + eos_coefficient=1e-4, + eval_idx=-1, + layer_scale=1, + max_num_bins=32, + reg_scale=4.0, + depth_mult=1.0, + top_prob_values=4, + lqe_hidden_dim=64, + lqe_layers=2, + decoder_offset_scale=0.5, + decoder_method="default", + up=0.5, + weight_loss_mal=1.0, + use_dense_o2o=True, + mal_alpha=None, + encoder_fuse_op="sum", + use_spatial_tuning_adapter=False, + sta_inplanes=16, + tie_word_embeddings=True, + **kwargs, + ): + self.weight_loss_mal = weight_loss_mal + self.use_dense_o2o = use_dense_o2o + self.mal_alpha = mal_alpha + self.encoder_fuse_op = encoder_fuse_op + self.use_spatial_tuning_adapter = use_spatial_tuning_adapter + self.sta_inplanes = sta_inplanes + self.initializer_range = initializer_range + self.initializer_bias_prior_prob = initializer_bias_prior_prob + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + + backbone_config, kwargs = consolidate_backbone_kwargs_to_config( + backbone_config=backbone_config, + default_config_type="hgnet_v2", + default_config_kwargs={"out_indices": [2, 3, 4]}, + **kwargs, + ) + + self.backbone_config = backbone_config + self.freeze_backbone_batch_norms = freeze_backbone_batch_norms + # encoder + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_attention_heads = encoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.encoder_layers = encoder_layers + self.positional_encoding_temperature = 
positional_encoding_temperature + self.eval_size = eval_size + self.normalize_before = normalize_before + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.hidden_expansion = hidden_expansion + # decoder + self.d_model = d_model + self.num_queries = num_queries + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_in_channels = decoder_in_channels + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.auxiliary_loss = auxiliary_loss + self.with_box_refine = with_box_refine + # Loss + self.matcher_alpha = matcher_alpha + self.matcher_gamma = matcher_gamma + self.matcher_class_cost = matcher_class_cost + self.matcher_bbox_cost = matcher_bbox_cost + self.matcher_giou_cost = matcher_giou_cost + self.use_focal_loss = use_focal_loss + self.focal_loss_alpha = focal_loss_alpha + self.focal_loss_gamma = focal_loss_gamma + self.weight_loss_vfl = weight_loss_vfl + self.weight_loss_bbox = weight_loss_bbox + self.weight_loss_giou = weight_loss_giou + self.weight_loss_fgl = weight_loss_fgl + self.weight_loss_ddf = weight_loss_ddf + self.eos_coefficient = eos_coefficient + # add the new attributes with the given values or defaults + self.eval_idx = eval_idx + self.layer_scale = layer_scale + self.max_num_bins = max_num_bins + self.reg_scale = reg_scale + self.depth_mult = depth_mult + self.decoder_offset_scale = decoder_offset_scale + self.decoder_method = decoder_method + self.top_prob_values = top_prob_values + self.lqe_hidden_dim = lqe_hidden_dim + 
self.lqe_layers = lqe_layers + self.up = up + self.tie_word_embeddings = tie_word_embeddings + + if isinstance(self.decoder_n_points, list): + if len(self.decoder_n_points) != self.num_feature_levels: + raise ValueError( + f"Length of decoder_n_points list ({len(self.decoder_n_points)}) must match num_feature_levels ({self.num_feature_levels})." + ) + + head_dim = self.d_model // self.decoder_attention_heads + if head_dim * self.decoder_attention_heads != self.d_model: + raise ValueError( + f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}" + ) + + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + + +__all__ = ["Deimv2Config"] diff --git a/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py b/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py new file mode 100644 index 000000000000..5aaecc844594 --- /dev/null +++ b/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py @@ -0,0 +1,451 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import argparse
import json
import re
from io import BytesIO
from pathlib import Path

import httpx
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

from transformers import Deimv2Config, Deimv2ForObjectDetection, RTDetrImageProcessor
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


# Short model names -> original checkpoint repositories on the Hugging Face Hub.
MODEL_NAME_TO_HUB_REPO = {
    "deimv2_hgnetv2_n_coco": "Intellindust/DEIMv2_HGNetv2_N_COCO",
    "deimv2_hgnetv2_pico_coco": "Intellindust/DEIMv2_HGNetv2_PICO_COCO",
    "deimv2_hgnetv2_femto_coco": "Intellindust/DEIMv2_HGNetv2_FEMTO_COCO",
    "deimv2_hgnetv2_atto_coco": "Intellindust/DEIMv2_HGNetv2_ATTO_COCO",
    "deimv2_dinov3_s_coco": "Intellindust/DEIMv2_DINOv3_S_COCO",
    "deimv2_dinov3_m_coco": "Intellindust/DEIMv2_DINOv3_M_COCO",
    "deimv2_dinov3_l_coco": "Intellindust/DEIMv2_DINOv3_L_COCO",
    "deimv2_dinov3_x_coco": "Intellindust/DEIMv2_DINOv3_X_COCO",
}


def get_deimv2_config(model_name: str) -> Deimv2Config:
    """Build a `Deimv2Config` from the original repository's `config.json`.

    Args:
        model_name: One of the keys of `MODEL_NAME_TO_HUB_REPO`.

    Returns:
        A populated `Deimv2Config` (COCO labels, encoder/decoder/backbone settings).

    Raises:
        ValueError: For variants whose encoder or backbone is not yet supported
            (LiteEncoder pico/femto/atto and DINOv3 backbones), or for unknown configs.
    """
    repo_id = MODEL_NAME_TO_HUB_REPO[model_name]
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    with open(config_path) as f:
        orig_config = json.load(f)

    # COCO labels (use a context manager so the file handle is not leaked).
    label_path = hf_hub_download("huggingface/label-files", "coco-detection-mmdet-id2label.json", repo_type="dataset")
    with open(label_path) as f:
        id2label = json.load(f)
    id2label = {int(k): v for k, v in id2label.items()}

    decoder_cfg = orig_config["DEIMTransformer"]
    if "HybridEncoder" in orig_config:
        encoder_cfg = orig_config["HybridEncoder"]
    elif "LiteEncoder" in orig_config:
        # NOTE: keep this message consistent with the DINOv3 branch below, which also raises.
        raise ValueError(
            "LiteEncoder variants (pico/femto/atto) are not yet supported. "
            "The LiteEncoder uses a different architecture (AvgPool downsampling, GAP fusion, "
            "RepNCSPELAN4 blocks) that requires a dedicated Deimv2LiteEncoder implementation. "
            "Supported variants: deimv2_hgnetv2_n_coco."
        )
    else:
        raise ValueError(f"No encoder config found. Available keys: {list(orig_config.keys())}")

    config = Deimv2Config()
    config.num_labels = 80
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}

    # Encoder settings
    config.encoder_hidden_dim = encoder_cfg["hidden_dim"]
    config.encoder_in_channels = encoder_cfg["in_channels"]
    config.feat_strides = encoder_cfg["feat_strides"]
    config.activation_function = encoder_cfg.get("act", "silu")
    config.depth_mult = encoder_cfg.get("depth_mult", 1.0)
    config.hidden_expansion = encoder_cfg.get("expansion", 1.0)
    config.encoder_fuse_op = encoder_cfg.get("fuse_op", "sum")
    config.encoder_ffn_dim = encoder_cfg["dim_feedforward"]
    config.encoder_attention_heads = encoder_cfg["nhead"]
    config.dropout = encoder_cfg.get("dropout", 0.0)
    config.encode_proj_layers = encoder_cfg["use_encoder_idx"]
    config.encoder_activation_function = encoder_cfg.get("enc_act", "gelu")

    # Decoder settings
    config.d_model = decoder_cfg["hidden_dim"]
    config.decoder_ffn_dim = decoder_cfg["dim_feedforward"]
    config.decoder_layers = decoder_cfg["num_layers"]
    config.num_feature_levels = decoder_cfg["num_levels"]
    config.decoder_n_points = decoder_cfg["num_points"]
    config.num_queries = decoder_cfg["num_queries"]
    config.num_denoising = decoder_cfg.get("num_denoising", 100)
    config.label_noise_ratio = decoder_cfg.get("label_noise_ratio", 0.5)
    config.box_noise_scale = decoder_cfg.get("box_noise_scale", 1.0)
    config.max_num_bins = decoder_cfg.get("reg_max", 32)
    config.reg_scale = decoder_cfg.get("reg_scale", 4.0)
    config.eval_idx = decoder_cfg.get("eval_idx", -1)
    config.layer_scale = decoder_cfg.get("layer_scale", 1)
    config.decoder_in_channels = decoder_cfg["feat_channels"]
    config.eval_size = tuple(decoder_cfg["eval_spatial_size"]) if "eval_spatial_size" in decoder_cfg else None
    config.decoder_activation_function = decoder_cfg.get("activation", "silu")

    # Backbone settings
    if "HGNetv2" in orig_config:
        backbone_cfg = orig_config["HGNetv2"]
        backbone_name = backbone_cfg.get("name", "B0")
        return_idx = backbone_cfg.get("return_idx", [2, 3])
        config.backbone_config.out_indices = [i + 1 for i in return_idx]
        config.backbone_config.use_learnable_affine_block = backbone_cfg.get("use_lab", True)

        # The released B0/B1/B2 checkpoints all share the same stage layout,
        # so a single branch covers them (the previous per-variant branches were identical).
        if backbone_name in ("B0", "B1", "B2"):
            config.backbone_config.hidden_sizes = [128, 256, 512, 1024]
            config.backbone_config.stem_channels = [3, 16, 16]
            config.backbone_config.stage_in_channels = [16, 64, 256, 512]
            config.backbone_config.stage_mid_channels = [16, 32, 64, 128]
            config.backbone_config.stage_out_channels = [64, 256, 512, 1024]
            config.backbone_config.stage_num_blocks = [1, 1, 2, 1]
            config.backbone_config.stage_downsample = [False, True, True, True]
            config.backbone_config.stage_light_block = [False, False, True, True]
            config.backbone_config.stage_kernel_size = [3, 3, 5, 5]
            config.backbone_config.stage_numb_of_layers = [3, 3, 3, 3]
        else:
            raise ValueError(f"Unknown HGNetv2 variant: {backbone_name}")

        config.use_spatial_tuning_adapter = False
    elif "DINOv3STAs" in orig_config:
        raise ValueError(
            "DINOv3 backbone variants are not yet supported. "
            "The DINOv3+STA architecture requires ViT backbone key mappings and "
            "STA adapter integration that are not yet implemented in the conversion script. "
            "Supported variants: deimv2_hgnetv2_n_coco."
        )
    else:
        raise ValueError(f"Unknown backbone in config: {list(orig_config.keys())}")

    return config


# Regex mapping from original checkpoint keys to the converted HF module paths.
# Each pattern is applied anchored (^pattern$) in `convert_old_keys_to_new_keys`.
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    # Backbone stem mappings
    r"backbone\.stem\.(stem\w+)\.conv\.weight": r"model.backbone.model.embedder.\1.convolution.weight",
    # Stem normalization
    r"backbone\.stem\.(stem\w+)\.bn\.(weight|bias|running_mean|running_var)": r"model.backbone.model.embedder.\1.normalization.\2",
    # Stem lab parameters
    r"backbone\.stem\.(stem\w+)\.lab\.(scale|bias)": r"model.backbone.model.embedder.\1.lab.\2",
    # Backbone stages mappings
    r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv\.weight": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.convolution.weight",
    r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.bn\.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.normalization.\4",
    r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.lab\.(scale|bias)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.lab.\4",
    # Conv1/Conv2 layers
    r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv1\.conv\.weight": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv1.convolution.weight",
    r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv1\.bn\.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv1.normalization.\4",
    r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv1\.lab\.(scale|bias)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv1.lab.\4",
    r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv2\.conv\.weight": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv2.convolution.weight",
    r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv2\.bn\.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv2.normalization.\4",
    r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv2\.lab\.(scale|bias)": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv2.lab.\4",
    # Backbone stages aggregation
    r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.aggregation\.(\d+)\.conv\.weight": r"model.backbone.model.encoder.stages.\1.blocks.\2.aggregation.\3.convolution.weight",
    r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.aggregation\.(\d+)\.bn\.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.blocks.\2.aggregation.\3.normalization.\4",
    r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.aggregation\.(\d+)\.lab\.(scale|bias)": r"model.backbone.model.encoder.stages.\1.blocks.\2.aggregation.\3.lab.\4",
    # Downsample
    r"backbone\.stages\.(\d+)\.downsample\.conv\.weight": r"model.backbone.model.encoder.stages.\1.downsample.convolution.weight",
    r"backbone\.stages\.(\d+)\.downsample\.bn\.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.downsample.normalization.\2",
    r"backbone\.stages\.(\d+)\.downsample\.lab\.(scale|bias)": r"model.backbone.model.encoder.stages.\1.downsample.lab.\2",
    # Encoder mappings
    # Input projections
    r"encoder\.input_proj\.(\d+)\.conv\.weight": r"model.encoder_input_proj.\1.0.weight",
    r"encoder\.input_proj\.(\d+)\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder_input_proj.\1.1.\2",
    # AIFI transformer encoder layers
    r"encoder\.encoder\.(\d+)\.layers\.0\.self_attn\.out_proj\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.self_attn.o_proj.\2",
    r"encoder\.encoder\.(\d+)\.layers\.0\.linear1\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.mlp.layers.0.\2",
    r"encoder\.encoder\.(\d+)\.layers\.0\.linear2\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.mlp.layers.1.\2",
    r"encoder\.encoder\.(\d+)\.layers\.0\.norm1\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.self_attn_layer_norm.\2",
    r"encoder\.encoder\.(\d+)\.layers\.0\.norm2\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.final_layer_norm.\2",
    # Encoder projections and convolutions
    r"encoder\.lateral_convs\.(\d+)\.conv\.weight": r"model.encoder.lateral_convs.\1.conv.weight",
    r"encoder\.lateral_convs\.(\d+)\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.lateral_convs.\1.norm.\2",
    # FPN blocks - complete structure
    # Basic convolutions
    r"encoder\.fpn_blocks\.(\d+)\.cv1\.conv\.weight": r"model.encoder.fpn_blocks.\1.conv1.conv.weight",
    r"encoder\.fpn_blocks\.(\d+)\.cv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.conv1.norm.\2",
    r"encoder\.fpn_blocks\.(\d+)\.cv4\.conv\.weight": r"model.encoder.fpn_blocks.\1.conv4.conv.weight",
    r"encoder\.fpn_blocks\.(\d+)\.cv4\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.conv4.norm.\2",
    # CSP Rep1 path
    r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.conv1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep1.conv1.conv.weight",
    r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep1.conv1.norm.\2",
    r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep1.bottlenecks.\2.conv1.conv.weight",
    r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep1.bottlenecks.\2.conv1.norm.\3",
    r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep1.bottlenecks.\2.conv2.conv.weight",
    r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep1.bottlenecks.\2.conv2.norm.\3",
    # CSP Rep2 path
    r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.conv1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep2.conv1.conv.weight",
    r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep2.conv1.norm.\2",
    r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep2.bottlenecks.\2.conv1.conv.weight",
    r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep2.bottlenecks.\2.conv1.norm.\3",
    r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep2.bottlenecks.\2.conv2.conv.weight",
    r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep2.bottlenecks.\2.conv2.norm.\3",
    # FPN trailing convs
    r"encoder\.fpn_blocks\.(\d+)\.cv2\.1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep1.conv3.conv.weight",
    r"encoder\.fpn_blocks\.(\d+)\.cv2\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep1.conv3.norm.\2",
    r"encoder\.fpn_blocks\.(\d+)\.cv3\.1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep2.conv3.conv.weight",
    r"encoder\.fpn_blocks\.(\d+)\.cv3\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep2.conv3.norm.\2",
    # PAN blocks - complete structure
    r"encoder\.pan_blocks\.(\d+)\.cv1\.conv\.weight": r"model.encoder.pan_blocks.\1.conv1.conv.weight",
    r"encoder\.pan_blocks\.(\d+)\.cv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.conv1.norm.\2",
    r"encoder\.pan_blocks\.(\d+)\.cv4\.conv\.weight": r"model.encoder.pan_blocks.\1.conv4.conv.weight",
    r"encoder\.pan_blocks\.(\d+)\.cv4\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.conv4.norm.\2",
    # CSP Rep1 path
    r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.conv1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep1.conv1.conv.weight",
    r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep1.conv1.norm.\2",
    r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep1.bottlenecks.\2.conv1.conv.weight",
    r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep1.bottlenecks.\2.conv1.norm.\3",
    r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep1.bottlenecks.\2.conv2.conv.weight",
    r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep1.bottlenecks.\2.conv2.norm.\3",
    # CSP Rep2 path
    r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.conv1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep2.conv1.conv.weight",
    r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep2.conv1.norm.\2",
    r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep2.bottlenecks.\2.conv1.conv.weight",
    r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep2.bottlenecks.\2.conv1.norm.\3",
    r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep2.bottlenecks.\2.conv2.conv.weight",
    r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep2.bottlenecks.\2.conv2.norm.\3",
    # PAN trailing convs
    r"encoder\.pan_blocks\.(\d+)\.cv2\.1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep1.conv3.conv.weight",
    r"encoder\.pan_blocks\.(\d+)\.cv2\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep1.conv3.norm.\2",
    r"encoder\.pan_blocks\.(\d+)\.cv3\.1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep2.conv3.conv.weight",
    r"encoder\.pan_blocks\.(\d+)\.cv3\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep2.conv3.norm.\2",
    # Downsample convolutions
    r"encoder\.downsample_convs\.(\d+)\.0\.cv(\d+)\.conv\.weight": r"model.encoder.downsample_convs.\1.conv\2.conv.weight",
    r"encoder\.downsample_convs\.(\d+)\.0\.cv(\d+)\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.downsample_convs.\1.conv\2.norm.\3",
    # Decoder layers
    r"decoder\.input_proj\.(\d+)\.0\.weight": r"model.decoder_input_proj.\1.0.weight",
    r"decoder\.input_proj\.(\d+)\.1\.(weight|bias|running_mean|running_var)": r"model.decoder_input_proj.\1.1.\2",
    r"decoder\.decoder\.layers\.(\d+)\.self_attn\.out_proj\.(weight|bias)": r"model.decoder.layers.\1.self_attn.o_proj.\2",
    r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.sampling_offsets\.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.sampling_offsets.\2",
    r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.attention_weights\.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.attention_weights.\2",
    r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.value_proj\.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.value_proj.\2",
    r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.output_proj\.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.output_proj.\2",
    r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.num_points_scale": r"model.decoder.layers.\1.encoder_attn.num_points_scale",
    r"decoder\.decoder\.layers\.(\d+)\.norm1\.scale": r"model.decoder.layers.\1.self_attn_layer_norm.scale",
    r"decoder\.decoder\.layers\.(\d+)\.norm3\.scale": r"model.decoder.layers.\1.final_layer_norm.scale",
    r"decoder\.decoder\.layers\.(\d+)\.swish_ffn\.w12\.(weight|bias)": r"model.decoder.layers.\1.mlp.w12.\2",
    r"decoder\.decoder\.layers\.(\d+)\.swish_ffn\.w3\.(weight|bias)": r"model.decoder.layers.\1.mlp.w3.\2",
    r"decoder\.decoder\.layers\.(\d+)\.gateway\.gate\.(weight|bias)": r"model.decoder.layers.\1.gateway.gate.\2",
    r"decoder\.decoder\.layers\.(\d+)\.gateway\.norm\.scale": r"model.decoder.layers.\1.gateway.norm.scale",
    # LQE layers
    r"decoder\.decoder\.lqe_layers\.(\d+)\.reg_conf\.layers\.(\d+)\.(weight|bias)": r"model.decoder.lqe_layers.\1.reg_conf.layers.\2.\3",
    # Decoder heads and projections
    r"decoder\.dec_score_head\.(\d+)\.(weight|bias)": r"model.decoder.class_embed.\1.\2",
    r"decoder\.dec_bbox_head\.(\d+)\.layers\.(\d+)\.(weight|bias)": r"model.decoder.bbox_embed.\1.layers.\2.\3",
    r"decoder\.pre_bbox_head\.layers\.(\d+)\.(weight|bias)": r"model.decoder.pre_bbox_head.layers.\1.\2",
    r"decoder\.query_pos_head\.layers\.(\d+)\.(weight|bias)": r"model.decoder.query_pos_head.layers.\1.\2",
    # Encoder output and score heads
    r"decoder\.enc_output\.proj\.(weight|bias)": r"model.enc_output.0.\1",
    r"decoder\.enc_output\.norm\.(weight|bias)": r"model.enc_output.1.\1",
    r"decoder\.enc_score_head\.(weight|bias)": r"model.enc_score_head.\1",
    r"decoder\.enc_bbox_head\.layers\.(\d+)\.(weight|bias)": r"model.enc_bbox_head.layers.\1.\2",
    # Denoising class embed
    r"decoder\.denoising_class_embed\.weight": r"model.denoising_class_embed.weight",
    # Decoder parameters
    r"decoder\.decoder\.up": r"model.decoder.up",
    r"decoder\.decoder\.reg_scale": r"model.decoder.reg_scale",
}


def convert_old_keys_to_new_keys(state_dict):
    """Rename `state_dict` keys in place using `ORIGINAL_TO_CONVERTED_KEY_MAPPING`.

    Each pattern is applied anchored (`^pattern$`), so only full-key matches are renamed.
    Returns the (mutated) `state_dict` for convenience.
    """
    for original_key, converted_key in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
        for key in list(state_dict.keys()):
            new_key = re.sub(f"^{original_key}$", converted_key, key)
            if new_key != key:
                state_dict[new_key] = state_dict.pop(key)
    return state_dict


def read_in_q_k_v(state_dict, config):
    """Split fused `in_proj_weight`/`in_proj_bias` attention matrices into q/k/v projections.

    Mutates `state_dict` in place for both the transformer encoder (AIFI) and decoder
    self-attention layers. Weights and biases are guarded independently: a checkpoint
    may legitimately contain one without the other.
    """
    encoder_hidden_dim = config.encoder_hidden_dim
    d_model = config.d_model

    # first: transformer encoder (PyTorch MultiHeadAttention stores q/k/v as one fused matrix + bias)
    for i in range(config.encoder_layers):
        in_proj_weight = state_dict.pop(f"encoder.encoder.{i}.layers.0.self_attn.in_proj_weight", None)
        in_proj_bias = state_dict.pop(f"encoder.encoder.{i}.layers.0.self_attn.in_proj_bias", None)
        if in_proj_weight is not None:
            # query, key and value weights (in that order)
            state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.q_proj.weight"] = in_proj_weight[
                :encoder_hidden_dim
            ]
            state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.k_proj.weight"] = in_proj_weight[
                encoder_hidden_dim : 2 * encoder_hidden_dim
            ]
            state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.v_proj.weight"] = in_proj_weight[
                -encoder_hidden_dim:
            ]
        if in_proj_bias is not None:
            state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.q_proj.bias"] = in_proj_bias[:encoder_hidden_dim]
            state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.k_proj.bias"] = in_proj_bias[
                encoder_hidden_dim : 2 * encoder_hidden_dim
            ]
            state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.v_proj.bias"] = in_proj_bias[-encoder_hidden_dim:]

    # next: transformer decoder self-attention
    for i in range(config.decoder_layers):
        in_proj_weight = state_dict.pop(f"decoder.decoder.layers.{i}.self_attn.in_proj_weight", None)
        in_proj_bias = state_dict.pop(f"decoder.decoder.layers.{i}.self_attn.in_proj_bias", None)
        if in_proj_weight is not None:
            state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:d_model]
            state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[d_model : 2 * d_model]
            state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-d_model:]
        # fix: bias was previously indexed inside the weight guard and would crash when absent
        if in_proj_bias is not None:
            state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:d_model]
            state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[d_model : 2 * d_model]
            state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-d_model:]


def load_original_state_dict(repo_id):
    """Download and load the original `model.safetensors` checkpoint from the Hub."""
    filepath = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    return load_file(filepath)


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    with httpx.stream("GET", url) as response:
        image = Image.open(BytesIO(response.read()))
    return image


@torch.no_grad()
def convert_deimv2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, repo_id):
    """
    Copy/paste/tweak model's weights to our Deimv2 structure.
    """
    hub_repo = MODEL_NAME_TO_HUB_REPO[model_name]
    config = get_deimv2_config(model_name)
    state_dict = load_original_state_dict(hub_repo)

    logger.info(f"Converting model {model_name} from {hub_repo}...")
    logger.info(f"Original state dict has {len(state_dict)} keys")

    # Buffers recomputed at model build time; drop them from the checkpoint.
    state_dict.pop("decoder.valid_mask", None)
    state_dict.pop("decoder.anchors", None)

    # query, key and value matrices need special treatment
    read_in_q_k_v(state_dict, config)

    state_dict = convert_old_keys_to_new_keys(state_dict)

    # If the original checkpoint had no enc_output projection, initialize it as identity.
    if "model.enc_output.0.weight" not in state_dict:
        d_model = config.d_model
        state_dict["model.enc_output.0.weight"] = torch.eye(d_model)
        state_dict["model.enc_output.0.bias"] = torch.zeros(d_model)
        state_dict["model.enc_output.1.weight"] = torch.ones(d_model)
        state_dict["model.enc_output.1.bias"] = torch.zeros(d_model)

    # for two_stage: duplicate decoder heads under their top-level aliases
    for key in list(state_dict.keys()):
        if "bbox_embed" in key or ("class_embed" in key and "denoising_" not in key):
            new_key = key.split("model.decoder.")[-1]
            if new_key not in state_dict:
                state_dict[new_key] = state_dict[key]

    model = Deimv2ForObjectDetection(config)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    if missing:
        logger.warning(f"Missing keys ({len(missing)}): {missing[:10]}...")
    if unexpected:
        logger.warning(f"Unexpected keys ({len(unexpected)}): {unexpected[:10]}...")

    model.eval()

    image_processor = RTDetrImageProcessor()
    img = prepare_img()

    # Sanity check: the HF image processor must match the original torchvision pipeline.
    transformations = transforms.Compose(
        [
            transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
        ]
    )
    original_pixel_values = transformations(img).unsqueeze(0)
    encoding = image_processor(images=img, return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    assert torch.allclose(original_pixel_values, pixel_values), "Image preprocessing mismatch!"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    pixel_values = pixel_values.to(device)

    outputs = model(pixel_values)
    logger.info(f"Logits shape: {outputs.logits.shape}")
    logger.info(f"Boxes shape: {outputs.pred_boxes.shape}")
    logger.info(f"Logits sample: {outputs.logits[0, :3, :3]}")
    logger.info(f"Boxes sample: {outputs.pred_boxes[0, :3, :3]}")

    if pytorch_dump_folder_path is not None:
        # parents=True so nested output paths work out of the box
        Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
        logger.info(f"Saving model to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        logger.info(f"Saving image processor to {pytorch_dump_folder_path}")
        image_processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        push_repo = repo_id or f"deimv2-{model_name}"
        logger.info(f"Pushing to hub: {push_repo}")
        config.push_to_hub(repo_id=push_repo)
        model.push_to_hub(repo_id=push_repo)
        image_processor.push_to_hub(repo_id=push_repo)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        default="deimv2_hgnetv2_n_coco",
        type=str,
        choices=list(MODEL_NAME_TO_HUB_REPO.keys()),
        help="Model name to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        help="Path to the output directory.",
    )
    parser.add_argument("--push_to_hub", action="store_true", help="Whether to push to the hub.")
    parser.add_argument("--repo_id", type=str, default=None, help="Hub repo_id to push to.")
    args = parser.parse_args()
    convert_deimv2_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.repo_id)
file was automatically generated from src/transformers/models/deimv2/modular_deimv2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_deimv2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from collections.abc import Callable +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from ... 
import initialization as init +from ...activations import ACT2CLS +from ...backbone_utils import load_backbone +from ...image_transforms import center_to_corners_format, corners_to_center_format +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import compile_compatible_method_lru_cache +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int +from ...utils.generic import can_return_tuple, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from .configuration_deimv2 import Deimv2Config + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the Deimv2Decoder. This class adds two attributes to + BaseModelOutputWithCrossAttentions, namely: + - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer) + - a stacked tensor of intermediate reference points. + """ +) +class Deimv2DecoderOutput(ModelOutput): + r""" + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate predicted corners (predicted corners of each layer of the decoder). 
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked initial reference points (initial reference points of each layer of the decoder). + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor | None = None + intermediate_hidden_states: torch.FloatTensor | None = None + intermediate_logits: torch.FloatTensor | None = None + intermediate_reference_points: torch.FloatTensor | None = None + intermediate_predicted_corners: torch.FloatTensor | None = None + initial_reference_points: torch.FloatTensor | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + cross_attentions: tuple[torch.FloatTensor] | None = None + + +class Deimv2RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.float() + hidden_states = hidden_states * torch.rsqrt(hidden_states.pow(2).mean(-1, keepdim=True) + self.eps) + return (hidden_states * self.scale).to(input_dtype) + + +class Deimv2SwiGLUFFN(nn.Module): + def __init__(self, in_features: int, hidden_features: int, out_features: int): + super().__init__() + self.w12 = nn.Linear(in_features, 2 * hidden_features) + self.w3 = nn.Linear(hidden_features, out_features) + + def forward(self, hidden_states: torch.Tensor) -> 
torch.Tensor: + x12 = self.w12(hidden_states) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +class Deimv2Gate(nn.Module): + def __init__(self, d_model: int): + super().__init__() + self.gate = nn.Linear(2 * d_model, 2 * d_model) + self.norm = Deimv2RMSNorm(d_model) + + def forward(self, second_residual: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor: + gate_input = torch.cat([second_residual, hidden_states], dim=-1) + gates = torch.sigmoid(self.gate(gate_input)) + gate1, gate2 = gates.chunk(2, dim=-1) + hidden_states = self.norm(gate1 * second_residual + gate2 * hidden_states) + return hidden_states + + +class Deimv2MLP(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act: str = "relu"): + super().__init__() + self.num_layers = num_layers + hidden_dims = [hidden_dim] * (num_layers - 1) + input_dims = [input_dim] + hidden_dims + output_dims = hidden_dims + [output_dim] + self.layers = nn.ModuleList(nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, output_dims)) + self.act = ACT2CLS[act]() + + def forward(self, stat_features: torch.Tensor) -> torch.Tensor: + for i, layer in enumerate(self.layers): + stat_features = self.act(layer(stat_features)) if i < self.num_layers - 1 else layer(stat_features) + return stat_features + + +def multi_scale_deformable_attention_v2( + value: Tensor, + value_spatial_shapes: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + num_points_list: list[int], + method="default", +) -> Tensor: + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points = sampling_locations.shape + value_list = ( + value.permute(0, 2, 3, 1) + .flatten(0, 1) + .split([height * width for height, width in value_spatial_shapes], dim=-1) + ) + # sampling_offsets [8, 480, 8, 12, 2] + if method == "default": + sampling_grids = 2 * sampling_locations - 1 + elif method == "discrete": + 
sampling_grids = sampling_locations + sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) + sampling_grids = sampling_grids.split(num_points_list, dim=-2) + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[level_id] + # batch_size*num_heads, hidden_dim, num_queries, num_points + if method == "default": + sampling_value_l_ = nn.functional.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + elif method == "discrete": + sampling_coord = (sampling_grid_l_ * torch.tensor([[width, height]], device=value.device) + 0.5).to( + torch.int64 + ) + + # Separate clamping for x and y coordinates + sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1) + sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1) + + # Combine the clamped coordinates + sampling_coord = torch.stack([sampling_coord_x, sampling_coord_y], dim=-1) + sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2) + sampling_idx = ( + torch.arange(sampling_coord.shape[0], device=value.device) + .unsqueeze(-1) + .repeat(1, sampling_coord.shape[1]) + ) + sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] + sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape( + batch_size * num_heads, hidden_dim, num_queries, num_points_list[level_id] + ) + sampling_value_list.append(sampling_value_l_) 
+ # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.permute(0, 2, 1, 3).reshape( + batch_size * num_heads, 1, num_queries, sum(num_points_list) + ) + output = ( + (torch.concat(sampling_value_list, dim=-1) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +class Deimv2MultiscaleDeformableAttention(nn.Module): + def __init__(self, config: Deimv2Config): + """ + D-Fine version of multiscale deformable attention + """ + super().__init__() + self.d_model = config.d_model + self.n_heads = config.decoder_attention_heads + self.n_levels = config.num_feature_levels + self.offset_scale = config.decoder_offset_scale + self.decoder_method = config.decoder_method + self.n_points = config.decoder_n_points + + if isinstance(self.n_points, list): + num_points_list = self.n_points + else: + num_points_list = [self.n_points for _ in range(self.n_levels)] + + self.num_points_list = num_points_list + num_points_scale = [1 / n for n in self.num_points_list for _ in range(n)] + self.register_buffer("num_points_scale", torch.tensor(num_points_scale, dtype=torch.float32)) + + self.total_points = self.n_heads * sum(self.num_points_list) + + self.sampling_offsets = nn.Linear(self.d_model, self.total_points * 2) + self.attention_weights = nn.Linear(self.d_model, self.total_points) + + self.ms_deformable_attn_core = multi_scale_deformable_attention_v2 + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + reference_points=None, + encoder_hidden_states=None, + spatial_shapes=None, + spatial_shapes_list=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, num_queries, _ = hidden_states.shape + batch_size, 
sequence_length, _ = encoder_hidden_states.shape + + torch_compilable_check( + (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == sequence_length, + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states", + ) + + # Reshape for multi-head attention + value = encoder_hidden_states.reshape(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + if attention_mask is not None: + value = value.masked_fill(~attention_mask[..., None], float(0)) + + sampling_offsets: torch.Tensor = self.sampling_offsets(hidden_states) + sampling_offsets = sampling_offsets.reshape( + batch_size, num_queries, self.n_heads, sum(self.num_points_list), 2 + ) + + attention_weights = self.attention_weights(hidden_states).reshape( + batch_size, num_queries, self.n_heads, sum(self.num_points_list) + ) + attention_weights = F.softmax(attention_weights, dim=-1) + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.tensor(spatial_shapes) + offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.n_levels, 1, 2) + sampling_locations = ( + reference_points.reshape(batch_size, sequence_length, 1, self.n_levels, 1, 2) + + sampling_offsets / offset_normalizer + ) + elif reference_points.shape[-1] == 4: + # reference_points [8, 480, None, 1, 4] + # sampling_offsets [8, 480, 8, 12, 2] + num_points_scale = self.num_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1) + offset = sampling_offsets * num_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale + sampling_locations = reference_points[:, :, None, :, :2] + offset + else: + raise ValueError( + f"Last dim of reference_points must be 2 or 4, but get {reference_points.shape[-1]} instead." 
+ ) + + output = self.ms_deformable_attn_core( + value, + spatial_shapes_list, + sampling_locations, + attention_weights, + self.num_points_list, + self.decoder_method, + ) + + return output, attention_weights + + +class Deimv2ConvNormLayer(nn.Module): + def __init__( + self, + config: Deimv2Config, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + groups: int = 1, + padding: int | None = None, + activation: str | None = None, + ): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + groups=groups, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=False, + ) + self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = self.norm(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class Deimv2RepVggBlock(nn.Module): + """ + RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again". 
+ """ + + def __init__(self, config: Deimv2Config, in_channels: int, out_channels: int): + super().__init__() + + activation = config.activation_function + hidden_channels = in_channels + self.conv1 = Deimv2ConvNormLayer(config, hidden_channels, out_channels, 3, 1, padding=1) + self.conv2 = Deimv2ConvNormLayer(config, hidden_channels, out_channels, 1, 1, padding=0) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, x): + y = self.conv1(x) + self.conv2(x) + return self.activation(y) + + +class Deimv2CSPRepLayer2(nn.Module): + def __init__( + self, config: Deimv2Config, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0 + ): + super().__init__() + activation = config.activation_function + hidden_channels = int(out_channels * expansion) + self.conv1 = Deimv2ConvNormLayer(config, in_channels, hidden_channels * 2, 1, 1, activation=activation) + self.bottlenecks = nn.ModuleList( + [Deimv2RepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)] + ) + self.conv3 = Deimv2ConvNormLayer(config, hidden_channels, out_channels, 3, 1, activation=activation) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + chunks = list(self.conv1(hidden_state).chunk(2, 1)) + bottleneck_out = chunks[1] + for bottleneck in self.bottlenecks: + bottleneck_out = bottleneck(bottleneck_out) + return self.conv3(chunks[0] + bottleneck_out) + + +class Deimv2RepNCSPELAN5(nn.Module): + def __init__(self, config: Deimv2Config, numb_blocks: int = 3): + super().__init__() + act = config.activation_function + c1 = config.encoder_hidden_dim + c2 = config.encoder_hidden_dim + c3 = config.encoder_hidden_dim * 2 + c4 = round(config.hidden_expansion * config.encoder_hidden_dim // 2) + self.conv_dim = c3 // 2 + self.conv1 = Deimv2ConvNormLayer(config, c1, c3, 1, 1, activation=act) + self.csp_rep1 = Deimv2CSPRepLayer2(config, c3 // 2, c4, num_blocks=numb_blocks) + self.csp_rep2 = 
Deimv2CSPRepLayer2(config, c4, c4, num_blocks=numb_blocks) + self.conv4 = Deimv2ConvNormLayer(config, c3 + (2 * c4), c2, 1, 1, activation=act) + + def forward(self, input_features: torch.Tensor) -> torch.Tensor: + split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1)) + branch1 = self.csp_rep1(split_features[-1]) + branch2 = self.csp_rep2(branch1) + split_features.extend([branch1, branch2]) + merged_features = torch.cat(split_features, 1) + return self.conv4(merged_features) + + +class Deimv2SCDown(nn.Module): + def __init__(self, config: Deimv2Config, kernel_size: int, stride: int): + super().__init__() + self.conv1 = Deimv2ConvNormLayer(config, config.encoder_hidden_dim, config.encoder_hidden_dim, 1, 1) + self.conv2 = Deimv2ConvNormLayer( + config, + config.encoder_hidden_dim, + config.encoder_hidden_dim, + kernel_size, + stride, + config.encoder_hidden_dim, + ) + + def forward(self, input_features: torch.Tensor) -> torch.Tensor: + input_features = self.conv1(input_features) + input_features = self.conv2(input_features) + return input_features + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float | None = None, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Deimv2SelfAttention(nn.Module): + """ + Multi-headed self-attention from 'Attention Is All You Need' paper. + + In DEIMV2, position embeddings are added to both queries and keys (but not values) in self-attention. + """ + + def __init__( + self, + config: Deimv2Config, + hidden_size: int, + num_attention_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.config = config + self.head_dim = hidden_size // num_attention_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = dropout + self.is_causal = False + + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Position embeddings are added to both queries and keys (but not values). 
+ """ + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states + + query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Deimv2EncoderLayer(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + self.normalize_before = config.normalize_before + self.hidden_size = config.encoder_hidden_dim + + # self-attention + self.self_attn = Deimv2SelfAttention( + config=config, + hidden_size=self.hidden_size, + num_attention_heads=config.num_attention_heads, + dropout=config.dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) + self.dropout = config.dropout + self.mlp = Deimv2MLP( + self.hidden_size, config.encoder_ffn_dim, self.hidden_size, 2, config.encoder_activation_function + ) + self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + spatial_position_embeddings: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, 
hidden_size)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + spatial_position_embeddings (`torch.FloatTensor`, *optional*): + Spatial position embeddings (2D positional encodings of image locations), to be added to both + the queries and keys in self-attention (but not to values). + """ + residual = hidden_states + if self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=spatial_position_embeddings, + **kwargs, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + if self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + residual = hidden_states + + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + if self.training: + if not torch.isfinite(hidden_states).all(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + return hidden_states + + +class Deimv2SinePositionEmbedding(nn.Module): + """ + 2D sinusoidal position embedding used in RT-DETR hybrid encoder. + """ + + def __init__(self, embed_dim: int = 256, temperature: int = 10000): + super().__init__() + self.embed_dim = embed_dim + self.temperature = temperature + + @compile_compatible_method_lru_cache(maxsize=32) + def forward( + self, + width: int, + height: int, + device: torch.device | str, + dtype: torch.dtype, + ) -> torch.Tensor: + """ + Generate 2D sinusoidal position embeddings. 
+ + Returns: + Position embeddings of shape (1, height*width, embed_dim) + """ + grid_w = torch.arange(torch_int(width), device=device).to(dtype) + grid_h = torch.arange(torch_int(height), device=device).to(dtype) + grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy") + if self.embed_dim % 4 != 0: + raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding") + pos_dim = self.embed_dim // 4 + omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim + omega = 1.0 / (self.temperature**omega) + + out_w = grid_w.flatten()[..., None] @ omega[None] + out_h = grid_h.flatten()[..., None] @ omega[None] + + return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :] + + +class Deimv2AIFILayer(nn.Module): + """ + AIFI (Attention-based Intra-scale Feature Interaction) layer used in RT-DETR hybrid encoder. + """ + + def __init__(self, config: Deimv2Config): + super().__init__() + self.config = config + self.encoder_hidden_dim = config.encoder_hidden_dim + self.eval_size = config.eval_size + + self.position_embedding = Deimv2SinePositionEmbedding( + embed_dim=self.encoder_hidden_dim, + temperature=config.positional_encoding_temperature, + ) + self.layers = nn.ModuleList([Deimv2EncoderLayer(config) for _ in range(config.encoder_layers)]) + + def forward( + self, + hidden_states: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`): + Feature map to process. 
+ """ + batch_size = hidden_states.shape[0] + height, width = hidden_states.shape[2:] + + hidden_states = hidden_states.flatten(2).permute(0, 2, 1) + + if self.training or self.eval_size is None: + pos_embed = self.position_embedding( + width=width, + height=height, + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + else: + pos_embed = None + + for layer in self.layers: + hidden_states = layer( + hidden_states, + attention_mask=None, + spatial_position_embeddings=pos_embed, + **kwargs, + ) + + hidden_states = ( + hidden_states.permute(0, 2, 1).reshape(batch_size, self.encoder_hidden_dim, height, width).contiguous() + ) + + return hidden_states + + +class Deimv2SpatialTuningAdapter(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + inplanes = config.sta_inplanes + self.stem = nn.Sequential( + nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(inplanes), + nn.GELU(), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + self.conv2 = nn.Sequential( + nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(2 * inplanes), + ) + self.conv3 = nn.Sequential( + nn.GELU(), + nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(4 * inplanes), + ) + self.conv4 = nn.Sequential( + nn.GELU(), + nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(4 * inplanes), + ) + + def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + c1 = self.stem(pixel_values) + c2 = self.conv2(c1) + c3 = self.conv3(c2) + c4 = self.conv4(c3) + return c2, c3, c4 + + +class Deimv2Integral(nn.Module): + """ + A static layer that calculates integral results from a distribution. 
+ + This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`, + where Pr(n) is the softmax probability vector representing the discrete + distribution, and W(n) is the non-uniform Weighting Function. + + Args: + max_num_bins (int): Max number of the discrete bins. Default is 32. + It can be adjusted based on the dataset or task requirements. + """ + + def __init__(self, config: Deimv2Config): + super().__init__() + self.max_num_bins = config.max_num_bins + + def forward(self, pred_corners: torch.Tensor, project: torch.Tensor) -> torch.Tensor: + batch_size, num_queries, _ = pred_corners.shape + pred_corners = F.softmax(pred_corners.reshape(-1, self.max_num_bins + 1), dim=1) + pred_corners = F.linear(pred_corners, project.to(pred_corners.device)).reshape(-1, 4) + pred_corners = pred_corners.reshape(batch_size, num_queries, -1) + return pred_corners + + +class Deimv2LQE(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + self.top_prob_values = config.top_prob_values + self.max_num_bins = config.max_num_bins + self.reg_conf = Deimv2MLP(4 * (self.top_prob_values + 1), config.lqe_hidden_dim, 1, config.lqe_layers) + + def forward(self, scores: torch.Tensor, pred_corners: torch.Tensor) -> torch.Tensor: + batch_size, length, _ = pred_corners.size() + prob = F.softmax(pred_corners.reshape(batch_size, length, 4, self.max_num_bins + 1), dim=-1) + prob_topk, _ = prob.topk(self.top_prob_values, dim=-1) + stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1) + quality_score = self.reg_conf(stat.reshape(batch_size, length, -1)) + scores = scores + quality_score + return scores + + +class Deimv2DecoderLayer(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + self.hidden_size = config.d_model + + # self-attention + self.self_attn = Deimv2SelfAttention( + config=config, + hidden_size=self.hidden_size, + num_attention_heads=config.decoder_attention_heads, + 
dropout=config.attention_dropout, + ) + self.dropout = config.dropout + self.self_attn_layer_norm = Deimv2RMSNorm(config.d_model) + self.encoder_attn = Deimv2MultiscaleDeformableAttention(config=config) + self.mlp = Deimv2SwiGLUFFN(config.d_model, config.decoder_ffn_dim // 2, config.d_model) + self.final_layer_norm = Deimv2RMSNorm(config.d_model) + self.gateway = Deimv2Gate(config.d_model) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor | None = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, hidden_size)`. + object_queries_position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings for the object query slots. These are added to both queries and keys + in the self-attention layer (not values). + reference_points (`torch.FloatTensor`, *optional*): + Reference points. + spatial_shapes (`torch.LongTensor`, *optional*): + Spatial shapes. + level_start_index (`torch.LongTensor`, *optional*): + Level start index. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, hidden_size)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. 
+ """ + residual = hidden_states + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=encoder_attention_mask, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + + # Cross-Attention + hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings + hidden_states, _ = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.gateway(residual, hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states.clamp(min=-65504, max=65504)) + + return hidden_states + + +class Deimv2MLPPredictionHead(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act: str = "relu"): + super().__init__() + self.num_layers = num_layers + hidden_dims = [hidden_dim] * (num_layers - 1) + input_dims = [input_dim] + hidden_dims + output_dims = hidden_dims + [output_dim] + self.layers = nn.ModuleList(nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, output_dims)) + self.act = ACT2CLS[act]() + + def forward(self, stat_features: torch.Tensor) -> torch.Tensor: + for i, layer in enumerate(self.layers): + stat_features = self.act(layer(stat_features)) if i < self.num_layers - 1 else layer(stat_features) + return stat_features + + +@auto_docstring +class Deimv2PreTrainedModel(PreTrainedModel): + config: 
Deimv2Config + base_model_prefix = "deimv2" + main_input_name = "pixel_values" + input_modalities = ("image",) + _no_split_modules = [r"Deimv2HybridEncoder", r"Deimv2DecoderLayer"] + _supports_sdpa = True + _supports_flash_attn = True + _supports_attention_backend = True + _supports_flex_attn = True + + @torch.no_grad() + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (Deimv2ForObjectDetection, Deimv2Decoder)): + if module.class_embed is not None: + for layer in module.class_embed: + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias = float(-math.log((1 - prior_prob) / prior_prob)) + init.xavier_uniform_(layer.weight) + init.constant_(layer.bias, bias) + + if module.bbox_embed is not None: + for layer in module.bbox_embed: + init.constant_(layer.layers[-1].weight, 0) + init.constant_(layer.layers[-1].bias, 0) + + if hasattr(module, "reg_scale"): + init.constant_(module.reg_scale, self.config.reg_scale) + + if hasattr(module, "up"): + init.constant_(module.up, self.config.up) + + if isinstance(module, Deimv2MultiscaleDeformableAttention): + init.constant_(module.sampling_offsets.weight, 0.0) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( + 2.0 * math.pi / module.n_heads + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values + grid_init = grid_init.reshape(module.n_heads, 1, 2).tile([1, sum(module.num_points_list), 1]) + scaling = torch.concat([torch.arange(1, n + 1) for n in module.num_points_list]).reshape(1, -1, 1) + grid_init *= scaling + init.copy_(module.sampling_offsets.bias, grid_init.flatten()) + + init.constant_(module.attention_weights.weight, 0.0) + init.constant_(module.attention_weights.bias, 0.0) + + num_points_scale = [1 / n for n in module.num_points_list for _ in range(n)] + 
init.copy_(module.num_points_scale, torch.tensor(num_points_scale, dtype=torch.float32)) + + if isinstance(module, Deimv2Model): + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias = float(-math.log((1 - prior_prob) / prior_prob)) + init.xavier_uniform_(module.enc_score_head.weight) + init.constant_(module.enc_score_head.bias, bias) + + if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + init.zeros_(module.bias) + if getattr(module, "running_mean", None) is not None: + init.zeros_(module.running_mean) + init.ones_(module.running_var) + init.zeros_(module.num_batches_tracked) + + if isinstance(module, Deimv2Gate): + bias = float(-math.log((1 - 0.5) / 0.5)) + init.constant_(module.gate.bias, bias) + init.constant_(module.gate.weight, 0) + + if isinstance(module, Deimv2LQE): + init.constant_(module.reg_conf.layers[-1].bias, 0) + init.constant_(module.reg_conf.layers[-1].weight, 0) + + if isinstance(module, Deimv2SwiGLUFFN): + init.xavier_uniform_(module.w12.weight) + init.constant_(module.w12.bias, 0) + init.xavier_uniform_(module.w3.weight) + init.constant_(module.w3.bias, 0) + + if isinstance(module, Deimv2RMSNorm): + init.ones_(module.scale) + + if isinstance(module, nn.LayerNorm): + init.ones_(module.weight) + init.zeros_(module.bias) + + if hasattr(module, "weight_embedding") and self.config.learn_initial_query: + init.xavier_uniform_(module.weight_embedding.weight) + if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: + init.xavier_uniform_(module.denoising_class_embed.weight) + + +class Deimv2HybridEncoder(Deimv2PreTrainedModel): + """ + Hybrid encoder consisting of AIFI (Attention-based Intra-scale Feature Interaction) layers, + a top-down Feature Pyramid Network (FPN) and a bottom-up Path Aggregation Network (PAN). 
    More details on the paper: https://huggingface.co/papers/2304.08069

    Args:
        config: Deimv2Config
    """

    _can_record_outputs = {
        "hidden_states": Deimv2AIFILayer,
        "attentions": Deimv2SelfAttention,
    }

    def __init__(self, config: Deimv2Config):
        super().__init__(config)
        self.config = config
        self.in_channels = config.encoder_in_channels
        self.num_fpn_stages = len(self.in_channels) - 1
        self.feat_strides = config.feat_strides
        self.encoder_hidden_dim = config.encoder_hidden_dim
        self.encode_proj_layers = config.encode_proj_layers
        self.positional_encoding_temperature = config.positional_encoding_temperature
        self.eval_size = config.eval_size
        self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
        self.out_strides = self.feat_strides
        # "sum" or concatenation — decides how FPN/PAN features are fused in forward().
        self.encoder_fuse_op = config.encoder_fuse_op

        # AIFI: one transformer layer per selected projection level.
        self.aifi = nn.ModuleList([Deimv2AIFILayer(config) for _ in range(len(self.encode_proj_layers))])

        # Top-down FPN path: 1x1 lateral conv + RepNCSPELAN5 fusion block per stage.
        self.lateral_convs = nn.ModuleList()
        self.fpn_blocks = nn.ModuleList()
        for _ in range(len(self.in_channels) - 1, 0, -1):
            lateral_layer = Deimv2ConvNormLayer(config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1)
            self.lateral_convs.append(lateral_layer)
            num_blocks = round(3 * config.depth_mult)
            fpn_layer = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks)
            self.fpn_blocks.append(fpn_layer)

        # Bottom-up PAN path: strided SCDown + RepNCSPELAN5 fusion block per stage.
        self.downsample_convs = nn.ModuleList()
        self.pan_blocks = nn.ModuleList()
        for _ in range(len(self.in_channels) - 1):
            self.downsample_convs.append(Deimv2SCDown(config, 3, 2))
            num_blocks = round(3 * config.depth_mult)
            self.pan_blocks.append(Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks))

        self.post_init()

    @merge_with_config_defaults
    @capture_outputs(tie_last_hidden_states=False)
    def forward(
        self,
        inputs_embeds=None,
        **kwargs,
    ) -> BaseModelOutput:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
        """
        feature_maps = inputs_embeds

        # Apply AIFI only to the configured levels; note this mutates the input list in place.
        if self.config.encoder_layers > 0:
            for i, enc_ind in enumerate(self.encode_proj_layers):
                feature_maps[enc_ind] = self.aifi[i](feature_maps[enc_ind], **kwargs)

        # Top-down FPN: start from the lowest-resolution map and fuse upwards.
        fpn_feature_maps = [feature_maps[-1]]
        for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
            backbone_feature_map = feature_maps[self.num_fpn_stages - idx - 1]
            top_fpn_feature_map = fpn_feature_maps[-1]
            top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
            # Keep the lateral-projected map as the stored output for this level.
            fpn_feature_maps[-1] = top_fpn_feature_map
            top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest")
            if self.encoder_fuse_op == "sum":
                fused_feature_map = top_fpn_feature_map + backbone_feature_map
            else:
                fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1)
            new_fpn_feature_map = fpn_block(fused_feature_map)
            fpn_feature_maps.append(new_fpn_feature_map)

        # Reorder to high-resolution-first for the bottom-up pass.
        fpn_feature_maps.reverse()

        # Bottom-up PAN: downsample and fuse back down the pyramid.
        pan_feature_maps = [fpn_feature_maps[0]]
        for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)):
            top_pan_feature_map = pan_feature_maps[-1]
            fpn_feature_map = fpn_feature_maps[idx + 1]
            downsampled_feature_map = downsample_conv(top_pan_feature_map)
            if self.encoder_fuse_op == "sum":
                fused_feature_map = downsampled_feature_map + fpn_feature_map
            else:
                fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1)
            new_pan_feature_map = pan_block(fused_feature_map)
            pan_feature_maps.append(new_pan_feature_map)

        # NOTE: last_hidden_state here is a *list* of per-level feature maps, not a single tensor.
        return BaseModelOutput(last_hidden_state=pan_feature_maps)


def inverse_sigmoid(x, eps=1e-5):
    """Numerically stable logit: inverse of sigmoid, with inputs clamped to [eps, 1 - eps]."""
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


def weighting_function(max_num_bins: int, up: torch.Tensor, reg_scale: int) -> torch.Tensor:
    """
    Generates the non-uniform Weighting Function W(n) for
    bounding box regression.

    Args:
        max_num_bins (int): Max number of the discrete bins.
        up (Tensor): Controls upper bounds of the sequence,
            where maximum offset is ±up * H / W.
        reg_scale (float): Controls the curvature of the Weighting Function.
            Larger values result in flatter weights near the central axis W(max_num_bins/2)=0
            and steeper weights at both ends.
    Returns:
        Tensor: Sequence of Weighting Function.
    """
    # NOTE(review): despite the `int` annotation, callers pass tensors for `reg_scale`
    # (a frozen nn.Parameter) — the arithmetic below works for both.
    upper_bound1 = abs(up[0]) * abs(reg_scale)
    upper_bound2 = abs(up[0]) * abs(reg_scale) * 2
    # Geometric progression of bin values, symmetric around zero: W(max_num_bins/2) == 0.
    step = (upper_bound1 + 1) ** (2 / (max_num_bins - 2))
    left_values = [-((step) ** i) + 1 for i in range(max_num_bins // 2 - 1, 0, -1)]
    right_values = [(step) ** i - 1 for i in range(1, max_num_bins // 2)]
    values = [-upper_bound2] + left_values + [torch.zeros_like(up[0][None])] + right_values + [upper_bound2]
    values = torch.cat(values, 0)
    return values


def distance2bbox(points, distance: torch.Tensor, reg_scale: float) -> torch.Tensor:
    """
    Decodes edge-distances into bounding box coordinates.

    Args:
        points (`torch.Tensor`):
            (batch_size, num_boxes, 4) or (num_boxes, 4) format, representing [x_center, y_center, width, height]
        distance (`torch.Tensor`):
            (batch_size, num_boxes, 4) or (num_boxes, 4), representing distances from the point to the left, top, right, and bottom boundaries.
        reg_scale (`float`):
            Controls the curvature of the Weighting Function.

    Returns:
        `torch.Tensor`: Bounding boxes in (batch_size, num_boxes, 4) or (num_boxes, 4) format, representing [x_center, y_center, width, height]
    """
    reg_scale = abs(reg_scale)
    # Distances are expressed relative to the anchor size (w/reg_scale, h/reg_scale),
    # with a 0.5*reg_scale offset so that zero distance recovers the anchor box.
    top_left_x = points[..., 0] - (0.5 * reg_scale + distance[..., 0]) * (points[..., 2] / reg_scale)
    top_left_y = points[..., 1] - (0.5 * reg_scale + distance[..., 1]) * (points[..., 3] / reg_scale)
    bottom_right_x = points[..., 0] + (0.5 * reg_scale + distance[..., 2]) * (points[..., 2] / reg_scale)
    bottom_right_y = points[..., 1] + (0.5 * reg_scale + distance[..., 3]) * (points[..., 3] / reg_scale)

    bboxes = torch.stack([top_left_x, top_left_y, bottom_right_x, bottom_right_y], -1)

    return corners_to_center_format(bboxes)


class Deimv2Decoder(Deimv2PreTrainedModel):
    """
    D-FINE Decoder implementing Fine-grained Distribution Refinement (FDR).

    This decoder refines object detection predictions through iterative updates across multiple layers,
    utilizing attention mechanisms, location quality estimators, and distribution refinement techniques
    to improve bounding box accuracy and robustness.
+ """ + + _can_record_outputs = { + "hidden_states": Deimv2DecoderLayer, + "attentions": Deimv2SelfAttention, + "cross_attentions": Deimv2MultiscaleDeformableAttention, + } + + def __init__(self, config: Deimv2Config): + super().__init__(config) + self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx + self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx + + self.dropout = config.dropout + self.layers = nn.ModuleList( + [Deimv2DecoderLayer(config) for _ in range(config.decoder_layers)] + + [Deimv2DecoderLayer(config) for _ in range(config.decoder_layers - self.eval_idx - 1)] + ) + self.query_pos_head = Deimv2MLP(4, config.d_model, config.d_model, 3, config.decoder_activation_function) + + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + self.reg_scale = nn.Parameter(torch.tensor([config.reg_scale]), requires_grad=False) + self.max_num_bins = config.max_num_bins + self.d_model = config.d_model + self.layer_scale = config.layer_scale + self.pre_bbox_head = Deimv2MLP(config.hidden_size, config.hidden_size, 4, 3) + self.integral = Deimv2Integral(config) + self.num_head = config.decoder_attention_heads + self.up = nn.Parameter(torch.tensor([config.up]), requires_grad=False) + self.lqe_layers = nn.ModuleList([Deimv2LQE(config) for _ in range(config.decoder_layers)]) + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + def forward( + self, + encoder_hidden_states: torch.Tensor, + reference_points: torch.Tensor, + inputs_embeds: torch.Tensor, + spatial_shapes, + level_start_index=None, + spatial_shapes_list=None, + encoder_attention_mask=None, + memory_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> Deimv2DecoderOutput: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, 
hidden_size)`): + The query embeddings that are passed into the decoder. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected + in `[0, 1]`: + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` is `as_two_stage` else `(batch_size, num_queries, 2)` or , *optional*): + Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area. + spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of the feature maps. + level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*): + Indexes for the start of each feature level. In range `[0, sequence_length]`. 
+ """ + if inputs_embeds is not None: + hidden_states = inputs_embeds + + # decoder layers + intermediate = () + intermediate_reference_points = () + intermediate_logits = () + intermediate_predicted_corners = () + initial_reference_points = () + + output_detach = pred_corners_undetach = 0 + + project = weighting_function(self.max_num_bins, self.up, self.reg_scale) + ref_points_detach = F.sigmoid(reference_points) + + for i, decoder_layer in enumerate(self.layers): + ref_points_input = ref_points_detach.unsqueeze(2) + query_pos_embed = self.query_pos_head(ref_points_detach).clamp(min=-10, max=10) + + hidden_states = decoder_layer( + hidden_states, + position_embeddings=query_pos_embed, + reference_points=ref_points_input, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + + if i == 0: + # Initial bounding box predictions with inverse sigmoid refinement + new_reference_points = F.sigmoid( + self.pre_bbox_head(hidden_states) + inverse_sigmoid(ref_points_detach) + ) + ref_points_initial = new_reference_points.detach() + + # Refine bounding box corners using FDR, integrating previous layer's corrections + if self.bbox_embed is not None: + pred_corners = self.bbox_embed[i](hidden_states + output_detach) + pred_corners_undetach + inter_ref_bbox = distance2bbox( + ref_points_initial, self.integral(pred_corners, project), self.reg_scale + ) + pred_corners_undetach = pred_corners + ref_points_detach = inter_ref_bbox.detach() + + output_detach = hidden_states.detach() + + intermediate += (hidden_states,) + + if self.class_embed is not None and (self.training or i == self.eval_idx): + scores = self.class_embed[i](hidden_states) + # Add initial logits and reference points with pre-bbox head + if i == 0: + intermediate_logits += (scores,) + intermediate_reference_points += (new_reference_points,) + # Lqe does not affect the performance here. 
+ scores = self.lqe_layers[i](scores, pred_corners) + intermediate_logits += (scores,) + intermediate_reference_points += (inter_ref_bbox,) + initial_reference_points += (ref_points_initial,) + intermediate_predicted_corners += (pred_corners,) + + # Keep batch_size as first dimension + intermediate = torch.stack(intermediate) + if self.class_embed is not None and self.bbox_embed is not None: + intermediate_logits = torch.stack(intermediate_logits, dim=1) + intermediate_predicted_corners = torch.stack(intermediate_predicted_corners, dim=1) + initial_reference_points = torch.stack(initial_reference_points, dim=1) + intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1) + + return Deimv2DecoderOutput( + last_hidden_state=hidden_states, + intermediate_hidden_states=intermediate, + intermediate_logits=intermediate_logits, + intermediate_reference_points=intermediate_reference_points, + intermediate_predicted_corners=intermediate_predicted_corners, + initial_reference_points=initial_reference_points, + ) + + +class Deimv2FrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans. 
+ """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the RT-DETR encoder-decoder model. + """ +) +class Deimv2ModelOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). 
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate predicted corners (predicted corners of each layer of the decoder). + initial_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points used for the first decoder layer. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`): + Logits of predicted bounding boxes coordinates in the encoder stage. + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. + denoising_meta_values (`dict`): + Extra dictionary for the denoising related values. 
+ """ + + last_hidden_state: torch.FloatTensor | None = None + intermediate_hidden_states: torch.FloatTensor | None = None + intermediate_logits: torch.FloatTensor | None = None + intermediate_reference_points: torch.FloatTensor | None = None + intermediate_predicted_corners: torch.FloatTensor | None = None + initial_reference_points: torch.FloatTensor | None = None + decoder_hidden_states: tuple[torch.FloatTensor] | None = None + decoder_attentions: tuple[torch.FloatTensor] | None = None + cross_attentions: tuple[torch.FloatTensor] | None = None + encoder_last_hidden_state: torch.FloatTensor | None = None + encoder_hidden_states: tuple[torch.FloatTensor] | None = None + encoder_attentions: tuple[torch.FloatTensor] | None = None + init_reference_points: torch.FloatTensor | None = None + enc_topk_logits: torch.FloatTensor | None = None + enc_topk_bboxes: torch.FloatTensor | None = None + enc_outputs_class: torch.FloatTensor | None = None + enc_outputs_coord_logits: torch.FloatTensor | None = None + denoising_meta_values: dict | None = None + + +def replace_batch_norm(model): + r""" + Recursively replace all `torch.nn.BatchNorm2d` with `Deimv2FrozenBatchNorm2d`. + + Args: + model (torch.nn.Module): + input model + """ + for name, module in model.named_children(): + if isinstance(module, nn.BatchNorm2d): + new_module = Deimv2FrozenBatchNorm2d(module.num_features) + + if module.weight.device != torch.device("meta"): + new_module.weight.copy_(module.weight) + new_module.bias.copy_(module.bias) + new_module.running_mean.copy_(module.running_mean) + new_module.running_var.copy_(module.running_var) + + model._modules[name] = new_module + + if len(list(module.children())) > 0: + replace_batch_norm(module) + + +class Deimv2ConvEncoder(nn.Module): + """ + Convolutional backbone using the modeling_deimv2_resnet.py. + + nn.BatchNorm2d layers are replaced by Deimv2FrozenBatchNorm2d as defined above. 
    https://github.com/lyuwenyu/RT-DETR/blob/main/Deimv2_pytorch/src/nn/backbone/presnet.py#L142
    """

    def __init__(self, config):
        super().__init__()

        backbone = load_backbone(config)

        if config.freeze_backbone_batch_norms:
            # replace batch norm by frozen batch norm
            with torch.no_grad():
                replace_batch_norm(backbone)
        self.model = backbone
        self.intermediate_channel_sizes = self.model.channels

    def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
        # send pixel_values through the model to get list of feature maps
        features = self.model(pixel_values).feature_maps

        out = []
        for feature_map in features:
            # downsample pixel_mask to match shape of corresponding feature_map
            mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
            out.append((feature_map, mask))
        return out


def get_contrastive_denoising_training_group(
    targets,
    num_classes,
    num_queries,
    class_embed,
    num_denoising_queries=100,
    label_noise_ratio=0.5,
    box_noise_scale=1.0,
):
    """
    Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes.

    Args:
        targets (`list[dict]`):
            The target objects, each containing 'class_labels' and 'boxes' for objects in an image.
        num_classes (`int`):
            Total number of classes in the dataset.
        num_queries (`int`):
            Number of query slots in the transformer.
        class_embed (`callable`):
            A function or a model layer to embed class labels.
        num_denoising_queries (`int`, *optional*, defaults to 100):
            Number of denoising queries.
        label_noise_ratio (`float`, *optional*, defaults to 0.5):
            Ratio of noise applied to labels.
        box_noise_scale (`float`, *optional*, defaults to 1.0):
            Scale of noise applied to bounding boxes.
    Returns:
        `tuple` comprising various elements:
        - **input_query_class** (`torch.FloatTensor`) --
          Class queries with applied label noise.
        - **input_query_bbox** (`torch.FloatTensor`) --
          Bounding box queries with applied box noise.
        - **attn_mask** (`torch.FloatTensor`) --
          Attention mask for separating denoising and reconstruction queries.
        - **denoising_meta_values** (`dict`) --
          Metadata including denoising positive indices, number of groups, and split sizes.
    """

    if num_denoising_queries <= 0:
        return None, None, None, None

    num_ground_truths = [len(t["class_labels"]) for t in targets]
    device = targets[0]["class_labels"].device

    max_gt_num = max(num_ground_truths)
    if max_gt_num == 0:
        # No ground truth in the whole batch: denoising is disabled for this step.
        return None, None, None, None

    num_groups_denoising_queries = num_denoising_queries // max_gt_num
    num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries
    # pad gt to max_num of a batch
    batch_size = len(num_ground_truths)

    # `num_classes` doubles as the padding label for empty slots.
    input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device)
    input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device)
    pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device)

    for i in range(batch_size):
        num_gt = num_ground_truths[i]
        if num_gt > 0:
            input_query_class[i, :num_gt] = targets[i]["class_labels"]
            input_query_bbox[i, :num_gt] = targets[i]["boxes"]
            pad_gt_mask[i, :num_gt] = 1
    # each group has positive and negative queries.
    input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries])
    input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1])
    pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries])
    # positive and negative mask
    negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device)
    negative_gt_mask[:, max_gt_num:] = 1
    negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1])
    positive_gt_mask = 1 - negative_gt_mask
    # contrastive denoising training positive index
    positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
    denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
    denoise_positive_idx = torch.split(
        denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths]
    )
    # total denoising queries
    num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries)

    if label_noise_ratio > 0:
        mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
        # randomly put a new one here
        new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
        input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)

    if box_noise_scale > 0:
        known_bbox = center_to_corners_format(input_query_bbox)
        diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
        rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
        rand_part = torch.rand_like(input_query_bbox)
        # Negative queries get a larger perturbation (shifted by +1) than positive ones.
        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
        rand_part *= rand_sign
        known_bbox += rand_part * diff
        known_bbox.clip_(min=0.0, max=1.0)
        input_query_bbox = corners_to_center_format(known_bbox)
        input_query_bbox = inverse_sigmoid(input_query_bbox)

    input_query_class = class_embed(input_query_class)

    target_size = num_denoising_queries + num_queries
    attn_mask = torch.full([target_size, target_size], 0,
dtype=torch.float, device=device) + # match query cannot see the reconstruction + attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf + + # reconstructions cannot see each other + for i in range(num_groups_denoising_queries): + idx_block_start = max_gt_num * 2 * i + idx_block_end = max_gt_num * 2 * (i + 1) + attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf + attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf + + denoising_meta_values = { + "dn_positive_idx": denoise_positive_idx, + "dn_num_group": num_groups_denoising_queries, + "dn_num_split": [num_denoising_queries, num_queries], + } + + return input_query_class, input_query_bbox, attn_mask, denoising_meta_values + + +@auto_docstring( + custom_intro=""" + RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top. + """ +) +class Deimv2Model(Deimv2PreTrainedModel): + def __init__(self, config: Deimv2Config): + super().__init__(config) + + # Create backbone + self.backbone = Deimv2ConvEncoder(config) + intermediate_channel_sizes = self.backbone.intermediate_channel_sizes + num_backbone_outs = len(config.decoder_in_channels) + encoder_input_proj_list = [] + for i in range(num_backbone_outs): + in_channels = intermediate_channel_sizes[i] + encoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(config.encoder_hidden_dim), + ) + ) + self.encoder_input_proj = nn.ModuleList(encoder_input_proj_list) + self.encoder = Deimv2HybridEncoder(config=config) + + # denoising part + if config.num_denoising > 0: + self.denoising_class_embed = nn.Embedding( + config.num_labels + 1, config.d_model, padding_idx=config.num_labels + ) + + # decoder embedding + if config.learn_initial_query: + self.weight_embedding = nn.Embedding(config.num_queries, config.d_model) + + # encoder head + self.enc_output = 
nn.Sequential( + nn.Linear(config.d_model, config.d_model), + nn.LayerNorm(config.d_model, eps=config.layer_norm_eps), + ) + self.enc_score_head = nn.Linear(config.d_model, config.num_labels) + self.enc_bbox_head = Deimv2MLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3) + + # init encoder output anchors and valid_mask + if config.anchor_image_size: + self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype) + num_backbone_outs = len(config.decoder_in_channels) + decoder_input_proj_list = [] + for i in range(num_backbone_outs): + in_channels = config.decoder_in_channels[i] + decoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False), + nn.BatchNorm2d(config.d_model, config.batch_norm_eps), + ) + ) + for _ in range(config.num_feature_levels - num_backbone_outs): + decoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(config.d_model, config.batch_norm_eps), + ) + ) + in_channels = config.d_model + self.decoder = Deimv2Decoder(config) + decoder_input_proj = [] + in_channels = config.decoder_in_channels[-1] + for _ in range(num_backbone_outs): + if config.hidden_size == config.decoder_in_channels[-1]: + decoder_input_proj.append(nn.Identity()) + else: + conv = nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False) + batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps) + decoder_input_proj.append(nn.Sequential(conv, batchnorm)) + for _ in range(config.num_feature_levels - num_backbone_outs): + if config.hidden_size == config.decoder_in_channels[-1]: + decoder_input_proj.append(nn.Identity()) + else: + conv = nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False) + batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps) + decoder_input_proj.append(nn.Sequential(conv, batchnorm)) + self.decoder_input_proj = 
nn.ModuleList(decoder_input_proj) + + if config.use_spatial_tuning_adapter: + self.spatial_tuning_adapter = Deimv2SpatialTuningAdapter(config) + + self.post_init() + + def freeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(True) + + @compile_compatible_method_lru_cache(maxsize=32) + def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32): + if spatial_shapes is None: + spatial_shapes = [ + [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)] + for s in self.config.feat_strides + ] + anchors = [] + for level, (height, width) in enumerate(spatial_shapes): + grid_y, grid_x = torch.meshgrid( + torch.arange(end=height, device=device).to(dtype), + torch.arange(end=width, device=device).to(dtype), + indexing="ij", + ) + grid_xy = torch.stack([grid_x, grid_y], -1) + grid_xy = grid_xy.unsqueeze(0) + 0.5 + grid_xy[..., 0] /= width + grid_xy[..., 1] /= height + wh = torch.ones_like(grid_xy) * grid_size * (2.0**level) + anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4)) + # define the valid range for anchor coordinates + eps = 1e-2 + anchors = torch.concat(anchors, 1) + valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) + anchors = torch.log(anchors / (1 - anchors)) + anchors = torch.where(valid_mask, anchors, torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device)) + + return anchors, valid_mask + + @auto_docstring + @can_return_tuple + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: torch.LongTensor | None = None, + encoder_outputs: torch.FloatTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: list[dict] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor] | Deimv2ModelOutput: + r""" + inputs_embeds 
(`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + labels (`list[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, Deimv2Model + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("PekingU/Deimv2_r50vd") + >>> model = Deimv2Model.from_pretrained("PekingU/Deimv2_r50vd") + + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 300, 256] + ```""" + if pixel_values is None and inputs_embeds is None: + raise ValueError("You have to specify either pixel_values or inputs_embeds") + + if inputs_embeds is None: + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), device=device) + features = self.backbone(pixel_values, pixel_mask) + proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] + else: + batch_size = inputs_embeds.shape[0] + device = inputs_embeds.device + 
proj_feats = inputs_embeds + + if encoder_outputs is None: + encoder_outputs = self.encoder( + proj_feats, + **kwargs, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput + elif not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # Equivalent to def _get_encoder_input + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/Deimv2_pytorch/src/zoo/Deimv2/Deimv2_decoder.py#L412 + sources = [] + for level, source in enumerate(encoder_outputs.last_hidden_state): + sources.append(self.decoder_input_proj[level](source)) + + # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage + if self.config.num_feature_levels > len(sources): + _len_sources = len(sources) + sources.append(self.decoder_input_proj[_len_sources](encoder_outputs.last_hidden_state)[-1]) + for i in range(_len_sources + 1, self.config.num_feature_levels): + sources.append(self.decoder_input_proj[i](encoder_outputs.last_hidden_state[-1])) + + # Prepare encoder inputs (by flattening) + source_flatten = [] + spatial_shapes_list = [] + spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long) + for level, source in enumerate(sources): + height, width = source.shape[-2:] + spatial_shapes[level, 0] = height + spatial_shapes[level, 1] = width + spatial_shapes_list.append((height, width)) + source = source.flatten(2).transpose(1, 2) + source_flatten.append(source) + source_flatten = torch.cat(source_flatten, 1) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + + # prepare denoising training + if self.training and self.config.num_denoising > 0 and labels is not None: + ( + denoising_class, + 
denoising_bbox_unact, + attention_mask, + denoising_meta_values, + ) = get_contrastive_denoising_training_group( + targets=labels, + num_classes=self.config.num_labels, + num_queries=self.config.num_queries, + class_embed=self.denoising_class_embed, + num_denoising_queries=self.config.num_denoising, + label_noise_ratio=self.config.label_noise_ratio, + box_noise_scale=self.config.box_noise_scale, + ) + else: + denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None + + batch_size = len(source_flatten) + device = source_flatten.device + dtype = source_flatten.dtype + + # prepare input for decoder + if self.training or self.config.anchor_image_size is None: + # Pass spatial_shapes as tuple to make it hashable and make sure + # lru_cache is working for generate_anchors() + spatial_shapes_tuple = tuple(spatial_shapes_list) + anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype) + else: + anchors, valid_mask = self.anchors, self.valid_mask + anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype) + + # use the valid_mask to selectively retain values in the feature map where the mask is `True` + memory = valid_mask.to(source_flatten.dtype) * source_flatten + + output_memory = self.enc_output(memory) + + enc_outputs_class = self.enc_score_head(output_memory) + enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors + + _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1) + + reference_points_unact = enc_outputs_coord_logits.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1]) + ) + + enc_topk_bboxes = F.sigmoid(reference_points_unact) + if denoising_bbox_unact is not None: + reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1) + + enc_topk_logits = enc_outputs_class.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, 
enc_outputs_class.shape[-1]) + ) + + # extract region features + if self.config.learn_initial_query: + target = self.weight_embedding.tile([batch_size, 1, 1]) + else: + target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1])) + target = target.detach() + + if denoising_class is not None: + target = torch.concat([denoising_class, target], 1) + + init_reference_points = reference_points_unact.detach() + + # decoder + decoder_outputs = self.decoder( + inputs_embeds=target, + encoder_hidden_states=source_flatten, + encoder_attention_mask=attention_mask, + reference_points=init_reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + **kwargs, + ) + + return Deimv2ModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + intermediate_logits=decoder_outputs.intermediate_logits, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + intermediate_predicted_corners=decoder_outputs.intermediate_predicted_corners, + initial_reference_points=decoder_outputs.initial_reference_points, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + init_reference_points=init_reference_points, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + enc_outputs_class=enc_outputs_class, + enc_outputs_coord_logits=enc_outputs_coord_logits, + denoising_meta_values=denoising_meta_values, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`Deimv2ForObjectDetection`]. 
+
+    """
+)
+class Deimv2ObjectDetectionOutput(ModelOutput):
+    r"""
+    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
+        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
+        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
+        scale-invariant IoU loss.
+    loss_dict (`Dict`, *optional*):
+        A dictionary containing the individual losses. Useful for logging.
+    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
+        Classification logits (including no-object) for all queries.
+    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
+        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
+        possible padding). You can use [`~Deimv2ImageProcessor.post_process_object_detection`] to retrieve the
+        unnormalized (absolute) bounding boxes.
+    auxiliary_outputs (`list[Dict]`, *optional*):
+        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
+        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
+        `pred_boxes`) for each decoder layer.
+    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+        Sequence of hidden-states at the output of the last layer of the decoder of the model.
+    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+        Stacked intermediate hidden states (output of each layer of the decoder).
+    intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`):
+        Stacked intermediate logits (logits of each layer of the decoder).
+    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+        Stacked intermediate reference points (reference points of each layer of the decoder).
+    intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+        Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
+    initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+        Stacked initial reference points (initial reference points of each layer of the decoder).
+    init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+        Initial reference points sent through the Transformer decoder.
+    enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+        Classification logits of the top-k selected queries in the encoder.
+    enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+        Normalized (sigmoid) coordinates of the top-k selected bounding boxes in the encoder.
+    enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+        Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+        picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+        foreground and background).
+    enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+        Logits of predicted bounding boxes coordinates in the first stage.
+    denoising_meta_values (`dict`):
+        Extra dictionary for the denoising related values
+    """
+
+    loss: torch.FloatTensor | None = None
+    loss_dict: dict | None = None
+    logits: torch.FloatTensor | None = None
+    pred_boxes: torch.FloatTensor | None = None
+    auxiliary_outputs: list[dict] | None = None
+    last_hidden_state: torch.FloatTensor | None = None
+    intermediate_hidden_states: torch.FloatTensor | None = None
+    intermediate_logits: torch.FloatTensor | None = None
+    intermediate_reference_points: torch.FloatTensor | None = None
+    intermediate_predicted_corners: torch.FloatTensor | None = None
+    initial_reference_points: torch.FloatTensor | None = None
+    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
+    decoder_attentions: tuple[torch.FloatTensor] | None = None
+    cross_attentions: tuple[torch.FloatTensor] | None = None
+    encoder_last_hidden_state: torch.FloatTensor | None = None
+    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
+    encoder_attentions: tuple[torch.FloatTensor] | None = None
+    init_reference_points: tuple[torch.FloatTensor] | None = None
+    enc_topk_logits: torch.FloatTensor | None = None
+    enc_topk_bboxes: torch.FloatTensor | None = None
+    enc_outputs_class: torch.FloatTensor | None = None
+    enc_outputs_coord_logits: torch.FloatTensor | None = None
+    denoising_meta_values: dict | None = None
+
+
+@auto_docstring(
+    custom_intro="""
+    DEIMv2 Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits to be further
+    decoded into scores and classes.
+ """ +) +class Deimv2ForObjectDetection(Deimv2PreTrainedModel): + _no_split_modules = None + _tied_weights_keys = { + r"bbox_embed.(?![0])\d+": r"bbox_embed.0", + r"class_embed.(?![0])\d+": r"^class_embed.0", + "class_embed": "model.decoder.class_embed", + "bbox_embed": "model.decoder.bbox_embed", + } + + def __init__(self, config: Deimv2Config): + super().__init__(config) + + self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx + self.model = Deimv2Model(config) + scaled_dim = round(config.layer_scale * config.hidden_size) + num_pred = config.decoder_layers + self.class_embed = nn.ModuleList([nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)]) + self.bbox_embed = nn.ModuleList( + [ + Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) + for _ in range(self.eval_idx + 1) + ] + + [ + Deimv2MLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3) + for _ in range(config.decoder_layers - self.eval_idx - 1) + ] + ) + + self.model.decoder.class_embed = self.class_embed + self.model.decoder.bbox_embed = self.bbox_embed + self.post_init() + + def _set_aux_loss(self, outputs_class, outputs_coord): + return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)] + + @auto_docstring + @can_return_tuple + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: torch.LongTensor | None = None, + encoder_outputs: torch.FloatTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: list[dict] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor] | Deimv2ObjectDetectionOutput: + r""" + Example: + + ```python + >>> import torch + >>> from transformers.image_utils import load_image + >>> from transformers import AutoImageProcessor, Deimv2ForObjectDetection + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = load_image(url) + + >>> image_processor = 
AutoImageProcessor.from_pretrained("Intellindust/DEIMv2_HGNetv2_N_COCO") + >>> model = Deimv2ForObjectDetection.from_pretrained("Intellindust/DEIMv2_HGNetv2_N_COCO") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> # forward pass + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 300, 80] + + >>> boxes = outputs.pred_boxes + >>> list(boxes.shape) + [1, 300, 4] + + >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> target_sizes = torch.tensor([image.size[::-1]]) + >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes) + >>> result = results[0] # first image in batch + + >>> for score, label, box in zip(result["scores"], result["labels"], result["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... print( + ... f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ... 
) + ``` + """ + outputs = self.model( + pixel_values, + pixel_mask=pixel_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + labels=labels, + **kwargs, + ) + + denoising_meta_values = outputs.denoising_meta_values if self.training else None + + outputs_class = outputs.intermediate_logits + outputs_coord = outputs.intermediate_reference_points + predicted_corners = outputs.intermediate_predicted_corners + initial_reference_points = outputs.initial_reference_points + + logits = outputs_class[:, -1] + pred_boxes = outputs_coord[:, -1] + + loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None + if labels is not None: + enc_topk_logits = outputs.enc_topk_logits + enc_topk_bboxes = outputs.enc_topk_bboxes + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, + labels, + self.device, + pred_boxes, + self.config, + outputs_class, + outputs_coord, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + denoising_meta_values=denoising_meta_values, + predicted_corners=predicted_corners, + initial_reference_points=initial_reference_points, + **kwargs, + ) + + return Deimv2ObjectDetectionOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=outputs.last_hidden_state, + intermediate_hidden_states=outputs.intermediate_hidden_states, + intermediate_logits=outputs.intermediate_logits, + intermediate_reference_points=outputs.intermediate_reference_points, + intermediate_predicted_corners=outputs.intermediate_predicted_corners, + initial_reference_points=outputs.initial_reference_points, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, 
+ init_reference_points=outputs.init_reference_points, + enc_topk_logits=outputs.enc_topk_logits, + enc_topk_bboxes=outputs.enc_topk_bboxes, + enc_outputs_class=outputs.enc_outputs_class, + enc_outputs_coord_logits=outputs.enc_outputs_coord_logits, + denoising_meta_values=outputs.denoising_meta_values, + ) + + +__all__ = ["Deimv2Model", "Deimv2PreTrainedModel", "Deimv2ForObjectDetection"] diff --git a/src/transformers/models/deimv2/modular_deimv2.py b/src/transformers/models/deimv2/modular_deimv2.py new file mode 100644 index 000000000000..be7cc90214e0 --- /dev/null +++ b/src/transformers/models/deimv2/modular_deimv2.py @@ -0,0 +1,819 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ... 
import initialization as init +from ...modeling_outputs import BaseModelOutput +from ...utils import logging +from ..d_fine.configuration_d_fine import DFineConfig +from ..d_fine.modeling_d_fine import ( + DFineConvNormLayer, + DFineDecoder, + DFineDecoderLayer, + DFineDecoderOutput, + DFineEncoderLayer, + DFineForObjectDetection, + DFineGate, + DFineHybridEncoder, + DFineIntegral, + DFineLQE, + DFineMLP, + DFineModel, + DFineMultiscaleDeformableAttention, + DFinePreTrainedModel, + DFineRepVggBlock, + DFineSCDown, +) +from ..rt_detr.modeling_rt_detr import RTDetrAIFILayer + + +logger = logging.get_logger(__name__) + + +class Deimv2Config(DFineConfig): + """ + This is the configuration class to store the configuration of a [`Deimv2Model`]. It is used to instantiate a + DEIMv2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of DEIMv2-L-COCO. + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + initializer_range (`float`, *optional*, defaults to 0.01): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_bias_prior_prob (`float`, *optional*): + The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`. + If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + batch_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the batch normalization layers. + backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `HGNetV2Config()`): + The configuration of the backbone model. 
+ freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`): + Whether to freeze the batch normalization layers in the backbone. + encoder_hidden_dim (`int`, *optional*, defaults to 256): + Dimension of the layers in hybrid encoder. + encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`): + Multi level features input for encoder. + feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`): + Strides used in each feature map. + encoder_layers (`int`, *optional*, defaults to 1): + Total of layers to be used by the encoder. + encoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in encoder. + encoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + dropout (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`): + Indexes of the projected layers to be used in the encoder. + positional_encoding_temperature (`int`, *optional*, defaults to 10000): + The temperature parameter used to create the positional encodings. + encoder_activation_function (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. + activation_function (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the general layer. + eval_size (`tuple[int, int]`, *optional*): + Height and width used to computes the effective height and width of the position embeddings after taking + into account the stride. 
+ normalize_before (`bool`, *optional*, defaults to `False`): + Determine whether to apply layer normalization in the transformer encoder layer before self-attention and + feed-forward modules. + hidden_expansion (`float`, *optional*, defaults to 1.0): + Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer. + d_model (`int`, *optional*, defaults to 256): + Dimension of the layers exclude hybrid encoder. + num_queries (`int`, *optional*, defaults to 300): + Number of object queries. + decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`): + Multi level features dimension for decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of input feature levels. + decoder_n_points (`int`, *optional*, defaults to 4): + The number of sampled keys in each feature level for each attention head in the decoder. + decoder_layers (`int`, *optional*, defaults to 6): + Number of decoder layers. + decoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_activation_function (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the decoder. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_denoising (`int`, *optional*, defaults to 100): + The total number of denoising tasks or queries to be used for contrastive denoising. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + The fraction of denoising labels to which random noise should be added. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale or magnitude of noise to be added to the bounding boxes. 
+    learn_initial_query (`bool`, *optional*, defaults to `False`):
+        Indicates whether the initial query embeddings for the decoder should be learned during training.
+    anchor_image_size (`tuple[int, int]`, *optional*):
+        Height and width of the input image used during evaluation to generate the bounding box anchors.
+    with_box_refine (`bool`, *optional*, defaults to `True`):
+        Whether to apply iterative bounding box refinement.
+    is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+        Whether the architecture has an encoder decoder structure.
+    matcher_alpha (`float`, *optional*, defaults to 0.25):
+        Parameter alpha used by the Hungarian Matcher.
+    matcher_gamma (`float`, *optional*, defaults to 2.0):
+        Parameter gamma used by the Hungarian Matcher.
+    matcher_class_cost (`float`, *optional*, defaults to 2.0):
+        The relative weight of the class loss used by the Hungarian Matcher.
+    matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
+        The relative weight of the bounding box loss used by the Hungarian Matcher.
+    matcher_giou_cost (`float`, *optional*, defaults to 2.0):
+        The relative weight of the giou loss used by the Hungarian Matcher.
+    use_focal_loss (`bool`, *optional*, defaults to `True`):
+        Parameter informing if focal loss should be used.
+    auxiliary_loss (`bool`, *optional*, defaults to `True`):
+        Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
+    focal_loss_alpha (`float`, *optional*, defaults to 0.75):
+        Parameter alpha used to compute the focal loss.
+    focal_loss_gamma (`float`, *optional*, defaults to 2.0):
+        Parameter gamma used to compute the focal loss.
+    weight_loss_vfl (`float`, *optional*, defaults to 1.0):
+        Relative weight of the varifocal loss in the object detection loss.
+    weight_loss_mal (`float`, *optional*, defaults to 1.0):
+        Relative weight of the matching auxiliary loss in the object detection loss.
+ weight_loss_bbox (`float`, *optional*, defaults to 5.0): + Relative weight of the L1 bounding box loss in the object detection loss. + weight_loss_giou (`float`, *optional*, defaults to 2.0): + Relative weight of the generalized IoU loss in the object detection loss. + weight_loss_fgl (`float`, *optional*, defaults to 0.15): + Relative weight of the fine-grained localization loss in the object detection loss. + weight_loss_ddf (`float`, *optional*, defaults to 1.5): + Relative weight of the decoupled distillation focal loss in the object detection loss. + eos_coefficient (`float`, *optional*, defaults to 0.0001): + Relative classification weight of the 'no-object' class in the object detection loss. + eval_idx (`int`, *optional*, defaults to -1): + Index of the decoder layer to use for evaluation. + layer_scale (`float`, *optional*, defaults to `1.0`): + Scaling factor for the hidden dimension in later decoder layers. + max_num_bins (`int`, *optional*, defaults to 32): + Maximum number of bins for the distribution-guided bounding box refinement. + reg_scale (`float`, *optional*, defaults to 4.0): + Scale factor for the regression distribution. + depth_mult (`float`, *optional*, defaults to 1.0): + Multiplier for the number of blocks in RepNCSPELAN5 layers. + top_prob_values (`int`, *optional*, defaults to 4): + Number of top probability values to consider from each corner's distribution. + lqe_hidden_dim (`int`, *optional*, defaults to 64): + Hidden dimension size for the Location Quality Estimator (LQE) network. + lqe_layers (`int`, *optional*, defaults to 2): + Number of layers in the Location Quality Estimator MLP. + decoder_offset_scale (`float`, *optional*, defaults to 0.5): + Offset scale used in deformable attention. + decoder_method (`str`, *optional*, defaults to `"default"`): + The method to use for the decoder: `"default"` or `"discrete"`. + up (`float`, *optional*, defaults to 0.5): + Controls the upper bounds of the Weighting Function. 
+ use_dense_o2o (`bool`, *optional*, defaults to `True`): + Whether to use dense one-to-one matching across decoder layers. + mal_alpha (`float`, *optional*): + Alpha parameter for the Matching Auxiliary Loss (MAL). If `None`, uses `focal_loss_alpha`. + encoder_fuse_op (`str`, *optional*, defaults to `"sum"`): + Fusion operation used in the encoder FPN. DEIMv2 uses `"sum"` instead of D-Fine's `"cat"`. + use_spatial_tuning_adapter (`bool`, *optional*, defaults to `False`): + Whether to use the Spatial Tuning Adapter (STA) for DINOv2 backbone variants. + sta_inplanes (`int`, *optional*, defaults to 16): + Number of input planes for the STA convolutional stem. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings. + """ + + model_type = "deimv2" + + def __init__( + self, + initializer_range=0.01, + initializer_bias_prior_prob=None, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + backbone_config=None, + freeze_backbone_batch_norms=True, + encoder_hidden_dim=256, + encoder_in_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=1024, + encoder_attention_heads=8, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + hidden_expansion=1.0, + d_model=256, + num_queries=300, + decoder_in_channels=[256, 256, 256], + decoder_ffn_dim=1024, + num_feature_levels=3, + decoder_n_points=4, + decoder_layers=6, + decoder_attention_heads=8, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + with_box_refine=True, + is_encoder_decoder=True, + matcher_alpha=0.25, + matcher_gamma=2.0, + matcher_class_cost=2.0, + matcher_bbox_cost=5.0, + matcher_giou_cost=2.0, + use_focal_loss=True, + auxiliary_loss=True, 
+ focal_loss_alpha=0.75, + focal_loss_gamma=2.0, + weight_loss_vfl=1.0, + weight_loss_bbox=5.0, + weight_loss_giou=2.0, + weight_loss_fgl=0.15, + weight_loss_ddf=1.5, + eos_coefficient=1e-4, + eval_idx=-1, + layer_scale=1, + max_num_bins=32, + reg_scale=4.0, + depth_mult=1.0, + top_prob_values=4, + lqe_hidden_dim=64, + lqe_layers=2, + decoder_offset_scale=0.5, + decoder_method="default", + up=0.5, + weight_loss_mal=1.0, + use_dense_o2o=True, + mal_alpha=None, + encoder_fuse_op="sum", + use_spatial_tuning_adapter=False, + sta_inplanes=16, + tie_word_embeddings=True, + **kwargs, + ): + self.weight_loss_mal = weight_loss_mal + self.use_dense_o2o = use_dense_o2o + self.mal_alpha = mal_alpha + self.encoder_fuse_op = encoder_fuse_op + self.use_spatial_tuning_adapter = use_spatial_tuning_adapter + self.sta_inplanes = sta_inplanes + super().__init__( + initializer_range=initializer_range, + initializer_bias_prior_prob=initializer_bias_prior_prob, + layer_norm_eps=layer_norm_eps, + batch_norm_eps=batch_norm_eps, + backbone_config=backbone_config, + freeze_backbone_batch_norms=freeze_backbone_batch_norms, + encoder_hidden_dim=encoder_hidden_dim, + encoder_in_channels=encoder_in_channels, + feat_strides=feat_strides, + encoder_layers=encoder_layers, + encoder_ffn_dim=encoder_ffn_dim, + encoder_attention_heads=encoder_attention_heads, + dropout=dropout, + activation_dropout=activation_dropout, + encode_proj_layers=encode_proj_layers, + positional_encoding_temperature=positional_encoding_temperature, + encoder_activation_function=encoder_activation_function, + activation_function=activation_function, + eval_size=eval_size, + normalize_before=normalize_before, + hidden_expansion=hidden_expansion, + d_model=d_model, + num_queries=num_queries, + decoder_in_channels=decoder_in_channels, + decoder_ffn_dim=decoder_ffn_dim, + num_feature_levels=num_feature_levels, + decoder_n_points=decoder_n_points, + decoder_layers=decoder_layers, + decoder_attention_heads=decoder_attention_heads, + 
decoder_activation_function=decoder_activation_function, + attention_dropout=attention_dropout, + num_denoising=num_denoising, + label_noise_ratio=label_noise_ratio, + box_noise_scale=box_noise_scale, + learn_initial_query=learn_initial_query, + anchor_image_size=anchor_image_size, + with_box_refine=with_box_refine, + is_encoder_decoder=is_encoder_decoder, + matcher_alpha=matcher_alpha, + matcher_gamma=matcher_gamma, + matcher_class_cost=matcher_class_cost, + matcher_bbox_cost=matcher_bbox_cost, + matcher_giou_cost=matcher_giou_cost, + use_focal_loss=use_focal_loss, + auxiliary_loss=auxiliary_loss, + focal_loss_alpha=focal_loss_alpha, + focal_loss_gamma=focal_loss_gamma, + weight_loss_vfl=weight_loss_vfl, + weight_loss_bbox=weight_loss_bbox, + weight_loss_giou=weight_loss_giou, + weight_loss_fgl=weight_loss_fgl, + weight_loss_ddf=weight_loss_ddf, + eos_coefficient=eos_coefficient, + eval_idx=eval_idx, + layer_scale=layer_scale, + max_num_bins=max_num_bins, + reg_scale=reg_scale, + depth_mult=depth_mult, + top_prob_values=top_prob_values, + lqe_hidden_dim=lqe_hidden_dim, + lqe_layers=lqe_layers, + decoder_offset_scale=decoder_offset_scale, + decoder_method=decoder_method, + up=up, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class Deimv2DecoderOutput(DFineDecoderOutput): + pass + + +class Deimv2RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.float() + hidden_states = hidden_states * torch.rsqrt(hidden_states.pow(2).mean(-1, keepdim=True) + self.eps) + return (hidden_states * self.scale).to(input_dtype) + + +class Deimv2SwiGLUFFN(nn.Module): + def __init__(self, in_features: int, hidden_features: int, out_features: int): + super().__init__() + self.w12 = nn.Linear(in_features, 2 * hidden_features) + self.w3 = 
nn.Linear(hidden_features, out_features) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + x12 = self.w12(hidden_states) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +class Deimv2Gate(DFineGate): + def __init__(self, d_model: int): + super().__init__(d_model) + self.norm = Deimv2RMSNorm(d_model) + + +class Deimv2MLP(DFineMLP): + pass + + +class Deimv2MultiscaleDeformableAttention(DFineMultiscaleDeformableAttention): + pass + + +class Deimv2ConvNormLayer(DFineConvNormLayer): + pass + + +class Deimv2RepVggBlock(DFineRepVggBlock): + pass + + +class Deimv2CSPRepLayer2(nn.Module): + def __init__( + self, config: Deimv2Config, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0 + ): + super().__init__() + activation = config.activation_function + hidden_channels = int(out_channels * expansion) + self.conv1 = Deimv2ConvNormLayer(config, in_channels, hidden_channels * 2, 1, 1, activation=activation) + self.bottlenecks = nn.ModuleList( + [Deimv2RepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)] + ) + self.conv3 = Deimv2ConvNormLayer(config, hidden_channels, out_channels, 3, 1, activation=activation) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + chunks = list(self.conv1(hidden_state).chunk(2, 1)) + bottleneck_out = chunks[1] + for bottleneck in self.bottlenecks: + bottleneck_out = bottleneck(bottleneck_out) + return self.conv3(chunks[0] + bottleneck_out) + + +class Deimv2RepNCSPELAN5(nn.Module): + def __init__(self, config: Deimv2Config, numb_blocks: int = 3): + super().__init__() + act = config.activation_function + c1 = config.encoder_hidden_dim + c2 = config.encoder_hidden_dim + c3 = config.encoder_hidden_dim * 2 + c4 = round(config.hidden_expansion * config.encoder_hidden_dim // 2) + self.conv_dim = c3 // 2 + self.conv1 = Deimv2ConvNormLayer(config, c1, c3, 1, 1, activation=act) + self.csp_rep1 = Deimv2CSPRepLayer2(config, c3 // 2, 
c4, num_blocks=numb_blocks) + self.csp_rep2 = Deimv2CSPRepLayer2(config, c4, c4, num_blocks=numb_blocks) + self.conv4 = Deimv2ConvNormLayer(config, c3 + (2 * c4), c2, 1, 1, activation=act) + + def forward(self, input_features: torch.Tensor) -> torch.Tensor: + split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1)) + branch1 = self.csp_rep1(split_features[-1]) + branch2 = self.csp_rep2(branch1) + split_features.extend([branch1, branch2]) + merged_features = torch.cat(split_features, 1) + return self.conv4(merged_features) + + +class Deimv2SCDown(DFineSCDown): + pass + + +class Deimv2EncoderLayer(DFineEncoderLayer): + pass + + +class Deimv2AIFILayer(RTDetrAIFILayer): + pass + + +class Deimv2SpatialTuningAdapter(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + inplanes = config.sta_inplanes + self.stem = nn.Sequential( + nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(inplanes), + nn.GELU(), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + self.conv2 = nn.Sequential( + nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(2 * inplanes), + ) + self.conv3 = nn.Sequential( + nn.GELU(), + nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(4 * inplanes), + ) + self.conv4 = nn.Sequential( + nn.GELU(), + nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(4 * inplanes), + ) + + def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + c1 = self.stem(pixel_values) + c2 = self.conv2(c1) + c3 = self.conv3(c2) + c4 = self.conv4(c3) + return c2, c3, c4 + + +class Deimv2Integral(DFineIntegral): + pass + + +class Deimv2LQE(DFineLQE): + pass + + +class Deimv2DecoderLayer(DFineDecoderLayer): + def __init__(self, config: Deimv2Config): + super().__init__(config) + self.encoder_attn 
= Deimv2MultiscaleDeformableAttention(config=config) + self.gateway = Deimv2Gate(config.d_model) + self.self_attn_layer_norm = Deimv2RMSNorm(config.d_model) + self.final_layer_norm = Deimv2RMSNorm(config.d_model) + self.mlp = Deimv2SwiGLUFFN(config.d_model, config.decoder_ffn_dim // 2, config.d_model) + + +class Deimv2MLPPredictionHead(DFineMLP): + pass + + +class Deimv2PreTrainedModel(DFinePreTrainedModel): + @torch.no_grad() + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (Deimv2ForObjectDetection, Deimv2Decoder)): + if module.class_embed is not None: + for layer in module.class_embed: + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias = float(-math.log((1 - prior_prob) / prior_prob)) + init.xavier_uniform_(layer.weight) + init.constant_(layer.bias, bias) + + if module.bbox_embed is not None: + for layer in module.bbox_embed: + init.constant_(layer.layers[-1].weight, 0) + init.constant_(layer.layers[-1].bias, 0) + + if hasattr(module, "reg_scale"): + init.constant_(module.reg_scale, self.config.reg_scale) + + if hasattr(module, "up"): + init.constant_(module.up, self.config.up) + + if isinstance(module, Deimv2MultiscaleDeformableAttention): + init.constant_(module.sampling_offsets.weight, 0.0) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( + 2.0 * math.pi / module.n_heads + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values + grid_init = grid_init.reshape(module.n_heads, 1, 2).tile([1, sum(module.num_points_list), 1]) + scaling = torch.concat([torch.arange(1, n + 1) for n in module.num_points_list]).reshape(1, -1, 1) + grid_init *= scaling + init.copy_(module.sampling_offsets.bias, grid_init.flatten()) + + init.constant_(module.attention_weights.weight, 0.0) + init.constant_(module.attention_weights.bias, 0.0) + + 
num_points_scale = [1 / n for n in module.num_points_list for _ in range(n)]
+            init.copy_(module.num_points_scale, torch.tensor(num_points_scale, dtype=torch.float32))
+
+        if isinstance(module, Deimv2Model):
+            prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
+            bias = float(-math.log((1 - prior_prob) / prior_prob))
+            init.xavier_uniform_(module.enc_score_head.weight)
+            init.constant_(module.enc_score_head.bias, bias)
+
+        if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
+
+        if isinstance(module, Deimv2Gate):
+            bias = float(-math.log((1 - 0.5) / 0.5))
+            init.constant_(module.gate.bias, bias)
+            init.constant_(module.gate.weight, 0)
+
+        if isinstance(module, Deimv2LQE):
+            init.constant_(module.reg_conf.layers[-1].bias, 0)
+            init.constant_(module.reg_conf.layers[-1].weight, 0)
+
+        if isinstance(module, Deimv2SwiGLUFFN):
+            init.xavier_uniform_(module.w12.weight)
+            init.constant_(module.w12.bias, 0)
+            init.xavier_uniform_(module.w3.weight)
+            init.constant_(module.w3.bias, 0)
+
+        if isinstance(module, Deimv2RMSNorm):
+            init.ones_(module.scale)
+
+        if isinstance(module, nn.LayerNorm):
+            init.ones_(module.weight)
+            init.zeros_(module.bias)
+
+        if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
+            init.xavier_uniform_(module.weight_embedding.weight)
+        if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:
+            init.xavier_uniform_(module.denoising_class_embed.weight)
+
+
+class Deimv2HybridEncoder(DFineHybridEncoder):
+    def __init__(self, config: Deimv2Config):
+        Deimv2PreTrainedModel.__init__(self, config)
+        self.config = config
+        self.in_channels = config.encoder_in_channels
+        
self.num_fpn_stages = len(self.in_channels) - 1 + self.feat_strides = config.feat_strides + self.encoder_hidden_dim = config.encoder_hidden_dim + self.encode_proj_layers = config.encode_proj_layers + self.positional_encoding_temperature = config.positional_encoding_temperature + self.eval_size = config.eval_size + self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels] + self.out_strides = self.feat_strides + self.encoder_fuse_op = config.encoder_fuse_op + + self.aifi = nn.ModuleList([Deimv2AIFILayer(config) for _ in range(len(self.encode_proj_layers))]) + + self.lateral_convs = nn.ModuleList() + self.fpn_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1, 0, -1): + lateral_layer = Deimv2ConvNormLayer(config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1) + self.lateral_convs.append(lateral_layer) + num_blocks = round(3 * config.depth_mult) + fpn_layer = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks) + self.fpn_blocks.append(fpn_layer) + + self.downsample_convs = nn.ModuleList() + self.pan_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1): + self.downsample_convs.append(Deimv2SCDown(config, 3, 2)) + num_blocks = round(3 * config.depth_mult) + self.pan_blocks.append(Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks)) + + self.post_init() + + def forward( + self, + inputs_embeds=None, + **kwargs, + ) -> BaseModelOutput: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. 
+ """ + feature_maps = inputs_embeds + + if self.config.encoder_layers > 0: + for i, enc_ind in enumerate(self.encode_proj_layers): + feature_maps[enc_ind] = self.aifi[i](feature_maps[enc_ind], **kwargs) + + fpn_feature_maps = [feature_maps[-1]] + for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)): + backbone_feature_map = feature_maps[self.num_fpn_stages - idx - 1] + top_fpn_feature_map = fpn_feature_maps[-1] + top_fpn_feature_map = lateral_conv(top_fpn_feature_map) + fpn_feature_maps[-1] = top_fpn_feature_map + top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest") + if self.encoder_fuse_op == "sum": + fused_feature_map = top_fpn_feature_map + backbone_feature_map + else: + fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1) + new_fpn_feature_map = fpn_block(fused_feature_map) + fpn_feature_maps.append(new_fpn_feature_map) + + fpn_feature_maps.reverse() + + pan_feature_maps = [fpn_feature_maps[0]] + for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)): + top_pan_feature_map = pan_feature_maps[-1] + fpn_feature_map = fpn_feature_maps[idx + 1] + downsampled_feature_map = downsample_conv(top_pan_feature_map) + if self.encoder_fuse_op == "sum": + fused_feature_map = downsampled_feature_map + fpn_feature_map + else: + fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1) + new_pan_feature_map = pan_block(fused_feature_map) + pan_feature_maps.append(new_pan_feature_map) + + return BaseModelOutput(last_hidden_state=pan_feature_maps) + + +class Deimv2Decoder(DFineDecoder): + def __init__(self, config: Deimv2Config): + self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx + super().__init__(config=config) + self.query_pos_head = Deimv2MLP(4, config.d_model, config.d_model, 3, config.decoder_activation_function) + self.reg_scale = 
nn.Parameter(torch.tensor([config.reg_scale]), requires_grad=False) + self.max_num_bins = config.max_num_bins + self.d_model = config.d_model + self.layer_scale = config.layer_scale + self.pre_bbox_head = Deimv2MLP(config.hidden_size, config.hidden_size, 4, 3) + self.integral = Deimv2Integral(config) + self.num_head = config.decoder_attention_heads + self.up = nn.Parameter(torch.tensor([config.up]), requires_grad=False) + self.lqe_layers = nn.ModuleList([Deimv2LQE(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + [Deimv2DecoderLayer(config) for _ in range(config.decoder_layers)] + + [Deimv2DecoderLayer(config) for _ in range(config.decoder_layers - self.eval_idx - 1)] + ) + + +class Deimv2Model(DFineModel): + def __init__(self, config: Deimv2Config): + super().__init__(config) + del self.decoder_input_proj + self.encoder = Deimv2HybridEncoder(config=config) + num_backbone_outs = len(config.decoder_in_channels) + decoder_input_proj = [] + in_channels = config.decoder_in_channels[-1] + for _ in range(num_backbone_outs): + if config.hidden_size == config.decoder_in_channels[-1]: + decoder_input_proj.append(nn.Identity()) + else: + conv = nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False) + batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps) + decoder_input_proj.append(nn.Sequential(conv, batchnorm)) + for _ in range(config.num_feature_levels - num_backbone_outs): + if config.hidden_size == config.decoder_in_channels[-1]: + decoder_input_proj.append(nn.Identity()) + else: + conv = nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False) + batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps) + decoder_input_proj.append(nn.Sequential(conv, batchnorm)) + self.decoder_input_proj = nn.ModuleList(decoder_input_proj) + self.decoder = Deimv2Decoder(config) + + if config.use_spatial_tuning_adapter: + self.spatial_tuning_adapter = Deimv2SpatialTuningAdapter(config) + + +class 
Deimv2ForObjectDetection(DFineForObjectDetection):
+    _no_split_modules = None
+    _tied_weights_keys = {
+        r"bbox_embed.(?![0])\d+": r"bbox_embed.0",
+        r"class_embed.(?![0])\d+": r"class_embed.0",
+        "class_embed": "model.decoder.class_embed",
+        "bbox_embed": "model.decoder.bbox_embed",
+    }
+
+    def __init__(self, config: Deimv2Config):
+        Deimv2PreTrainedModel.__init__(self, config)
+
+        self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx
+        self.model = Deimv2Model(config)
+        scaled_dim = round(config.layer_scale * config.hidden_size)
+        num_pred = config.decoder_layers
+        self.class_embed = nn.ModuleList([nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)])
+        self.bbox_embed = nn.ModuleList(
+            [
+                Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3)
+                for _ in range(self.eval_idx + 1)
+            ]
+            + [
+                Deimv2MLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3)
+                for _ in range(config.decoder_layers - self.eval_idx - 1)
+            ]
+        )
+
+        self.model.decoder.class_embed = self.class_embed
+        self.model.decoder.bbox_embed = self.bbox_embed
+        self.post_init()
+
+    def forward(self, **super_kwargs):
+        r"""
+        Example:
+
+        ```python
+        >>> import torch
+        >>> from transformers.image_utils import load_image
+        >>> from transformers import AutoImageProcessor, Deimv2ForObjectDetection
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = load_image(url)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("Intellindust/DEIMv2_HGNetv2_N_COCO")
+        >>> model = Deimv2ForObjectDetection.from_pretrained("Intellindust/DEIMv2_HGNetv2_N_COCO")
+
+        >>> # prepare image for the model
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+
+        >>> # forward pass
+        >>> outputs = model(**inputs)
+
+        >>> logits = outputs.logits
+        >>> list(logits.shape)
+        [1, 300, 80]
+
+        >>> boxes = outputs.pred_boxes
+        >>> list(boxes.shape)
+        [1, 300, 4]
+
+        >>> # convert outputs 
(bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> target_sizes = torch.tensor([image.size[::-1]]) + >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes) + >>> result = results[0] # first image in batch + + >>> for score, label, box in zip(result["scores"], result["labels"], result["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... print( + ... f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ... ) + ``` + """ + super().forward(**super_kwargs) + + +__all__ = [ + "Deimv2Config", + "Deimv2Model", + "Deimv2PreTrainedModel", + "Deimv2ForObjectDetection", +] diff --git a/tests/models/deimv2/__init__.py b/tests/models/deimv2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/deimv2/test_modeling_deimv2.py b/tests/models/deimv2/test_modeling_deimv2.py new file mode 100644 index 000000000000..9b612c7833d1 --- /dev/null +++ b/tests/models/deimv2/test_modeling_deimv2.py @@ -0,0 +1,655 @@ +# coding = utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch DEIMv2 model.""" + +import copy +import inspect +import math +import tempfile +import unittest + +from parameterized import parameterized + +from transformers import ( + Deimv2Config, + HGNetV2Config, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import ( + require_torch, + require_torch_accelerator, + slow, + torch_device, +) + + +if is_torch_available(): + import torch + + from transformers import Deimv2ForObjectDetection, Deimv2Model + +if is_vision_available(): + from PIL import Image + + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +class Deimv2ModelTester: + def __init__( + self, + parent, + batch_size=3, + is_training=True, + use_labels=True, + n_targets=3, + num_labels=10, + initializer_range=0.02, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + backbone_config=None, + encoder_hidden_dim=32, + encoder_in_channels=[128, 256, 512], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=64, + encoder_attention_heads=2, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + d_model=32, + num_queries=30, + decoder_in_channels=[32, 32, 32], + decoder_ffn_dim=64, + num_feature_levels=3, + decoder_n_points=[3, 6, 3], + decoder_n_levels=3, + decoder_layers=2, + decoder_attention_heads=2, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=0, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + image_size=64, + disable_custom_kernels=True, + with_box_refine=True, + decoder_offset_scale=0.5, + eval_idx=-1, + layer_scale=1, + reg_max=32, + reg_scale=4.0, + depth_mult=0.34, + hidden_expansion=0.5, + ): + 
self.parent = parent + self.batch_size = batch_size + self.num_channels = 3 + self.is_training = is_training + self.use_labels = use_labels + self.n_targets = n_targets + self.num_labels = num_labels + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + self.backbone_config = backbone_config + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_layers = encoder_layers + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.eval_size = eval_size + self.normalize_before = normalize_before + self.d_model = d_model + self.num_queries = num_queries + self.decoder_in_channels = decoder_in_channels + self.decoder_ffn_dim = decoder_ffn_dim + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_n_levels = decoder_n_levels + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.decoder_offset_scale = decoder_offset_scale + self.eval_idx = eval_idx + self.layer_scale = layer_scale + self.reg_max = reg_max + self.reg_scale = reg_scale + self.depth_mult = depth_mult + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.image_size = image_size + self.disable_custom_kernels = disable_custom_kernels + 
self.with_box_refine = with_box_refine + self.hidden_expansion = hidden_expansion + + self.encoder_seq_length = math.ceil(self.image_size / 32) * math.ceil(self.image_size / 32) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device) + + labels = None + if self.use_labels: + labels = [] + for i in range(self.batch_size): + target = {} + target["class_labels"] = torch.randint( + high=self.num_labels, size=(self.n_targets,), device=torch_device + ) + target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device) + labels.append(target) + + config = self.get_config() + config.num_labels = self.num_labels + return config, pixel_values, pixel_mask, labels + + def get_config(self): + hidden_sizes = [64, 128, 256, 512] + backbone_config = HGNetV2Config( + stage_in_channels=[16, 64, 128, 256], + stage_mid_channels=[16, 32, 64, 128], + stage_out_channels=[64, 128, 256, 512], + stage_num_blocks=[1, 1, 2, 1], + stage_downsample=[False, True, True, True], + stage_light_block=[False, False, True, True], + stage_kernel_size=[3, 3, 5, 5], + stage_numb_of_layers=[3, 3, 3, 3], + embeddings_size=10, + hidden_sizes=hidden_sizes, + depths=[1, 1, 2, 1], + out_features=["stage2", "stage3", "stage4"], + out_indices=[2, 3, 4], + stem_channels=[3, 16, 16], + use_lab=True, + ) + return Deimv2Config( + backbone_config=backbone_config, + encoder_hidden_dim=self.encoder_hidden_dim, + encoder_in_channels=self.encoder_in_channels, + feat_strides=self.feat_strides, + encoder_layers=self.encoder_layers, + encoder_ffn_dim=self.encoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + dropout=self.dropout, + activation_dropout=self.activation_dropout, + encode_proj_layers=self.encode_proj_layers, + positional_encoding_temperature=self.positional_encoding_temperature, + 
encoder_activation_function=self.encoder_activation_function, + activation_function=self.activation_function, + eval_size=self.eval_size, + normalize_before=self.normalize_before, + d_model=self.d_model, + num_queries=self.num_queries, + decoder_in_channels=self.decoder_in_channels, + decoder_ffn_dim=self.decoder_ffn_dim, + num_feature_levels=self.num_feature_levels, + decoder_n_points=self.decoder_n_points, + decoder_n_levels=self.decoder_n_levels, + decoder_layers=self.decoder_layers, + decoder_attention_heads=self.decoder_attention_heads, + decoder_activation_function=self.decoder_activation_function, + decoder_offset_scale=self.decoder_offset_scale, + eval_idx=self.eval_idx, + layer_scale=self.layer_scale, + reg_max=self.reg_max, + reg_scale=self.reg_scale, + depth_mult=self.depth_mult, + attention_dropout=self.attention_dropout, + num_denoising=self.num_denoising, + label_noise_ratio=self.label_noise_ratio, + box_noise_scale=self.box_noise_scale, + learn_initial_query=self.learn_initial_query, + anchor_image_size=self.anchor_image_size, + image_size=self.image_size, + disable_custom_kernels=self.disable_custom_kernels, + with_box_refine=self.with_box_refine, + ) + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + def create_and_check_deimv2_model(self, config, pixel_values, pixel_mask, labels): + model = Deimv2Model(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.d_model)) + + def create_and_check_deimv2_object_detection_head_model(self, config, pixel_values, pixel_mask, labels): + model = Deimv2ForObjectDetection(config=config) + model.to(torch_device) + model.eval() + + result = 
model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels) + + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + +@require_torch +class Deimv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (Deimv2Model, Deimv2ForObjectDetection) if is_torch_available() else () + pipeline_model_mapping = ( + {"image-feature-extraction": Deimv2Model, "object-detection": Deimv2ForObjectDetection} + if is_torch_available() + else {} + ) + is_encoder_decoder = True + + test_missing_keys = False + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "Deimv2ForObjectDetection": + labels = [] + for i in range(self.model_tester.batch_size): + target = {} + target["class_labels"] = torch.ones( + size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long + ) + target["boxes"] = torch.ones( + self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float + ) + labels.append(target) + inputs_dict["labels"] = labels + + return inputs_dict + + def setUp(self): + self.model_tester = Deimv2ModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=Deimv2Config, + has_text_modality=False, + common_properties=["hidden_size", "num_attention_heads"], + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_deimv2_model(self): + 
config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deimv2_model(*config_and_inputs) + + def test_deimv2_object_detection_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deimv2_object_detection_head_model(*config_and_inputs) + + @unittest.skip(reason="Deimv2 doesn't work well with `nn.DataParallel") + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip(reason="Deimv2 does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Deimv2 does not use test_inputs_embeds_matches_input_ids") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip(reason="Deimv2 does not support input and output embeddings") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="Deimv2 does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="Deimv2 does not use token embeddings") + def test_resize_tokens_embeddings(self): + pass + + @unittest.skip(reason="Feed forward chunking is not implemented") + def test_feed_forward_chunking(self): + pass + + @unittest.skip(reason="Weight tying is hardcoded (module_x = module_y) and always `True`") + def test_load_save_without_tied_weights(self): + pass + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class._from_config(config, attn_implementation="eager") + config = model.config + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions + self.assertEqual(len(attentions), 
self.model_tester.encoder_layers) + + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions + self.assertEqual(len(attentions), self.model_tester.encoder_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [ + self.model_tester.encoder_attention_heads, + self.model_tester.encoder_seq_length, + self.model_tester.encoder_seq_length, + ], + ) + out_len = len(outputs) + + correct_outlen = 15 + + if "labels" in inputs_dict: + correct_outlen += 1 + if model_class.__name__ == "Deimv2ForObjectDetection": + correct_outlen += 2 + + self.assertEqual(out_len, correct_outlen) + + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.decoder_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [ + self.model_tester.decoder_attention_heads, + self.model_tester.num_queries, + self.model_tester.num_queries, + ], + ) + + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.decoder_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_queries, + self.model_tester.decoder_attention_heads, + self.model_tester.decoder_n_levels * self.model_tester.decoder_n_points + if isinstance(self.model_tester.decoder_n_points, int) + else sum(self.model_tester.decoder_n_points), + ], + ) + + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, 
"num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + else: + added_hidden_states = 2 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions + + self.assertEqual(len(self_attentions), self.model_tester.encoder_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [ + self.model_tester.encoder_attention_heads, + self.model_tester.encoder_seq_length, + self.model_tester.encoder_seq_length, + ], + ) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", len(self.model_tester.encoder_in_channels) - 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[1].shape[-2:]), + [ + self.model_tester.image_size // self.model_tester.feat_strides[-1], + self.model_tester.image_size // self.model_tester.feat_strides[-1], + ], + ) + + if config.is_encoder_decoder: + hidden_states = outputs.decoder_hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.decoder_layers + 1 + ) + + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.num_queries, self.model_tester.d_model], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, 
model_class) + + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + inputs = self._prepare_for_class(inputs_dict, model_class) + + outputs = model(**inputs) + + output = outputs[0] + + encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_attentions = outputs.encoder_attentions[0] + encoder_hidden_states.retain_grad() + encoder_attentions.retain_grad() + + decoder_attentions = outputs.decoder_attentions[0] + decoder_attentions.retain_grad() + + cross_attentions = outputs.cross_attentions[0] + cross_attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(encoder_hidden_states.grad) + self.assertIsNotNone(encoder_attentions.grad) + self.assertIsNotNone(decoder_attentions.grad) + self.assertIsNotNone(cross_attentions.grad) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + arg_names = [*signature.parameters.keys()] + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_backbone_selection(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def _validate_backbone_init(config): + for model_class in self.all_model_classes: + model = model_class(copy.deepcopy(config)) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if model_class.__name__ == 
"Deimv2ForObjectDetection": + expected_shape = ( + self.model_tester.batch_size, + self.model_tester.num_queries, + self.model_tester.num_labels, + ) + self.assertEqual(outputs.logits.shape, expected_shape) + self.assertEqual(len(model.model.backbone.intermediate_channel_sizes), 3) + else: + self.assertEqual(len(model.backbone.intermediate_channel_sizes), 3) + + self.assertTrue(outputs) + + config_dict = config.to_dict() + config_dict["encoder_in_channels"] = [24, 40, 432] + config_dict["backbone"] = "tf_mobilenetv3_small_075" + config_dict["backbone_config"] = None + config_dict["use_timm_backbone"] = True + config_dict["backbone_kwargs"] = {"out_indices": [2, 3, 4]} + config = config.__class__(**config_dict) + _validate_backbone_init(config) + + config_dict = config.to_dict() + config_dict["backbone"] = "microsoft/resnet-18" + config_dict["backbone_config"] = None + config_dict["use_timm_backbone"] = False + config_dict["use_pretrained_backbone"] = True + config_dict["backbone_kwargs"] = {"out_indices": [2, 3, 4]} + config = config.__class__(**config_dict) + _validate_backbone_init(config) + + @parameterized.expand(["float32", "float16", "bfloat16"]) + @require_torch_accelerator + @slow + def test_inference_with_different_dtypes(self, dtype_str): + dtype = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[dtype_str] + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device).to(dtype) + model.eval() + for key, tensor in inputs_dict.items(): + if tensor.dtype == torch.float32: + inputs_dict[key] = tensor.to(dtype) + with torch.no_grad(): + _ = model(**self._prepare_for_class(inputs_dict, model_class)) + + @parameterized.expand(["float32", "float16", "bfloat16"]) + @require_torch_accelerator + @slow + def test_inference_equivalence_for_static_and_dynamic_anchors(self, dtype_str): + dtype = { + 
"float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[dtype_str] + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + h, w = inputs_dict["pixel_values"].shape[-2:] + + for key, tensor in inputs_dict.items(): + if tensor.dtype == torch.float32: + inputs_dict[key] = tensor.to(dtype) + + for model_class in self.all_model_classes: + with tempfile.TemporaryDirectory() as tmpdirname: + model_class(config).save_pretrained(tmpdirname) + model_static = model_class.from_pretrained( + tmpdirname, anchor_image_size=[h, w], device_map=torch_device, dtype=dtype + ).eval() + model_dynamic = model_class.from_pretrained( + tmpdirname, anchor_image_size=None, device_map=torch_device, dtype=dtype + ).eval() + + self.assertIsNotNone(model_static.config.anchor_image_size) + self.assertIsNone(model_dynamic.config.anchor_image_size) + + with torch.no_grad(): + outputs_static = model_static(**self._prepare_for_class(inputs_dict, model_class)) + outputs_dynamic = model_dynamic(**self._prepare_for_class(inputs_dict, model_class)) + + torch.testing.assert_close( + outputs_static.last_hidden_state, outputs_dynamic.last_hidden_state, rtol=1e-4, atol=1e-4 + ) + + +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image From ddc1bd71b93335e8789e9768e285f2917463cd20 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Fri, 27 Feb 2026 22:25:05 +0400 Subject: [PATCH 02/25] fix: Fix ci/circleci: check_repository_consistency --- src/transformers/models/deimv2/modular_deimv2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/deimv2/modular_deimv2.py b/src/transformers/models/deimv2/modular_deimv2.py index be7cc90214e0..05bd3b8d4da5 100644 --- a/src/transformers/models/deimv2/modular_deimv2.py +++ b/src/transformers/models/deimv2/modular_deimv2.py @@ -154,8 +154,6 @@ class Deimv2Config(DFineConfig): Parameter gamma used to compute the 
focal loss. weight_loss_vfl (`float`, *optional*, defaults to 1.0): Relative weight of the varifocal loss in the object detection loss. - weight_loss_mal (`float`, *optional*, defaults to 1.0): - Relative weight of the matching auxiliary loss in the object detection loss. weight_loss_bbox (`float`, *optional*, defaults to 5.0): Relative weight of the L1 bounding box loss in the object detection loss. weight_loss_giou (`float`, *optional*, defaults to 2.0): @@ -188,6 +186,8 @@ class Deimv2Config(DFineConfig): The method to use for the decoder: `"default"` or `"discrete"`. up (`float`, *optional*, defaults to 0.5): Controls the upper bounds of the Weighting Function. + weight_loss_mal (`float`, *optional*, defaults to 1.0): + Relative weight of the matching auxiliary loss in the object detection loss. use_dense_o2o (`bool`, *optional*, defaults to `True`): Whether to use dense one-to-one matching across decoder layers. mal_alpha (`float`, *optional*): From 85c7356ece2be84c5aa18d08e5a94513e8114cce Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Sun, 1 Mar 2026 11:33:22 +0400 Subject: [PATCH 03/25] feat: Add support and test harness for all variants --- .../models/deimv2/configuration_deimv2.py | 210 +++- ...eimv2_original_pytorch_checkpoint_to_hf.py | 424 ++++++- .../models/deimv2/modeling_deimv2.py | 843 ++++++++++---- .../models/deimv2/modular_deimv2.py | 549 ++++++++- tests/models/deimv2/test_modeling_deimv2.py | 1010 ++++++++++++++++- 5 files changed, 2749 insertions(+), 287 deletions(-) diff --git a/src/transformers/models/deimv2/configuration_deimv2.py b/src/transformers/models/deimv2/configuration_deimv2.py index 56936d2cdd7a..d98372c44509 100644 --- a/src/transformers/models/deimv2/configuration_deimv2.py +++ b/src/transformers/models/deimv2/configuration_deimv2.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from ...backbone_utils import consolidate_backbone_kwargs_to_config +from ...backbone_utils import BackboneConfigMixin, consolidate_backbone_kwargs_to_config from ...configuration_utils import PreTrainedConfig from ..auto import AutoConfig @@ -28,7 +28,8 @@ class Deimv2Config(PreTrainedConfig): """ This is the configuration class to store the configuration of a [`Deimv2Model`]. It is used to instantiate a DEIMv2 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of DEIMv2-L-COCO. + with the defaults will yield a similar configuration to that of DEIMv2-HGNetv2-N-COCO + [Intellindust/DEIMv2_HGNetv2_N_COCO](https://huggingface.co/Intellindust/DEIMv2_HGNetv2_N_COCO). Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PreTrainedConfig`] for more information. @@ -177,6 +178,29 @@ class Deimv2Config(PreTrainedConfig): Whether to use the Spatial Tuning Adapter (STA) for DINOv2 backbone variants. sta_inplanes (`int`, *optional*, defaults to 16): Number of input planes for the STA convolutional stem. + encoder_type (`str`, *optional*, defaults to `"hybrid"`): + Type of encoder to use. `"hybrid"` uses the full HybridEncoder with AIFI, FPN, and PAN. + `"lite"` uses the lightweight LiteEncoder with GAP fusion for smaller variants (Atto, Femto, Pico). + use_gateway (`bool`, *optional*, defaults to `True`): + Whether to use the gateway mechanism (cross-attention gating) in decoder layers. When `False`, + uses RMSNorm on the encoder attention output instead. + share_bbox_head (`bool`, *optional*, defaults to `False`): + Whether to share the bounding box prediction head across all decoder layers. + backbone_type (`str`, *optional*, defaults to `"hgnetv2"`): + Type of backbone to use. 
`"hgnetv2"` uses HGNetV2, `"dinov3"` uses DINOv3 ViT backbone with STA. + dinov3_backbone_config (`dict`, *optional*): + Configuration dictionary for the DINOv3 ViT backbone. Passed as kwargs to `DINOv3ViTConfig`. + dinov3_interaction_indexes (`list[int]`, *optional*): + Layer indices in the DINOv3 ViT backbone from which to extract intermediate features. + dinov3_hidden_dim (`int`, *optional*): + Hidden dimension for the DINOv3 backbone projection convolutions. If `None`, uses `hidden_size` from + the DINOv3 ViT config. + dinov3_apply_layernorm (`bool`, *optional*, defaults to `False`): + Whether to apply LayerNorm to intermediate features extracted from the DINOv3 ViT backbone. + encoder_has_trailing_conv (`bool`, *optional*, defaults to `True`): + Whether the encoder's CSP blocks include a trailing 3x3 convolution after the bottleneck path. + `True` for RepNCSPELAN4 (used by HGNetV2 N and LiteEncoder variants). + `False` for RepNCSPELAN5 (used by DINOv3 variants). tie_word_embeddings (`bool`, *optional*, defaults to `True`): Whether to tie weight embeddings. 
""" @@ -261,6 +285,15 @@ def __init__( encoder_fuse_op="sum", use_spatial_tuning_adapter=False, sta_inplanes=16, + encoder_type="hybrid", + use_gateway=True, + share_bbox_head=False, + backbone_type="hgnetv2", + dinov3_backbone_config=None, + dinov3_interaction_indexes=None, + dinov3_hidden_dim=None, + dinov3_apply_layernorm=False, + encoder_has_trailing_conv=True, tie_word_embeddings=True, **kwargs, ): @@ -270,6 +303,15 @@ def __init__( self.encoder_fuse_op = encoder_fuse_op self.use_spatial_tuning_adapter = use_spatial_tuning_adapter self.sta_inplanes = sta_inplanes + self.encoder_type = encoder_type + self.use_gateway = use_gateway + self.share_bbox_head = share_bbox_head + self.backbone_type = backbone_type + self.dinov3_backbone_config = dinov3_backbone_config + self.dinov3_interaction_indexes = dinov3_interaction_indexes + self.dinov3_hidden_dim = dinov3_hidden_dim + self.dinov3_apply_layernorm = dinov3_apply_layernorm + self.encoder_has_trailing_conv = encoder_has_trailing_conv self.initializer_range = initializer_range self.initializer_bias_prior_prob = initializer_bias_prior_prob self.layer_norm_eps = layer_norm_eps @@ -362,4 +404,168 @@ def __init__( super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) +class Deimv2DINOv3ViTConfig(BackboneConfigMixin, PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DINOv3Model`]. It is used to instantiate an + DINOv3 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the DINOv3 + [facebook/dinov3-vits16-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vits16-pretrain-lvd1689m) architecture. + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. 
+ + Args: + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_size (`int`, *optional*, defaults to 384): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 1536): + Dimensionality of the "intermediate" (i.e., feed-forward) layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 6): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + rope_theta (`float`, *optional*, defaults to 100.0): + The base period of the RoPE embeddings. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + query_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the query projection. + key_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the key projection. + value_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the value projection. + proj_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the output projection. + mlp_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the MLP layers. 
+ layerscale_value (`float`, *optional*, defaults to 1.0): + Initial value to use for layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Stochastic depth rate per sample (when applied in the main path of residual layers). + use_gated_mlp (`bool`, *optional*, defaults to `False`): + Whether to use the SwiGLU feedforward neural network. + num_register_tokens (`int`, *optional*, defaults to 0): + The number of register tokens. + pos_embed_shift (`float`, *optional*): + Amount to randomly shift position embedding coordinates in [-shift, shift], + applied only in training mode if not `None`. + pos_embed_jitter (`float`, *optional*): + Amount to randomly jitter position embedding coordinates in log-uniform value in [1/jitter, jitter], + applied only in training mode if not `None`. + pos_embed_rescale (`float`, *optional*, defaults to 2.0): + Amount to randomly rescale position embedding coordinates in log-uniform value in [1/rescale, rescale], + applied only in training mode if not `None`. + out_features (`list[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). Will default to the last stage if unset. + out_indices (`list[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. + (depending on how many stages the model has). Will default to the last stage if unset. + apply_layernorm (`bool`, *optional*, defaults to `True`): + Whether to apply layer normalization to the feature maps when used as backbone. + reshape_hidden_states (`bool`, *optional*, defaults to `True`): + Whether to reshape the hidden states to spatial dimensions when used as backbone. 
+ + Example: + + ```python + >>> from transformers import Deimv2DINOv3ViTConfig, Deimv2DINOv3ViTModel + + >>> # Initializing a DINOv3 ViT-small style configuration + >>> config = Deimv2DINOv3ViTConfig() + + >>> # Initializing a model (with random weights) from the config + >>> model = Deimv2DINOv3ViTModel(config) + + >>> # Accessing the model config + >>> config = model.config + ```""" + + model_type = "deimv2_dinov3_vit" + + def __init__( + self, + patch_size: int = 16, + hidden_size: int = 384, + intermediate_size: int = 1536, + num_hidden_layers: int = 12, + num_attention_heads: int = 6, + hidden_act: str = "gelu", + attention_dropout: float = 0.0, + initializer_range: float = 0.02, + layer_norm_eps: float = 1e-5, + rope_theta: float = 100.0, + image_size: int = 224, + num_channels: int = 3, + query_bias: bool = True, + key_bias: bool = False, + value_bias: bool = True, + proj_bias: bool = True, + mlp_bias: bool = True, + layerscale_value: float = 1.0, + drop_path_rate: float = 0.0, + use_gated_mlp: bool = False, + num_register_tokens: int = 0, + # train augs + pos_embed_shift: float | None = None, + pos_embed_jitter: float | None = None, + pos_embed_rescale: float | None = 2.0, + out_features: list[str] | None = None, + out_indices: list[int] | None = None, + apply_layernorm: bool = True, + reshape_hidden_states: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.layerscale_value = layerscale_value + self.drop_path_rate = drop_path_rate + self.use_gated_mlp = use_gated_mlp + self.rope_theta = rope_theta + self.query_bias = 
query_bias + self.key_bias = key_bias + self.value_bias = value_bias + self.proj_bias = proj_bias + self.mlp_bias = mlp_bias + self.num_register_tokens = num_register_tokens + + # train augs + self.pos_embed_shift = pos_embed_shift + self.pos_embed_jitter = pos_embed_jitter + self.pos_embed_rescale = pos_embed_rescale + # Initialize backbone-specific configuration + self.apply_layernorm = apply_layernorm + self.reshape_hidden_states = reshape_hidden_states + + # Initialize backbone stage names + stage_names = ["stem"] + [f"stage{i}" for i in range(1, num_hidden_layers + 1)] + self.stage_names = stage_names + + # Initialize backbone features/indices + self.set_output_features_output_indices(out_indices=out_indices, out_features=out_features) + + __all__ = ["Deimv2Config"] diff --git a/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py b/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py index 5aaecc844594..a6125952deeb 100644 --- a/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py +++ b/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py @@ -60,13 +60,10 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: decoder_cfg = orig_config["DEIMTransformer"] if "HybridEncoder" in orig_config: encoder_cfg = orig_config["HybridEncoder"] + encoder_type = "hybrid" elif "LiteEncoder" in orig_config: - raise ValueError( - "LiteEncoder variants (pico/femto/atto) are not yet supported. " - "The LiteEncoder uses a different architecture (AvgPool downsampling, GAP fusion, " - "RepNCSPELAN4 blocks) that requires a dedicated Deimv2LiteEncoder implementation. " - "Supported variants: deimv2_hgnetv2_n_coco and DINOv3 variants." - ) + encoder_cfg = orig_config["LiteEncoder"] + encoder_type = "lite" else: raise ValueError(f"No encoder config found. 
Available keys: {list(orig_config.keys())}") @@ -76,18 +73,34 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: config.label2id = {v: k for k, v in id2label.items()} # Encoder settings - config.encoder_hidden_dim = encoder_cfg["hidden_dim"] - config.encoder_in_channels = encoder_cfg["in_channels"] - config.feat_strides = encoder_cfg["feat_strides"] - config.activation_function = encoder_cfg.get("act", "silu") - config.depth_mult = encoder_cfg.get("depth_mult", 1.0) - config.hidden_expansion = encoder_cfg.get("expansion", 1.0) - config.encoder_fuse_op = encoder_cfg.get("fuse_op", "sum") - config.encoder_ffn_dim = encoder_cfg["dim_feedforward"] - config.encoder_attention_heads = encoder_cfg["nhead"] - config.dropout = encoder_cfg.get("dropout", 0.0) - config.encode_proj_layers = encoder_cfg["use_encoder_idx"] - config.encoder_activation_function = encoder_cfg.get("enc_act", "gelu") + config.encoder_type = encoder_type + if encoder_type == "lite": + config.encoder_hidden_dim = encoder_cfg["hidden_dim"] + config.encoder_in_channels = encoder_cfg["in_channels"] + config.feat_strides = encoder_cfg.get("feat_strides", [16]) + config.activation_function = encoder_cfg.get("act", "silu") + config.depth_mult = encoder_cfg.get("depth_mult", 1.0) + config.hidden_expansion = encoder_cfg.get("expansion", 1.0) + config.encoder_fuse_op = "sum" + config.encoder_ffn_dim = 1024 + config.encoder_attention_heads = 8 + config.dropout = 0.0 + config.encode_proj_layers = [2] + config.encoder_activation_function = "gelu" + config.encoder_layers = 0 + else: + config.encoder_hidden_dim = encoder_cfg["hidden_dim"] + config.encoder_in_channels = encoder_cfg["in_channels"] + config.feat_strides = encoder_cfg.get("feat_strides", [8, 16, 32]) + config.activation_function = encoder_cfg.get("act", "silu") + config.depth_mult = encoder_cfg.get("depth_mult", 1.0) + config.hidden_expansion = encoder_cfg.get("expansion", 1.0) + config.encoder_fuse_op = encoder_cfg.get("fuse_op", "sum") + 
config.encoder_ffn_dim = encoder_cfg.get("dim_feedforward", 1024) + config.encoder_attention_heads = encoder_cfg.get("nhead", 8) + config.dropout = encoder_cfg.get("dropout", 0.0) + config.encode_proj_layers = encoder_cfg.get("use_encoder_idx", [2]) + config.encoder_activation_function = encoder_cfg.get("enc_act", "gelu") # Decoder settings config.d_model = decoder_cfg["hidden_dim"] @@ -106,6 +119,8 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: config.decoder_in_channels = decoder_cfg["feat_channels"] config.eval_size = tuple(decoder_cfg["eval_spatial_size"]) if "eval_spatial_size" in decoder_cfg else None config.decoder_activation_function = decoder_cfg.get("activation", "silu") + config.share_bbox_head = decoder_cfg.get("share_bbox_head", False) + config.use_gateway = decoder_cfg.get("use_gateway", True) # Backbone settings if "HGNetv2" in orig_config: @@ -115,8 +130,7 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: config.backbone_config.out_indices = [i + 1 for i in return_idx] config.backbone_config.use_learnable_affine_block = backbone_cfg.get("use_lab", True) - # Set backbone sizes based on the model variant - if backbone_name == "B0": + if backbone_name in ["B0", "B1", "B2"]: config.backbone_config.hidden_sizes = [128, 256, 512, 1024] config.backbone_config.stem_channels = [3, 16, 16] config.backbone_config.stage_in_channels = [16, 64, 256, 512] @@ -127,28 +141,128 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: config.backbone_config.stage_light_block = [False, False, True, True] config.backbone_config.stage_kernel_size = [3, 3, 5, 5] config.backbone_config.stage_numb_of_layers = [3, 3, 3, 3] - elif backbone_name in ["B1", "B2"]: - config.backbone_config.hidden_sizes = [128, 256, 512, 1024] + elif backbone_name == "Atto": + config.backbone_config.hidden_sizes = [64, 256, 256] config.backbone_config.stem_channels = [3, 16, 16] - config.backbone_config.stage_in_channels = [16, 64, 256, 512] - 
config.backbone_config.stage_mid_channels = [16, 32, 64, 128] - config.backbone_config.stage_out_channels = [64, 256, 512, 1024] - config.backbone_config.stage_num_blocks = [1, 1, 2, 1] - config.backbone_config.stage_downsample = [False, True, True, True] - config.backbone_config.stage_light_block = [False, False, True, True] - config.backbone_config.stage_kernel_size = [3, 3, 5, 5] - config.backbone_config.stage_numb_of_layers = [3, 3, 3, 3] + config.backbone_config.stage_in_channels = [16, 64, 256] + config.backbone_config.stage_mid_channels = [16, 32, 64] + config.backbone_config.stage_out_channels = [64, 256, 256] + config.backbone_config.stage_num_blocks = [1, 1, 1] + config.backbone_config.stage_downsample = [False, True, True] + config.backbone_config.stage_light_block = [False, False, True] + config.backbone_config.stage_kernel_size = [3, 3, 3] + config.backbone_config.stage_numb_of_layers = [3, 3, 3] + elif backbone_name == "Femto": + config.backbone_config.hidden_sizes = [64, 256, 512] + config.backbone_config.stem_channels = [3, 16, 16] + config.backbone_config.stage_in_channels = [16, 64, 256] + config.backbone_config.stage_mid_channels = [16, 32, 64] + config.backbone_config.stage_out_channels = [64, 256, 512] + config.backbone_config.stage_num_blocks = [1, 1, 1] + config.backbone_config.stage_downsample = [False, True, True] + config.backbone_config.stage_light_block = [False, False, True] + config.backbone_config.stage_kernel_size = [3, 3, 5] + config.backbone_config.stage_numb_of_layers = [3, 3, 3] + elif backbone_name == "Pico": + config.backbone_config.hidden_sizes = [64, 256, 512] + config.backbone_config.stem_channels = [3, 16, 16] + config.backbone_config.stage_in_channels = [16, 64, 256] + config.backbone_config.stage_mid_channels = [16, 32, 64] + config.backbone_config.stage_out_channels = [64, 256, 512] + config.backbone_config.stage_num_blocks = [1, 1, 2] + config.backbone_config.stage_downsample = [False, True, True] + 
config.backbone_config.stage_light_block = [False, False, True] + config.backbone_config.stage_kernel_size = [3, 3, 5] + config.backbone_config.stage_numb_of_layers = [3, 3, 3] else: raise ValueError(f"Unknown HGNetv2 variant: {backbone_name}") + num_stages = len(config.backbone_config.hidden_sizes) + config.backbone_config.depths = config.backbone_config.stage_numb_of_layers + config.backbone_config.stage_names = ["stem"] + [f"stage{i}" for i in range(1, num_stages + 1)] + config.use_spatial_tuning_adapter = False elif "DINOv3STAs" in orig_config: - raise ValueError( - "DINOv3 backbone variants are not yet supported. " - "The DINOv3+STA architecture requires ViT backbone key mappings and " - "STA adapter integration that are not yet implemented in the conversion script. " - "Supported variants: deimv2_hgnetv2_n_coco." - ) + dinov3_cfg = orig_config["DINOv3STAs"] + name = dinov3_cfg["name"] + config.backbone_type = "dinov3" + config.dinov3_interaction_indexes = dinov3_cfg["interaction_indexes"] + config.sta_inplanes = dinov3_cfg.get("conv_inplane") or 16 + config.use_spatial_tuning_adapter = True + + is_dinov3 = "dinov3" in name + + DINOV3_PRESETS = { + "vit_tiny": { + "ffn_ratio": 4, + "use_gated_mlp": False, + "layerscale_value": 1.0, + "num_register_tokens": 0, + "pos_embed_rescale": None, + "key_bias": False, + }, + "vit_tinyplus": { + "ffn_ratio": 4, + "use_gated_mlp": False, + "layerscale_value": 1.0, + "num_register_tokens": 0, + "pos_embed_rescale": None, + "key_bias": False, + }, + "dinov3_vits16": { + "vit_hidden_size": 384, + "vit_num_heads": 6, + "ffn_ratio": 4, + "use_gated_mlp": False, + "layerscale_value": 1e-5, + "num_register_tokens": 4, + "pos_embed_rescale": 2.0, + "key_bias": True, + }, + "dinov3_vits16plus": { + "vit_hidden_size": 384, + "vit_num_heads": 6, + "ffn_ratio": 6, + "use_gated_mlp": True, + "layerscale_value": 1e-5, + "num_register_tokens": 4, + "pos_embed_rescale": 2.0, + "key_bias": True, + }, + } + preset = DINOV3_PRESETS[name] + 
# Regex-based key renaming tables consumed by `convert_old_keys_to_new_keys`.
# Every pattern is matched against a full original checkpoint key via
# `re.sub(f"^{pattern}$", replacement, key)`, so the `\1` / `\2` backreferences
# in the replacement strings carry the captured block indices and parameter
# suffixes (weight / bias / BatchNorm running stats) over to the new name.

# Lite-encoder variant: input projections, the two average-pool downsample
# paths, the GAP fusion conv, and the FPN/PAN RepNCSPELAN blocks. The original
# `cv*` attribute naming maps onto the HF `conv*` / `csp_rep*` module names.
LITE_ENCODER_KEY_MAPPING = {
    # LiteEncoder input_proj (Conv2d + BatchNorm2d packed in an nn.Sequential)
    r"encoder\.input_proj\.(\d+)\.conv\.weight": r"model.encoder.input_proj.\1.0.weight",
    r"encoder\.input_proj\.(\d+)\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.input_proj.\1.1.\2",
    # Downsample paths
    r"encoder\.down_sample(\d+)\.1\.weight": r"model.encoder.down_sample\1.1.weight",
    r"encoder\.down_sample(\d+)\.2\.(weight|bias|running_mean|running_var)": r"model.encoder.down_sample\1.2.\2",
    # GAP fusion conv
    r"encoder\.bi_fusion\.cv\.conv\.weight": r"model.encoder.bi_fusion.cv.conv.weight",
    r"encoder\.bi_fusion\.cv\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.bi_fusion.cv.norm.\1",
    # FPN block (RepNCSPELAN): cv1/cv4 -> conv1/conv4, cv2/cv3 -> csp_rep1/csp_rep2
    r"encoder\.fpn_block\.cv1\.conv\.weight": r"model.encoder.fpn_block.conv1.conv.weight",
    r"encoder\.fpn_block\.cv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.conv1.norm.\1",
    r"encoder\.fpn_block\.cv4\.conv\.weight": r"model.encoder.fpn_block.conv4.conv.weight",
    r"encoder\.fpn_block\.cv4\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.conv4.norm.\1",
    r"encoder\.fpn_block\.cv2\.0\.conv1\.conv\.weight": r"model.encoder.fpn_block.csp_rep1.conv1.conv.weight",
    r"encoder\.fpn_block\.cv2\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep1.conv1.norm.\1",
    r"encoder\.fpn_block\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.fpn_block.csp_rep1.bottlenecks.\1.conv1.conv.weight",
    r"encoder\.fpn_block\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep1.bottlenecks.\1.conv1.norm.\2",
    r"encoder\.fpn_block\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.fpn_block.csp_rep1.bottlenecks.\1.conv2.conv.weight",
    r"encoder\.fpn_block\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep1.bottlenecks.\1.conv2.norm.\2",
    r"encoder\.fpn_block\.cv2\.1\.conv\.weight": r"model.encoder.fpn_block.csp_rep1.conv3.conv.weight",
    r"encoder\.fpn_block\.cv2\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep1.conv3.norm.\1",
    r"encoder\.fpn_block\.cv3\.0\.conv1\.conv\.weight": r"model.encoder.fpn_block.csp_rep2.conv1.conv.weight",
    r"encoder\.fpn_block\.cv3\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep2.conv1.norm.\1",
    r"encoder\.fpn_block\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.fpn_block.csp_rep2.bottlenecks.\1.conv1.conv.weight",
    r"encoder\.fpn_block\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep2.bottlenecks.\1.conv1.norm.\2",
    r"encoder\.fpn_block\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.fpn_block.csp_rep2.bottlenecks.\1.conv2.conv.weight",
    r"encoder\.fpn_block\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep2.bottlenecks.\1.conv2.norm.\2",
    r"encoder\.fpn_block\.cv3\.1\.conv\.weight": r"model.encoder.fpn_block.csp_rep2.conv3.conv.weight",
    r"encoder\.fpn_block\.cv3\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep2.conv3.norm.\1",
    # PAN block (structurally identical to the FPN block)
    r"encoder\.pan_block\.cv1\.conv\.weight": r"model.encoder.pan_block.conv1.conv.weight",
    r"encoder\.pan_block\.cv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.conv1.norm.\1",
    r"encoder\.pan_block\.cv4\.conv\.weight": r"model.encoder.pan_block.conv4.conv.weight",
    r"encoder\.pan_block\.cv4\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.conv4.norm.\1",
    r"encoder\.pan_block\.cv2\.0\.conv1\.conv\.weight": r"model.encoder.pan_block.csp_rep1.conv1.conv.weight",
    r"encoder\.pan_block\.cv2\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep1.conv1.norm.\1",
    r"encoder\.pan_block\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.pan_block.csp_rep1.bottlenecks.\1.conv1.conv.weight",
    r"encoder\.pan_block\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep1.bottlenecks.\1.conv1.norm.\2",
    r"encoder\.pan_block\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.pan_block.csp_rep1.bottlenecks.\1.conv2.conv.weight",
    r"encoder\.pan_block\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep1.bottlenecks.\1.conv2.norm.\2",
    r"encoder\.pan_block\.cv2\.1\.conv\.weight": r"model.encoder.pan_block.csp_rep1.conv3.conv.weight",
    r"encoder\.pan_block\.cv2\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep1.conv3.norm.\1",
    r"encoder\.pan_block\.cv3\.0\.conv1\.conv\.weight": r"model.encoder.pan_block.csp_rep2.conv1.conv.weight",
    r"encoder\.pan_block\.cv3\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep2.conv1.norm.\1",
    r"encoder\.pan_block\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.pan_block.csp_rep2.bottlenecks.\1.conv1.conv.weight",
    r"encoder\.pan_block\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep2.bottlenecks.\1.conv1.norm.\2",
    r"encoder\.pan_block\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.pan_block.csp_rep2.bottlenecks.\1.conv2.conv.weight",
    r"encoder\.pan_block\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep2.bottlenecks.\1.conv2.norm.\2",
    r"encoder\.pan_block\.cv3\.1\.conv\.weight": r"model.encoder.pan_block.csp_rep2.conv3.conv.weight",
    r"encoder\.pan_block\.cv3\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep2.conv3.norm.\1",
}

# Decoder variant without the cross-attention gateway: only the RMSNorm scale of
# the encoder-attention layer norm needs a dedicated rename.
DECODER_NO_GATEWAY_KEY_MAPPING = {
    r"decoder\.decoder\.layers\.(\d+)\.norm2\.scale": r"model.decoder.layers.\1.encoder_attn_layer_norm.scale",
}

# DINOv3 ViT backbone plus Spatial Tuning Adapter (STA) renames.
DINOV3_KEY_MAPPING = {
    # ViT embeddings
    r"backbone\.dinov3\.patch_embed\.proj\.(weight|bias)": r"model.dinov3_backbone.embeddings.patch_embeddings.\1",
    r"backbone\.dinov3\.cls_token": r"model.dinov3_backbone.embeddings.cls_token",
    r"backbone\.dinov3\.storage_tokens": r"model.dinov3_backbone.embeddings.register_tokens",
    r"backbone\.dinov3\.mask_token": r"model.dinov3_backbone.embeddings.mask_token",
    # ViT blocks
    r"backbone\.dinov3\.blocks\.(\d+)\.norm1\.(weight|bias)": r"model.dinov3_backbone.layers.\1.norm1.\2",
    r"backbone\.dinov3\.blocks\.(\d+)\.norm2\.(weight|bias)": r"model.dinov3_backbone.layers.\1.norm2.\2",
    r"backbone\.dinov3\.blocks\.(\d+)\.attn\.qkv\.(weight|bias)": r"model.dinov3_backbone.layers.\1.attention.qkv.\2",
    r"backbone\.dinov3\.blocks\.(\d+)\.attn\.proj\.(weight|bias)": r"model.dinov3_backbone.layers.\1.attention.o_proj.\2",
    # Standard MLP (S/M/L variants)
    r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.fc1\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.up_proj.\2",
    r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.fc2\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.down_proj.\2",
    # SwiGLU MLP (X variant only)
    r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w1\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.gate_proj.\2",
    r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w2\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.up_proj.\2",
    r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w3\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.down_proj.\2",
    # LayerScale (L/X variants only)
    r"backbone\.dinov3\.blocks\.(\d+)\.ls1\.gamma": r"model.dinov3_backbone.layers.\1.layer_scale1.lambda1",
    r"backbone\.dinov3\.blocks\.(\d+)\.ls2\.gamma": r"model.dinov3_backbone.layers.\1.layer_scale2.lambda1",
    # Final norm (L/X variants only)
    r"backbone\.dinov3\.norm\.(weight|bias)": r"model.dinov3_backbone.norm.\1",
    # STA adapter
    r"backbone\.sta\.stem\.0\.(weight|bias)": r"model.dinov3_backbone.sta.stem.0.\1",
    r"backbone\.sta\.stem\.1\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.sta.stem.1.\1",
    r"backbone\.sta\.conv2\.0\.(weight)": r"model.dinov3_backbone.sta.conv2.0.\1",
    r"backbone\.sta\.conv2\.1\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.sta.conv2.1.\1",
    r"backbone\.sta\.conv3\.1\.(weight)": r"model.dinov3_backbone.sta.conv3.1.\1",
    r"backbone\.sta\.conv3\.2\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.sta.conv3.2.\1",
    r"backbone\.sta\.conv4\.1\.(weight)": r"model.dinov3_backbone.sta.conv4.1.\1",
    r"backbone\.sta\.conv4\.2\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.sta.conv4.2.\1",
    # Projection convs/norms feeding the encoder
    r"backbone\.convs\.(\d+)\.weight": r"model.dinov3_backbone.convs.\1.weight",
    r"backbone\.norms\.(\d+)\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.norms.\1.\2",
}
def strip_dinov3_model_prefix(state_dict):
    """Drop the `_model.` wrapper segment from DINOv3 backbone keys, in place.

    Original checkpoints nest the ViT weights under `backbone.dinov3._model.*`,
    while the key-mapping tables expect `backbone.dinov3.*`. Returns the same
    dict object so calls can be chained.
    """
    old_prefix, new_prefix = "backbone.dinov3._model.", "backbone.dinov3."
    for name in list(state_dict):
        if old_prefix in name:
            state_dict[name.replace(old_prefix, new_prefix)] = state_dict.pop(name)
    return state_dict


def read_in_q_k_v_vit(state_dict, config):
    """Split fused ViT `qkv` projections into separate q/k/v tensors, in place.

    Original DINOv3 blocks store a single fused `attention.qkv` Linear; the HF
    port uses distinct `q_proj` / `k_proj` / `v_proj` modules. The key-bias
    chunk is only kept when the backbone config says the checkpoint has one
    (`key_bias`); otherwise it is deliberately discarded.
    """
    from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig

    vit_config = DINOv3ViTConfig(**config.dinov3_backbone_config)
    has_key_bias = config.dinov3_backbone_config.get("key_bias", True)
    prefix = "model.dinov3_backbone"
    for layer_idx in range(vit_config.num_hidden_layers):
        weight_key = f"{prefix}.layers.{layer_idx}.attention.qkv.weight"
        if weight_key in state_dict:
            # Fused weight is stacked [q; k; v] along dim 0.
            query, key, value = state_dict.pop(weight_key).chunk(3, dim=0)
            state_dict[f"{prefix}.layers.{layer_idx}.attention.q_proj.weight"] = query
            state_dict[f"{prefix}.layers.{layer_idx}.attention.k_proj.weight"] = key
            state_dict[f"{prefix}.layers.{layer_idx}.attention.v_proj.weight"] = value

        bias_key = f"{prefix}.layers.{layer_idx}.attention.qkv.bias"
        if bias_key in state_dict:
            query_bias, key_bias, value_bias = state_dict.pop(bias_key).chunk(3, dim=0)
            state_dict[f"{prefix}.layers.{layer_idx}.attention.q_proj.bias"] = query_bias
            if has_key_bias:
                state_dict[f"{prefix}.layers.{layer_idx}.attention.k_proj.bias"] = key_bias
            state_dict[f"{prefix}.layers.{layer_idx}.attention.v_proj.bias"] = value_bias


def load_original_state_dict(repo_id):
    """Download the original `model.safetensors` from the Hub and load it."""
    filepath = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    return load_file(filepath)
convert_old_keys_to_new_keys(state_dict) + state_dict = convert_old_keys_to_new_keys(state_dict, config) + + if config.backbone_type == "dinov3": + read_in_q_k_v_vit(state_dict, config) + mask_key = "model.dinov3_backbone.embeddings.mask_token" + if mask_key in state_dict and state_dict[mask_key].dim() == 2: + state_dict[mask_key] = state_dict[mask_key].unsqueeze(1) if "model.enc_output.0.weight" not in state_dict: d_model = config.d_model @@ -373,37 +658,74 @@ def convert_deimv2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, state_dict["model.enc_output.1.weight"] = torch.ones(d_model) state_dict["model.enc_output.1.bias"] = torch.zeros(d_model) + if config.share_bbox_head: + num_decoder_layers = config.decoder_layers + for key in list(state_dict.keys()): + if "model.decoder.bbox_embed.0." in key: + for i in range(1, num_decoder_layers): + new_key = key.replace("bbox_embed.0.", f"bbox_embed.{i}.") + if new_key not in state_dict: + state_dict[new_key] = state_dict[key] + # for two_stage for key in list(state_dict.keys()): if "bbox_embed" in key or ("class_embed" in key and "denoising_" not in key): new_key = key.split("model.decoder.")[-1] - if new_key not in state_dict: + if new_key != key and new_key not in state_dict: state_dict[new_key] = state_dict[key] model = Deimv2ForObjectDetection(config) missing, unexpected = model.load_state_dict(state_dict, strict=False) - if missing: - logger.warning(f"Missing keys ({len(missing)}): {missing[:10]}...") + expected_missing = {"mask_token", "register_tokens", "layer_scale1", "layer_scale2"} + unexpected_missing = [k for k in missing if not any(e in k for e in expected_missing)] + if unexpected_missing: + logger.warning(f"Missing keys ({len(unexpected_missing)}): {unexpected_missing[:10]}...") + elif missing: + logger.info( + f"All {len(missing)} missing keys are expected model-init defaults (mask_token, register_tokens, layer_scale)" + ) if unexpected: logger.warning(f"Unexpected keys ({len(unexpected)}): 
{unexpected[:10]}...") model.eval() - image_processor = RTDetrImageProcessor() + is_dinov3 = config.backbone_type == "dinov3" + if is_dinov3: + image_processor = RTDetrImageProcessor( + do_normalize=True, + image_mean=[0.485, 0.456, 0.406], + image_std=[0.229, 0.224, 0.225], + ) + else: + image_processor = RTDetrImageProcessor() + img = prepare_img() - transformations = transforms.Compose( - [ - transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR), - transforms.ToTensor(), - ] - ) + if is_dinov3: + transformations = transforms.Compose( + [ + transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + else: + transformations = transforms.Compose( + [ + transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + ] + ) original_pixel_values = transformations(img).unsqueeze(0) encoding = image_processor(images=img, return_tensors="pt") pixel_values = encoding["pixel_values"] - assert torch.allclose(original_pixel_values, pixel_values), "Image preprocessing mismatch!" + if not torch.allclose(original_pixel_values, pixel_values, atol=1e-4): + max_diff = (original_pixel_values - pixel_values).abs().max().item() + logger.warning(f"Image preprocessing mismatch! 
Max diff: {max_diff:.6f}") + if max_diff > 1e-2: + raise ValueError(f"Image preprocessing mismatch too large: {max_diff}") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) diff --git a/src/transformers/models/deimv2/modeling_deimv2.py b/src/transformers/models/deimv2/modeling_deimv2.py index 154170174227..9173609eb0df 100644 --- a/src/transformers/models/deimv2/modeling_deimv2.py +++ b/src/transformers/models/deimv2/modeling_deimv2.py @@ -21,23 +21,24 @@ from collections.abc import Callable from dataclasses import dataclass +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from ... import initialization as init -from ...activations import ACT2CLS -from ...backbone_utils import load_backbone +from ...activations import ACT2CLS, ACT2FN from ...image_transforms import center_to_corners_format, corners_to_center_format +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import compile_compatible_method_lru_cache from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int -from ...utils.generic import can_return_tuple, merge_with_config_defaults +from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from .configuration_deimv2 import Deimv2Config +from .configuration_deimv2 import Deimv2Config, Deimv2DINOv3ViTConfig @dataclass @@ -351,6 +352,399 @@ def forward(self, x): return self.activation(y) +class Deimv2DINOv3ViTEmbeddings(nn.Module): + """ + Construct the CLS token, mask token, position and patch embeddings. 
+ """ + + def __init__(self, config: Deimv2DINOv3ViTConfig): + super().__init__() + self.config = config + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.register_tokens = nn.Parameter(torch.empty(1, config.num_register_tokens, config.hidden_size)) + self.patch_embeddings = nn.Conv2d( + config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size + ) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embeddings.weight.dtype + + # (batch_size, num_channels, height, width) -> (batch_size, num_patches, hidden_size) + patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2) + + if bool_masked_pos is not None: + mask_token = self.mask_token.to(patch_embeddings.dtype) + patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings) + + # Add CLS and register tokens + cls_token = self.cls_token.expand(batch_size, -1, -1) + register_tokens = self.register_tokens.expand(batch_size, -1, -1) + embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) + + return embeddings + + +class Deimv2DINOv3ViTLayerScale(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size)) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return hidden_state * self.lambda1 + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: 
torch.Tensor | None, + scaling: float | None = None, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, **kwargs +) -> tuple[torch.Tensor, torch.Tensor]: + """Applies Rotary Position Embedding to the query and key tensors, but only to the patch tokens, + ignoring the prefix tokens (cls token and register tokens). + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + + num_tokens = q.shape[-2] + num_patches = sin.shape[-2] + num_prefix_tokens = num_tokens - num_patches # cls token + register tokens + + q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2) + k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2) + + # apply rope only to patch tokens + q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin) + k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin) + + q = torch.cat((q_prefix_tokens, q_patches), dim=-2) + k = torch.cat((k_prefix_tokens, k_patches), dim=-2) + + return q, k + + +class Deimv2DINOv3ViTAttention(nn.Module): + """ + Multi-headed attention compatible with ALL_ATTENTION_FUNCTIONS. + """ + + def __init__(self, config: Deimv2DINOv3ViTConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.is_causal = False + + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.dropout = config.attention_dropout + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.key_bias) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.value_bias) + + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.query_bias) + self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Input shape: Batch x Time x Channel""" + + batch_size, patches, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, patches, self.num_heads, 
self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class Deimv2DINOv3ViTDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float | None = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +class Deimv2DINOv3ViTMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) + + +class Deimv2DINOv3ViTGatedMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return 
class Deimv2DINOv3ViTLayer(GradientCheckpointingLayer):
    """This corresponds to the Block class in the original implementation."""

    def __init__(self, config: Deimv2DINOv3ViTConfig):
        super().__init__()

        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = Deimv2DINOv3ViTAttention(config)
        self.layer_scale1 = Deimv2DINOv3ViTLayerScale(config)
        # Stochastic depth is a no-op identity unless a positive rate is configured.
        self.drop_path = (
            Deimv2DINOv3ViTDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
        )

        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Gated (SwiGLU) MLP is only used by the larger variants.
        if config.use_gated_mlp:
            self.mlp = Deimv2DINOv3ViTGatedMLP(config)
        else:
            self.mlp = Deimv2DINOv3ViTMLP(config)
        self.layer_scale2 = Deimv2DINOv3ViTLayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        # Pre-norm self-attention sublayer: norm -> attn -> LayerScale -> drop-path + residual.
        shortcut = hidden_states
        normed = self.norm1(hidden_states)
        attn_out, _ = self.attention(
            normed,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        hidden_states = self.drop_path(self.layer_scale1(attn_out)) + shortcut

        # Pre-norm MLP sublayer with the same residual pattern.
        shortcut = hidden_states
        mlp_out = self.mlp(self.norm2(hidden_states))
        hidden_states = self.drop_path(self.layer_scale2(mlp_out)) + shortcut

        return hidden_states


@compile_compatible_method_lru_cache(maxsize=32)
def get_patches_center_coordinates(
    num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """
    Computes the 2D coordinates of the centers of image patches, normalized to the range [-1, +1].
    The center of each patch is exactly halfway between its top-left and bottom-right corners.

    Args:
        num_patches_h (int): Number of patches along the vertical (height) axis.
        num_patches_w (int): Number of patches along the horizontal (width) axis.
        dtype (torch.dtype): The desired data type of the returned tensor.
        device (torch.device): The device to allocate the tensor on.

    Returns:
        torch.Tensor: A tensor of shape (height * width, 2), where each row contains the (y, x)
        coordinates of a patch center, normalized to [-1, +1].
    """
    # Patch centers sit half a cell in from the grid lines: 0.5, 1.5, ..., n - 0.5.
    ys = torch.arange(0.5, num_patches_h, dtype=dtype, device=device) / num_patches_h
    xs = torch.arange(0.5, num_patches_w, dtype=dtype, device=device) / num_patches_w
    # (height, width, 2) -> (height * width, 2)
    grid = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1).flatten(0, 1)
    # Map the [0, 1] range onto [-1, +1].
    return 2.0 * grid - 1.0


def augment_patches_center_coordinates(
    coords: torch.Tensor,
    shift: float | None = None,
    jitter: float | None = None,
    rescale: float | None = None,
) -> torch.Tensor:
    """Randomly shift / jitter / rescale normalized patch coordinates (train-time RoPE augmentation)."""
    if shift is not None:
        # Uniform translation in [-shift, shift], drawn independently per axis.
        offset = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
        offset = offset.uniform_(-shift, shift)
        coords = coords + offset

    if jitter is not None:
        # Per-axis log-uniform scale factor in [1/jitter, jitter].
        jitter_log = np.log(jitter)
        jitter_factor = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
        jitter_factor = jitter_factor.uniform_(-jitter_log, jitter_log).exp()
        coords = coords * jitter_factor

    if rescale is not None:
        # Single global log-uniform scale factor in [1/rescale, rescale].
        rescale_log = np.log(rescale)
        rescale_factor = torch.empty(1, device=coords.device, dtype=coords.dtype)
        rescale_factor = rescale_factor.uniform_(-rescale_log, rescale_log).exp()
        coords = coords * rescale_factor

    return coords


class Deimv2DINOv3ViTRopePositionEmbedding(nn.Module):
    """Axial 2D rotary position embedding over patch-center coordinates."""

    inv_freq: torch.Tensor

    def __init__(self, config: Deimv2DINOv3ViTConfig):
        super().__init__()

        self.config = config
        self.base = config.rope_theta
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_patches_h = config.image_size // config.patch_size
        self.num_patches_w = config.image_size // config.patch_size

        # head_dim / 4 frequencies: each (y, x) coordinate pair consumes two rotary dims.
        inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32)  # (head_dim / 4,)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        _, _, height, width = pixel_values.shape
        num_patches_h = height // self.config.patch_size
        num_patches_w = width // self.config.patch_size

        device = pixel_values.device
        device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            # Although static patch coords could be precomputed from image_size and
            # patch_size, the model was trained with random rescale and must handle
            # varying image sizes, so coords are computed dynamically (lru-cached).
            patch_coords = get_patches_center_coordinates(
                num_patches_h, num_patches_w, dtype=torch.float32, device=device
            )
            if self.training:
                patch_coords = augment_patches_center_coordinates(
                    patch_coords,
                    shift=self.config.pos_embed_shift,
                    jitter=self.config.pos_embed_jitter,
                    rescale=self.config.pos_embed_rescale,
                )

            # (height * width, 2, head_dim / 4) -> (height * width, head_dim / 2) -> (height * width, head_dim)
            angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
            angles = angles.flatten(1, 2)
            angles = angles.tile(2)

            cos = torch.cos(angles)
            sin = torch.sin(angles)

        dtype = pixel_values.dtype
        return cos.to(dtype=dtype), sin.to(dtype=dtype)
config.encoder_hidden_dim // 2) + c1 = c1 if c1 is not None else config.encoder_hidden_dim + c2 = c2 if c2 is not None else config.encoder_hidden_dim + c3 = c3 if c3 is not None else config.encoder_hidden_dim * 2 + c4 = c4 if c4 is not None else round(config.hidden_expansion * config.encoder_hidden_dim // 2) self.conv_dim = c3 // 2 self.conv1 = Deimv2ConvNormLayer(config, c1, c3, 1, 1, activation=act) self.csp_rep1 = Deimv2CSPRepLayer2(config, c3 // 2, c4, num_blocks=numb_blocks) @@ -414,34 +819,6 @@ def forward(self, input_features: torch.Tensor) -> torch.Tensor: return input_features -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor | None, - scaling: float | None = None, - dropout: float = 0.0, - **kwargs: Unpack[TransformersKwargs], -): - if scaling is None: - scaling = query.size(-1) ** -0.5 - - # Take the dot product between "query" and "key" to get the raw attention scores. - attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - - attn_output = torch.matmul(attn_weights, value) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - class Deimv2SelfAttention(nn.Module): """ Multi-headed self-attention from 'Attention Is All You Need' paper. 
@@ -708,6 +1085,135 @@ def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso return c2, c3, c4 +class Deimv2GAPFusion(nn.Module): + def __init__(self, config: Deimv2Config, channels: int): + super().__init__() + self.cv = Deimv2ConvNormLayer(config, channels, channels, 1, 1, activation=config.activation_function) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return self.cv(hidden_state + F.adaptive_avg_pool2d(hidden_state, 1)) + + +class Deimv2LiteEncoder(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + hidden_dim = config.encoder_hidden_dim + act = config.activation_function + + self.input_proj = nn.ModuleList() + for in_channel in config.encoder_in_channels: + self.input_proj.append( + nn.Sequential( + nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(hidden_dim), + ) + ) + + self.down_sample1 = nn.Sequential( + nn.AvgPool2d(kernel_size=3, stride=2, padding=1), + nn.Conv2d(hidden_dim, hidden_dim, 1, 1, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.SiLU(inplace=True) if act == "silu" else nn.ReLU(inplace=True), + ) + self.down_sample2 = nn.Sequential( + nn.AvgPool2d(kernel_size=3, stride=2, padding=1), + nn.Conv2d(hidden_dim, hidden_dim, 1, 1, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.SiLU(inplace=True) if act == "silu" else nn.ReLU(inplace=True), + ) + + self.bi_fusion = Deimv2GAPFusion(config, hidden_dim) + + c1, c2 = hidden_dim, hidden_dim + c3 = hidden_dim * 2 + c4 = round(config.hidden_expansion * hidden_dim // 2) + num_blocks = round(3 * config.depth_mult) + self.fpn_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks, c1=c1, c2=c2, c3=c3, c4=c4) + self.pan_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks, c1=c1, c2=c2, c3=c3, c4=c4) + + def forward(self, inputs_embeds=None, **kwargs) -> BaseModelOutput: + feats = inputs_embeds + proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] + 
proj_feats.append(self.down_sample1(proj_feats[-1])) + + proj_feats[-1] = self.bi_fusion(proj_feats[-1]) + + outs = [] + fuse_feat = proj_feats[0] + F.interpolate(proj_feats[1], scale_factor=2.0, mode="nearest") + outs.append(self.fpn_block(fuse_feat)) + + fuse_feat = proj_feats[1] + self.down_sample2(outs[-1]) + outs.append(self.pan_block(fuse_feat)) + + return BaseModelOutput(last_hidden_state=outs) + + +class Deimv2DINOv3Backbone(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + vit_config = Deimv2DINOv3ViTConfig(**config.dinov3_backbone_config) + + self.embeddings = Deimv2DINOv3ViTEmbeddings(vit_config) + self.rope_embeddings = Deimv2DINOv3ViTRopePositionEmbedding(vit_config) + self.layers = nn.ModuleList([Deimv2DINOv3ViTLayer(vit_config) for _ in range(vit_config.num_hidden_layers)]) + + self.apply_layernorm = config.dinov3_apply_layernorm + if self.apply_layernorm: + self.norm = nn.LayerNorm(vit_config.hidden_size, eps=vit_config.layer_norm_eps) + + self.interaction_indexes = config.dinov3_interaction_indexes + self.patch_size = vit_config.patch_size + self.num_prefix_tokens = 1 + vit_config.num_register_tokens + + self.sta = Deimv2SpatialTuningAdapter(config) + + embed_dim = vit_config.hidden_size + hidden_dim = config.dinov3_hidden_dim or embed_dim + sta_ch = config.sta_inplanes + self.convs = nn.ModuleList( + [ + nn.Conv2d(embed_dim + sta_ch * 2, hidden_dim, 1, bias=False), + nn.Conv2d(embed_dim + sta_ch * 4, hidden_dim, 1, bias=False), + nn.Conv2d(embed_dim + sta_ch * 4, hidden_dim, 1, bias=False), + ] + ) + self.norms = nn.ModuleList([nn.BatchNorm2d(hidden_dim) for _ in range(3)]) + + def forward(self, pixel_values: torch.Tensor) -> list[torch.Tensor]: + hidden_states = self.embeddings(pixel_values) + position_embeddings = self.rope_embeddings(pixel_values) + + intermediate_outputs = [] + for i, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, position_embeddings=position_embeddings) + if i in 
self.interaction_indexes: + out = self.norm(hidden_states) if self.apply_layernorm else hidden_states + intermediate_outputs.append(out) + + batch_size = pixel_values.shape[0] + h_patches = pixel_values.shape[2] // self.patch_size + w_patches = pixel_values.shape[3] // self.patch_size + + sem_feats = [] + num_scales = len(intermediate_outputs) + for i, feat in enumerate(intermediate_outputs): + patch_tokens = feat[:, self.num_prefix_tokens :] + spatial = patch_tokens.transpose(1, 2).reshape(batch_size, -1, h_patches, w_patches).contiguous() + resize_h = int(h_patches * 2 ** (num_scales - 2 - i)) + resize_w = int(w_patches * 2 ** (num_scales - 2 - i)) + spatial = F.interpolate(spatial, size=[resize_h, resize_w], mode="bilinear", align_corners=False) + sem_feats.append(spatial) + + detail_feats = self.sta(pixel_values) + + outs = [] + for i, (sem_feat, detail_feat) in enumerate(zip(sem_feats, detail_feats)): + fused = torch.cat([sem_feat, detail_feat], dim=1) + outs.append(self.norms[i](self.convs[i](fused))) + + return outs + + class Deimv2Integral(nn.Module): """ A static layer that calculates integral results from a distribution. 
@@ -767,7 +1273,13 @@ def __init__(self, config: Deimv2Config): self.encoder_attn = Deimv2MultiscaleDeformableAttention(config=config) self.mlp = Deimv2SwiGLUFFN(config.d_model, config.decoder_ffn_dim // 2, config.d_model) self.final_layer_norm = Deimv2RMSNorm(config.d_model) + # gate the cross-attention output with the residual (see config.use_gateway) self.gateway = Deimv2Gate(config.d_model) + if config.use_gateway: + self.gateway = Deimv2Gate(config.d_model) + else: + del self.gateway + self.encoder_attn_layer_norm = Deimv2RMSNorm(config.d_model) def forward( self, @@ -778,7 +1290,7 @@ def forward( spatial_shapes_list=None, encoder_hidden_states: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, - **kwargs: Unpack[TransformersKwargs], + **kwargs, ) -> torch.Tensor: """ Args: @@ -801,7 +1313,6 @@ def forward( """ residual = hidden_states - # Self Attention hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=encoder_attention_mask, @@ -815,7 +1326,6 @@ def forward( residual = hidden_states - # Cross-Attention hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings hidden_states, _ = self.encoder_attn( hidden_states=hidden_states, @@ -825,9 +1335,13 @@ def forward( spatial_shapes_list=spatial_shapes_list, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = self.gateway(residual, hidden_states) - # Fully Connected + if hasattr(self, "gateway"): + hidden_states = self.gateway(residual, hidden_states) + else: + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + residual = hidden_states hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states @@ -942,6 +1456,20 @@ def _init_weights(self, module): init.ones_(module.weight) init.zeros_(module.bias) + if isinstance(module, Deimv2DINOv3ViTEmbeddings): + init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) + if 
module.config.num_register_tokens > 0: + init.trunc_normal_(module.register_tokens, mean=0.0, std=self.config.initializer_range) + init.zeros_(module.mask_token) + + if isinstance(module, Deimv2DINOv3ViTLayerScale) and self.config.dinov3_backbone_config is not None: + layerscale_value = self.config.dinov3_backbone_config.get("layerscale_value", 1.0) + init.constant_(module.lambda1, layerscale_value) + + if isinstance(module, Deimv2DINOv3ViTRopePositionEmbedding): + inv_freq = 1 / module.base ** torch.arange(0, 1, 4 / module.head_dim, dtype=torch.float32) + init.copy_(module.inv_freq, inv_freq) + if hasattr(module, "weight_embedding") and self.config.learn_initial_query: init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: @@ -1259,45 +1787,6 @@ def forward( ) -class Deimv2FrozenBatchNorm2d(nn.Module): - """ - BatchNorm2d where the batch statistics and the affine parameters are fixed. - - Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than - torchvision.models.resnet[18,34,50,101] produce nans. 
- """ - - def __init__(self, n): - super().__init__() - self.register_buffer("weight", torch.ones(n)) - self.register_buffer("bias", torch.zeros(n)) - self.register_buffer("running_mean", torch.zeros(n)) - self.register_buffer("running_var", torch.ones(n)) - - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - num_batches_tracked_key = prefix + "num_batches_tracked" - if num_batches_tracked_key in state_dict: - del state_dict[num_batches_tracked_key] - - super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) - - def forward(self, x): - # move reshapes to the beginning - # to make it user-friendly - weight = self.weight.reshape(1, -1, 1, 1) - bias = self.bias.reshape(1, -1, 1, 1) - running_var = self.running_var.reshape(1, -1, 1, 1) - running_mean = self.running_mean.reshape(1, -1, 1, 1) - epsilon = 1e-5 - scale = weight * (running_var + epsilon).rsqrt() - bias = bias - running_mean * scale - return x * scale + bias - - @dataclass @auto_docstring( custom_intro=""" @@ -1356,62 +1845,6 @@ class Deimv2ModelOutput(ModelOutput): denoising_meta_values: dict | None = None -def replace_batch_norm(model): - r""" - Recursively replace all `torch.nn.BatchNorm2d` with `Deimv2FrozenBatchNorm2d`. 
- - Args: - model (torch.nn.Module): - input model - """ - for name, module in model.named_children(): - if isinstance(module, nn.BatchNorm2d): - new_module = Deimv2FrozenBatchNorm2d(module.num_features) - - if module.weight.device != torch.device("meta"): - new_module.weight.copy_(module.weight) - new_module.bias.copy_(module.bias) - new_module.running_mean.copy_(module.running_mean) - new_module.running_var.copy_(module.running_var) - - model._modules[name] = new_module - - if len(list(module.children())) > 0: - replace_batch_norm(module) - - -class Deimv2ConvEncoder(nn.Module): - """ - Convolutional backbone using the modeling_deimv2_resnet.py. - - nn.BatchNorm2d layers are replaced by Deimv2FrozenBatchNorm2d as defined above. - https://github.com/lyuwenyu/RT-DETR/blob/main/Deimv2_pytorch/src/nn/backbone/presnet.py#L142 - """ - - def __init__(self, config): - super().__init__() - - backbone = load_backbone(config) - - if config.freeze_backbone_batch_norms: - # replace batch norm by frozen batch norm - with torch.no_grad(): - replace_batch_norm(backbone) - self.model = backbone - self.intermediate_channel_sizes = self.model.channels - - def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): - # send pixel_values through the model to get list of feature maps - features = self.model(pixel_values).feature_maps - - out = [] - for feature_map in features: - # downsample pixel_mask to match shape of corresponding feature_map - mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] - out.append((feature_map, mask)) - return out - - def get_contrastive_denoising_training_group( targets, num_classes, @@ -1544,62 +1977,48 @@ class Deimv2Model(Deimv2PreTrainedModel): def __init__(self, config: Deimv2Config): super().__init__(config) - # Create backbone - self.backbone = Deimv2ConvEncoder(config) - intermediate_channel_sizes = self.backbone.intermediate_channel_sizes - num_backbone_outs = 
len(config.decoder_in_channels) - encoder_input_proj_list = [] - for i in range(num_backbone_outs): - in_channels = intermediate_channel_sizes[i] - encoder_input_proj_list.append( - nn.Sequential( - nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False), - nn.BatchNorm2d(config.encoder_hidden_dim), - ) - ) - self.encoder_input_proj = nn.ModuleList(encoder_input_proj_list) - self.encoder = Deimv2HybridEncoder(config=config) + if config.backbone_type == "dinov3": + self.dinov3_backbone = Deimv2DINOv3Backbone(config) + else: + from ..d_fine.modeling_d_fine import DFineConvEncoder + + self.backbone = DFineConvEncoder(config) + if config.encoder_type != "lite": + intermediate_channel_sizes = self.backbone.intermediate_channel_sizes + encoder_input_proj = [] + for in_channel in intermediate_channel_sizes: + encoder_input_proj.append( + nn.Sequential( + nn.Conv2d(in_channel, config.encoder_hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(config.encoder_hidden_dim), + ) + ) + self.encoder_input_proj = nn.ModuleList(encoder_input_proj) + + if config.encoder_type == "lite": + self.encoder = Deimv2LiteEncoder(config) + else: + self.encoder = Deimv2HybridEncoder(config=config) - # denoising part if config.num_denoising > 0: self.denoising_class_embed = nn.Embedding( config.num_labels + 1, config.d_model, padding_idx=config.num_labels ) - # decoder embedding if config.learn_initial_query: self.weight_embedding = nn.Embedding(config.num_queries, config.d_model) - # encoder head self.enc_output = nn.Sequential( nn.Linear(config.d_model, config.d_model), nn.LayerNorm(config.d_model, eps=config.layer_norm_eps), ) self.enc_score_head = nn.Linear(config.d_model, config.num_labels) - self.enc_bbox_head = Deimv2MLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3) + self.enc_bbox_head = Deimv2MLP(config.d_model, config.d_model, 4, 3) - # init encoder output anchors and valid_mask if config.anchor_image_size: self.anchors, self.valid_mask = 
self.generate_anchors(dtype=self.dtype) + num_backbone_outs = len(config.decoder_in_channels) - decoder_input_proj_list = [] - for i in range(num_backbone_outs): - in_channels = config.decoder_in_channels[i] - decoder_input_proj_list.append( - nn.Sequential( - nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False), - nn.BatchNorm2d(config.d_model, config.batch_norm_eps), - ) - ) - for _ in range(config.num_feature_levels - num_backbone_outs): - decoder_input_proj_list.append( - nn.Sequential( - nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False), - nn.BatchNorm2d(config.d_model, config.batch_norm_eps), - ) - ) - in_channels = config.d_model - self.decoder = Deimv2Decoder(config) decoder_input_proj = [] in_channels = config.decoder_in_channels[-1] for _ in range(num_backbone_outs): @@ -1617,8 +2036,9 @@ def __init__(self, config: Deimv2Config): batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps) decoder_input_proj.append(nn.Sequential(conv, batchnorm)) self.decoder_input_proj = nn.ModuleList(decoder_input_proj) + self.decoder = Deimv2Decoder(config) - if config.use_spatial_tuning_adapter: + if config.use_spatial_tuning_adapter and config.backbone_type != "dinov3": self.spatial_tuning_adapter = Deimv2SpatialTuningAdapter(config) self.post_init() @@ -1669,7 +2089,7 @@ def forward( encoder_outputs: torch.FloatTensor | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: list[dict] | None = None, - **kwargs: Unpack[TransformersKwargs], + **kwargs, ) -> tuple[torch.FloatTensor] | Deimv2ModelOutput: r""" inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): @@ -1710,8 +2130,15 @@ def forward( device = pixel_values.device if pixel_mask is None: pixel_mask = torch.ones(((batch_size, height, width)), device=device) - features = self.backbone(pixel_values, pixel_mask) - proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in 
enumerate(features)] + + if self.config.backbone_type == "dinov3": + proj_feats = self.dinov3_backbone(pixel_values) + elif self.config.encoder_type == "lite": + features = self.backbone(pixel_values, pixel_mask) + proj_feats = [source for source, mask in features] + else: + features = self.backbone(pixel_values, pixel_mask) + proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] else: batch_size = inputs_embeds.shape[0] device = inputs_embeds.device @@ -1722,7 +2149,6 @@ def forward( proj_feats, **kwargs, ) - # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput elif not isinstance(encoder_outputs, BaseModelOutput): encoder_outputs = BaseModelOutput( last_hidden_state=encoder_outputs[0], @@ -1730,20 +2156,16 @@ def forward( attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - # Equivalent to def _get_encoder_input - # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/Deimv2_pytorch/src/zoo/Deimv2/Deimv2_decoder.py#L412 sources = [] for level, source in enumerate(encoder_outputs.last_hidden_state): sources.append(self.decoder_input_proj[level](source)) - # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage if self.config.num_feature_levels > len(sources): _len_sources = len(sources) - sources.append(self.decoder_input_proj[_len_sources](encoder_outputs.last_hidden_state)[-1]) + sources.append(self.decoder_input_proj[_len_sources](encoder_outputs.last_hidden_state[-1])) for i in range(_len_sources + 1, self.config.num_feature_levels): sources.append(self.decoder_input_proj[i](encoder_outputs.last_hidden_state[-1])) - # Prepare encoder inputs (by flattening) source_flatten = [] spatial_shapes_list = [] spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long) @@ -1757,8 +2179,9 @@ def forward( source_flatten = torch.cat(source_flatten, 1) level_start_index = 
torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) - # prepare denoising training if self.training and self.config.num_denoising > 0 and labels is not None: + from ..d_fine.modeling_d_fine import get_contrastive_denoising_training_group + ( denoising_class, denoising_bbox_unact, @@ -1780,17 +2203,13 @@ def forward( device = source_flatten.device dtype = source_flatten.dtype - # prepare input for decoder if self.training or self.config.anchor_image_size is None: - # Pass spatial_shapes as tuple to make it hashable and make sure - # lru_cache is working for generate_anchors() spatial_shapes_tuple = tuple(spatial_shapes_list) anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype) else: anchors, valid_mask = self.anchors, self.valid_mask anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype) - # use the valid_mask to selectively retain values in the feature map where the mask is `True` memory = valid_mask.to(source_flatten.dtype) * source_flatten output_memory = self.enc_output(memory) @@ -1812,7 +2231,6 @@ def forward( dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]) ) - # extract region features if self.config.learn_initial_query: target = self.weight_embedding.tile([batch_size, 1, 1]) else: @@ -1824,7 +2242,8 @@ def forward( init_reference_points = reference_points_unact.detach() - # decoder + from ..d_fine.modeling_d_fine import DFineModelOutput + decoder_outputs = self.decoder( inputs_embeds=target, encoder_hidden_states=source_flatten, @@ -1836,7 +2255,7 @@ def forward( **kwargs, ) - return Deimv2ModelOutput( + return DFineModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, intermediate_logits=decoder_outputs.intermediate_logits, @@ -1959,16 +2378,20 @@ def __init__(self, config: Deimv2Config): scaled_dim = round(config.layer_scale * config.hidden_size) 
num_pred = config.decoder_layers self.class_embed = nn.ModuleList([nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)]) - self.bbox_embed = nn.ModuleList( - [ - Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) - for _ in range(self.eval_idx + 1) - ] - + [ - Deimv2MLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3) - for _ in range(config.decoder_layers - self.eval_idx - 1) - ] - ) + if config.share_bbox_head: + shared_bbox = Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) + self.bbox_embed = nn.ModuleList([shared_bbox] * num_pred) + else: + self.bbox_embed = nn.ModuleList( + [ + Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) + for _ in range(self.eval_idx + 1) + ] + + [ + Deimv2MLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3) + for _ in range(config.decoder_layers - self.eval_idx - 1) + ] + ) self.model.decoder.class_embed = self.class_embed self.model.decoder.bbox_embed = self.bbox_embed @@ -2094,5 +2517,17 @@ def forward( denoising_meta_values=outputs.denoising_meta_values, ) + @property + def _tied_weights_keys(self): + keys = { + r"class_embed.(?![0])\d+": r"^class_embed.0", + "class_embed": "model.decoder.class_embed", + "bbox_embed": "model.decoder.bbox_embed", + } + if getattr(self.config, "share_bbox_head", False): + keys[r"model\.decoder\.bbox_embed\.(?![0])\d+"] = r"model.decoder.bbox_embed.0" + keys[r"bbox_embed.(?![0])\d+"] = r"bbox_embed.0" + return keys + __all__ = ["Deimv2Model", "Deimv2PreTrainedModel", "Deimv2ForObjectDetection"] diff --git a/src/transformers/models/deimv2/modular_deimv2.py b/src/transformers/models/deimv2/modular_deimv2.py index 05bd3b8d4da5..5a7edf292378 100644 --- a/src/transformers/models/deimv2/modular_deimv2.py +++ b/src/transformers/models/deimv2/modular_deimv2.py @@ -39,6 +39,13 @@ DFineRepVggBlock, DFineSCDown, ) +from ..dinov3_vit.configuration_dinov3_vit import 
DINOv3ViTConfig +from ..dinov3_vit.modeling_dinov3_vit import ( + DINOv3ViTEmbeddings, + DINOv3ViTLayer, + DINOv3ViTLayerScale, + DINOv3ViTRopePositionEmbedding, +) from ..rt_detr.modeling_rt_detr import RTDetrAIFILayer @@ -49,7 +56,8 @@ class Deimv2Config(DFineConfig): """ This is the configuration class to store the configuration of a [`Deimv2Model`]. It is used to instantiate a DEIMv2 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of DEIMv2-L-COCO. + with the defaults will yield a similar configuration to that of DEIMv2-HGNetv2-N-COCO + [Intellindust/DEIMv2_HGNetv2_N_COCO](https://huggingface.co/Intellindust/DEIMv2_HGNetv2_N_COCO). Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PreTrainedConfig`] for more information. @@ -198,6 +206,29 @@ class Deimv2Config(DFineConfig): Whether to use the Spatial Tuning Adapter (STA) for DINOv2 backbone variants. sta_inplanes (`int`, *optional*, defaults to 16): Number of input planes for the STA convolutional stem. + encoder_type (`str`, *optional*, defaults to `"hybrid"`): + Type of encoder to use. `"hybrid"` uses the full HybridEncoder with AIFI, FPN, and PAN. + `"lite"` uses the lightweight LiteEncoder with GAP fusion for smaller variants (Atto, Femto, Pico). + use_gateway (`bool`, *optional*, defaults to `True`): + Whether to use the gateway mechanism (cross-attention gating) in decoder layers. When `False`, + uses RMSNorm on the encoder attention output instead. + share_bbox_head (`bool`, *optional*, defaults to `False`): + Whether to share the bounding box prediction head across all decoder layers. + backbone_type (`str`, *optional*, defaults to `"hgnetv2"`): + Type of backbone to use. `"hgnetv2"` uses HGNetV2, `"dinov3"` uses DINOv3 ViT backbone with STA. 
+ dinov3_backbone_config (`dict`, *optional*): + Configuration dictionary for the DINOv3 ViT backbone. Passed as kwargs to `DINOv3ViTConfig`. + dinov3_interaction_indexes (`list[int]`, *optional*): + Layer indices in the DINOv3 ViT backbone from which to extract intermediate features. + dinov3_hidden_dim (`int`, *optional*): + Hidden dimension for the DINOv3 backbone projection convolutions. If `None`, uses `hidden_size` from + the DINOv3 ViT config. + dinov3_apply_layernorm (`bool`, *optional*, defaults to `False`): + Whether to apply LayerNorm to intermediate features extracted from the DINOv3 ViT backbone. + encoder_has_trailing_conv (`bool`, *optional*, defaults to `True`): + Whether the encoder's CSP blocks include a trailing 3x3 convolution after the bottleneck path. + `True` for RepNCSPELAN4 (used by HGNetV2 N and LiteEncoder variants). + `False` for RepNCSPELAN5 (used by DINOv3 variants). tie_word_embeddings (`bool`, *optional*, defaults to `True`): Whether to tie weight embeddings. 
""" @@ -276,6 +307,15 @@ def __init__( encoder_fuse_op="sum", use_spatial_tuning_adapter=False, sta_inplanes=16, + encoder_type="hybrid", + use_gateway=True, + share_bbox_head=False, + backbone_type="hgnetv2", + dinov3_backbone_config=None, + dinov3_interaction_indexes=None, + dinov3_hidden_dim=None, + dinov3_apply_layernorm=False, + encoder_has_trailing_conv=True, tie_word_embeddings=True, **kwargs, ): @@ -285,6 +325,15 @@ def __init__( self.encoder_fuse_op = encoder_fuse_op self.use_spatial_tuning_adapter = use_spatial_tuning_adapter self.sta_inplanes = sta_inplanes + self.encoder_type = encoder_type + self.use_gateway = use_gateway + self.share_bbox_head = share_bbox_head + self.backbone_type = backbone_type + self.dinov3_backbone_config = dinov3_backbone_config + self.dinov3_interaction_indexes = dinov3_interaction_indexes + self.dinov3_hidden_dim = dinov3_hidden_dim + self.dinov3_apply_layernorm = dinov3_apply_layernorm + self.encoder_has_trailing_conv = encoder_has_trailing_conv super().__init__( initializer_range=initializer_range, initializer_bias_prior_prob=initializer_bias_prior_prob, @@ -355,6 +404,10 @@ def __init__( ) +class Deimv2DINOv3ViTConfig(DINOv3ViTConfig): + model_type = "deimv2_dinov3_vit" + + class Deimv2DecoderOutput(DFineDecoderOutput): pass @@ -407,6 +460,22 @@ class Deimv2RepVggBlock(DFineRepVggBlock): pass +class Deimv2DINOv3ViTEmbeddings(DINOv3ViTEmbeddings): + pass + + +class Deimv2DINOv3ViTLayerScale(DINOv3ViTLayerScale): + pass + + +class Deimv2DINOv3ViTLayer(DINOv3ViTLayer): + pass + + +class Deimv2DINOv3ViTRopePositionEmbedding(DINOv3ViTRopePositionEmbedding): + pass + + class Deimv2CSPRepLayer2(nn.Module): def __init__( self, config: Deimv2Config, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0 @@ -418,7 +487,10 @@ def __init__( self.bottlenecks = nn.ModuleList( [Deimv2RepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)] ) - self.conv3 = Deimv2ConvNormLayer(config, 
hidden_channels, out_channels, 3, 1, activation=activation) + if config.encoder_has_trailing_conv: + self.conv3 = Deimv2ConvNormLayer(config, hidden_channels, out_channels, 3, 1, activation=activation) + else: + self.conv3 = nn.Identity() def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: chunks = list(self.conv1(hidden_state).chunk(2, 1)) @@ -429,13 +501,21 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: class Deimv2RepNCSPELAN5(nn.Module): - def __init__(self, config: Deimv2Config, numb_blocks: int = 3): + def __init__( + self, + config: Deimv2Config, + numb_blocks: int = 3, + c1: int | None = None, + c2: int | None = None, + c3: int | None = None, + c4: int | None = None, + ): super().__init__() act = config.activation_function - c1 = config.encoder_hidden_dim - c2 = config.encoder_hidden_dim - c3 = config.encoder_hidden_dim * 2 - c4 = round(config.hidden_expansion * config.encoder_hidden_dim // 2) + c1 = c1 if c1 is not None else config.encoder_hidden_dim + c2 = c2 if c2 is not None else config.encoder_hidden_dim + c3 = c3 if c3 is not None else config.encoder_hidden_dim * 2 + c4 = c4 if c4 is not None else round(config.hidden_expansion * config.encoder_hidden_dim // 2) self.conv_dim = c3 // 2 self.conv1 = Deimv2ConvNormLayer(config, c1, c3, 1, 1, activation=act) self.csp_rep1 = Deimv2CSPRepLayer2(config, c3 // 2, c4, num_blocks=numb_blocks) @@ -496,6 +576,135 @@ def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso return c2, c3, c4 +class Deimv2GAPFusion(nn.Module): + def __init__(self, config: Deimv2Config, channels: int): + super().__init__() + self.cv = Deimv2ConvNormLayer(config, channels, channels, 1, 1, activation=config.activation_function) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return self.cv(hidden_state + F.adaptive_avg_pool2d(hidden_state, 1)) + + +class Deimv2LiteEncoder(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + hidden_dim = 
config.encoder_hidden_dim + act = config.activation_function + + self.input_proj = nn.ModuleList() + for in_channel in config.encoder_in_channels: + self.input_proj.append( + nn.Sequential( + nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(hidden_dim), + ) + ) + + self.down_sample1 = nn.Sequential( + nn.AvgPool2d(kernel_size=3, stride=2, padding=1), + nn.Conv2d(hidden_dim, hidden_dim, 1, 1, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.SiLU(inplace=True) if act == "silu" else nn.ReLU(inplace=True), + ) + self.down_sample2 = nn.Sequential( + nn.AvgPool2d(kernel_size=3, stride=2, padding=1), + nn.Conv2d(hidden_dim, hidden_dim, 1, 1, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.SiLU(inplace=True) if act == "silu" else nn.ReLU(inplace=True), + ) + + self.bi_fusion = Deimv2GAPFusion(config, hidden_dim) + + c1, c2 = hidden_dim, hidden_dim + c3 = hidden_dim * 2 + c4 = round(config.hidden_expansion * hidden_dim // 2) + num_blocks = round(3 * config.depth_mult) + self.fpn_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks, c1=c1, c2=c2, c3=c3, c4=c4) + self.pan_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks, c1=c1, c2=c2, c3=c3, c4=c4) + + def forward(self, inputs_embeds=None, **kwargs) -> BaseModelOutput: + feats = inputs_embeds + proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] + proj_feats.append(self.down_sample1(proj_feats[-1])) + + proj_feats[-1] = self.bi_fusion(proj_feats[-1]) + + outs = [] + fuse_feat = proj_feats[0] + F.interpolate(proj_feats[1], scale_factor=2.0, mode="nearest") + outs.append(self.fpn_block(fuse_feat)) + + fuse_feat = proj_feats[1] + self.down_sample2(outs[-1]) + outs.append(self.pan_block(fuse_feat)) + + return BaseModelOutput(last_hidden_state=outs) + + +class Deimv2DINOv3Backbone(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + vit_config = Deimv2DINOv3ViTConfig(**config.dinov3_backbone_config) + + self.embeddings = 
Deimv2DINOv3ViTEmbeddings(vit_config) + self.rope_embeddings = Deimv2DINOv3ViTRopePositionEmbedding(vit_config) + self.layers = nn.ModuleList([Deimv2DINOv3ViTLayer(vit_config) for _ in range(vit_config.num_hidden_layers)]) + + self.apply_layernorm = config.dinov3_apply_layernorm + if self.apply_layernorm: + self.norm = nn.LayerNorm(vit_config.hidden_size, eps=vit_config.layer_norm_eps) + + self.interaction_indexes = config.dinov3_interaction_indexes + self.patch_size = vit_config.patch_size + self.num_prefix_tokens = 1 + vit_config.num_register_tokens + + self.sta = Deimv2SpatialTuningAdapter(config) + + embed_dim = vit_config.hidden_size + hidden_dim = config.dinov3_hidden_dim or embed_dim + sta_ch = config.sta_inplanes + self.convs = nn.ModuleList( + [ + nn.Conv2d(embed_dim + sta_ch * 2, hidden_dim, 1, bias=False), + nn.Conv2d(embed_dim + sta_ch * 4, hidden_dim, 1, bias=False), + nn.Conv2d(embed_dim + sta_ch * 4, hidden_dim, 1, bias=False), + ] + ) + self.norms = nn.ModuleList([nn.BatchNorm2d(hidden_dim) for _ in range(3)]) + + def forward(self, pixel_values: torch.Tensor) -> list[torch.Tensor]: + hidden_states = self.embeddings(pixel_values) + position_embeddings = self.rope_embeddings(pixel_values) + + intermediate_outputs = [] + for i, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, position_embeddings=position_embeddings) + if i in self.interaction_indexes: + out = self.norm(hidden_states) if self.apply_layernorm else hidden_states + intermediate_outputs.append(out) + + batch_size = pixel_values.shape[0] + h_patches = pixel_values.shape[2] // self.patch_size + w_patches = pixel_values.shape[3] // self.patch_size + + sem_feats = [] + num_scales = len(intermediate_outputs) + for i, feat in enumerate(intermediate_outputs): + patch_tokens = feat[:, self.num_prefix_tokens :] + spatial = patch_tokens.transpose(1, 2).reshape(batch_size, -1, h_patches, w_patches).contiguous() + resize_h = int(h_patches * 2 ** (num_scales - 2 - i)) + resize_w = 
int(w_patches * 2 ** (num_scales - 2 - i)) + spatial = F.interpolate(spatial, size=[resize_h, resize_w], mode="bilinear", align_corners=False) + sem_feats.append(spatial) + + detail_feats = self.sta(pixel_values) + + outs = [] + for i, (sem_feat, detail_feat) in enumerate(zip(sem_feats, detail_feats)): + fused = torch.cat([sem_feat, detail_feat], dim=1) + outs.append(self.norms[i](self.convs[i](fused))) + + return outs + + class Deimv2Integral(DFineIntegral): pass @@ -508,10 +717,63 @@ class Deimv2DecoderLayer(DFineDecoderLayer): def __init__(self, config: Deimv2Config): super().__init__(config) self.encoder_attn = Deimv2MultiscaleDeformableAttention(config=config) - self.gateway = Deimv2Gate(config.d_model) self.self_attn_layer_norm = Deimv2RMSNorm(config.d_model) self.final_layer_norm = Deimv2RMSNorm(config.d_model) self.mlp = Deimv2SwiGLUFFN(config.d_model, config.decoder_ffn_dim // 2, config.d_model) + if config.use_gateway: + self.gateway = Deimv2Gate(config.d_model) + else: + del self.gateway + self.encoder_attn_layer_norm = Deimv2RMSNorm(config.d_model) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor | None = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=encoder_attention_mask, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + + hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings + hidden_states, _ = self.encoder_attn( + 
hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if hasattr(self, "gateway"): + hidden_states = self.gateway(residual, hidden_states) + else: + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states.clamp(min=-65504, max=65504)) + + return hidden_states class Deimv2MLPPredictionHead(DFineMLP): @@ -597,6 +859,20 @@ def _init_weights(self, module): init.ones_(module.weight) init.zeros_(module.bias) + if isinstance(module, Deimv2DINOv3ViTEmbeddings): + init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) + if module.config.num_register_tokens > 0: + init.trunc_normal_(module.register_tokens, mean=0.0, std=self.config.initializer_range) + init.zeros_(module.mask_token) + + if isinstance(module, Deimv2DINOv3ViTLayerScale) and self.config.dinov3_backbone_config is not None: + layerscale_value = self.config.dinov3_backbone_config.get("layerscale_value", 1.0) + init.constant_(module.lambda1, layerscale_value) + + if isinstance(module, Deimv2DINOv3ViTRopePositionEmbedding): + inv_freq = 1 / module.base ** torch.arange(0, 1, 4 / module.head_dim, dtype=torch.float32) + init.copy_(module.inv_freq, inv_freq) + if hasattr(module, "weight_embedding") and self.config.learn_initial_query: init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: @@ -605,7 +881,7 @@ def _init_weights(self, module): class Deimv2HybridEncoder(DFineHybridEncoder): def __init__(self, config: Deimv2Config): - Deimv2PreTrainedModel.__init__(config) + 
Deimv2PreTrainedModel.__init__(self, config) self.config = config self.in_channels = config.encoder_in_channels self.num_fpn_stages = len(self.in_channels) - 1 @@ -707,9 +983,49 @@ def __init__(self, config: Deimv2Config): class Deimv2Model(DFineModel): def __init__(self, config: Deimv2Config): - super().__init__(config) - del self.decoder_input_proj - self.encoder = Deimv2HybridEncoder(config=config) + Deimv2PreTrainedModel.__init__(self, config) + + if config.backbone_type == "dinov3": + self.dinov3_backbone = Deimv2DINOv3Backbone(config) + else: + from ..d_fine.modeling_d_fine import DFineConvEncoder + + self.backbone = DFineConvEncoder(config) + if config.encoder_type != "lite": + intermediate_channel_sizes = self.backbone.intermediate_channel_sizes + encoder_input_proj = [] + for in_channel in intermediate_channel_sizes: + encoder_input_proj.append( + nn.Sequential( + nn.Conv2d(in_channel, config.encoder_hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(config.encoder_hidden_dim), + ) + ) + self.encoder_input_proj = nn.ModuleList(encoder_input_proj) + + if config.encoder_type == "lite": + self.encoder = Deimv2LiteEncoder(config) + else: + self.encoder = Deimv2HybridEncoder(config=config) + + if config.num_denoising > 0: + self.denoising_class_embed = nn.Embedding( + config.num_labels + 1, config.d_model, padding_idx=config.num_labels + ) + + if config.learn_initial_query: + self.weight_embedding = nn.Embedding(config.num_queries, config.d_model) + + self.enc_output = nn.Sequential( + nn.Linear(config.d_model, config.d_model), + nn.LayerNorm(config.d_model, eps=config.layer_norm_eps), + ) + self.enc_score_head = nn.Linear(config.d_model, config.num_labels) + self.enc_bbox_head = Deimv2MLP(config.d_model, config.d_model, 4, 3) + + if config.anchor_image_size: + self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype) + num_backbone_outs = len(config.decoder_in_channels) decoder_input_proj = [] in_channels = config.decoder_in_channels[-1] @@ 
-730,18 +1046,189 @@ def __init__(self, config: Deimv2Config): self.decoder_input_proj = nn.ModuleList(decoder_input_proj) self.decoder = Deimv2Decoder(config) - if config.use_spatial_tuning_adapter: + if config.use_spatial_tuning_adapter and config.backbone_type != "dinov3": self.spatial_tuning_adapter = Deimv2SpatialTuningAdapter(config) + self.post_init() + + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: torch.LongTensor | None = None, + encoder_outputs: torch.FloatTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: list[dict] | None = None, + **kwargs, + ): + if pixel_values is None and inputs_embeds is None: + raise ValueError("You have to specify either pixel_values or inputs_embeds") + + if inputs_embeds is None: + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), device=device) + + if self.config.backbone_type == "dinov3": + proj_feats = self.dinov3_backbone(pixel_values) + elif self.config.encoder_type == "lite": + features = self.backbone(pixel_values, pixel_mask) + proj_feats = [source for source, mask in features] + else: + features = self.backbone(pixel_values, pixel_mask) + proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] + else: + batch_size = inputs_embeds.shape[0] + device = inputs_embeds.device + proj_feats = inputs_embeds + + if encoder_outputs is None: + encoder_outputs = self.encoder( + proj_feats, + **kwargs, + ) + elif not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + sources = [] + for level, source in enumerate(encoder_outputs.last_hidden_state): + 
sources.append(self.decoder_input_proj[level](source)) + + if self.config.num_feature_levels > len(sources): + _len_sources = len(sources) + sources.append(self.decoder_input_proj[_len_sources](encoder_outputs.last_hidden_state[-1])) + for i in range(_len_sources + 1, self.config.num_feature_levels): + sources.append(self.decoder_input_proj[i](encoder_outputs.last_hidden_state[-1])) + + source_flatten = [] + spatial_shapes_list = [] + spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long) + for level, source in enumerate(sources): + height, width = source.shape[-2:] + spatial_shapes[level, 0] = height + spatial_shapes[level, 1] = width + spatial_shapes_list.append((height, width)) + source = source.flatten(2).transpose(1, 2) + source_flatten.append(source) + source_flatten = torch.cat(source_flatten, 1) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + + if self.training and self.config.num_denoising > 0 and labels is not None: + from ..d_fine.modeling_d_fine import get_contrastive_denoising_training_group + + ( + denoising_class, + denoising_bbox_unact, + attention_mask, + denoising_meta_values, + ) = get_contrastive_denoising_training_group( + targets=labels, + num_classes=self.config.num_labels, + num_queries=self.config.num_queries, + class_embed=self.denoising_class_embed, + num_denoising_queries=self.config.num_denoising, + label_noise_ratio=self.config.label_noise_ratio, + box_noise_scale=self.config.box_noise_scale, + ) + else: + denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None + + batch_size = len(source_flatten) + device = source_flatten.device + dtype = source_flatten.dtype + + if self.training or self.config.anchor_image_size is None: + spatial_shapes_tuple = tuple(spatial_shapes_list) + anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype) + else: + anchors, valid_mask = 
self.anchors, self.valid_mask + anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype) + + memory = valid_mask.to(source_flatten.dtype) * source_flatten + + output_memory = self.enc_output(memory) + + enc_outputs_class = self.enc_score_head(output_memory) + enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors + + _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1) + + reference_points_unact = enc_outputs_coord_logits.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1]) + ) + + enc_topk_bboxes = F.sigmoid(reference_points_unact) + if denoising_bbox_unact is not None: + reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1) + + enc_topk_logits = enc_outputs_class.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]) + ) + + if self.config.learn_initial_query: + target = self.weight_embedding.tile([batch_size, 1, 1]) + else: + target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1])) + target = target.detach() + + if denoising_class is not None: + target = torch.concat([denoising_class, target], 1) + + init_reference_points = reference_points_unact.detach() + + from ..d_fine.modeling_d_fine import DFineModelOutput + + decoder_outputs = self.decoder( + inputs_embeds=target, + encoder_hidden_states=source_flatten, + encoder_attention_mask=attention_mask, + reference_points=init_reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + **kwargs, + ) + + return DFineModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + intermediate_logits=decoder_outputs.intermediate_logits, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + 
intermediate_predicted_corners=decoder_outputs.intermediate_predicted_corners, + initial_reference_points=decoder_outputs.initial_reference_points, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + init_reference_points=init_reference_points, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + enc_outputs_class=enc_outputs_class, + enc_outputs_coord_logits=enc_outputs_coord_logits, + denoising_meta_values=denoising_meta_values, + ) + class Deimv2ForObjectDetection(DFineForObjectDetection): _no_split_modules = None - _tied_weights_keys = { - r"bbox_embed.(?![0])\d+": r"bbox_embed.0", - r"class_embed.(?![0])\d+": r"^class_embed.0", - "class_embed": "model.decoder.class_embed", - "bbox_embed": "model.decoder.bbox_embed", - } + + @property + def _tied_weights_keys(self): + keys = { + r"class_embed.(?![0])\d+": r"^class_embed.0", + "class_embed": "model.decoder.class_embed", + "bbox_embed": "model.decoder.bbox_embed", + } + if getattr(self.config, "share_bbox_head", False): + keys[r"model\.decoder\.bbox_embed\.(?![0])\d+"] = r"model.decoder.bbox_embed.0" + keys[r"bbox_embed.(?![0])\d+"] = r"bbox_embed.0" + return keys def __init__(self, config: Deimv2Config): Deimv2PreTrainedModel.__init__(self, config) @@ -751,16 +1238,20 @@ def __init__(self, config: Deimv2Config): scaled_dim = round(config.layer_scale * config.hidden_size) num_pred = config.decoder_layers self.class_embed = nn.ModuleList([nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)]) - self.bbox_embed = nn.ModuleList( - [ - Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) - for _ in range(self.eval_idx + 1) - ] - + [ - Deimv2MLP(scaled_dim, scaled_dim, 4 * 
(config.max_num_bins + 1), 3) - for _ in range(config.decoder_layers - self.eval_idx - 1) - ] - ) + if config.share_bbox_head: + shared_bbox = Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) + self.bbox_embed = nn.ModuleList([shared_bbox] * num_pred) + else: + self.bbox_embed = nn.ModuleList( + [ + Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) + for _ in range(self.eval_idx + 1) + ] + + [ + Deimv2MLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3) + for _ in range(config.decoder_layers - self.eval_idx - 1) + ] + ) self.model.decoder.class_embed = self.class_embed self.model.decoder.bbox_embed = self.bbox_embed diff --git a/tests/models/deimv2/test_modeling_deimv2.py b/tests/models/deimv2/test_modeling_deimv2.py index 9b612c7833d1..db4ce97086b4 100644 --- a/tests/models/deimv2/test_modeling_deimv2.py +++ b/tests/models/deimv2/test_modeling_deimv2.py @@ -46,7 +46,12 @@ from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_modeling_common import ( + TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, + ModelTesterMixin, + _test_eager_matches_sdpa_inference, + floats_tensor, +) from ...test_pipeline_mixin import PipelineTesterMixin @@ -650,6 +655,1009 @@ def test_inference_equivalence_for_static_and_dynamic_anchors(self, dtype_str): ) +class Deimv2LiteEncoderModelTester: + def __init__( + self, + parent, + batch_size=3, + is_training=True, + use_labels=True, + n_targets=3, + num_labels=10, + initializer_range=0.02, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + encoder_hidden_dim=32, + encoder_in_channels=[256], + feat_strides=[16, 32], + dropout=0.0, + activation_dropout=0.0, + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + d_model=32, + num_queries=10, + decoder_in_channels=[32, 32], + 
decoder_ffn_dim=64, + num_feature_levels=2, + decoder_n_points=[4, 2], + decoder_n_levels=2, + decoder_layers=2, + decoder_attention_heads=2, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=0, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + image_size=64, + disable_custom_kernels=True, + with_box_refine=True, + decoder_offset_scale=0.5, + eval_idx=-1, + layer_scale=1, + reg_max=32, + reg_scale=4.0, + depth_mult=0.34, + hidden_expansion=0.5, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = 3 + self.is_training = is_training + self.use_labels = use_labels + self.n_targets = n_targets + self.num_labels = num_labels + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_layers = 0 + self.encoder_ffn_dim = 64 + self.encoder_attention_heads = 2 + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = [] + self.positional_encoding_temperature = positional_encoding_temperature + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.eval_size = eval_size + self.normalize_before = normalize_before + self.d_model = d_model + self.num_queries = num_queries + self.decoder_in_channels = decoder_in_channels + self.decoder_ffn_dim = decoder_ffn_dim + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_n_levels = decoder_n_levels + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.decoder_offset_scale = decoder_offset_scale + self.eval_idx 
= eval_idx + self.layer_scale = layer_scale + self.reg_max = reg_max + self.reg_scale = reg_scale + self.depth_mult = depth_mult + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.image_size = image_size + self.disable_custom_kernels = disable_custom_kernels + self.with_box_refine = with_box_refine + self.hidden_expansion = hidden_expansion + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device) + + labels = None + if self.use_labels: + labels = [] + for i in range(self.batch_size): + target = {} + target["class_labels"] = torch.randint( + high=self.num_labels, size=(self.n_targets,), device=torch_device + ) + target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device) + labels.append(target) + + config = self.get_config() + config.num_labels = self.num_labels + return config, pixel_values, pixel_mask, labels + + def get_config(self): + backbone_config = HGNetV2Config( + stage_in_channels=[16, 64, 128], + stage_mid_channels=[16, 32, 64], + stage_out_channels=[64, 128, 256], + stage_num_blocks=[1, 1, 1], + stage_downsample=[False, True, True], + stage_light_block=[False, False, True], + stage_kernel_size=[3, 3, 3], + stage_numb_of_layers=[3, 3, 3], + embeddings_size=10, + hidden_sizes=[64, 128, 256], + depths=[1, 1, 1], + out_features=["stage3"], + out_indices=[3], + stem_channels=[3, 16, 16], + use_lab=True, + ) + return Deimv2Config( + backbone_config=backbone_config, + encoder_hidden_dim=self.encoder_hidden_dim, + encoder_in_channels=self.encoder_in_channels, + feat_strides=self.feat_strides, + encoder_layers=self.encoder_layers, + encoder_ffn_dim=self.encoder_ffn_dim, + 
encoder_attention_heads=self.encoder_attention_heads, + dropout=self.dropout, + activation_dropout=self.activation_dropout, + encode_proj_layers=self.encode_proj_layers, + positional_encoding_temperature=self.positional_encoding_temperature, + encoder_activation_function=self.encoder_activation_function, + activation_function=self.activation_function, + eval_size=self.eval_size, + normalize_before=self.normalize_before, + d_model=self.d_model, + num_queries=self.num_queries, + decoder_in_channels=self.decoder_in_channels, + decoder_ffn_dim=self.decoder_ffn_dim, + num_feature_levels=self.num_feature_levels, + decoder_n_points=self.decoder_n_points, + decoder_n_levels=self.decoder_n_levels, + decoder_layers=self.decoder_layers, + decoder_attention_heads=self.decoder_attention_heads, + decoder_activation_function=self.decoder_activation_function, + decoder_offset_scale=self.decoder_offset_scale, + eval_idx=self.eval_idx, + layer_scale=self.layer_scale, + reg_max=self.reg_max, + reg_scale=self.reg_scale, + depth_mult=self.depth_mult, + attention_dropout=self.attention_dropout, + num_denoising=self.num_denoising, + label_noise_ratio=self.label_noise_ratio, + box_noise_scale=self.box_noise_scale, + learn_initial_query=self.learn_initial_query, + anchor_image_size=self.anchor_image_size, + image_size=self.image_size, + disable_custom_kernels=self.disable_custom_kernels, + with_box_refine=self.with_box_refine, + encoder_type="lite", + use_gateway=False, + share_bbox_head=False, + ) + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + def create_and_check_deimv2_model(self, config, pixel_values, pixel_mask, labels): + model = Deimv2Model(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + 
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.d_model)) + + def create_and_check_deimv2_object_detection_head_model(self, config, pixel_values, pixel_mask, labels): + model = Deimv2ForObjectDetection(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels) + + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + +@require_torch +class Deimv2LiteEncoderModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (Deimv2Model, Deimv2ForObjectDetection) if is_torch_available() else () + pipeline_model_mapping = {} + is_encoder_decoder = True + + test_missing_keys = False + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "Deimv2ForObjectDetection": + labels = [] + for i in range(self.model_tester.batch_size): + target = {} + target["class_labels"] = torch.ones( + size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long + ) + target["boxes"] = torch.ones( + self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float + ) + labels.append(target) + inputs_dict["labels"] = labels + + return inputs_dict + + def setUp(self): + self.model_tester = Deimv2LiteEncoderModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=Deimv2Config, + 
has_text_modality=False,
+            common_properties=["hidden_size", "num_attention_heads"],
+        )
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    def test_deimv2_lite_encoder_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_deimv2_model(*config_and_inputs)
+
+    def test_deimv2_lite_encoder_object_detection_head_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_deimv2_object_detection_head_model(*config_and_inputs)
+
+    @unittest.skip(reason="Deimv2 doesn't work well with `nn.DataParallel`")
+    def test_multi_gpu_data_parallel_forward(self):
+        pass
+
+    @unittest.skip(reason="Deimv2 does not use inputs_embeds")
+    def test_inputs_embeds(self):
+        pass
+
+    @unittest.skip(reason="Deimv2 does not use test_inputs_embeds_matches_input_ids")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
+    @unittest.skip(reason="Deimv2 does not support input and output embeddings")
+    def test_model_get_set_embeddings(self):
+        pass
+
+    @unittest.skip(reason="Deimv2 does not support input and output embeddings")
+    def test_model_common_attributes(self):
+        pass
+
+    @unittest.skip(reason="Deimv2 does not use token embeddings")
+    def test_resize_tokens_embeddings(self):
+        pass
+
+    @unittest.skip(reason="Feed forward chunking is not implemented")
+    def test_feed_forward_chunking(self):
+        pass
+
+    @unittest.skip(reason="Weight tying is hardcoded (module_x = module_y) and always `True`")
+    def test_load_save_without_tied_weights(self):
+        pass
+
+    @unittest.skip(reason="LiteEncoder has no AIFI layers, so no encoder attentions are produced")
+    def test_attention_outputs(self):
+        pass
+
+    @unittest.skip(reason="LiteEncoder has no encoder attentions for gradient retention check")
+    def test_retain_grad_hidden_states_attentions(self):
+        pass
+
+    @unittest.skip(
+        reason="LiteEncoder expects exactly 1 backbone feature map (single scale), but 
test_backbone_selection hardcodes 3-output backbones (out_indices=[2,3,4])" + ) + def test_backbone_selection(self): + pass + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if config.is_encoder_decoder: + hidden_states = outputs.decoder_hidden_states + + expected_num_layers = self.model_tester.decoder_layers + 1 + + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.num_queries, self.model_tester.d_model], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + arg_names = [*signature.parameters.keys()] + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + +class Deimv2DINOv3ModelTester: + def __init__( + self, + parent, + batch_size=3, + is_training=True, + use_labels=True, + n_targets=3, + num_labels=10, + initializer_range=0.02, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + encoder_hidden_dim=32, + encoder_in_channels=[32, 32, 32], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=64, + encoder_attention_heads=2, + dropout=0.0, + activation_dropout=0.0, + 
encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + d_model=32, + num_queries=30, + decoder_in_channels=[32, 32, 32], + decoder_ffn_dim=64, + num_feature_levels=3, + decoder_n_points=4, + decoder_n_levels=3, + decoder_layers=2, + decoder_attention_heads=2, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=0, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + image_size=64, + disable_custom_kernels=True, + with_box_refine=True, + decoder_offset_scale=0.5, + eval_idx=-1, + layer_scale=1, + reg_max=32, + reg_scale=4.0, + depth_mult=0.34, + hidden_expansion=0.5, + sta_inplanes=8, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = 3 + self.is_training = is_training + self.use_labels = use_labels + self.n_targets = n_targets + self.num_labels = num_labels + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_layers = encoder_layers + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.eval_size = eval_size + self.normalize_before = normalize_before + self.d_model = d_model + self.num_queries = num_queries + self.decoder_in_channels = decoder_in_channels + self.decoder_ffn_dim = decoder_ffn_dim + self.num_feature_levels = num_feature_levels + self.decoder_n_points = 
        decoder_n_points  # tail of the `self.decoder_n_points = decoder_n_points` assignment split by the chunk boundary
        self.decoder_n_levels = decoder_n_levels
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.decoder_activation_function = decoder_activation_function
        self.attention_dropout = attention_dropout
        self.decoder_offset_scale = decoder_offset_scale
        self.eval_idx = eval_idx
        self.layer_scale = layer_scale
        self.reg_max = reg_max
        self.reg_scale = reg_scale
        self.depth_mult = depth_mult
        self.num_denoising = num_denoising
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale
        self.learn_initial_query = learn_initial_query
        self.anchor_image_size = anchor_image_size
        self.image_size = image_size
        self.disable_custom_kernels = disable_custom_kernels
        self.with_box_refine = with_box_refine
        self.hidden_expansion = hidden_expansion
        self.sta_inplanes = sta_inplanes

        # Encoder sequence length used by the shape checks in the common attention tests.
        # The coarsest feature level has stride 32, hence ceil(image_size / 32)^2 tokens.
        self.encoder_seq_length = math.ceil(self.image_size / 32) * math.ceil(self.image_size / 32)

    def prepare_config_and_inputs(self):
        """Build a tiny config plus random pixel inputs (and optional detection labels).

        Returns a `(config, pixel_values, pixel_mask, labels)` tuple; `labels` is `None`
        unless `self.use_labels` is set, in which case it is a per-image list of dicts
        with `class_labels` and normalized `boxes` (the format the detection loss expects).
        """
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        # All-ones mask: every pixel is valid (no padding in these synthetic images).
        pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device)

        labels = None
        if self.use_labels:
            labels = []
            for i in range(self.batch_size):
                target = {}
                target["class_labels"] = torch.randint(
                    high=self.num_labels, size=(self.n_targets,), device=torch_device
                )
                # Random boxes in [0, 1); presumably (cx, cy, w, h) normalized — TODO confirm
                # against the DEIMv2 loss/matcher convention.
                target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device)
                labels.append(target)

        config = self.get_config()
        config.num_labels = self.num_labels
        return config, pixel_values, pixel_mask, labels

    def get_config(self):
        """Return a minimal `Deimv2Config` wired to a tiny DINOv3 ViT backbone + STA.

        Note: an HGNetV2 `backbone_config` is still passed, but `backbone_type="dinov3"`
        selects the DINOv3 path, so the HGNetV2 settings mainly keep the config valid.
        """
        hidden_sizes = [64, 128, 256, 512]
        backbone_config = HGNetV2Config(
            stage_in_channels=[16, 64, 128, 256],
            stage_mid_channels=[16, 32, 64, 128],
            stage_out_channels=[64, 128, 256, 512],
            stage_num_blocks=[1, 1, 2, 1],
            stage_downsample=[False, True, True, True],
            stage_light_block=[False, False, True, True],
            stage_kernel_size=[3, 3, 5, 5],
            stage_numb_of_layers=[3, 3, 3, 3],
            embeddings_size=10,
            hidden_sizes=hidden_sizes,
            depths=[1, 1, 2, 1],
            out_features=["stage2", "stage3", "stage4"],
            out_indices=[2, 3, 4],
            stem_channels=[3, 16, 16],
            use_lab=True,
        )
        # Tiny ViT so the common tests stay fast; kwargs are forwarded to DINOv3ViTConfig.
        dinov3_backbone_config = {
            "hidden_size": 32,
            "num_attention_heads": 2,
            "num_hidden_layers": 4,
            "intermediate_size": 64,
            "num_register_tokens": 1,
            "layerscale_value": 1.0,
            "use_gated_mlp": False,
            "rope_theta": 100.0,
            "patch_size": 16,
            "image_size": self.image_size,
        }
        return Deimv2Config(
            backbone_config=backbone_config,
            encoder_hidden_dim=self.encoder_hidden_dim,
            encoder_in_channels=self.encoder_in_channels,
            feat_strides=self.feat_strides,
            encoder_layers=self.encoder_layers,
            encoder_ffn_dim=self.encoder_ffn_dim,
            encoder_attention_heads=self.encoder_attention_heads,
            dropout=self.dropout,
            activation_dropout=self.activation_dropout,
            encode_proj_layers=self.encode_proj_layers,
            positional_encoding_temperature=self.positional_encoding_temperature,
            encoder_activation_function=self.encoder_activation_function,
            activation_function=self.activation_function,
            eval_size=self.eval_size,
            normalize_before=self.normalize_before,
            d_model=self.d_model,
            num_queries=self.num_queries,
            decoder_in_channels=self.decoder_in_channels,
            decoder_ffn_dim=self.decoder_ffn_dim,
            num_feature_levels=self.num_feature_levels,
            decoder_n_points=self.decoder_n_points,
            decoder_n_levels=self.decoder_n_levels,
            decoder_layers=self.decoder_layers,
            decoder_attention_heads=self.decoder_attention_heads,
            decoder_activation_function=self.decoder_activation_function,
            decoder_offset_scale=self.decoder_offset_scale,
            eval_idx=self.eval_idx,
            layer_scale=self.layer_scale,
            reg_max=self.reg_max,
            reg_scale=self.reg_scale,
            depth_mult=self.depth_mult,
            attention_dropout=self.attention_dropout,
            num_denoising=self.num_denoising,
            label_noise_ratio=self.label_noise_ratio,
            box_noise_scale=self.box_noise_scale,
            learn_initial_query=self.learn_initial_query,
            anchor_image_size=self.anchor_image_size,
            image_size=self.image_size,
            disable_custom_kernels=self.disable_custom_kernels,
            with_box_refine=self.with_box_refine,
            backbone_type="dinov3",
            dinov3_backbone_config=dinov3_backbone_config,
            dinov3_interaction_indexes=[1, 2, 3],
            dinov3_hidden_dim=self.encoder_hidden_dim,
            dinov3_apply_layernorm=False,
            sta_inplanes=self.sta_inplanes,
            use_spatial_tuning_adapter=True,
            encoder_has_trailing_conv=False,
        )

    def prepare_config_and_inputs_for_common(self):
        """Adapter for the common test mixin: only `pixel_values` is exposed.

        `pixel_mask` and `labels` are intentionally dropped here; tests that need
        labels re-create them via `_prepare_for_class(..., return_labels=True)`.
        """
        config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs()
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict

    def create_and_check_deimv2_model(self, config, pixel_values, pixel_mask, labels):
        """Forward the base model (with and without the mask) and check the output shape."""
        model = Deimv2Model(config=config)
        model.to(torch_device)
        model.eval()

        # Two calls: one with an explicit pixel_mask, one positional-only — both must work.
        result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
        result = model(pixel_values)

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.d_model))

    def create_and_check_deimv2_object_detection_head_model(self, config, pixel_values, pixel_mask, labels):
        """Forward the detection head, with and without labels, and check logits/boxes/loss shapes."""
        model = Deimv2ForObjectDetection(config=config)
        model.to(torch_device)
        model.eval()

        result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
        result = model(pixel_values)

        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
        self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))

        # With labels the loss must be a scalar; prediction shapes are unchanged.
        result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)

        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
        self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
+ + +@require_torch +class Deimv2DINOv3ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (Deimv2Model, Deimv2ForObjectDetection) if is_torch_available() else () + pipeline_model_mapping = ( + {"image-feature-extraction": Deimv2Model, "object-detection": Deimv2ForObjectDetection} + if is_torch_available() + else {} + ) + is_encoder_decoder = True + + test_missing_keys = False + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "Deimv2ForObjectDetection": + labels = [] + for i in range(self.model_tester.batch_size): + target = {} + target["class_labels"] = torch.ones( + size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long + ) + target["boxes"] = torch.ones( + self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float + ) + labels.append(target) + inputs_dict["labels"] = labels + + return inputs_dict + + def setUp(self): + self.model_tester = Deimv2DINOv3ModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=Deimv2Config, + has_text_modality=False, + common_properties=["hidden_size", "num_attention_heads"], + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_deimv2_dinov3_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deimv2_model(*config_and_inputs) + + def test_deimv2_dinov3_object_detection_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deimv2_object_detection_head_model(*config_and_inputs) + + @unittest.skip(reason="Deimv2 doesn't work well with `nn.DataParallel") + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip(reason="Deimv2 does not use inputs_embeds") + def test_inputs_embeds(self): + 
pass + + @unittest.skip(reason="Deimv2 does not use test_inputs_embeds_matches_input_ids") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip(reason="Deimv2 does not support input and output embeddings") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="Deimv2 does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="Deimv2 does not use token embeddings") + def test_resize_tokens_embeddings(self): + pass + + @unittest.skip(reason="Feed forward chunking is not implemented") + def test_feed_forward_chunking(self): + pass + + @unittest.skip(reason="Weight tying is hardcoded (module_x = module_y) and always `True`") + def test_load_save_without_tied_weights(self): + pass + + @unittest.skip(reason="DINOv3 backbone does not support timm/HF backbone selection") + def test_backbone_selection(self): + pass + + @unittest.skip( + reason="DINOv3 backbone with RoPE and LayerScale produces numerical differences beyond tolerance during offloading" + ) + def test_cpu_offload(self): + pass + + @unittest.skip( + reason="DINOv3 backbone with RoPE and LayerScale produces numerical differences beyond tolerance during offloading" + ) + def test_disk_offload_bin(self): + pass + + @unittest.skip( + reason="DINOv3 backbone with RoPE and LayerScale produces numerical differences beyond tolerance during offloading" + ) + def test_disk_offload_safetensors(self): + pass + + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) + def test_eager_matches_sdpa_inference( + self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels + ): + atols = { + ("cpu", False, torch.float32): 1e-4, + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-4, + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-4, + ("cuda", 
False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-4, + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + rtols = { + ("cpu", False, torch.float32): 1e-3, + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-3, + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-3, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-3, + ("cuda", True, torch.bfloat16): 3e-2, + ("cuda", True, torch.float16): 5e-3, + } + _test_eager_matches_sdpa_inference( + self, + name, + torch_dtype, + padding_side, + use_attention_mask, + output_attentions, + enable_kernels, + atols=atols, + rtols=rtols, + ) + + def test_batching_equivalence(self): + super().test_batching_equivalence(atol=1e-4, rtol=1e-4) + + @unittest.skip(reason="Deimv2 is not a generative encoder-decoder model and has no decoder_input_ids") + def test_flex_attention_with_grads(self): + pass + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class._from_config(config, attn_implementation="eager") + config = model.config + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions + self.assertEqual(len(attentions), self.model_tester.encoder_layers) + + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, 
model_class)) + attentions = outputs.encoder_attentions + self.assertEqual(len(attentions), self.model_tester.encoder_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [ + self.model_tester.encoder_attention_heads, + self.model_tester.encoder_seq_length, + self.model_tester.encoder_seq_length, + ], + ) + out_len = len(outputs) + + correct_outlen = 15 + + if "labels" in inputs_dict: + correct_outlen += 1 + if model_class.__name__ == "Deimv2ForObjectDetection": + correct_outlen += 2 + + self.assertEqual(out_len, correct_outlen) + + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.decoder_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [ + self.model_tester.decoder_attention_heads, + self.model_tester.num_queries, + self.model_tester.num_queries, + ], + ) + + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.decoder_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_queries, + self.model_tester.decoder_attention_heads, + self.model_tester.decoder_n_levels * self.model_tester.decoder_n_points + if isinstance(self.model_tester.decoder_n_points, int) + else sum(self.model_tester.decoder_n_points), + ], + ) + + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + else: + added_hidden_states = 2 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions + + 
self.assertEqual(len(self_attentions), self.model_tester.encoder_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [ + self.model_tester.encoder_attention_heads, + self.model_tester.encoder_seq_length, + self.model_tester.encoder_seq_length, + ], + ) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", len(self.model_tester.encoder_in_channels) - 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[1].shape[-2:]), + [ + self.model_tester.image_size // self.model_tester.feat_strides[-1], + self.model_tester.image_size // self.model_tester.feat_strides[-1], + ], + ) + + if config.is_encoder_decoder: + hidden_states = outputs.decoder_hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.decoder_layers + 1 + ) + + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.num_queries, self.model_tester.d_model], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = 
self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + inputs = self._prepare_for_class(inputs_dict, model_class) + + outputs = model(**inputs) + + output = outputs[0] + + encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_attentions = outputs.encoder_attentions[0] + encoder_hidden_states.retain_grad() + encoder_attentions.retain_grad() + + decoder_attentions = outputs.decoder_attentions[0] + decoder_attentions.retain_grad() + + cross_attentions = outputs.cross_attentions[0] + cross_attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(encoder_hidden_states.grad) + self.assertIsNotNone(encoder_attentions.grad) + self.assertIsNotNone(decoder_attentions.grad) + self.assertIsNotNone(cross_attentions.grad) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + arg_names = [*signature.parameters.keys()] + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + @parameterized.expand(["float32", "float16", "bfloat16"]) + @require_torch_accelerator + @slow + def test_inference_with_different_dtypes(self, dtype_str): + dtype = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[dtype_str] + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device).to(dtype) + model.eval() + for key, tensor in inputs_dict.items(): + if tensor.dtype == torch.float32: + inputs_dict[key] = tensor.to(dtype) + with torch.no_grad(): + _ = 
model(**self._prepare_for_class(inputs_dict, model_class)) + + @parameterized.expand(["float32", "float16", "bfloat16"]) + @require_torch_accelerator + @slow + def test_inference_equivalence_for_static_and_dynamic_anchors(self, dtype_str): + dtype = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[dtype_str] + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + h, w = inputs_dict["pixel_values"].shape[-2:] + + for key, tensor in inputs_dict.items(): + if tensor.dtype == torch.float32: + inputs_dict[key] = tensor.to(dtype) + + for model_class in self.all_model_classes: + with tempfile.TemporaryDirectory() as tmpdirname: + model_class(config).save_pretrained(tmpdirname) + model_static = model_class.from_pretrained( + tmpdirname, anchor_image_size=[h, w], device_map=torch_device, dtype=dtype + ).eval() + model_dynamic = model_class.from_pretrained( + tmpdirname, anchor_image_size=None, device_map=torch_device, dtype=dtype + ).eval() + + self.assertIsNotNone(model_static.config.anchor_image_size) + self.assertIsNone(model_dynamic.config.anchor_image_size) + + with torch.no_grad(): + outputs_static = model_static(**self._prepare_for_class(inputs_dict, model_class)) + outputs_dynamic = model_dynamic(**self._prepare_for_class(inputs_dict, model_class)) + + torch.testing.assert_close( + outputs_static.last_hidden_state, outputs_dynamic.last_hidden_state, rtol=5e-3, atol=5e-3 + ) + + def prepare_img(): image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") return image From adc4079e540e1549581ca75bae3cd13de7e85afa Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Sun, 1 Mar 2026 11:46:55 +0400 Subject: [PATCH 04/25] fix: Fix ci/circleci: check_repository_consistency --- utils/check_config_attributes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index c301d841ab2f..413f5e6d61fa 100644 --- 
a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -96,6 +96,8 @@ "SwitchTransformersConfig": True, "DetrConfig": True, "DFineConfig": True, + "Deimv2Config": True, + "Deimv2DINOv3ViTConfig": True, "GroundingDinoConfig": True, "MMGroundingDinoConfig": True, "RTDetrConfig": True, From 39d300e87c7988c2c9da4826c7ef7ee11e1332bc Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Tue, 17 Mar 2026 18:46:31 +0400 Subject: [PATCH 05/25] refactor: Resolve review comments --- .../models/deimv2/configuration_deimv2.py | 195 +--- ...eimv2_original_pytorch_checkpoint_to_hf.py | 157 ++-- .../models/deimv2/modeling_deimv2.py | 856 ++++++------------ .../models/deimv2/modular_deimv2.py | 354 ++------ tests/models/deimv2/test_modeling_deimv2.py | 52 +- utils/check_config_attributes.py | 1 - 6 files changed, 465 insertions(+), 1150 deletions(-) diff --git a/src/transformers/models/deimv2/configuration_deimv2.py b/src/transformers/models/deimv2/configuration_deimv2.py index d98372c44509..962c34a74596 100644 --- a/src/transformers/models/deimv2/configuration_deimv2.py +++ b/src/transformers/models/deimv2/configuration_deimv2.py @@ -17,7 +17,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ...backbone_utils import BackboneConfigMixin, consolidate_backbone_kwargs_to_config + +from ...backbone_utils import consolidate_backbone_kwargs_to_config from ...configuration_utils import PreTrainedConfig from ..auto import AutoConfig @@ -45,7 +46,8 @@ class Deimv2Config(PreTrainedConfig): batch_norm_eps (`float`, *optional*, defaults to 1e-05): The epsilon used by the batch normalization layers. backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `HGNetV2Config()`): - The configuration of the backbone model. + The configuration of the backbone model. For HGNetV2 variants, use `HGNetV2Config`. 
+ For DINOv3 variants, use `DINOv3ViTConfig`. freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`): Whether to freeze the batch normalization layers in the backbone. encoder_hidden_dim (`int`, *optional*, defaults to 256): @@ -174,8 +176,6 @@ class Deimv2Config(PreTrainedConfig): Alpha parameter for the Matching Auxiliary Loss (MAL). If `None`, uses `focal_loss_alpha`. encoder_fuse_op (`str`, *optional*, defaults to `"sum"`): Fusion operation used in the encoder FPN. DEIMv2 uses `"sum"` instead of D-Fine's `"cat"`. - use_spatial_tuning_adapter (`bool`, *optional*, defaults to `False`): - Whether to use the Spatial Tuning Adapter (STA) for DINOv2 backbone variants. sta_inplanes (`int`, *optional*, defaults to 16): Number of input planes for the STA convolutional stem. encoder_type (`str`, *optional*, defaults to `"hybrid"`): @@ -186,17 +186,6 @@ class Deimv2Config(PreTrainedConfig): uses RMSNorm on the encoder attention output instead. share_bbox_head (`bool`, *optional*, defaults to `False`): Whether to share the bounding box prediction head across all decoder layers. - backbone_type (`str`, *optional*, defaults to `"hgnetv2"`): - Type of backbone to use. `"hgnetv2"` uses HGNetV2, `"dinov3"` uses DINOv3 ViT backbone with STA. - dinov3_backbone_config (`dict`, *optional*): - Configuration dictionary for the DINOv3 ViT backbone. Passed as kwargs to `DINOv3ViTConfig`. - dinov3_interaction_indexes (`list[int]`, *optional*): - Layer indices in the DINOv3 ViT backbone from which to extract intermediate features. - dinov3_hidden_dim (`int`, *optional*): - Hidden dimension for the DINOv3 backbone projection convolutions. If `None`, uses `hidden_size` from - the DINOv3 ViT config. - dinov3_apply_layernorm (`bool`, *optional*, defaults to `False`): - Whether to apply LayerNorm to intermediate features extracted from the DINOv3 ViT backbone. 
encoder_has_trailing_conv (`bool`, *optional*, defaults to `True`): Whether the encoder's CSP blocks include a trailing 3x3 convolution after the bottleneck path. `True` for RepNCSPELAN4 (used by HGNetV2 N and LiteEncoder variants). @@ -283,16 +272,10 @@ def __init__( use_dense_o2o=True, mal_alpha=None, encoder_fuse_op="sum", - use_spatial_tuning_adapter=False, sta_inplanes=16, encoder_type="hybrid", use_gateway=True, share_bbox_head=False, - backbone_type="hgnetv2", - dinov3_backbone_config=None, - dinov3_interaction_indexes=None, - dinov3_hidden_dim=None, - dinov3_apply_layernorm=False, encoder_has_trailing_conv=True, tie_word_embeddings=True, **kwargs, @@ -301,16 +284,10 @@ def __init__( self.use_dense_o2o = use_dense_o2o self.mal_alpha = mal_alpha self.encoder_fuse_op = encoder_fuse_op - self.use_spatial_tuning_adapter = use_spatial_tuning_adapter self.sta_inplanes = sta_inplanes self.encoder_type = encoder_type self.use_gateway = use_gateway self.share_bbox_head = share_bbox_head - self.backbone_type = backbone_type - self.dinov3_backbone_config = dinov3_backbone_config - self.dinov3_interaction_indexes = dinov3_interaction_indexes - self.dinov3_hidden_dim = dinov3_hidden_dim - self.dinov3_apply_layernorm = dinov3_apply_layernorm self.encoder_has_trailing_conv = encoder_has_trailing_conv self.initializer_range = initializer_range self.initializer_bias_prior_prob = initializer_bias_prior_prob @@ -404,168 +381,4 @@ def __init__( super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) -class Deimv2DINOv3ViTConfig(BackboneConfigMixin, PreTrainedConfig): - r""" - This is the configuration class to store the configuration of a [`DINOv3Model`]. It is used to instantiate an - DINOv3 model according to the specified arguments, defining the model architecture. 
Instantiating a configuration - with the defaults will yield a similar configuration to that of the DINOv3 - [facebook/dinov3-vits16-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vits16-pretrain-lvd1689m) architecture. - - Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PreTrainedConfig`] for more information. - - Args: - patch_size (`int`, *optional*, defaults to 16): - The size (resolution) of each patch. - hidden_size (`int`, *optional*, defaults to 384): - Dimensionality of the encoder layers and the pooler layer. - intermediate_size (`int`, *optional*, defaults to 1536): - Dimensionality of the "intermediate" (i.e., feed-forward) layer. - num_hidden_layers (`int`, *optional*, defaults to 12): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 6): - Number of attention heads for each attention layer in the Transformer encoder. - hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): - The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, - `"relu"`, `"selu"` and `"gelu_new"` are supported. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - layer_norm_eps (`float`, *optional*, defaults to 1e-05): - The epsilon used by the layer normalization layers. - rope_theta (`float`, *optional*, defaults to 100.0): - The base period of the RoPE embeddings. - image_size (`int`, *optional*, defaults to 224): - The size (resolution) of each image. - num_channels (`int`, *optional*, defaults to 3): - The number of input channels. - query_bias (`bool`, *optional*, defaults to `True`): - Whether to add a bias to the query projection. 
- key_bias (`bool`, *optional*, defaults to `False`): - Whether to add a bias to the key projection. - value_bias (`bool`, *optional*, defaults to `True`): - Whether to add a bias to the value projection. - proj_bias (`bool`, *optional*, defaults to `True`): - Whether to add a bias to the output projection. - mlp_bias (`bool`, *optional*, defaults to `True`): - Whether to add a bias to the MLP layers. - layerscale_value (`float`, *optional*, defaults to 1.0): - Initial value to use for layer scale. - drop_path_rate (`float`, *optional*, defaults to 0.0): - Stochastic depth rate per sample (when applied in the main path of residual layers). - use_gated_mlp (`bool`, *optional*, defaults to `False`): - Whether to use the SwiGLU feedforward neural network. - num_register_tokens (`int`, *optional*, defaults to 0): - The number of register tokens. - pos_embed_shift (`float`, *optional*): - Amount to randomly shift position embedding coordinates in [-shift, shift], - applied only in training mode if not `None`. - pos_embed_jitter (`float`, *optional*): - Amount to randomly jitter position embedding coordinates in log-uniform value in [1/jitter, jitter], - applied only in training mode if not `None`. - pos_embed_rescale (`float`, *optional*, defaults to 2.0): - Amount to randomly rescale position embedding coordinates in log-uniform value in [1/rescale, rescale], - applied only in training mode if not `None`. - out_features (`list[str]`, *optional*): - If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. - (depending on how many stages the model has). Will default to the last stage if unset. - out_indices (`list[int]`, *optional*): - If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. - (depending on how many stages the model has). Will default to the last stage if unset. 
- apply_layernorm (`bool`, *optional*, defaults to `True`): - Whether to apply layer normalization to the feature maps when used as backbone. - reshape_hidden_states (`bool`, *optional*, defaults to `True`): - Whether to reshape the hidden states to spatial dimensions when used as backbone. - - Example: - - ```python - >>> from transformers import Deimv2DINOv3ViTConfig, Deimv2DINOv3ViTModel - - >>> # Initializing a DINOv3 ViT-small style configuration - >>> config = Deimv2DINOv3ViTConfig() - - >>> # Initializing a model (with random weights) from the config - >>> model = Deimv2DINOv3ViTModel(config) - - >>> # Accessing the model config - >>> config = model.config - ```""" - - model_type = "deimv2_dinov3_vit" - - def __init__( - self, - patch_size: int = 16, - hidden_size: int = 384, - intermediate_size: int = 1536, - num_hidden_layers: int = 12, - num_attention_heads: int = 6, - hidden_act: str = "gelu", - attention_dropout: float = 0.0, - initializer_range: float = 0.02, - layer_norm_eps: float = 1e-5, - rope_theta: float = 100.0, - image_size: int = 224, - num_channels: int = 3, - query_bias: bool = True, - key_bias: bool = False, - value_bias: bool = True, - proj_bias: bool = True, - mlp_bias: bool = True, - layerscale_value: float = 1.0, - drop_path_rate: float = 0.0, - use_gated_mlp: bool = False, - num_register_tokens: int = 0, - # train augs - pos_embed_shift: float | None = None, - pos_embed_jitter: float | None = None, - pos_embed_rescale: float | None = 2.0, - out_features: list[str] | None = None, - out_indices: list[int] | None = None, - apply_layernorm: bool = True, - reshape_hidden_states: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - 
self.attention_dropout = attention_dropout - self.initializer_range = initializer_range - self.layer_norm_eps = layer_norm_eps - self.layerscale_value = layerscale_value - self.drop_path_rate = drop_path_rate - self.use_gated_mlp = use_gated_mlp - self.rope_theta = rope_theta - self.query_bias = query_bias - self.key_bias = key_bias - self.value_bias = value_bias - self.proj_bias = proj_bias - self.mlp_bias = mlp_bias - self.num_register_tokens = num_register_tokens - - # train augs - self.pos_embed_shift = pos_embed_shift - self.pos_embed_jitter = pos_embed_jitter - self.pos_embed_rescale = pos_embed_rescale - # Initialize backbone-specific configuration - self.apply_layernorm = apply_layernorm - self.reshape_hidden_states = reshape_hidden_states - - # Initialize backbone stage names - stage_names = ["stem"] + [f"stage{i}" for i in range(1, num_hidden_layers + 1)] - self.stage_names = stage_names - - # Initialize backbone features/indices - self.set_output_features_output_indices(out_indices=out_indices, out_features=out_features) - - __all__ = ["Deimv2Config"] diff --git a/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py b/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py index a6125952deeb..a4ba5aeda67f 100644 --- a/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py +++ b/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py @@ -26,6 +26,7 @@ from torchvision import transforms from transformers import Deimv2Config, Deimv2ForObjectDetection, RTDetrImageProcessor +from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig from transformers.utils import logging @@ -180,15 +181,11 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: num_stages = len(config.backbone_config.hidden_sizes) config.backbone_config.depths = config.backbone_config.stage_numb_of_layers config.backbone_config.stage_names = ["stem"] + 
[f"stage{i}" for i in range(1, num_stages + 1)] - - config.use_spatial_tuning_adapter = False elif "DINOv3STAs" in orig_config: dinov3_cfg = orig_config["DINOv3STAs"] name = dinov3_cfg["name"] - config.backbone_type = "dinov3" - config.dinov3_interaction_indexes = dinov3_cfg["interaction_indexes"] + interaction_indexes = dinov3_cfg["interaction_indexes"] config.sta_inplanes = dinov3_cfg.get("conv_inplane") or 16 - config.use_spatial_tuning_adapter = True is_dinov3 = "dinov3" in name @@ -239,8 +236,6 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: vit_hidden_size = dinov3_cfg.get("embed_dim") or 192 vit_num_heads = dinov3_cfg.get("num_heads") or 3 - config.dinov3_hidden_dim = dinov3_cfg.get("hidden_dim") or vit_hidden_size - ffn_ratio = preset["ffn_ratio"] if preset["use_gated_mlp"]: hidden_features = vit_hidden_size * ffn_ratio @@ -249,19 +244,22 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: else: intermediate_size = vit_hidden_size * ffn_ratio - config.dinov3_backbone_config = { - "hidden_size": vit_hidden_size, - "num_attention_heads": vit_num_heads, - "num_hidden_layers": 12, - "intermediate_size": intermediate_size, - "layerscale_value": preset["layerscale_value"], - "use_gated_mlp": preset["use_gated_mlp"], - "num_register_tokens": preset["num_register_tokens"], - "pos_embed_rescale": preset["pos_embed_rescale"], - "key_bias": preset["key_bias"], - "rope_theta": 100.0, - } - config.dinov3_apply_layernorm = is_dinov3 + out_indices = [idx + 1 for idx in interaction_indexes] + config.backbone_config = DINOv3ViTConfig( + hidden_size=vit_hidden_size, + num_attention_heads=vit_num_heads, + num_hidden_layers=12, + intermediate_size=intermediate_size, + layerscale_value=preset["layerscale_value"], + use_gated_mlp=preset["use_gated_mlp"], + num_register_tokens=preset["num_register_tokens"], + pos_embed_rescale=preset["pos_embed_rescale"], + key_bias=preset["key_bias"], + rope_theta=100.0, + out_indices=out_indices, + apply_layernorm=is_dinov3, + 
reshape_hidden_states=True, + ) config.encoder_has_trailing_conv = False else: raise ValueError(f"Unknown backbone in config: {list(orig_config.keys())}") @@ -297,8 +295,8 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: r"backbone\.stages\.(\d+)\.downsample\.lab\.(scale|bias)": r"model.backbone.model.encoder.stages.\1.downsample.lab.\2", # Encoder mappings # Input projections - r"encoder\.input_proj\.(\d+)\.conv\.weight": r"model.encoder_input_proj.\1.0.weight", - r"encoder\.input_proj\.(\d+)\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder_input_proj.\1.1.\2", + r"encoder\.input_proj\.(\d+)\.conv\.weight": r"model.backbone.encoder_input_proj.\1.0.weight", + r"encoder\.input_proj\.(\d+)\.norm\.(weight|bias|running_mean|running_var)": r"model.backbone.encoder_input_proj.\1.1.\2", # AIFI transformer encoder layers r"encoder\.encoder\.(\d+)\.layers\.0\.self_attn\.out_proj\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.self_attn.o_proj.\2", r"encoder\.encoder\.(\d+)\.layers\.0\.linear1\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.mlp.layers.0.\2", @@ -369,12 +367,12 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.value_proj\.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.value_proj.\2", r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.output_proj\.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.output_proj.\2", r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.num_points_scale": r"model.decoder.layers.\1.encoder_attn.num_points_scale", - r"decoder\.decoder\.layers\.(\d+)\.norm1\.scale": r"model.decoder.layers.\1.self_attn_layer_norm.scale", - r"decoder\.decoder\.layers\.(\d+)\.norm3\.scale": r"model.decoder.layers.\1.final_layer_norm.scale", + r"decoder\.decoder\.layers\.(\d+)\.norm1\.scale": r"model.decoder.layers.\1.self_attn_layer_norm.weight", + r"decoder\.decoder\.layers\.(\d+)\.norm3\.scale": r"model.decoder.layers.\1.final_layer_norm.weight", 
r"decoder\.decoder\.layers\.(\d+)\.swish_ffn\.w12\.(weight|bias)": r"model.decoder.layers.\1.mlp.w12.\2", r"decoder\.decoder\.layers\.(\d+)\.swish_ffn\.w3\.(weight|bias)": r"model.decoder.layers.\1.mlp.w3.\2", r"decoder\.decoder\.layers\.(\d+)\.gateway\.gate\.(weight|bias)": r"model.decoder.layers.\1.gateway.gate.\2", - r"decoder\.decoder\.layers\.(\d+)\.gateway\.norm\.scale": r"model.decoder.layers.\1.gateway.norm.scale", + r"decoder\.decoder\.layers\.(\d+)\.gateway\.norm\.scale": r"model.decoder.layers.\1.gateway.norm.weight", # LQE layers r"decoder\.decoder\.lqe_layers\.(\d+)\.reg_conf\.layers\.(\d+)\.(weight|bias)": r"model.decoder.lqe_layers.\1.reg_conf.layers.\2.\3", # Decoder heads and projections @@ -449,50 +447,63 @@ def get_deimv2_config(model_name: str) -> Deimv2Config: } DECODER_NO_GATEWAY_KEY_MAPPING = { - r"decoder\.decoder\.layers\.(\d+)\.norm2\.scale": r"model.decoder.layers.\1.encoder_attn_layer_norm.scale", + r"decoder\.decoder\.layers\.(\d+)\.norm2\.scale": r"model.decoder.layers.\1.encoder_attn_layer_norm.weight", } DINOV3_KEY_MAPPING = { # ViT embeddings - r"backbone\.dinov3\.patch_embed\.proj\.(weight|bias)": r"model.dinov3_backbone.embeddings.patch_embeddings.\1", - r"backbone\.dinov3\.cls_token": r"model.dinov3_backbone.embeddings.cls_token", - r"backbone\.dinov3\.storage_tokens": r"model.dinov3_backbone.embeddings.register_tokens", - r"backbone\.dinov3\.mask_token": r"model.dinov3_backbone.embeddings.mask_token", + r"backbone\.dinov3\.patch_embed\.proj\.(weight|bias)": r"model.backbone.backbone.embeddings.patch_embeddings.\1", + r"backbone\.dinov3\.cls_token": r"model.backbone.backbone.embeddings.cls_token", + r"backbone\.dinov3\.storage_tokens": r"model.backbone.backbone.embeddings.register_tokens", + r"backbone\.dinov3\.mask_token": r"model.backbone.backbone.embeddings.mask_token", # ViT blocks - r"backbone\.dinov3\.blocks\.(\d+)\.norm1\.(weight|bias)": r"model.dinov3_backbone.layers.\1.norm1.\2", - 
r"backbone\.dinov3\.blocks\.(\d+)\.norm2\.(weight|bias)": r"model.dinov3_backbone.layers.\1.norm2.\2", - r"backbone\.dinov3\.blocks\.(\d+)\.attn\.qkv\.(weight|bias)": r"model.dinov3_backbone.layers.\1.attention.qkv.\2", - r"backbone\.dinov3\.blocks\.(\d+)\.attn\.proj\.(weight|bias)": r"model.dinov3_backbone.layers.\1.attention.o_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.norm1\.(weight|bias)": r"model.backbone.backbone.layer.\1.norm1.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.norm2\.(weight|bias)": r"model.backbone.backbone.layer.\1.norm2.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.attn\.qkv\.(weight|bias)": r"model.backbone.backbone.layer.\1.attention.qkv.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.attn\.proj\.(weight|bias)": r"model.backbone.backbone.layer.\1.attention.o_proj.\2", # Standard MLP (S/M/L) - r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.fc1\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.up_proj.\2", - r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.fc2\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.down_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.fc1\.(weight|bias)": r"model.backbone.backbone.layer.\1.mlp.up_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.fc2\.(weight|bias)": r"model.backbone.backbone.layer.\1.mlp.down_proj.\2", # SwiGLU MLP (X only) - r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w1\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.gate_proj.\2", - r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w2\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.up_proj.\2", - r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w3\.(weight|bias)": r"model.dinov3_backbone.layers.\1.mlp.down_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w1\.(weight|bias)": r"model.backbone.backbone.layer.\1.mlp.gate_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w2\.(weight|bias)": r"model.backbone.backbone.layer.\1.mlp.up_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w3\.(weight|bias)": r"model.backbone.backbone.layer.\1.mlp.down_proj.\2", # 
LayerScale (L/X only) - r"backbone\.dinov3\.blocks\.(\d+)\.ls1\.gamma": r"model.dinov3_backbone.layers.\1.layer_scale1.lambda1", - r"backbone\.dinov3\.blocks\.(\d+)\.ls2\.gamma": r"model.dinov3_backbone.layers.\1.layer_scale2.lambda1", + r"backbone\.dinov3\.blocks\.(\d+)\.ls1\.gamma": r"model.backbone.backbone.layer.\1.layer_scale1.lambda1", + r"backbone\.dinov3\.blocks\.(\d+)\.ls2\.gamma": r"model.backbone.backbone.layer.\1.layer_scale2.lambda1", # Norm (L/X only) - r"backbone\.dinov3\.norm\.(weight|bias)": r"model.dinov3_backbone.norm.\1", + r"backbone\.dinov3\.norm\.(weight|bias)": r"model.backbone.backbone.norm.\1", # STA adapter - r"backbone\.sta\.stem\.0\.(weight|bias)": r"model.dinov3_backbone.sta.stem.0.\1", - r"backbone\.sta\.stem\.1\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.sta.stem.1.\1", - r"backbone\.sta\.conv2\.0\.(weight)": r"model.dinov3_backbone.sta.conv2.0.\1", - r"backbone\.sta\.conv2\.1\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.sta.conv2.1.\1", - r"backbone\.sta\.conv3\.1\.(weight)": r"model.dinov3_backbone.sta.conv3.1.\1", - r"backbone\.sta\.conv3\.2\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.sta.conv3.2.\1", - r"backbone\.sta\.conv4\.1\.(weight)": r"model.dinov3_backbone.sta.conv4.1.\1", - r"backbone\.sta\.conv4\.2\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.sta.conv4.2.\1", + r"backbone\.sta\.stem\.0\.(weight|bias)": r"model.backbone.sta.stem.0.\1", + r"backbone\.sta\.stem\.1\.(weight|bias|running_mean|running_var)": r"model.backbone.sta.stem.1.\1", + r"backbone\.sta\.conv2\.0\.(weight)": r"model.backbone.sta.conv2.0.\1", + r"backbone\.sta\.conv2\.1\.(weight|bias|running_mean|running_var)": r"model.backbone.sta.conv2.1.\1", + r"backbone\.sta\.conv3\.1\.(weight)": r"model.backbone.sta.conv3.1.\1", + r"backbone\.sta\.conv3\.2\.(weight|bias|running_mean|running_var)": r"model.backbone.sta.conv3.2.\1", + r"backbone\.sta\.conv4\.1\.(weight)": 
r"model.backbone.sta.conv4.1.\1", + r"backbone\.sta\.conv4\.2\.(weight|bias|running_mean|running_var)": r"model.backbone.sta.conv4.2.\1", # Projection convs/norms - r"backbone\.convs\.(\d+)\.weight": r"model.dinov3_backbone.convs.\1.weight", - r"backbone\.norms\.(\d+)\.(weight|bias|running_mean|running_var)": r"model.dinov3_backbone.norms.\1.\2", + r"backbone\.convs\.(\d+)\.weight": r"model.backbone.convs.\1.weight", + r"backbone\.norms\.(\d+)\.(weight|bias|running_mean|running_var)": r"model.backbone.norms.\1.\2", } +def split_swiglu_weights(state_dict): + for key in list(state_dict.keys()): + if ".mlp.w12." in key: + w12 = state_dict.pop(key) + gate, up = w12.chunk(2, dim=0) + state_dict[key.replace(".w12.", ".gate_proj.")] = gate + state_dict[key.replace(".w12.", ".up_proj.")] = up + elif ".mlp.w3." in key: + state_dict[key.replace(".w3.", ".down_proj.")] = state_dict.pop(key) + + def convert_old_keys_to_new_keys(state_dict, config=None): mapping = dict(ORIGINAL_TO_CONVERTED_KEY_MAPPING) + is_dinov3 = getattr(config.backbone_config, "model_type", None) == "dinov3_vit" if config else False + if config is not None: if config.encoder_type == "lite": for k in list(mapping.keys()): @@ -513,9 +524,9 @@ def convert_old_keys_to_new_keys(state_dict, config=None): if "gateway" in k: del mapping[k] - if config.backbone_type == "dinov3": + if is_dinov3: for k in list(mapping.keys()): - if k.startswith(r"backbone\."): + if k.startswith(r"backbone\.") or k.startswith(r"encoder\.input_proj"): del mapping[k] mapping.update(DINOV3_KEY_MAPPING) @@ -578,28 +589,25 @@ def strip_dinov3_model_prefix(state_dict): def read_in_q_k_v_vit(state_dict, config): - from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig - - vit_config = DINOv3ViTConfig(**config.dinov3_backbone_config) - has_key_bias = config.dinov3_backbone_config.get("key_bias", True) - prefix = "model.dinov3_backbone" - for i in range(vit_config.num_hidden_layers): - qkv_key = 
f"{prefix}.layers.{i}.attention.qkv.weight" + has_key_bias = getattr(config.backbone_config, "key_bias", True) + prefix = "model.backbone.backbone" + for i in range(config.backbone_config.num_hidden_layers): + qkv_key = f"{prefix}.layer.{i}.attention.qkv.weight" if qkv_key in state_dict: qkv_w = state_dict.pop(qkv_key) q, k, v = qkv_w.chunk(3, dim=0) - state_dict[f"{prefix}.layers.{i}.attention.q_proj.weight"] = q - state_dict[f"{prefix}.layers.{i}.attention.k_proj.weight"] = k - state_dict[f"{prefix}.layers.{i}.attention.v_proj.weight"] = v + state_dict[f"{prefix}.layer.{i}.attention.q_proj.weight"] = q + state_dict[f"{prefix}.layer.{i}.attention.k_proj.weight"] = k + state_dict[f"{prefix}.layer.{i}.attention.v_proj.weight"] = v - qkv_bias_key = f"{prefix}.layers.{i}.attention.qkv.bias" + qkv_bias_key = f"{prefix}.layer.{i}.attention.qkv.bias" if qkv_bias_key in state_dict: qkv_b = state_dict.pop(qkv_bias_key) q_b, k_b, v_b = qkv_b.chunk(3, dim=0) - state_dict[f"{prefix}.layers.{i}.attention.q_proj.bias"] = q_b + state_dict[f"{prefix}.layer.{i}.attention.q_proj.bias"] = q_b if has_key_bias: - state_dict[f"{prefix}.layers.{i}.attention.k_proj.bias"] = k_b - state_dict[f"{prefix}.layers.{i}.attention.v_proj.bias"] = v_b + state_dict[f"{prefix}.layer.{i}.attention.k_proj.bias"] = k_b + state_dict[f"{prefix}.layer.{i}.attention.v_proj.bias"] = v_b def load_original_state_dict(repo_id): @@ -634,7 +642,9 @@ def convert_deimv2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, if key.endswith(".num_batches_tracked"): state_dict.pop(key) - if config.backbone_type == "dinov3": + is_dinov3 = getattr(config.backbone_config, "model_type", None) == "dinov3_vit" + + if is_dinov3: strip_dinov3_model_prefix(state_dict) for key in list(state_dict.keys()): if "rope_embed.periods" in key or "qkv.bias_mask" in key: @@ -645,9 +655,11 @@ def convert_deimv2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, state_dict = convert_old_keys_to_new_keys(state_dict, 
config) - if config.backbone_type == "dinov3": + split_swiglu_weights(state_dict) + + if is_dinov3: read_in_q_k_v_vit(state_dict, config) - mask_key = "model.dinov3_backbone.embeddings.mask_token" + mask_key = "model.backbone.backbone.embeddings.mask_token" if mask_key in state_dict and state_dict[mask_key].dim() == 2: state_dict[mask_key] = state_dict[mask_key].unsqueeze(1) @@ -677,7 +689,7 @@ def convert_deimv2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, model = Deimv2ForObjectDetection(config) missing, unexpected = model.load_state_dict(state_dict, strict=False) - expected_missing = {"mask_token", "register_tokens", "layer_scale1", "layer_scale2"} + expected_missing = {"mask_token", "register_tokens", "layer_scale1", "layer_scale2", "backbone.norm"} unexpected_missing = [k for k in missing if not any(e in k for e in expected_missing)] if unexpected_missing: logger.warning(f"Missing keys ({len(unexpected_missing)}): {unexpected_missing[:10]}...") @@ -690,7 +702,6 @@ def convert_deimv2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, model.eval() - is_dinov3 = config.backbone_type == "dinov3" if is_dinov3: image_processor = RTDetrImageProcessor( do_normalize=True, diff --git a/src/transformers/models/deimv2/modeling_deimv2.py b/src/transformers/models/deimv2/modeling_deimv2.py index 9173609eb0df..525695807d90 100644 --- a/src/transformers/models/deimv2/modeling_deimv2.py +++ b/src/transformers/models/deimv2/modeling_deimv2.py @@ -17,28 +17,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import math from collections.abc import Callable from dataclasses import dataclass -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from ... 
import initialization as init -from ...activations import ACT2CLS, ACT2FN +from ...activations import ACT2CLS +from ...backbone_utils import load_backbone from ...image_transforms import center_to_corners_format, corners_to_center_format -from ...modeling_layers import GradientCheckpointingLayer +from ...integrations import use_kernel_forward_from_hub from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import compile_compatible_method_lru_cache from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int -from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults +from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from .configuration_deimv2 import Deimv2Config, Deimv2DINOv3ViTConfig +from .configuration_deimv2 import Deimv2Config @dataclass @@ -79,30 +80,94 @@ class Deimv2DecoderOutput(ModelOutput): cross_attentions: tuple[torch.FloatTensor] | None = None +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the RT-DETR encoder-decoder model. + """ +) +class Deimv2ModelOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). 
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate predicted corners (predicted corners of each layer of the decoder). + initial_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points used for the first decoder layer. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`): + Logits of predicted bounding boxes coordinates in the encoder stage. + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. + denoising_meta_values (`dict`): + Extra dictionary for the denoising related values. 
+ """ + + last_hidden_state: torch.FloatTensor | None = None + intermediate_hidden_states: torch.FloatTensor | None = None + intermediate_logits: torch.FloatTensor | None = None + intermediate_reference_points: torch.FloatTensor | None = None + intermediate_predicted_corners: torch.FloatTensor | None = None + initial_reference_points: torch.FloatTensor | None = None + decoder_hidden_states: tuple[torch.FloatTensor] | None = None + decoder_attentions: tuple[torch.FloatTensor] | None = None + cross_attentions: tuple[torch.FloatTensor] | None = None + encoder_last_hidden_state: torch.FloatTensor | None = None + encoder_hidden_states: tuple[torch.FloatTensor] | None = None + encoder_attentions: tuple[torch.FloatTensor] | None = None + init_reference_points: torch.FloatTensor | None = None + enc_topk_logits: torch.FloatTensor | None = None + enc_topk_bboxes: torch.FloatTensor | None = None + enc_outputs_class: torch.FloatTensor | None = None + enc_outputs_coord_logits: torch.FloatTensor | None = None + denoising_meta_values: dict | None = None + + +@use_kernel_forward_from_hub("RMSNorm") class Deimv2RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + Deimv2RMSNorm is equivalent to T5LayerNorm + """ super().__init__() - self.eps = eps - self.scale = nn.Parameter(torch.ones(dim)) + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_dtype = hidden_states.dtype - hidden_states = hidden_states.float() - hidden_states = hidden_states * torch.rsqrt(hidden_states.pow(2).mean(-1, keepdim=True) + self.eps) - return (hidden_states * self.scale).to(input_dtype) + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * 
hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class Deimv2SwiGLUFFN(nn.Module): def __init__(self, in_features: int, hidden_features: int, out_features: int): super().__init__() - self.w12 = nn.Linear(in_features, 2 * hidden_features) - self.w3 = nn.Linear(hidden_features, out_features) + self.gate_proj = nn.Linear(in_features, hidden_features) + self.up_proj = nn.Linear(in_features, hidden_features) + self.down_proj = nn.Linear(hidden_features, out_features) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - x12 = self.w12(hidden_states) - x1, x2 = x12.chunk(2, dim=-1) - hidden = F.silu(x1) * x2 - return self.w3(hidden) + return self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) class Deimv2Gate(nn.Module): @@ -352,400 +417,7 @@ def forward(self, x): return self.activation(y) -class Deimv2DINOv3ViTEmbeddings(nn.Module): - """ - Construct the CLS token, mask token, position and patch embeddings. 
- """ - - def __init__(self, config: Deimv2DINOv3ViTConfig): - super().__init__() - self.config = config - self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) - self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) - self.register_tokens = nn.Parameter(torch.empty(1, config.num_register_tokens, config.hidden_size)) - self.patch_embeddings = nn.Conv2d( - config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size - ) - - def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None) -> torch.Tensor: - batch_size = pixel_values.shape[0] - target_dtype = self.patch_embeddings.weight.dtype - - # (batch_size, num_channels, height, width) -> (batch_size, num_patches, hidden_size) - patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) - patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2) - - if bool_masked_pos is not None: - mask_token = self.mask_token.to(patch_embeddings.dtype) - patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings) - - # Add CLS and register tokens - cls_token = self.cls_token.expand(batch_size, -1, -1) - register_tokens = self.register_tokens.expand(batch_size, -1, -1) - embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) - - return embeddings - - -class Deimv2DINOv3ViTLayerScale(nn.Module): - def __init__(self, config) -> None: - super().__init__() - self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size)) - - def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: - return hidden_state * self.lambda1 - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: 
torch.Tensor | None, - scaling: float | None = None, - dropout: float = 0.0, - **kwargs: Unpack[TransformersKwargs], -): - if scaling is None: - scaling = query.size(-1) ** -0.5 - - # Take the dot product between "query" and "key" to get the raw attention scores. - attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - - attn_output = torch.matmul(attn_weights, value) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -def apply_rotary_pos_emb( - q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, **kwargs -) -> tuple[torch.Tensor, torch.Tensor]: - """Applies Rotary Position Embedding to the query and key tensors, but only to the patch tokens, - ignoring the prefix tokens (cls token and register tokens). - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - - num_tokens = q.shape[-2] - num_patches = sin.shape[-2] - num_prefix_tokens = num_tokens - num_patches # cls token + register tokens - - q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2) - k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2) - - # apply rope only to patch tokens - q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin) - k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin) - - q = torch.cat((q_prefix_tokens, q_patches), dim=-2) - k = torch.cat((k_prefix_tokens, k_patches), dim=-2) - - return q, k - - -class Deimv2DINOv3ViTAttention(nn.Module): - """ - Multi-headed attention compatible with ALL_ATTENTION_FUNCTIONS. - """ - - def __init__(self, config: Deimv2DINOv3ViTConfig): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - self.is_causal = False - - self.scaling = self.head_dim**-0.5 - self.is_causal = False - - self.dropout = config.attention_dropout - self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.key_bias) - self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.value_bias) - - self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.query_bias) - self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, - **kwargs: Unpack[TransformersKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - batch_size, patches, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(batch_size, patches, self.num_heads, 
self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( - self.config._attn_implementation, eager_attention_forward - ) - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() - attn_output = self.o_proj(attn_output) - - return attn_output, attn_weights - - -def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: - """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
- - """ - if drop_prob == 0.0 or not training: - return input - keep_prob = 1 - drop_prob - shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) - random_tensor.floor_() # binarize - output = input.div(keep_prob) * random_tensor - return output - - -class Deimv2DINOv3ViTDropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob: float | None = None) -> None: - super().__init__() - self.drop_prob = drop_prob - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - return drop_path(hidden_states, self.drop_prob, self.training) - - def extra_repr(self) -> str: - return f"p={self.drop_prob}" - - -class Deimv2DINOv3ViTMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.up_proj(x))) - - -class Deimv2DINOv3ViTGatedMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return 
down_proj - - -class Deimv2DINOv3ViTLayer(GradientCheckpointingLayer): - """This corresponds to the Block class in the original implementation.""" - - def __init__(self, config: Deimv2DINOv3ViTConfig): - super().__init__() - - self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = Deimv2DINOv3ViTAttention(config) - self.layer_scale1 = Deimv2DINOv3ViTLayerScale(config) - self.drop_path = ( - Deimv2DINOv3ViTDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() - ) - - self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - if config.use_gated_mlp: - self.mlp = Deimv2DINOv3ViTGatedMLP(config) - else: - self.mlp = Deimv2DINOv3ViTMLP(config) - self.layer_scale2 = Deimv2DINOv3ViTLayerScale(config) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, - ) -> torch.Tensor: - # Attention with residual connection - residual = hidden_states - hidden_states = self.norm1(hidden_states) - hidden_states, _ = self.attention( - hidden_states, - attention_mask=attention_mask, - position_embeddings=position_embeddings, - ) - hidden_states = self.layer_scale1(hidden_states) - hidden_states = self.drop_path(hidden_states) + residual - - # MLP with residual connection - residual = hidden_states - hidden_states = self.norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = self.layer_scale2(hidden_states) - hidden_states = self.drop_path(hidden_states) + residual - - return hidden_states - - -@compile_compatible_method_lru_cache(maxsize=32) -def get_patches_center_coordinates( - num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device -) -> torch.Tensor: - """ - Computes the 2D coordinates of the centers of image patches, normalized to the range [-1, +1]. 
- The center of each patch is exactly halfway between its top-left and bottom-right corners. - - Args: - num_patches_h (int): Number of patches along the vertical (height) axis. - num_patches_w (int): Number of patches along the horizontal (width) axis. - dtype (torch.dtype): The desired data type of the returned tensor. - - Returns: - torch.Tensor: A tensor of shape (height * width, 2), where each row contains the (y, x) - coordinates of a patch center, normalized to [-1, +1]. - """ - coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device) - coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device) - coords_h = coords_h / num_patches_h - coords_w = coords_w / num_patches_w - # (height, width, 2) -> (height * width, 2) - coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) - coords = coords.flatten(0, 1) - # Shift range [0, 1] to [-1, +1] - coords = 2.0 * coords - 1.0 - return coords - - -def augment_patches_center_coordinates( - coords: torch.Tensor, - shift: float | None = None, - jitter: float | None = None, - rescale: float | None = None, -) -> torch.Tensor: - # Shift coords by adding a uniform value in [-shift, shift] - if shift is not None: - shift_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype) - shift_hw = shift_hw.uniform_(-shift, shift) - coords = coords + shift_hw - - # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter] - if jitter is not None: - jitter_range = np.log(jitter) - jitter_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype) - jitter_hw = jitter_hw.uniform_(-jitter_range, jitter_range).exp() - coords = coords * jitter_hw - - # Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale] - if rescale is not None: - rescale_range = np.log(rescale) - rescale_hw = torch.empty(1, device=coords.device, dtype=coords.dtype) - rescale_hw = rescale_hw.uniform_(-rescale_range, 
rescale_range).exp() - coords = coords * rescale_hw - - return coords - - -class Deimv2DINOv3ViTRopePositionEmbedding(nn.Module): - inv_freq: torch.Tensor - - def __init__(self, config: Deimv2DINOv3ViTConfig): - super().__init__() - - self.config = config - self.base = config.rope_theta - self.head_dim = config.hidden_size // config.num_attention_heads - self.num_patches_h = config.image_size // config.patch_size - self.num_patches_w = config.image_size // config.patch_size - - inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32) # (head_dim / 4,) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - _, _, height, width = pixel_values.shape - num_patches_h = height // self.config.patch_size - num_patches_w = width // self.config.patch_size - - device = pixel_values.device - device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu" - - with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - # Although we could precompute static patch_coords from image_size and patch_size in the config, - # the model was trained with random_scale, so it can process images of varying sizes. - # Therefore, it's better to compute patch_coords dynamically (with lru_cache). 
- patch_coords = get_patches_center_coordinates( - num_patches_h, num_patches_w, dtype=torch.float32, device=device - ) - if self.training: - patch_coords = augment_patches_center_coordinates( - patch_coords, - shift=self.config.pos_embed_shift, - jitter=self.config.pos_embed_jitter, - rescale=self.config.pos_embed_rescale, - ) - - # (height * width, 2, head_dim / 4) -> (height * width, head_dim / 2) -> (height * width, head_dim) - angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :] - angles = angles.flatten(1, 2) - angles = angles.tile(2) - - cos = torch.cos(angles) - sin = torch.sin(angles) - - dtype = pixel_values.dtype - return cos.to(dtype=dtype), sin.to(dtype=dtype) - - -class Deimv2CSPRepLayer2(nn.Module): +class Deimv2CSPRepLayer(nn.Module): def __init__( self, config: Deimv2Config, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0 ): @@ -762,34 +434,27 @@ def __init__( self.conv3 = nn.Identity() def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: - chunks = list(self.conv1(hidden_state).chunk(2, 1)) - bottleneck_out = chunks[1] + hidden_state_1, hidden_state_2 = self.conv1(hidden_state).chunk(2, 1) for bottleneck in self.bottlenecks: - bottleneck_out = bottleneck(bottleneck_out) - return self.conv3(chunks[0] + bottleneck_out) + hidden_state_2 = bottleneck(hidden_state_2) + return self.conv3(hidden_state_1 + hidden_state_2) class Deimv2RepNCSPELAN5(nn.Module): - def __init__( - self, - config: Deimv2Config, - numb_blocks: int = 3, - c1: int | None = None, - c2: int | None = None, - c3: int | None = None, - c4: int | None = None, - ): + def __init__(self, config: Deimv2Config, numb_blocks: int = 3): super().__init__() - act = config.activation_function - c1 = c1 if c1 is not None else config.encoder_hidden_dim - c2 = c2 if c2 is not None else config.encoder_hidden_dim - c3 = c3 if c3 is not None else config.encoder_hidden_dim * 2 - c4 = c4 if c4 is not None else round(config.hidden_expansion * 
config.encoder_hidden_dim // 2) - self.conv_dim = c3 // 2 - self.conv1 = Deimv2ConvNormLayer(config, c1, c3, 1, 1, activation=act) - self.csp_rep1 = Deimv2CSPRepLayer2(config, c3 // 2, c4, num_blocks=numb_blocks) - self.csp_rep2 = Deimv2CSPRepLayer2(config, c4, c4, num_blocks=numb_blocks) - self.conv4 = Deimv2ConvNormLayer(config, c3 + (2 * c4), c2, 1, 1, activation=act) + activation = config.activation_function + in_channels = config.encoder_hidden_dim + out_channels = config.encoder_hidden_dim + split_channels = config.encoder_hidden_dim * 2 + csp_channels = round(config.hidden_expansion * config.encoder_hidden_dim // 2) + self.conv_dim = split_channels // 2 + self.conv1 = Deimv2ConvNormLayer(config, in_channels, split_channels, 1, 1, activation=activation) + self.csp_rep1 = Deimv2CSPRepLayer(config, split_channels // 2, csp_channels, num_blocks=numb_blocks) + self.csp_rep2 = Deimv2CSPRepLayer(config, csp_channels, csp_channels, num_blocks=numb_blocks) + self.conv4 = Deimv2ConvNormLayer( + config, split_channels + (2 * csp_channels), out_channels, 1, 1, activation=activation + ) def forward(self, input_features: torch.Tensor) -> torch.Tensor: split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1)) @@ -819,6 +484,34 @@ def forward(self, input_features: torch.Tensor) -> torch.Tensor: return input_features +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float | None = None, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Deimv2SelfAttention(nn.Module): """ Multi-headed self-attention from 'Attention Is All You Need' paper. @@ -1078,11 +771,11 @@ def __init__(self, config: Deimv2Config): ) def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - c1 = self.stem(pixel_values) - c2 = self.conv2(c1) - c3 = self.conv3(c2) - c4 = self.conv4(c3) - return c2, c3, c4 + hidden_states = self.stem(pixel_values) + feature_map_8 = self.conv2(hidden_states) + feature_map_16 = self.conv3(feature_map_8) + feature_map_32 = self.conv4(feature_map_16) + return feature_map_8, feature_map_16, feature_map_32 class Deimv2GAPFusion(nn.Module): @@ -1124,12 +817,9 @@ def __init__(self, config: Deimv2Config): self.bi_fusion = Deimv2GAPFusion(config, hidden_dim) - c1, c2 = hidden_dim, hidden_dim - c3 = hidden_dim * 2 - c4 = round(config.hidden_expansion * hidden_dim // 2) num_blocks = round(3 * config.depth_mult) - self.fpn_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks, c1=c1, c2=c2, c3=c3, c4=c4) - self.pan_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks, c1=c1, c2=c2, c3=c3, c4=c4) + self.fpn_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks) + self.pan_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks) def forward(self, inputs_embeds=None, **kwargs) -> BaseModelOutput: feats = inputs_embeds @@ -1148,27 +838,115 @@ def forward(self, inputs_embeds=None, **kwargs) -> BaseModelOutput: return BaseModelOutput(last_hidden_state=outs) -class Deimv2DINOv3Backbone(nn.Module): - def 
__init__(self, config: Deimv2Config): +class Deimv2FrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans. + """ + + def __init__(self, n): super().__init__() - vit_config = Deimv2DINOv3ViTConfig(**config.dinov3_backbone_config) + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) - self.embeddings = Deimv2DINOv3ViTEmbeddings(vit_config) - self.rope_embeddings = Deimv2DINOv3ViTRopePositionEmbedding(vit_config) - self.layers = nn.ModuleList([Deimv2DINOv3ViTLayer(vit_config) for _ in range(vit_config.num_hidden_layers)]) + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] - self.apply_layernorm = config.dinov3_apply_layernorm - if self.apply_layernorm: - self.norm = nn.LayerNorm(vit_config.hidden_size, eps=vit_config.layer_norm_eps) + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +def replace_batch_norm(model): + r""" + Recursively replace all `torch.nn.BatchNorm2d` with `Deimv2FrozenBatchNorm2d`.
- self.interaction_indexes = config.dinov3_interaction_indexes - self.patch_size = vit_config.patch_size - self.num_prefix_tokens = 1 + vit_config.num_register_tokens + Args: + model (torch.nn.Module): + input model + """ + for name, module in model.named_children(): + if isinstance(module, nn.BatchNorm2d): + new_module = Deimv2FrozenBatchNorm2d(module.num_features) + + if module.weight.device != torch.device("meta"): + new_module.weight.copy_(module.weight) + new_module.bias.copy_(module.bias) + new_module.running_mean.copy_(module.running_mean) + new_module.running_var.copy_(module.running_var) + + model._modules[name] = new_module + + if len(list(module.children())) > 0: + replace_batch_norm(module) + + +class Deimv2ConvEncoder(nn.Module): + """ + Convolutional backbone using the modeling_deimv2_resnet.py. + + nn.BatchNorm2d layers are replaced by Deimv2FrozenBatchNorm2d as defined above. + https://github.com/lyuwenyu/RT-DETR/blob/main/rtdetr_pytorch/src/nn/backbone/presnet.py#L142 + """ + + def __init__(self, config): + super().__init__() + + backbone = load_backbone(config) + + if config.freeze_backbone_batch_norms: + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = self.model.channels + if config.encoder_type != "lite": + encoder_input_proj = [] + for in_channel in self.intermediate_channel_sizes: + encoder_input_proj.append( + nn.Sequential( + nn.Conv2d(in_channel, config.encoder_hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(config.encoder_hidden_dim), + ) + ) + self.encoder_input_proj = nn.ModuleList(encoder_input_proj) + + def forward(self, pixel_values: torch.Tensor) -> list[torch.Tensor]: + features = self.model(pixel_values).feature_maps + if hasattr(self, "encoder_input_proj"): + return [self.encoder_input_proj[i](feat) for i, feat in enumerate(features)] + return list(features) + + +class Deimv2DINOv3ConvEncoder(nn.Module): + def
__init__(self, config: Deimv2Config): + super().__init__() + self.backbone = load_backbone(config) self.sta = Deimv2SpatialTuningAdapter(config) - embed_dim = vit_config.hidden_size - hidden_dim = config.dinov3_hidden_dim or embed_dim + embed_dim = config.backbone_config.hidden_size + hidden_dim = config.encoder_hidden_dim sta_ch = config.sta_inplanes self.convs = nn.ModuleList( [ @@ -1180,28 +958,19 @@ def __init__(self, config: Deimv2Config): self.norms = nn.ModuleList([nn.BatchNorm2d(hidden_dim) for _ in range(3)]) def forward(self, pixel_values: torch.Tensor) -> list[torch.Tensor]: - hidden_states = self.embeddings(pixel_values) - position_embeddings = self.rope_embeddings(pixel_values) - - intermediate_outputs = [] - for i, layer in enumerate(self.layers): - hidden_states = layer(hidden_states, position_embeddings=position_embeddings) - if i in self.interaction_indexes: - out = self.norm(hidden_states) if self.apply_layernorm else hidden_states - intermediate_outputs.append(out) + backbone_output = self.backbone(pixel_values) + feature_maps = backbone_output.feature_maps - batch_size = pixel_values.shape[0] - h_patches = pixel_values.shape[2] // self.patch_size - w_patches = pixel_values.shape[3] // self.patch_size + patch_size = self.backbone.config.patch_size + h_patches = pixel_values.shape[2] // patch_size + w_patches = pixel_values.shape[3] // patch_size sem_feats = [] - num_scales = len(intermediate_outputs) - for i, feat in enumerate(intermediate_outputs): - patch_tokens = feat[:, self.num_prefix_tokens :] - spatial = patch_tokens.transpose(1, 2).reshape(batch_size, -1, h_patches, w_patches).contiguous() + num_scales = len(feature_maps) + for i, feat in enumerate(feature_maps): resize_h = int(h_patches * 2 ** (num_scales - 2 - i)) resize_w = int(w_patches * 2 ** (num_scales - 2 - i)) - spatial = F.interpolate(spatial, size=[resize_h, resize_w], mode="bilinear", align_corners=False) + spatial = F.interpolate(feat, size=[resize_h, resize_w], 
mode="bilinear", align_corners=False) sem_feats.append(spatial) detail_feats = self.sta(pixel_values) @@ -1275,6 +1044,7 @@ def __init__(self, config: Deimv2Config): self.final_layer_norm = Deimv2RMSNorm(config.d_model) # gate self.gateway = Deimv2Gate(config.d_model) + self.use_gateway = config.use_gateway if config.use_gateway: self.gateway = Deimv2Gate(config.d_model) else: @@ -1336,7 +1106,7 @@ def forward( ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if hasattr(self, "gateway"): + if self.use_gateway: hidden_states = self.gateway(residual, hidden_states) else: hidden_states = residual + hidden_states @@ -1345,7 +1115,8 @@ def forward( residual = hidden_states hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states.clamp(min=-65504, max=65504)) + clamp_value = torch.finfo(hidden_states.dtype).max + hidden_states = self.final_layer_norm(hidden_states.clamp(min=-clamp_value, max=clamp_value)) return hidden_states @@ -1381,6 +1152,7 @@ class Deimv2PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" + # initialize linear layer bias value according to a given probability value. 
if isinstance(module, (Deimv2ForObjectDetection, Deimv2Decoder)): if module.class_embed is not None: for layer in module.class_embed: @@ -1443,38 +1215,26 @@ def _init_weights(self, module): init.constant_(module.reg_conf.layers[-1].bias, 0) init.constant_(module.reg_conf.layers[-1].weight, 0) - if isinstance(module, Deimv2SwiGLUFFN): - init.xavier_uniform_(module.w12.weight) - init.constant_(module.w12.bias, 0) - init.xavier_uniform_(module.w3.weight) - init.constant_(module.w3.bias, 0) - - if isinstance(module, Deimv2RMSNorm): - init.ones_(module.scale) - if isinstance(module, nn.LayerNorm): init.ones_(module.weight) init.zeros_(module.bias) - if isinstance(module, Deimv2DINOv3ViTEmbeddings): - init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) - if module.config.num_register_tokens > 0: - init.trunc_normal_(module.register_tokens, mean=0.0, std=self.config.initializer_range) - init.zeros_(module.mask_token) - - if isinstance(module, Deimv2DINOv3ViTLayerScale) and self.config.dinov3_backbone_config is not None: - layerscale_value = self.config.dinov3_backbone_config.get("layerscale_value", 1.0) - init.constant_(module.lambda1, layerscale_value) - - if isinstance(module, Deimv2DINOv3ViTRopePositionEmbedding): - inv_freq = 1 / module.base ** torch.arange(0, 1, 4 / module.head_dim, dtype=torch.float32) - init.copy_(module.inv_freq, inv_freq) - if hasattr(module, "weight_embedding") and self.config.learn_initial_query: init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: init.xavier_uniform_(module.denoising_class_embed.weight) + if isinstance(module, Deimv2SwiGLUFFN): + init.xavier_uniform_(module.gate_proj.weight) + init.constant_(module.gate_proj.bias, 0) + init.xavier_uniform_(module.up_proj.weight) + init.constant_(module.up_proj.bias, 0) + init.xavier_uniform_(module.down_proj.weight) + init.constant_(module.down_proj.bias, 0) + + if isinstance(module, 
Deimv2RMSNorm): + init.ones_(module.weight) + class Deimv2HybridEncoder(Deimv2PreTrainedModel): """ @@ -1787,64 +1547,6 @@ def forward( ) -@dataclass -@auto_docstring( - custom_intro=""" - Base class for outputs of the RT-DETR encoder-decoder model. - """ -) -class Deimv2ModelOutput(ModelOutput): - r""" - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the decoder of the model. - intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): - Stacked intermediate hidden states (output of each layer of the decoder). - intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): - Stacked intermediate logits (logits of each layer of the decoder). - intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): - Stacked intermediate reference points (reference points of each layer of the decoder). - intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): - Stacked intermediate predicted corners (predicted corners of each layer of the decoder). - initial_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): - Initial reference points used for the first decoder layer. - init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): - Initial reference points sent through the Transformer decoder. - enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): - Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are - picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e. - foreground and background). 
- enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`): - Logits of predicted bounding boxes coordinates in the encoder stage. - enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): - Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are - picked as region proposals in the first stage. Output of bounding box binary classification (i.e. - foreground and background). - enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): - Logits of predicted bounding boxes coordinates in the first stage. - denoising_meta_values (`dict`): - Extra dictionary for the denoising related values. - """ - - last_hidden_state: torch.FloatTensor | None = None - intermediate_hidden_states: torch.FloatTensor | None = None - intermediate_logits: torch.FloatTensor | None = None - intermediate_reference_points: torch.FloatTensor | None = None - intermediate_predicted_corners: torch.FloatTensor | None = None - initial_reference_points: torch.FloatTensor | None = None - decoder_hidden_states: tuple[torch.FloatTensor] | None = None - decoder_attentions: tuple[torch.FloatTensor] | None = None - cross_attentions: tuple[torch.FloatTensor] | None = None - encoder_last_hidden_state: torch.FloatTensor | None = None - encoder_hidden_states: tuple[torch.FloatTensor] | None = None - encoder_attentions: tuple[torch.FloatTensor] | None = None - init_reference_points: torch.FloatTensor | None = None - enc_topk_logits: torch.FloatTensor | None = None - enc_topk_bboxes: torch.FloatTensor | None = None - enc_outputs_class: torch.FloatTensor | None = None - enc_outputs_coord_logits: torch.FloatTensor | None = None - denoising_meta_values: dict | None = None - - def 
get_contrastive_denoising_training_group( targets, num_classes, @@ -1977,23 +1679,11 @@ class Deimv2Model(Deimv2PreTrainedModel): def __init__(self, config: Deimv2Config): super().__init__(config) - if config.backbone_type == "dinov3": - self.dinov3_backbone = Deimv2DINOv3Backbone(config) + is_dinov3 = getattr(config.backbone_config, "model_type", None) == "dinov3_vit" + if is_dinov3: + self.backbone = Deimv2DINOv3ConvEncoder(config) else: - from ..d_fine.modeling_d_fine import DFineConvEncoder - - self.backbone = DFineConvEncoder(config) - if config.encoder_type != "lite": - intermediate_channel_sizes = self.backbone.intermediate_channel_sizes - encoder_input_proj = [] - for in_channel in intermediate_channel_sizes: - encoder_input_proj.append( - nn.Sequential( - nn.Conv2d(in_channel, config.encoder_hidden_dim, kernel_size=1, bias=False), - nn.BatchNorm2d(config.encoder_hidden_dim), - ) - ) - self.encoder_input_proj = nn.ModuleList(encoder_input_proj) + self.backbone = Deimv2ConvEncoder(config) if config.encoder_type == "lite": self.encoder = Deimv2LiteEncoder(config) @@ -2038,9 +1728,6 @@ def __init__(self, config: Deimv2Config): self.decoder_input_proj = nn.ModuleList(decoder_input_proj) self.decoder = Deimv2Decoder(config) - if config.use_spatial_tuning_adapter and config.backbone_type != "dinov3": - self.spatial_tuning_adapter = Deimv2SpatialTuningAdapter(config) - self.post_init() def freeze_backbone(self): @@ -2131,14 +1818,7 @@ def forward( if pixel_mask is None: pixel_mask = torch.ones(((batch_size, height, width)), device=device) - if self.config.backbone_type == "dinov3": - proj_feats = self.dinov3_backbone(pixel_values) - elif self.config.encoder_type == "lite": - features = self.backbone(pixel_values, pixel_mask) - proj_feats = [source for source, mask in features] - else: - features = self.backbone(pixel_values, pixel_mask) - proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] + proj_feats = 
self.backbone(pixel_values) else: batch_size = inputs_embeds.shape[0] device = inputs_embeds.device @@ -2180,8 +1860,6 @@ def forward( level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) if self.training and self.config.num_denoising > 0 and labels is not None: - from ..d_fine.modeling_d_fine import get_contrastive_denoising_training_group - ( denoising_class, denoising_bbox_unact, @@ -2242,8 +1920,6 @@ def forward( init_reference_points = reference_points_unact.detach() - from ..d_fine.modeling_d_fine import DFineModelOutput - decoder_outputs = self.decoder( inputs_embeds=target, encoder_hidden_states=source_flatten, @@ -2255,7 +1931,7 @@ def forward( **kwargs, ) - return DFineModelOutput( + return Deimv2ModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, intermediate_logits=decoder_outputs.intermediate_logits, diff --git a/src/transformers/models/deimv2/modular_deimv2.py b/src/transformers/models/deimv2/modular_deimv2.py index 5a7edf292378..536f0c10937d 100644 --- a/src/transformers/models/deimv2/modular_deimv2.py +++ b/src/transformers/models/deimv2/modular_deimv2.py @@ -11,17 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math import torch import torch.nn as nn import torch.nn.functional as F from ... 
import initialization as init +from ...backbone_utils import load_backbone from ...modeling_outputs import BaseModelOutput from ...utils import logging +from ..auto import AutoConfig from ..d_fine.configuration_d_fine import DFineConfig from ..d_fine.modeling_d_fine import ( + DFineAIFILayer, + DFineConvEncoder, DFineConvNormLayer, DFineDecoder, DFineDecoderLayer, @@ -34,19 +37,14 @@ DFineLQE, DFineMLP, DFineModel, + DFineModelOutput, DFineMultiscaleDeformableAttention, DFinePreTrainedModel, DFineRepVggBlock, DFineSCDown, + get_contrastive_denoising_training_group, ) -from ..dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig -from ..dinov3_vit.modeling_dinov3_vit import ( - DINOv3ViTEmbeddings, - DINOv3ViTLayer, - DINOv3ViTLayerScale, - DINOv3ViTRopePositionEmbedding, -) -from ..rt_detr.modeling_rt_detr import RTDetrAIFILayer +from ..llama.modeling_llama import LlamaRMSNorm logger = logging.get_logger(__name__) @@ -73,7 +71,8 @@ class Deimv2Config(DFineConfig): batch_norm_eps (`float`, *optional*, defaults to 1e-05): The epsilon used by the batch normalization layers. backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `HGNetV2Config()`): - The configuration of the backbone model. + The configuration of the backbone model. For HGNetV2 variants, use `HGNetV2Config`. + For DINOv3 variants, use `DINOv3ViTConfig`. freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`): Whether to freeze the batch normalization layers in the backbone. encoder_hidden_dim (`int`, *optional*, defaults to 256): @@ -202,8 +201,6 @@ class Deimv2Config(DFineConfig): Alpha parameter for the Matching Auxiliary Loss (MAL). If `None`, uses `focal_loss_alpha`. encoder_fuse_op (`str`, *optional*, defaults to `"sum"`): Fusion operation used in the encoder FPN. DEIMv2 uses `"sum"` instead of D-Fine's `"cat"`. 
- use_spatial_tuning_adapter (`bool`, *optional*, defaults to `False`): - Whether to use the Spatial Tuning Adapter (STA) for DINOv2 backbone variants. sta_inplanes (`int`, *optional*, defaults to 16): Number of input planes for the STA convolutional stem. encoder_type (`str`, *optional*, defaults to `"hybrid"`): @@ -214,17 +211,6 @@ class Deimv2Config(DFineConfig): uses RMSNorm on the encoder attention output instead. share_bbox_head (`bool`, *optional*, defaults to `False`): Whether to share the bounding box prediction head across all decoder layers. - backbone_type (`str`, *optional*, defaults to `"hgnetv2"`): - Type of backbone to use. `"hgnetv2"` uses HGNetV2, `"dinov3"` uses DINOv3 ViT backbone with STA. - dinov3_backbone_config (`dict`, *optional*): - Configuration dictionary for the DINOv3 ViT backbone. Passed as kwargs to `DINOv3ViTConfig`. - dinov3_interaction_indexes (`list[int]`, *optional*): - Layer indices in the DINOv3 ViT backbone from which to extract intermediate features. - dinov3_hidden_dim (`int`, *optional*): - Hidden dimension for the DINOv3 backbone projection convolutions. If `None`, uses `hidden_size` from - the DINOv3 ViT config. - dinov3_apply_layernorm (`bool`, *optional*, defaults to `False`): - Whether to apply LayerNorm to intermediate features extracted from the DINOv3 ViT backbone. encoder_has_trailing_conv (`bool`, *optional*, defaults to `True`): Whether the encoder's CSP blocks include a trailing 3x3 convolution after the bottleneck path. `True` for RepNCSPELAN4 (used by HGNetV2 N and LiteEncoder variants). 
@@ -234,6 +220,7 @@ class Deimv2Config(DFineConfig): """ model_type = "deimv2" + sub_configs = {"backbone_config": AutoConfig} def __init__( self, @@ -305,16 +292,10 @@ def __init__( use_dense_o2o=True, mal_alpha=None, encoder_fuse_op="sum", - use_spatial_tuning_adapter=False, sta_inplanes=16, encoder_type="hybrid", use_gateway=True, share_bbox_head=False, - backbone_type="hgnetv2", - dinov3_backbone_config=None, - dinov3_interaction_indexes=None, - dinov3_hidden_dim=None, - dinov3_apply_layernorm=False, encoder_has_trailing_conv=True, tie_word_embeddings=True, **kwargs, @@ -323,16 +304,10 @@ def __init__( self.use_dense_o2o = use_dense_o2o self.mal_alpha = mal_alpha self.encoder_fuse_op = encoder_fuse_op - self.use_spatial_tuning_adapter = use_spatial_tuning_adapter self.sta_inplanes = sta_inplanes self.encoder_type = encoder_type self.use_gateway = use_gateway self.share_bbox_head = share_bbox_head - self.backbone_type = backbone_type - self.dinov3_backbone_config = dinov3_backbone_config - self.dinov3_interaction_indexes = dinov3_interaction_indexes - self.dinov3_hidden_dim = dinov3_hidden_dim - self.dinov3_apply_layernorm = dinov3_apply_layernorm self.encoder_has_trailing_conv = encoder_has_trailing_conv super().__init__( initializer_range=initializer_range, @@ -404,38 +379,27 @@ def __init__( ) -class Deimv2DINOv3ViTConfig(DINOv3ViTConfig): - model_type = "deimv2_dinov3_vit" - - class Deimv2DecoderOutput(DFineDecoderOutput): pass -class Deimv2RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.scale = nn.Parameter(torch.ones(dim)) +class Deimv2ModelOutput(DFineModelOutput): + pass - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.float() - hidden_states = hidden_states * torch.rsqrt(hidden_states.pow(2).mean(-1, keepdim=True) + self.eps) - return (hidden_states * self.scale).to(input_dtype) + +class 
Deimv2RMSNorm(LlamaRMSNorm): + pass class Deimv2SwiGLUFFN(nn.Module): def __init__(self, in_features: int, hidden_features: int, out_features: int): super().__init__() - self.w12 = nn.Linear(in_features, 2 * hidden_features) - self.w3 = nn.Linear(hidden_features, out_features) + self.gate_proj = nn.Linear(in_features, hidden_features) + self.up_proj = nn.Linear(in_features, hidden_features) + self.down_proj = nn.Linear(hidden_features, out_features) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - x12 = self.w12(hidden_states) - x1, x2 = x12.chunk(2, dim=-1) - hidden = F.silu(x1) * x2 - return self.w3(hidden) + return self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) class Deimv2Gate(DFineGate): @@ -460,23 +424,7 @@ class Deimv2RepVggBlock(DFineRepVggBlock): pass -class Deimv2DINOv3ViTEmbeddings(DINOv3ViTEmbeddings): - pass - - -class Deimv2DINOv3ViTLayerScale(DINOv3ViTLayerScale): - pass - - -class Deimv2DINOv3ViTLayer(DINOv3ViTLayer): - pass - - -class Deimv2DINOv3ViTRopePositionEmbedding(DINOv3ViTRopePositionEmbedding): - pass - - -class Deimv2CSPRepLayer2(nn.Module): +class Deimv2CSPRepLayer(nn.Module): def __init__( self, config: Deimv2Config, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0 ): @@ -493,34 +441,27 @@ def __init__( self.conv3 = nn.Identity() def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: - chunks = list(self.conv1(hidden_state).chunk(2, 1)) - bottleneck_out = chunks[1] + hidden_state_1, hidden_state_2 = self.conv1(hidden_state).chunk(2, 1) for bottleneck in self.bottlenecks: - bottleneck_out = bottleneck(bottleneck_out) - return self.conv3(chunks[0] + bottleneck_out) + hidden_state_2 = bottleneck(hidden_state_2) + return self.conv3(hidden_state_1 + hidden_state_2) class Deimv2RepNCSPELAN5(nn.Module): - def __init__( - self, - config: Deimv2Config, - numb_blocks: int = 3, - c1: int | None = None, - c2: int | None = None, - c3: int | None = None, - c4: 
int | None = None, - ): + def __init__(self, config: Deimv2Config, numb_blocks: int = 3): super().__init__() - act = config.activation_function - c1 = c1 if c1 is not None else config.encoder_hidden_dim - c2 = c2 if c2 is not None else config.encoder_hidden_dim - c3 = c3 if c3 is not None else config.encoder_hidden_dim * 2 - c4 = c4 if c4 is not None else round(config.hidden_expansion * config.encoder_hidden_dim // 2) - self.conv_dim = c3 // 2 - self.conv1 = Deimv2ConvNormLayer(config, c1, c3, 1, 1, activation=act) - self.csp_rep1 = Deimv2CSPRepLayer2(config, c3 // 2, c4, num_blocks=numb_blocks) - self.csp_rep2 = Deimv2CSPRepLayer2(config, c4, c4, num_blocks=numb_blocks) - self.conv4 = Deimv2ConvNormLayer(config, c3 + (2 * c4), c2, 1, 1, activation=act) + activation = config.activation_function + in_channels = config.encoder_hidden_dim + out_channels = config.encoder_hidden_dim + split_channels = config.encoder_hidden_dim * 2 + csp_channels = round(config.hidden_expansion * config.encoder_hidden_dim // 2) + self.conv_dim = split_channels // 2 + self.conv1 = Deimv2ConvNormLayer(config, in_channels, split_channels, 1, 1, activation=activation) + self.csp_rep1 = Deimv2CSPRepLayer(config, split_channels // 2, csp_channels, num_blocks=numb_blocks) + self.csp_rep2 = Deimv2CSPRepLayer(config, csp_channels, csp_channels, num_blocks=numb_blocks) + self.conv4 = Deimv2ConvNormLayer( + config, split_channels + (2 * csp_channels), out_channels, 1, 1, activation=activation + ) def forward(self, input_features: torch.Tensor) -> torch.Tensor: split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1)) @@ -539,7 +480,7 @@ class Deimv2EncoderLayer(DFineEncoderLayer): pass -class Deimv2AIFILayer(RTDetrAIFILayer): +class Deimv2AIFILayer(DFineAIFILayer): pass @@ -569,11 +510,11 @@ def __init__(self, config: Deimv2Config): ) def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - c1 = self.stem(pixel_values) - c2 
= self.conv2(c1) - c3 = self.conv3(c2) - c4 = self.conv4(c3) - return c2, c3, c4 + hidden_states = self.stem(pixel_values) + feature_map_8 = self.conv2(hidden_states) + feature_map_16 = self.conv3(feature_map_8) + feature_map_32 = self.conv4(feature_map_16) + return feature_map_8, feature_map_16, feature_map_32 class Deimv2GAPFusion(nn.Module): @@ -615,12 +556,9 @@ def __init__(self, config: Deimv2Config): self.bi_fusion = Deimv2GAPFusion(config, hidden_dim) - c1, c2 = hidden_dim, hidden_dim - c3 = hidden_dim * 2 - c4 = round(config.hidden_expansion * hidden_dim // 2) num_blocks = round(3 * config.depth_mult) - self.fpn_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks, c1=c1, c2=c2, c3=c3, c4=c4) - self.pan_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks, c1=c1, c2=c2, c3=c3, c4=c4) + self.fpn_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks) + self.pan_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks) def forward(self, inputs_embeds=None, **kwargs) -> BaseModelOutput: feats = inputs_embeds @@ -639,27 +577,36 @@ def forward(self, inputs_embeds=None, **kwargs) -> BaseModelOutput: return BaseModelOutput(last_hidden_state=outs) -class Deimv2DINOv3Backbone(nn.Module): - def __init__(self, config: Deimv2Config): - super().__init__() - vit_config = Deimv2DINOv3ViTConfig(**config.dinov3_backbone_config) +class Deimv2ConvEncoder(DFineConvEncoder): + def __init__(self, config): + super().__init__(config) + if config.encoder_type != "lite": + encoder_input_proj = [] + for in_channel in self.intermediate_channel_sizes: + encoder_input_proj.append( + nn.Sequential( + nn.Conv2d(in_channel, config.encoder_hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(config.encoder_hidden_dim), + ) + ) + self.encoder_input_proj = nn.ModuleList(encoder_input_proj) - self.embeddings = Deimv2DINOv3ViTEmbeddings(vit_config) - self.rope_embeddings = Deimv2DINOv3ViTRopePositionEmbedding(vit_config) - self.layers = 
nn.ModuleList([Deimv2DINOv3ViTLayer(vit_config) for _ in range(vit_config.num_hidden_layers)]) + def forward(self, pixel_values: torch.Tensor) -> list[torch.Tensor]: + features = self.model(pixel_values).feature_maps + if hasattr(self, "encoder_input_proj"): + return [self.encoder_input_proj[i](feat) for i, feat in enumerate(features)] + return list(features) - self.apply_layernorm = config.dinov3_apply_layernorm - if self.apply_layernorm: - self.norm = nn.LayerNorm(vit_config.hidden_size, eps=vit_config.layer_norm_eps) - self.interaction_indexes = config.dinov3_interaction_indexes - self.patch_size = vit_config.patch_size - self.num_prefix_tokens = 1 + vit_config.num_register_tokens +class Deimv2DINOv3ConvEncoder(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + self.backbone = load_backbone(config) self.sta = Deimv2SpatialTuningAdapter(config) - embed_dim = vit_config.hidden_size - hidden_dim = config.dinov3_hidden_dim or embed_dim + embed_dim = config.backbone_config.hidden_size + hidden_dim = config.encoder_hidden_dim sta_ch = config.sta_inplanes self.convs = nn.ModuleList( [ @@ -671,28 +618,19 @@ def __init__(self, config: Deimv2Config): self.norms = nn.ModuleList([nn.BatchNorm2d(hidden_dim) for _ in range(3)]) def forward(self, pixel_values: torch.Tensor) -> list[torch.Tensor]: - hidden_states = self.embeddings(pixel_values) - position_embeddings = self.rope_embeddings(pixel_values) - - intermediate_outputs = [] - for i, layer in enumerate(self.layers): - hidden_states = layer(hidden_states, position_embeddings=position_embeddings) - if i in self.interaction_indexes: - out = self.norm(hidden_states) if self.apply_layernorm else hidden_states - intermediate_outputs.append(out) + backbone_output = self.backbone(pixel_values) + feature_maps = backbone_output.feature_maps - batch_size = pixel_values.shape[0] - h_patches = pixel_values.shape[2] // self.patch_size - w_patches = pixel_values.shape[3] // self.patch_size + patch_size = 
self.backbone.config.patch_size + h_patches = pixel_values.shape[2] // patch_size + w_patches = pixel_values.shape[3] // patch_size sem_feats = [] - num_scales = len(intermediate_outputs) - for i, feat in enumerate(intermediate_outputs): - patch_tokens = feat[:, self.num_prefix_tokens :] - spatial = patch_tokens.transpose(1, 2).reshape(batch_size, -1, h_patches, w_patches).contiguous() + num_scales = len(feature_maps) + for i, feat in enumerate(feature_maps): resize_h = int(h_patches * 2 ** (num_scales - 2 - i)) resize_w = int(w_patches * 2 ** (num_scales - 2 - i)) - spatial = F.interpolate(spatial, size=[resize_h, resize_w], mode="bilinear", align_corners=False) + spatial = F.interpolate(feat, size=[resize_h, resize_w], mode="bilinear", align_corners=False) sem_feats.append(spatial) detail_feats = self.sta(pixel_values) @@ -720,6 +658,7 @@ def __init__(self, config: Deimv2Config): self.self_attn_layer_norm = Deimv2RMSNorm(config.d_model) self.final_layer_norm = Deimv2RMSNorm(config.d_model) self.mlp = Deimv2SwiGLUFFN(config.d_model, config.decoder_ffn_dim // 2, config.d_model) + self.use_gateway = config.use_gateway if config.use_gateway: self.gateway = Deimv2Gate(config.d_model) else: @@ -762,7 +701,7 @@ def forward( ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if hasattr(self, "gateway"): + if self.use_gateway: hidden_states = self.gateway(residual, hidden_states) else: hidden_states = residual + hidden_states @@ -771,7 +710,8 @@ def forward( residual = hidden_states hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states.clamp(min=-65504, max=65504)) + clamp_value = torch.finfo(hidden_states.dtype).max + hidden_states = self.final_layer_norm(hidden_states.clamp(min=-clamp_value, max=clamp_value)) return hidden_states @@ -783,100 +723,18 @@ class Deimv2MLPPredictionHead(DFineMLP): class Deimv2PreTrainedModel(DFinePreTrainedModel): 
@torch.no_grad() def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (Deimv2ForObjectDetection, Deimv2Decoder)): - if module.class_embed is not None: - for layer in module.class_embed: - prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) - bias = float(-math.log((1 - prior_prob) / prior_prob)) - init.xavier_uniform_(layer.weight) - init.constant_(layer.bias, bias) - - if module.bbox_embed is not None: - for layer in module.bbox_embed: - init.constant_(layer.layers[-1].weight, 0) - init.constant_(layer.layers[-1].bias, 0) - - if hasattr(module, "reg_scale"): - init.constant_(module.reg_scale, self.config.reg_scale) - - if hasattr(module, "up"): - init.constant_(module.up, self.config.up) - - if isinstance(module, Deimv2MultiscaleDeformableAttention): - init.constant_(module.sampling_offsets.weight, 0.0) - default_dtype = torch.get_default_dtype() - thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( - 2.0 * math.pi / module.n_heads - ) - grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) - grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values - grid_init = grid_init.reshape(module.n_heads, 1, 2).tile([1, sum(module.num_points_list), 1]) - scaling = torch.concat([torch.arange(1, n + 1) for n in module.num_points_list]).reshape(1, -1, 1) - grid_init *= scaling - init.copy_(module.sampling_offsets.bias, grid_init.flatten()) - - init.constant_(module.attention_weights.weight, 0.0) - init.constant_(module.attention_weights.bias, 0.0) - - num_points_scale = [1 / n for n in module.num_points_list for _ in range(n)] - init.copy_(module.num_points_scale, torch.tensor(num_points_scale, dtype=torch.float32)) - - if isinstance(module, Deimv2Model): - prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) - bias = float(-math.log((1 - prior_prob) / prior_prob)) - init.xavier_uniform_(module.enc_score_head.weight) - 
init.constant_(module.enc_score_head.bias, bias) - - if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - init.zeros_(module.bias) - if getattr(module, "running_mean", None) is not None: - init.zeros_(module.running_mean) - init.ones_(module.running_var) - init.zeros_(module.num_batches_tracked) - - if isinstance(module, Deimv2Gate): - bias = float(-math.log((1 - 0.5) / 0.5)) - init.constant_(module.gate.bias, bias) - init.constant_(module.gate.weight, 0) - - if isinstance(module, Deimv2LQE): - init.constant_(module.reg_conf.layers[-1].bias, 0) - init.constant_(module.reg_conf.layers[-1].weight, 0) + super()._init_weights(module) if isinstance(module, Deimv2SwiGLUFFN): - init.xavier_uniform_(module.w12.weight) - init.constant_(module.w12.bias, 0) - init.xavier_uniform_(module.w3.weight) - init.constant_(module.w3.bias, 0) + init.xavier_uniform_(module.gate_proj.weight) + init.constant_(module.gate_proj.bias, 0) + init.xavier_uniform_(module.up_proj.weight) + init.constant_(module.up_proj.bias, 0) + init.xavier_uniform_(module.down_proj.weight) + init.constant_(module.down_proj.bias, 0) if isinstance(module, Deimv2RMSNorm): - init.ones_(module.scale) - - if isinstance(module, nn.LayerNorm): init.ones_(module.weight) - init.zeros_(module.bias) - - if isinstance(module, Deimv2DINOv3ViTEmbeddings): - init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) - if module.config.num_register_tokens > 0: - init.trunc_normal_(module.register_tokens, mean=0.0, std=self.config.initializer_range) - init.zeros_(module.mask_token) - - if isinstance(module, Deimv2DINOv3ViTLayerScale) and self.config.dinov3_backbone_config is not None: - layerscale_value = self.config.dinov3_backbone_config.get("layerscale_value", 1.0) - init.constant_(module.lambda1, layerscale_value) - - if isinstance(module, Deimv2DINOv3ViTRopePositionEmbedding): - 
inv_freq = 1 / module.base ** torch.arange(0, 1, 4 / module.head_dim, dtype=torch.float32) - init.copy_(module.inv_freq, inv_freq) - - if hasattr(module, "weight_embedding") and self.config.learn_initial_query: - init.xavier_uniform_(module.weight_embedding.weight) - if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: - init.xavier_uniform_(module.denoising_class_embed.weight) class Deimv2HybridEncoder(DFineHybridEncoder): @@ -985,23 +843,11 @@ class Deimv2Model(DFineModel): def __init__(self, config: Deimv2Config): Deimv2PreTrainedModel.__init__(self, config) - if config.backbone_type == "dinov3": - self.dinov3_backbone = Deimv2DINOv3Backbone(config) + is_dinov3 = getattr(config.backbone_config, "model_type", None) == "dinov3_vit" + if is_dinov3: + self.backbone = Deimv2DINOv3ConvEncoder(config) else: - from ..d_fine.modeling_d_fine import DFineConvEncoder - - self.backbone = DFineConvEncoder(config) - if config.encoder_type != "lite": - intermediate_channel_sizes = self.backbone.intermediate_channel_sizes - encoder_input_proj = [] - for in_channel in intermediate_channel_sizes: - encoder_input_proj.append( - nn.Sequential( - nn.Conv2d(in_channel, config.encoder_hidden_dim, kernel_size=1, bias=False), - nn.BatchNorm2d(config.encoder_hidden_dim), - ) - ) - self.encoder_input_proj = nn.ModuleList(encoder_input_proj) + self.backbone = Deimv2ConvEncoder(config) if config.encoder_type == "lite": self.encoder = Deimv2LiteEncoder(config) @@ -1046,9 +892,6 @@ def __init__(self, config: Deimv2Config): self.decoder_input_proj = nn.ModuleList(decoder_input_proj) self.decoder = Deimv2Decoder(config) - if config.use_spatial_tuning_adapter and config.backbone_type != "dinov3": - self.spatial_tuning_adapter = Deimv2SpatialTuningAdapter(config) - self.post_init() def forward( @@ -1069,14 +912,7 @@ def forward( if pixel_mask is None: pixel_mask = torch.ones(((batch_size, height, width)), device=device) - if self.config.backbone_type == "dinov3": - 
proj_feats = self.dinov3_backbone(pixel_values) - elif self.config.encoder_type == "lite": - features = self.backbone(pixel_values, pixel_mask) - proj_feats = [source for source, mask in features] - else: - features = self.backbone(pixel_values, pixel_mask) - proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] + proj_feats = self.backbone(pixel_values) else: batch_size = inputs_embeds.shape[0] device = inputs_embeds.device @@ -1118,8 +954,6 @@ def forward( level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) if self.training and self.config.num_denoising > 0 and labels is not None: - from ..d_fine.modeling_d_fine import get_contrastive_denoising_training_group - ( denoising_class, denoising_bbox_unact, @@ -1180,8 +1014,6 @@ def forward( init_reference_points = reference_points_unact.detach() - from ..d_fine.modeling_d_fine import DFineModelOutput - decoder_outputs = self.decoder( inputs_embeds=target, encoder_hidden_states=source_flatten, @@ -1193,7 +1025,7 @@ def forward( **kwargs, ) - return DFineModelOutput( + return Deimv2ModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, intermediate_logits=decoder_outputs.intermediate_logits, diff --git a/tests/models/deimv2/test_modeling_deimv2.py b/tests/models/deimv2/test_modeling_deimv2.py index db4ce97086b4..8ec713a524ef 100644 --- a/tests/models/deimv2/test_modeling_deimv2.py +++ b/tests/models/deimv2/test_modeling_deimv2.py @@ -24,6 +24,7 @@ from transformers import ( Deimv2Config, + DINOv3ViTConfig, HGNetV2Config, is_torch_available, is_vision_available, @@ -1136,36 +1137,21 @@ def prepare_config_and_inputs(self): return config, pixel_values, pixel_mask, labels def get_config(self): - hidden_sizes = [64, 128, 256, 512] - backbone_config = HGNetV2Config( - stage_in_channels=[16, 64, 128, 256], - stage_mid_channels=[16, 32, 64, 128], - 
stage_out_channels=[64, 128, 256, 512], - stage_num_blocks=[1, 1, 2, 1], - stage_downsample=[False, True, True, True], - stage_light_block=[False, False, True, True], - stage_kernel_size=[3, 3, 5, 5], - stage_numb_of_layers=[3, 3, 3, 3], - embeddings_size=10, - hidden_sizes=hidden_sizes, - depths=[1, 1, 2, 1], - out_features=["stage2", "stage3", "stage4"], + backbone_config = DINOv3ViTConfig( + hidden_size=32, + num_attention_heads=2, + num_hidden_layers=4, + intermediate_size=64, + num_register_tokens=1, + layerscale_value=1.0, + use_gated_mlp=False, + rope_theta=100.0, + patch_size=16, + image_size=self.image_size, out_indices=[2, 3, 4], - stem_channels=[3, 16, 16], - use_lab=True, + apply_layernorm=False, + reshape_hidden_states=True, ) - dinov3_backbone_config = { - "hidden_size": 32, - "num_attention_heads": 2, - "num_hidden_layers": 4, - "intermediate_size": 64, - "num_register_tokens": 1, - "layerscale_value": 1.0, - "use_gated_mlp": False, - "rope_theta": 100.0, - "patch_size": 16, - "image_size": self.image_size, - } return Deimv2Config( backbone_config=backbone_config, encoder_hidden_dim=self.encoder_hidden_dim, @@ -1207,13 +1193,7 @@ def get_config(self): image_size=self.image_size, disable_custom_kernels=self.disable_custom_kernels, with_box_refine=self.with_box_refine, - backbone_type="dinov3", - dinov3_backbone_config=dinov3_backbone_config, - dinov3_interaction_indexes=[1, 2, 3], - dinov3_hidden_dim=self.encoder_hidden_dim, - dinov3_apply_layernorm=False, sta_inplanes=self.sta_inplanes, - use_spatial_tuning_adapter=True, encoder_has_trailing_conv=False, ) @@ -1355,6 +1335,10 @@ def test_disk_offload_bin(self): def test_disk_offload_safetensors(self): pass + @unittest.skip(reason="DINOv3 backbone with RoPE and dynamic interpolation causes torch.compile inductor overflow") + def test_sdpa_can_compile_dynamic(self): + pass + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) def test_eager_matches_sdpa_inference( self, name, 
torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 413f5e6d61fa..345d8172c049 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -97,7 +97,6 @@ "DetrConfig": True, "DFineConfig": True, "Deimv2Config": True, - "Deimv2DINOv3ViTConfig": True, "GroundingDinoConfig": True, "MMGroundingDinoConfig": True, "RTDetrConfig": True, From 4ad0dc5bd9e4c0ab04a150ca77e05dbd9df3c07f Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Thu, 19 Mar 2026 15:58:56 +0400 Subject: [PATCH 06/25] refactor: Resolve second review round --- docs/source/en/model_doc/deimv2.md | 2 +- src/transformers/models/deimv2/__init__.py | 2 +- .../models/deimv2/configuration_deimv2.py | 542 +++++++----------- ...eimv2_original_pytorch_checkpoint_to_hf.py | 136 ++--- .../models/deimv2/modeling_deimv2.py | 49 +- .../models/deimv2/modular_deimv2.py | 491 +++++----------- tests/models/deimv2/test_modeling_deimv2.py | 148 ++--- 7 files changed, 519 insertions(+), 851 deletions(-) diff --git a/docs/source/en/model_doc/deimv2.md b/docs/source/en/model_doc/deimv2.md index 3d4e4a4a77b1..67cb46398921 100644 --- a/docs/source/en/model_doc/deimv2.md +++ b/docs/source/en/model_doc/deimv2.md @@ -1,4 +1,4 @@ - +*This model was released on 2025-09-25 and added to Hugging Face Transformers on 2026-04-22.* # DEIMv2 diff --git a/src/transformers/loss/loss_deimv2.py b/src/transformers/loss/loss_deimv2.py index 083e32fd2a82..88c1aa94d1fa 100644 --- a/src/transformers/loss/loss_deimv2.py +++ b/src/transformers/loss/loss_deimv2.py @@ -109,16 +109,29 @@ def get_loss(self, loss, outputs, targets, indices, num_boxes): return loss_map[loss](outputs, targets, indices, num_boxes) def forward(self, outputs, targets): + """ + This performs the loss computation. + + Args: + outputs (`dict`, *optional*): + Dictionary of tensors, see the output specification of the model for the format. 
+ targets (`list[dict]`, *optional*): + List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the + losses applied, see each loss' doc. + """ if not self.use_dense_one_to_one: return super().forward(outputs, targets) + # Retrieve the matching between the outputs of the last layer and the targets outputs_without_aux = {k: v for k, v in outputs.items() if "auxiliary_outputs" not in k} indices = self.matcher(outputs_without_aux, targets) + # Compute the average number of target boxes across all nodes, for normalization purposes num_boxes = sum(len(t["class_labels"]) for t in targets) num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) num_boxes = torch.clamp(num_boxes, min=1).item() + # Handle auxiliary outputs matching cached_indices = [] indices_aux_list = [] if "auxiliary_outputs" in outputs: @@ -127,11 +140,13 @@ def forward(self, outputs, targets): cached_indices.append(aux_indices) indices_aux_list.append(aux_indices) + # Dense one-to-one matching indices_go = self._get_dense_o2o_indices(indices, indices_aux_list) num_boxes_go = sum(len(x[0]) for x in indices_go) num_boxes_go = torch.as_tensor([num_boxes_go], dtype=torch.float, device=next(iter(outputs.values())).device) num_boxes_go = torch.clamp(num_boxes_go, min=1).item() + # Compute all the requested losses losses = {} for loss in self.losses: use_union = loss in ("boxes", "local") @@ -141,6 +156,7 @@ def forward(self, outputs, targets): l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} losses.update(l_dict) + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 
if "auxiliary_outputs" in outputs: for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]): for loss in self.losses: @@ -152,6 +168,7 @@ def forward(self, outputs, targets): l_dict = {k + f"_aux_{i}": v for k, v in l_dict.items()} losses.update(l_dict) + # In case of cdn auxiliary losses. For deimv2 if "dn_auxiliary_outputs" in outputs: if "denoising_meta_values" not in outputs: raise ValueError( @@ -187,55 +204,50 @@ def Deimv2ForObjectDetectionLoss( ): criterion = Deimv2Loss(config) criterion.to(device) - outputs_loss = {} + + outputs_loss = {"logits": logits, "pred_boxes": pred_boxes.clamp(min=0, max=1)} auxiliary_outputs = None - if config.auxiliary_loss: - if denoising_meta_values is not None: - dn_out_coord, outputs_coord = torch.split( - outputs_coord.clamp(min=0, max=1), denoising_meta_values["dn_num_split"], dim=2 - ) - dn_out_class, outputs_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2) - # https://github.com/Intellindust-AI-Lab/DEIMv2/blob/main/engine/deim/deim_decoder.py#L562-L571 - # The original splits denoising queries in the decoder; here it happens in the loss since the decoder returns unsplit tensors. 
- _, logits = torch.split(logits, denoising_meta_values["dn_num_split"], dim=1) - _, pred_boxes = torch.split(pred_boxes, denoising_meta_values["dn_num_split"], dim=1) - dn_out_corners, out_corners = torch.split(predicted_corners, denoising_meta_values["dn_num_split"], dim=2) - dn_out_refs, out_refs = torch.split(initial_reference_points, denoising_meta_values["dn_num_split"], dim=2) - - outputs_loss["logits"] = logits - outputs_loss["pred_boxes"] = pred_boxes.clamp(min=0, max=1) - - auxiliary_outputs = _set_aux_loss2( - outputs_class[:, :-1].transpose(0, 1), - outputs_coord[:, :-1].transpose(0, 1), - out_corners[:, :-1].transpose(0, 1), - out_refs[:, :-1].transpose(0, 1), - out_corners[:, -1], - outputs_class[:, -1], - ) - - outputs_loss["auxiliary_outputs"] = auxiliary_outputs - outputs_loss["auxiliary_outputs"].extend( - _set_aux_loss([enc_topk_logits], [enc_topk_bboxes.clamp(min=0, max=1)]) - ) - - dn_auxiliary_outputs = _set_aux_loss2( - dn_out_class.transpose(0, 1), - dn_out_coord.transpose(0, 1), - dn_out_corners.transpose(0, 1), - dn_out_refs.transpose(0, 1), - dn_out_corners[:, -1], - dn_out_class[:, -1], - ) - outputs_loss["dn_auxiliary_outputs"] = dn_auxiliary_outputs - outputs_loss["denoising_meta_values"] = denoising_meta_values - else: - outputs_loss["logits"] = logits - outputs_loss["pred_boxes"] = pred_boxes.clamp(min=0, max=1) - else: + + if config.auxiliary_loss and denoising_meta_values is not None: + dn_out_coord, outputs_coord = torch.split( + outputs_coord.clamp(min=0, max=1), denoising_meta_values["dn_num_split"], dim=2 + ) + dn_out_class, outputs_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2) + # https://github.com/Intellindust-AI-Lab/DEIMv2/blob/main/engine/deim/deim_decoder.py#L562-L571 + # The original splits denoising queries in the decoder; here it happens in the loss since the decoder returns unsplit tensors. 
+ _, logits = torch.split(logits, denoising_meta_values["dn_num_split"], dim=1) + _, pred_boxes = torch.split(pred_boxes, denoising_meta_values["dn_num_split"], dim=1) + dn_out_corners, out_corners = torch.split(predicted_corners, denoising_meta_values["dn_num_split"], dim=2) + dn_out_refs, out_refs = torch.split(initial_reference_points, denoising_meta_values["dn_num_split"], dim=2) + outputs_loss["logits"] = logits outputs_loss["pred_boxes"] = pred_boxes.clamp(min=0, max=1) + auxiliary_outputs = _set_aux_loss2( + outputs_class[:, :-1].transpose(0, 1), + outputs_coord[:, :-1].transpose(0, 1), + out_corners[:, :-1].transpose(0, 1), + out_refs[:, :-1].transpose(0, 1), + out_corners[:, -1], + outputs_class[:, -1], + ) + + outputs_loss["auxiliary_outputs"] = auxiliary_outputs + outputs_loss["auxiliary_outputs"].extend( + _set_aux_loss([enc_topk_logits], [enc_topk_bboxes.clamp(min=0, max=1)]) + ) + + dn_auxiliary_outputs = _set_aux_loss2( + dn_out_class.transpose(0, 1), + dn_out_coord.transpose(0, 1), + dn_out_corners.transpose(0, 1), + dn_out_refs.transpose(0, 1), + dn_out_corners[:, -1], + dn_out_class[:, -1], + ) + outputs_loss["dn_auxiliary_outputs"] = dn_auxiliary_outputs + outputs_loss["denoising_meta_values"] = denoising_meta_values + loss_dict = criterion(outputs_loss, labels) loss = sum(loss_dict.values()) diff --git a/src/transformers/models/deimv2/modeling_deimv2.py b/src/transformers/models/deimv2/modeling_deimv2.py index 04dd673ed826..070592ff08f0 100644 --- a/src/transformers/models/deimv2/modeling_deimv2.py +++ b/src/transformers/models/deimv2/modeling_deimv2.py @@ -1158,6 +1158,7 @@ def _init_weights(self, module): init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: init.xavier_uniform_(module.denoising_class_embed.weight) + PreTrainedModel._init_weights(self, module) if isinstance(module, Deimv2SwiGLUFFN): init.xavier_uniform_(module.gate_proj.weight) @@ -1167,9 +1168,6 @@ 
def _init_weights(self, module): init.xavier_uniform_(module.down_proj.weight) init.constant_(module.down_proj.bias, 0) - if isinstance(module, Deimv2RMSNorm): - init.ones_(module.weight) - class Deimv2LiteEncoder(Deimv2PreTrainedModel): # LiteEncoder has no transformer layers, so hidden_states are recorded from the conv projections. @@ -2023,8 +2021,6 @@ class Deimv2ObjectDetectionOutput(ModelOutput): """ ) class Deimv2ForObjectDetection(Deimv2PreTrainedModel): - # When using clones, all layers > 0 will be clones, but layer 0 *is* required - # We can't initialize the model on meta device as some weights are modified during the initialization _no_split_modules = None _tied_weights_keys = { r"bbox_embed.(?![0])\d+": r"bbox_embed.0", diff --git a/src/transformers/models/deimv2/modular_deimv2.py b/src/transformers/models/deimv2/modular_deimv2.py index 59ad178d8278..271f2756df67 100644 --- a/src/transformers/models/deimv2/modular_deimv2.py +++ b/src/transformers/models/deimv2/modular_deimv2.py @@ -22,6 +22,7 @@ from ... import initialization as init from ...backbone_utils import load_backbone from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging from ...utils.generic import merge_with_config_defaults @@ -491,6 +492,7 @@ class Deimv2PreTrainedModel(DFinePreTrainedModel): @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) + PreTrainedModel._init_weights(self, module) if isinstance(module, Deimv2SwiGLUFFN): init.xavier_uniform_(module.gate_proj.weight) @@ -500,9 +502,6 @@ def _init_weights(self, module): init.xavier_uniform_(module.down_proj.weight) init.constant_(module.down_proj.bias, 0) - if isinstance(module, Deimv2RMSNorm): - init.ones_(module.weight) - class Deimv2LiteEncoder(Deimv2PreTrainedModel): # LiteEncoder has no transformer layers, so hidden_states are recorded from the conv projections. 
@@ -854,6 +853,8 @@ def forward( class Deimv2ForObjectDetection(DFineForObjectDetection): + _no_split_modules = None + @property def _tied_weights_keys(self): keys = { From 943f4bb7dcacbeb0424f21ed77ac00511622b7d7 Mon Sep 17 00:00:00 2001 From: vasqu Date: Wed, 22 Apr 2026 13:07:32 +0200 Subject: [PATCH 16/25] fixup their init weights --- .../models/d_fine/modeling_d_fine.py | 14 +------------- src/transformers/models/d_fine/modular_d_fine.py | 15 ++------------- .../models/deimv2/modeling_deimv2.py | 16 +--------------- src/transformers/models/deimv2/modular_deimv2.py | 4 +--- 4 files changed, 5 insertions(+), 44 deletions(-) diff --git a/src/transformers/models/d_fine/modeling_d_fine.py b/src/transformers/models/d_fine/modeling_d_fine.py index 1c758f8b1dcd..f1d23356fb2b 100644 --- a/src/transformers/models/d_fine/modeling_d_fine.py +++ b/src/transformers/models/d_fine/modeling_d_fine.py @@ -874,6 +874,7 @@ class DFinePreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" + super()._init_weights(module) # initialize linear layer bias value according to a given probability value. 
if isinstance(module, (DFineForObjectDetection, DFineDecoder)): if module.class_embed is not None: @@ -919,15 +920,6 @@ def _init_weights(self, module): init.xavier_uniform_(module.enc_score_head.weight) init.constant_(module.enc_score_head.bias, bias) - if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - init.zeros_(module.bias) - if getattr(module, "running_mean", None) is not None: - init.zeros_(module.running_mean) - init.ones_(module.running_var) - init.zeros_(module.num_batches_tracked) - if isinstance(module, DFineGate): bias = float(-math.log((1 - 0.5) / 0.5)) init.constant_(module.gate.bias, bias) @@ -937,10 +929,6 @@ def _init_weights(self, module): init.constant_(module.reg_conf.layers[-1].bias, 0) init.constant_(module.reg_conf.layers[-1].weight, 0) - if isinstance(module, nn.LayerNorm): - init.ones_(module.weight) - init.zeros_(module.bias) - if hasattr(module, "weight_embedding") and self.config.learn_initial_query: init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: diff --git a/src/transformers/models/d_fine/modular_d_fine.py b/src/transformers/models/d_fine/modular_d_fine.py index ba5798ad93cb..49289f075037 100644 --- a/src/transformers/models/d_fine/modular_d_fine.py +++ b/src/transformers/models/d_fine/modular_d_fine.py @@ -23,6 +23,7 @@ from ...backbone_utils import consolidate_backbone_kwargs_to_config from ...configuration_utils import PreTrainedConfig from ...image_transforms import corners_to_center_format +from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging, torch_compilable_check from ..auto import AutoConfig @@ -678,6 +679,7 @@ class DFinePreTrainedModel(RTDetrPreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the 
weights""" + PreTrainedModel._init_weights(self, module) # initialize linear layer bias value according to a given probability value. if isinstance(module, (DFineForObjectDetection, DFineDecoder)): if module.class_embed is not None: @@ -723,15 +725,6 @@ def _init_weights(self, module): init.xavier_uniform_(module.enc_score_head.weight) init.constant_(module.enc_score_head.bias, bias) - if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - init.zeros_(module.bias) - if getattr(module, "running_mean", None) is not None: - init.zeros_(module.running_mean) - init.ones_(module.running_var) - init.zeros_(module.num_batches_tracked) - if isinstance(module, DFineGate): bias = float(-math.log((1 - 0.5) / 0.5)) init.constant_(module.gate.bias, bias) @@ -741,10 +734,6 @@ def _init_weights(self, module): init.constant_(module.reg_conf.layers[-1].bias, 0) init.constant_(module.reg_conf.layers[-1].weight, 0) - if isinstance(module, nn.LayerNorm): - init.ones_(module.weight) - init.zeros_(module.bias) - if hasattr(module, "weight_embedding") and self.config.learn_initial_query: init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: diff --git a/src/transformers/models/deimv2/modeling_deimv2.py b/src/transformers/models/deimv2/modeling_deimv2.py index 070592ff08f0..fe0f002890c5 100644 --- a/src/transformers/models/deimv2/modeling_deimv2.py +++ b/src/transformers/models/deimv2/modeling_deimv2.py @@ -1087,6 +1087,7 @@ class Deimv2PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" + super()._init_weights(module) # initialize linear layer bias value according to a given probability value. 
if isinstance(module, (Deimv2ForObjectDetection, Deimv2Decoder)): if module.class_embed is not None: @@ -1132,15 +1133,6 @@ def _init_weights(self, module): init.xavier_uniform_(module.enc_score_head.weight) init.constant_(module.enc_score_head.bias, bias) - if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - init.zeros_(module.bias) - if getattr(module, "running_mean", None) is not None: - init.zeros_(module.running_mean) - init.ones_(module.running_var) - init.zeros_(module.num_batches_tracked) - if isinstance(module, Deimv2Gate): bias = float(-math.log((1 - 0.5) / 0.5)) init.constant_(module.gate.bias, bias) @@ -1150,15 +1142,10 @@ def _init_weights(self, module): init.constant_(module.reg_conf.layers[-1].bias, 0) init.constant_(module.reg_conf.layers[-1].weight, 0) - if isinstance(module, nn.LayerNorm): - init.ones_(module.weight) - init.zeros_(module.bias) - if hasattr(module, "weight_embedding") and self.config.learn_initial_query: init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: init.xavier_uniform_(module.denoising_class_embed.weight) - PreTrainedModel._init_weights(self, module) if isinstance(module, Deimv2SwiGLUFFN): init.xavier_uniform_(module.gate_proj.weight) @@ -2021,7 +2008,6 @@ class Deimv2ObjectDetectionOutput(ModelOutput): """ ) class Deimv2ForObjectDetection(Deimv2PreTrainedModel): - _no_split_modules = None _tied_weights_keys = { r"bbox_embed.(?![0])\d+": r"bbox_embed.0", r"class_embed.(?![0])\d+": r"^class_embed.0", diff --git a/src/transformers/models/deimv2/modular_deimv2.py b/src/transformers/models/deimv2/modular_deimv2.py index 271f2756df67..e675fd12114c 100644 --- a/src/transformers/models/deimv2/modular_deimv2.py +++ b/src/transformers/models/deimv2/modular_deimv2.py @@ -22,7 +22,6 @@ from ... 
import initialization as init from ...backbone_utils import load_backbone from ...modeling_outputs import ModelOutput -from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging from ...utils.generic import merge_with_config_defaults @@ -492,7 +491,6 @@ class Deimv2PreTrainedModel(DFinePreTrainedModel): @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) - PreTrainedModel._init_weights(self, module) if isinstance(module, Deimv2SwiGLUFFN): init.xavier_uniform_(module.gate_proj.weight) @@ -853,7 +851,7 @@ def forward( class Deimv2ForObjectDetection(DFineForObjectDetection): - _no_split_modules = None + _no_split_modules = AttributeError() # Don't have the same restriction as DFine @property def _tied_weights_keys(self): From 07e3831133a37487ee6faf57aed965f39c962a43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Ouazan?= <83456801+remi-or@users.noreply.github.com> Date: Thu, 23 Apr 2026 11:34:15 +0200 Subject: [PATCH 17/25] [CB] Changes for long generation (#45530) * Fix KV dedup for decode batches * Fix memory estimation * Change default * Added write-only fast path * Take both peaks into account * Revert unused config field * Review 1 * Fix p1s * Fix p2s and p3s that needed it * Added a TODO * Fix test, lower max cached graph, add TODO * Fix fragmentation with big warmup * Add more space for logits processors * Fix --- .../generation/configuration_utils.py | 18 +- .../generation/continuous_batching/cache.py | 170 ++++++++++++------ .../cb_logits_processors.py | 2 + .../continuous_batching/continuous_api.py | 38 ++-- .../continuous_batching/input_outputs.py | 37 ++-- .../continuous_batching/requests.py | 7 +- .../continuous_batching/scheduler.py | 21 ++- tests/generation/test_continuous_batching.py | 6 +- 8 files changed, 189 insertions(+), 110 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py 
b/src/transformers/generation/configuration_utils.py index 308c42564295..f601a97959c6 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -1556,8 +1556,10 @@ class ContinuousBatchingConfig: Number of blocks in the KV cache. Auto-inferred from GPU memory when `None`. max_batch_tokens (`int`, *optional*): Maximum number of tokens in a batch. Auto-inferred from GPU memory when `None`. - max_memory_percent (`float`, *optional*, defaults to 0.8): - Maximum percentage of free GPU memory (after the model is loaded) to use for the KV cache. + max_memory_percent (`float`, *optional*): + Maximum percentage of free GPU memory (after the model is loaded) to use for the KV cache. When `None`, + resolved at runtime to 0.9 if there is no logit processing and 0.8 if there is, to leave headroom for + vocabulary-sized temporary tensors. max_blocks_per_request (`int`, *optional*, defaults to 0): Maximum blocks per request, used in the `flash_attn_with_kvcache` fast decode path to dimension the block table. Setting this to 0 disables the fast decode path. @@ -1607,8 +1609,9 @@ class ContinuousBatchingConfig: num_blocks: int | None = None max_batch_tokens: int | None = None - # The max percentage of free GPU memory (after the model is loaded) to use for the KV cache. - max_memory_percent: float = 0.8 + # The max percentage of free GPU memory (after the model is loaded) to use for the KV cache. If None, auto resolved + # to 0.9 (no logit processing) or 0.8 (logit processing) to leave headroom for temporary tensors. + max_memory_percent: float | None = None # This is only used in the flash_attn_with_kvcache fast decode path to dimension the block table. If it is set to 0, # the fast decode path will not be used. Currently turned off by default. 
@@ -1773,6 +1776,13 @@ def decide_use_async_batching(self, is_attn_mask_needed: bool) -> bool: ) return self.use_async_batching + def resolve_max_memory_percent(self, has_logit_processors: bool) -> None: + """Resolves `max_memory_percent` when unset: 0.9 without logit processors, 0.8 with them. Active processors + materialize `[N, V]` intermediates (e.g. top-p sort, softmax) that get captured into the CUDA graph pool, so + the cache has to cede some budget to that pool.""" + if self.max_memory_percent is None: + self.max_memory_percent = 0.8 if has_logit_processors else 0.9 + def resolve_sentinel_values(self) -> None: """For some parameters (padding intervals and max cached graphs), the default is a sentinel value of 0: that way, if the user specifies a value for those parameters, we know they want it used, ie. we turn on cuda graphs. diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index 9fd0d3afba11..59de60bc957c 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -182,15 +182,30 @@ def __init__( else: num_attention_masks = 1 + # Peak activations coefficients (for number of blocks and number of batch tokens) + q_bytes_per_token = config.num_attention_heads * self.head_dim + lm_head_peak = ( + 0, # number of blocks does not affect the LM head peak activation + config.hidden_size + 2 * config.vocab_size, # hidden states + logits + ) + attention_peak = ( + 2 * page_size, # old K and V, read from cache (in the worst case scenario: whole cache is read) + config.hidden_size + q_bytes_per_token + 2 * page_size, # hidden state + Q + new K and V + ) + memory_handler = PagedAttentionMemoryHandler( - block_size=self.block_size, + continuous_batching_config=continuous_batching_config, page_size=page_size, num_groups=self.num_groups, group_size=group_size, - peak_activation_per_token=(config.hidden_size + 
config.vocab_size), + activation_peaks=[lm_head_peak, attention_peak], num_attention_masks=num_attention_masks, - continuous_batching_config=continuous_batching_config, ) + + # If somehow the max memory percent is not yet resolved, resolve it conservatively + if continuous_batching_config.max_memory_percent is None: + continuous_batching_config.resolve_max_memory_percent(has_logit_processors=True) + num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens( num_blocks=continuous_batching_config.num_blocks, max_batch_tokens=continuous_batching_config.max_batch_tokens, @@ -316,17 +331,20 @@ def extend_read_and_write_indices( request_id: str, past_length: int, query_length: int, - read_index: list[list[int]], + read_index: list[list[int]] | None, write_index: list[list[int]], ) -> None: """Retrieve physical cache indices for reading KV states in the cache across all layer groups. This method coordinates with all cache managers to build the complete set of read indices needed for attention computation. + When read_index is None, the batch has no cache reads and we only compute the write indices. 
""" - for cm, read_indices, write_indices in zip(self.group_cache_managers, read_index, write_index): - indices = cm.get_read_indices(request_id, past_length, query_length) - read_indices.extend(indices) - indices = cm.get_write_indices(request_id, past_length, query_length) - write_indices.extend(indices) + # Write indices are always computed + for cm, write_indices in zip(self.group_cache_managers, write_index): + write_indices.extend(cm.get_write_indices(request_id, past_length, query_length)) + # Read indices are only computed if there are cache indices + if read_index is not None: + for cm, read_indices in zip(self.group_cache_managers, read_index): + read_indices.extend(cm.get_read_indices(request_id, past_length, query_length)) def fill_block_table( self, request_id: str, past_length: int, query_length: int, block_table: torch.Tensor @@ -355,26 +373,34 @@ def update( read_index: list[torch.Tensor], # shape [num_layer_groups, seqlen_kv + past_length] write_index: list[torch.Tensor], # shape [num_layer_groups, seqlen_q] ) -> tuple[torch.Tensor, torch.Tensor]: # shape [seqlen_kv + past_length, num_kv_heads, head_dim] - """Update the cache with new key-value states for a specific layer. This method writes new KV states to the - appropriate cache locations. The behavior differs based on the layer's attention type: + """Update the cache with new key-value states for a specific layer, and retrieves the relevant KV states from + the cache for attention computation. The behavior differs based on the layer's attention type: - Full attention: New KV states are written to cache, then complete sequence is read from cache - Sliding window: Old KV is read from cache along with extra spaces for the new KV, then new KV is written to cache. This is because new KV might overwrite the old KV, so we need to read the old KV first. 
+ When the layer's read index is empty, the batch has no cache reads (all requests are non-chunked prefills): we + only write to the cache and return the input KV states directly, skipping the index_select read-back. + Returns the complete KV states (cached + new) for attention computation. """ - # Retrieve the layer read and write indices + # Retrieve the layer write index and the relevant cache tensors group_idx, layer_idx_in_group = self.layer_index_to_group_indices[layer_idx] layer_read_index = read_index[group_idx] layer_write_index = write_index[group_idx] - # Select the correct cache k_cache = self.key_cache[layer_idx_in_group] v_cache = self.value_cache[layer_idx_in_group] # Transpose the key and value states to match the cache shape, after which shape is [seqlen_kv, num_kv_heads, head_dim] key_states = key_states.transpose(1, 2).squeeze(0) value_states = value_states.transpose(1, 2).squeeze(0) + # Case: write-only, no cache read. The input KV states already contain everything the attention needs. + if layer_read_index.numel() == 0: + k_cache.index_copy_(0, layer_write_index, key_states) + v_cache.index_copy_(0, layer_write_index, value_states) + return key_states, value_states + # Case: full attention sliding_window = self.sliding_windows[layer_idx] if sliding_window == 1: @@ -509,25 +535,26 @@ class PagedAttentionMemoryHandler: _activation_dtype = torch.bfloat16 _input_dtype = torch.int32 - _upper_bound_max_batch_tokens = 256 + _upper_bound_max_batch_tokens = 1024 _upper_bound_num_blocks = 4096 def __init__( self, - block_size: int, + continuous_batching_config: ContinuousBatchingConfig, page_size: int, num_groups: int, group_size: int, - peak_activation_per_token: int, + activation_peaks: list[tuple[int, int]], num_attention_masks: int, - continuous_batching_config: ContinuousBatchingConfig, ) -> None: - """Initialize the memory handler.""" - self.block_size = block_size + """Initialize the memory handler. 
`activation_peaks` is a list of `(Δcn, Δcm)` pairs giving the activation memory + contributions proportional to N (pages) and M (batch tokens) for each peak. Memory must satisfy the constraint + at every peak, so we solve each polynomial independently and take the most restrictive result.""" + self.block_size = continuous_batching_config.block_size self.page_size = page_size self.num_groups = num_groups self.group_size = group_size - self.peak_activation_per_token = peak_activation_per_token + self.activation_peaks = activation_peaks self.num_attention_masks = num_attention_masks self.max_blocks_per_request = continuous_batching_config.max_blocks_per_request or 0 # This is the number of output rows for the output_ids tensor @@ -545,23 +572,29 @@ def get_available_memory(max_memory_percent: float = 1.0) -> int: # Formatting is disabled because of comment indentation, which improves readability. # fmt: off - def _equation_coefficients(self, cache_dtype: torch.dtype) -> tuple[int, int, int, int]: - """Returns (coeff_n, coeff_m, coeff_nm, coeff_mm) for the memory polynomial. Each addend is annotated with - the tensor it corresponds to in `ContinuousBatchingIOs._setup_static_tensors`. + def _equation_coefficients( + self, peak: tuple[int, int], cache_dtype: torch.dtype + ) -> tuple[int, int, int, int]: + """Returns `(coeff_n, coeff_m, coeff_nm, coeff_mm)` for the memory polynomial of a single activation peak. + `peak = (Δcn, Δcm)` is the peak-specific activation contribution; the rest of the coefficients are shared + across peaks. Each addend is annotated with the tensor it corresponds to in + `ContinuousBatchingIOs._setup_static_tensors` (or the forward pass, for activation terms). 
""" i = self._input_dtype.itemsize # int32 a = self._activation_dtype.itemsize # bfloat16 c = cache_dtype.itemsize k = self.io_multiplier # 1 sync, 2 async (IO tensors only) + delta_n, delta_m = peak # -- N terms: cost per cache page -------------------------------------------------- coeff_n = ( 2 * self.group_size * self.page_size * c # kv_cache: 2 * group_size * [N, page_size] * cache_dtype + k * self.num_groups * 8 # read_index: [num_groups, N + M] (N part only, int64) + + delta_n * a # activation peak: N-proportional part ) # -- M terms: cost per batch token ------------------------------------------------- coeff_m = ( - self.peak_activation_per_token * a # activation peak (largest hidden state per token) + delta_m * a # activation peak: M-proportional part + k * 7 * i # bulk_input: [7, M] int32, packed as 7 rows + k * self.num_output_rows * i # output_ids: [num_output_rows, M] int32 + k * self.num_groups # block_table: [bt_groups, M, max_blocks_per_req] int32 @@ -569,9 +602,9 @@ def _equation_coefficients(self, cache_dtype: torch.dtype) -> tuple[int, int, in + k * self.num_groups * 8 # write_index: [num_groups, M] int64 + k * self.num_groups * 8 # read_index: [num_groups, N + M] (M part only, int64) ) - # -- N·M terms: cost per (page × batch token) ------------------------------------- + # -- N·M terms: cost per (page × batch token) -------------------------------------- coeff_nm = k * self.num_attention_masks * a # attention_mask: [1, 1, M, N + M] (N·M part only) - # -- M² terms: cost per (batch token squared) ------------------------------------- + # -- M² terms: cost per (batch token squared) -------------------------------------- coeff_mm = k * self.num_attention_masks * a # attention_mask: [1, 1, M, N + M] (M² part only) return coeff_n, coeff_m, coeff_nm, coeff_mm @@ -590,55 +623,80 @@ def _solve_quadratic(a: float, b: float, c: float) -> float: raise ValueError(f"No positive solution (root = {root})") return root - def 
infer_num_blocks_and_max_batch_tokens( + def _solve_for_peak( self, - num_blocks: int | None = None, - max_batch_tokens: int | None = None, - max_memory_percent: float = 0.8, # FIXME: it seems we overcommit memory, was changed from 0.9 which caused OOMs in our benchmarking CI - cache_dtype: torch.dtype = torch.float16, + peak: tuple[int, int], + available: int, + num_blocks: int | None, + max_batch_tokens: int | None, + cache_dtype: torch.dtype, ) -> tuple[int, int]: - """Solve for the missing variable(s) in the memory polynomial (see ``_equation_coefficients``). When both - are unknown, assumes M = m·N (m = 0.01, i.e. one batch fills ~1 % of the cache) and solves the resulting - quadratic in N. - """ - available = self.get_available_memory(max_memory_percent) - coeff_n, coeff_m, coeff_nm, coeff_mm = self._equation_coefficients(cache_dtype) - logger.info(f"Cache memory: {available}") + """Solve for `(num_blocks, max_batch_tokens)` against one activation peak's memory polynomial. Clamps to upper + bounds. 
Either input may be None; whichever is None is solved for.""" + cn, cm, cnm, cmm = self._equation_coefficients(peak, cache_dtype) if num_blocks is None and max_batch_tokens is None: # Substitute M = m·N → (coeff_nm·m + coeff_mm·m²)·N² + (coeff_n + coeff_m·m)·N − avail = 0 m = 0.01 - num_pages = self._solve_quadratic( - coeff_nm * m + coeff_mm * m**2, - coeff_n + coeff_m * m, - -available, - ) - num_blocks = min(floor(num_pages) // self.block_size, self._upper_bound_num_blocks) - max_batch_tokens = min(int(num_pages * m), self._upper_bound_max_batch_tokens) - - elif num_blocks is None: + num_pages = self._solve_quadratic(cnm * m + cmm * m**2, cn + cm * m, -available) + max_batch_tokens = int(num_pages * m) + if max_batch_tokens > self._upper_bound_max_batch_tokens: + max_batch_tokens = self._upper_bound_max_batch_tokens + # If max_batch_tokens is clamped, we recompute num_blocks below to get a higher value + num_blocks = None + else: + num_blocks = min(floor(num_pages) // self.block_size, self._upper_bound_num_blocks) + + if num_blocks is None: # M given → linear in N: (coeff_n + coeff_nm·M)·N = avail − coeff_m·M − coeff_mm·M² M = max_batch_tokens - num_pages = floor((available - coeff_m * M - coeff_mm * M**2) / (coeff_n + coeff_nm * M)) + num_pages = floor((available - cm * M - cmm * M**2) / (cn + cnm * M)) num_blocks = min(num_pages // self.block_size, self._upper_bound_num_blocks) - elif max_batch_tokens is None: # N given → quadratic in M: coeff_mm·M² + (coeff_m + coeff_nm·N)·M + (coeff_n·N − avail) = 0 N = num_blocks * self.block_size - M = self._solve_quadratic(coeff_mm, coeff_m + coeff_nm * N, coeff_n * N - available) + M = self._solve_quadratic(cmm, cm + cnm * N, cn * N - available) max_batch_tokens = min(floor(M), self._upper_bound_max_batch_tokens) + return num_blocks, max_batch_tokens + + def infer_num_blocks_and_max_batch_tokens( + self, + num_blocks: int | None = None, + max_batch_tokens: int | None = None, + max_memory_percent: float = 0.9, + 
cache_dtype: torch.dtype = torch.float16, + ) -> tuple[int, int]: + """Solve for the missing variable(s) in the memory polynomial (see ``_equation_coefficients``). There is one + polynomial per activation peak; we solve each independently and take the most restrictive (smallest) result. + When both `N` and `M` are unknown, assumes `M = m·N` (m = 0.01, i.e. one batch fills ~1 % of the cache) and + solves the resulting quadratic in N. + """ + available = self.get_available_memory(max_memory_percent) + logger.info(f"Cache memory: {available}") + # Solve each peak independently, then take the element-wise min (tightest constraint wins) + acc_num_blocks = float("inf") + acc_max_batch_tokens = float("inf") + for peak in self.activation_peaks: + n_blocks, m_batch_tokens = self._solve_for_peak(peak, available, num_blocks, max_batch_tokens, cache_dtype) + acc_num_blocks = min(acc_num_blocks, n_blocks) + acc_max_batch_tokens = min(acc_max_batch_tokens, m_batch_tokens) + # Now update the value (cannot update in loop, it would overwrite the user-passed values) + num_blocks, max_batch_tokens = acc_num_blocks, acc_max_batch_tokens # Validate - memory_footprint = self.compute_memory_footprint( - max_batch_tokens=max_batch_tokens, num_blocks=num_blocks, cache_dtype=cache_dtype - ) + memory_footprint = self.compute_memory_footprint(num_blocks, max_batch_tokens, cache_dtype) if memory_footprint > available: raise MemoryError(f"Memory footprint {memory_footprint} is more than available memory {available}") return num_blocks, max_batch_tokens def compute_memory_footprint(self, num_blocks: int, max_batch_tokens: int, cache_dtype: torch.dtype) -> int: - """Evaluate the memory polynomial at concrete (N, M) values.""" + """Evaluate the memory polynomial at concrete (N, M) values, taking the max across activation peaks.""" N = num_blocks * self.block_size M = max_batch_tokens - cn, cm, cnm, cmm = self._equation_coefficients(cache_dtype) - return cn * N + cm * M + cnm * N * M + cmm * M * M 
+ + max_memory_footprint = 0 + for peak in self.activation_peaks: + cn, cm, cnm, cmm = self._equation_coefficients(peak, cache_dtype) + memory_footprint = cn * N + cm * M + cnm * N * M + cmm * M * M + max_memory_footprint = max(max_memory_footprint, memory_footprint) + return max_memory_footprint diff --git a/src/transformers/generation/continuous_batching/cb_logits_processors.py b/src/transformers/generation/continuous_batching/cb_logits_processors.py index 3a5f7eb8df26..619d9fefea5e 100644 --- a/src/transformers/generation/continuous_batching/cb_logits_processors.py +++ b/src/transformers/generation/continuous_batching/cb_logits_processors.py @@ -319,6 +319,8 @@ def __call__(self, scores: torch.FloatTensor, tensor_arg: torch.Tensor) -> torch return scores.masked_fill(indices_to_remove, self.filter_value) +# TODO: add non-per-request CB variants so the memory-efficient warpers work when `per_request_processors=False`. +# TODO: fuse temperature + top-k + top-p into a single pass to reuse the softmax/sort and cut activation peak. 
CLASSIC_TO_CB_PROCESSORS_MAP = { TemperatureLogitsWarper: ContinuousBatchingTemperatureLogitsWarper, TopKLogitsWarper: ContinuousBatchingTopKLogitsWarper, diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 47290b9d70b6..0521c6402ca9 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -623,26 +623,18 @@ def _sample(self, scores: torch.Tensor, logits_indices: torch.Tensor, output_ids output_ids[1, :tokens].copy_(logprobs.view(dtype=torch.int32)) @torch.inference_mode() - def warmup( - self, - model: nn.Module, - logit_processor: LogitsProcessorList, - num_query_tokens: int = 0, - num_cache_tokens: int = 0, - ) -> None: + def warmup(self, model: nn.Module) -> None: """Pre-capture CUDA graphs (or trigger compile warmup) for varlen and decode paths. In async mode, both IO - pairs are warmed up since each has its own graph buffer and static tensors.""" + pairs are warmed up since each has its own graph buffer and static tensors. 
The varlen path is warmed up at + the largest possible `(q, kv)` sizes so subsequent captures fit inside it without growing the pool.""" if not self._pad_inputs: logger.info("CUDA graphs and compile are disabled, skipping warmup.") return None - num_query_tokens = num_query_tokens if num_query_tokens > 0 else self.max_batch_tokens - num_query_tokens = min(num_query_tokens, self.max_batch_tokens) - num_cache_tokens = num_cache_tokens if num_cache_tokens > 0 else self.cache.block_size * num_query_tokens - num_cache_tokens = min(num_cache_tokens, self.cache.num_blocks * self.cache.block_size) - + num_query_tokens = self.max_batch_tokens num_pages = self.cache.num_blocks * self.cache.block_size + num_cache_tokens = num_pages - num_query_tokens compute_stream = self.inputs_and_outputs.compute_stream # In async mode, each IO pair has its own graph buffer and static tensors, so we warm up both @@ -677,7 +669,7 @@ def warmup( forward_fn(*forward_fn_args) logger.info(f"Varlen warmup completed in {perf_counter() - start:.2f}s") except Exception as e: - logger.warning(f"Failed to warm up varlen path: {e}") + logger.warning(f"Failed to warm up varlen path: {e}. Graph pool may fragment and OOM under load.") finally: for fs in future_states: self.cache.free_blocks(fs.state.request_id) @@ -811,12 +803,12 @@ def is_running(self) -> bool: """Check if the background generation thread is running.""" return self._generation_thread is not None and self._generation_thread.is_alive() - def warmup(self, num_query_tokens: int = 0, num_cache_tokens: int = 0) -> None: + def warmup(self) -> None: """Pre-capture CUDA graphs for varlen and decode paths by running dummy batches. 
Initializes the batch processor if not already done.""" if self.batch_processor is None: self.batch_processor = self._create_batch_processor() - self.batch_processor.warmup(self.model, self.logit_processor, num_query_tokens, num_cache_tokens) + self.batch_processor.warmup(self.model) self.warmed_up = True # NOTE: don't forget to update `continuous_batching_context_manager` when changing this method's definition @@ -1040,6 +1032,8 @@ def _generation_step(self) -> None: self.batch_processor._generation_step(self.model) def _create_batch_processor(self) -> ContinuousBatchProcessor: + # Resolve max_memory_percent now that we know whether any logit processors are active. + self.continuous_batching_config.resolve_max_memory_percent(self.logit_processor.do_processing) # Create the PagedAttentionCache paged_attention_cache = PagedAttentionCache( self.model.config, @@ -1225,25 +1219,25 @@ def continuous_batching_context_manager( timeout: float | None = None, continuous_batching_config: ContinuousBatchingConfig | None = None, persistent_manager: bool = False, - warmup_requests: int | None = 0, + warmup: bool = True, **deprecated_kwargs, ) -> Generator[ContinuousBatchingManager]: """A context manager to safely use the continuous batching manager. Arguments are similar to the ones of `init_continuous_batching`, except for: - block: whether to block the thread when stopping the manager. Default is True. - timeout: maximum time to wait for the thread to stop. Default is None (no timeout). - - warmup_query_tokens: the number of expected requests for which to warmup. 0 is auto, None is no warmup. + - warmup: whether to pre-capture CUDA graphs at the largest sizes before running. Default is True. 
""" manager = self.init_continuous_batching( generation_config=generation_config, continuous_batching_config=continuous_batching_config, **deprecated_kwargs, ) - if not (warmup_requests is None or manager.warmed_up): + if warmup and not manager.warmed_up: # Warmup is long (~30 sec): best to signal the user it's happening than let them think the manager is stuck - logger.warning("Warming up for coninuous batching...") + logger.warning("Warming up for continuous batching...") start = perf_counter() - manager.warmup(num_query_tokens=warmup_requests, num_cache_tokens=0) + manager.warmup() logger.warning(f"Warming up completed in {perf_counter() - start:.2f}s.") manager.start() try: @@ -1320,7 +1314,7 @@ def generate_batch( block=True, timeout=5, persistent_manager=persistent_manager, - warmup_requests=len(inputs) if warmup else None, + warmup=warmup, **deprecated_kwargs, ) logging_cm = logging_redirect_tqdm([logger]) diff --git a/src/transformers/generation/continuous_batching/input_outputs.py b/src/transformers/generation/continuous_batching/input_outputs.py index 134941c2526f..fbe7890a15b9 100644 --- a/src/transformers/generation/continuous_batching/input_outputs.py +++ b/src/transformers/generation/continuous_batching/input_outputs.py @@ -14,7 +14,6 @@ from contextlib import nullcontext from dataclasses import dataclass from functools import partial -from itertools import count from typing import Any import torch @@ -250,10 +249,11 @@ def _transfer_inputs( # Only transfer block_table for decode-only batches (when it's actually used) if self.use_block_table: other.block_table.copy_(self.block_table, non_blocking=non_blocking) - # Otherwise, we transfer the read and write indices + # Otherwise, we transfer the write indices (and read indices if the batch uses any cache reads) else: other.write_index_storage.copy_(self.write_index_storage, non_blocking=non_blocking) - other.read_index_storage.copy_(self.read_index_storage, non_blocking=non_blocking) + if 
self.max_kv_read > 0: + other.read_index_storage.copy_(self.read_index_storage, non_blocking=non_blocking) # Transfer the attention masks if needed if self.attention_mask is not None and other.attention_mask is not None: for layer_type in self.attention_mask.keys(): @@ -373,14 +373,15 @@ def prepare_batch_tensors( self.requests_in_batch = [] self.req_id_to_new_token_position = {} - # Prepare accumulators + # Prepare accumulators. For batches with no past cache to read, we leave read_index empty: the cache.update + # will detect the 0-size indices and skip the read. input_ids = [] position_ids = [] cumulative_seqlens_q = [0] logits_indices = [] cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k.keys()} - read_index = [[] for _ in range(self.cache.num_groups)] write_index = [[] for _ in range(self.cache.num_groups)] + read_index = None if self.max_kv_read == 0 else [[] for _ in range(self.cache.num_groups)] # Go through all the requests in the batch for i, future_state in enumerate(requests_in_batch): @@ -448,14 +449,16 @@ def prepare_batch_tensors( sliding_window=self.sliding_window if layer_type == "sliding_attention" else 1, ) - # If we are not using the block table, we populate the read and write indices + # If we are not using the block table, we populate the write indices (and maybe the read indices) if not self.use_block_table: to_index_tensor = partial(torch.tensor, dtype=torch.int64, device=self.device) - for i, group_read_indices, group_write_indices in zip(count(), read_index, write_index): - self.read_index_storage[i, : len(group_read_indices)] = to_index_tensor(group_read_indices) + for i, group_write_indices in enumerate(write_index): self.write_index_storage[i, : len(group_write_indices)] = to_index_tensor(group_write_indices) - self.true_read_sizes[i] = len(group_read_indices) self.true_write_sizes[i] = len(group_write_indices) + if read_index is not None: + for i, group_read_indices in enumerate(read_index): + 
self.read_index_storage[i, : len(group_read_indices)] = to_index_tensor(group_read_indices) + self.true_read_sizes[i] = len(group_read_indices) def get_model_kwargs(self, use_padding: bool = False) -> dict[str, Any]: """Get model keyword arguments for the current batch, eventually padding the query dimension and KV dimensions @@ -500,10 +503,14 @@ def get_model_kwargs(self, use_padding: bool = False) -> dict[str, Any]: # For the attributes that are lists of tensors, we construct list of tensor references for i in range(self.cache.num_groups): - read_index_size = kv_size if use_padding else self.true_read_sizes[i] write_index_size = q_size if use_padding else self.true_write_sizes[i] - kwargs.read_index.append(self.read_index_storage[i, :read_index_size]) kwargs.write_index.append(self.write_index_storage[i, :write_index_size]) + # If there is no cache to read, pass a list of empty tensors so `cache.update` uses the write-only fast path + if self.max_kv_read == 0: + read_index_size = 0 + else: + read_index_size = kv_size if use_padding else self.true_read_sizes[i] + kwargs.read_index.append(self.read_index_storage[i, :read_index_size]) # For the attributes that are dict of tensors, we first fill the dict with the actual values for layer_type, seqlens_k in self.cumulative_seqlens_k.items(): @@ -531,11 +538,11 @@ def get_cb_kwargs(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return self.carry_over_ids, self.output_ids, self.output_ids def _get_graph_key(self) -> tuple[int, ...]: - # Keys for varlen path - if self.max_kv_read > 0: - return (self.num_q_tokens, self.max_kv_read, *self.max_seqlen_k.values()) # Keys for decode fast path - return (self.num_q_tokens,) + if self.use_block_table: + return (self.num_q_tokens,) + # Keys for varlen path + return (self.num_q_tokens, self.max_kv_read, *self.max_seqlen_k.values()) def get_graph(self) -> torch.cuda.CUDAGraph | None: key = self._get_graph_key() diff --git 
a/src/transformers/generation/continuous_batching/requests.py b/src/transformers/generation/continuous_batching/requests.py index 05bf65725c5a..381c94bc2dc9 100644 --- a/src/transformers/generation/continuous_batching/requests.py +++ b/src/transformers/generation/continuous_batching/requests.py @@ -27,6 +27,7 @@ import psutil # This is a temporary token ID used to represent a token that is not yet generated +# TODO: update this to 0 and check it breaks nothing + simplify carry over and time new logic TMP_TOKEN_ID = -1 @@ -45,9 +46,11 @@ def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]: device = torch.device("cuda") torch.cuda.empty_cache() torch.cuda.synchronize() - total_memory = torch.cuda.get_device_properties(device).total_memory + # Use mem_get_info to get actual free memory: device_properties().total_memory returns the physical device + # total which ignores CUDA context and driver overhead (~0.5 GiB), leading to overcommit. + free_memory, total_memory = torch.cuda.mem_get_info(device) reserved_memory = torch.cuda.memory_reserved(device) - allocated_memory = torch.cuda.memory_allocated(device) + allocated_memory = total_memory - free_memory elif is_torch_xpu_available(): device = torch.device("xpu") torch.xpu.empty_cache() diff --git a/src/transformers/generation/continuous_batching/scheduler.py b/src/transformers/generation/continuous_batching/scheduler.py index f35d2e968342..284c202267c5 100644 --- a/src/transformers/generation/continuous_batching/scheduler.py +++ b/src/transformers/generation/continuous_batching/scheduler.py @@ -205,7 +205,7 @@ def _process_candidates( """ scheduled_requests = [] one_allocation_failed = False - decode_fast_path = True + decode_fast_path = self.cache.max_blocks_per_request > 0 # best way to check if decode fast path availability safety_margins = safety_margin * self.cache.num_blocks original_token_budget, original_cache_budget = token_budget, cache_budget @@ -219,17 +219,22 @@ def 
_process_candidates( ) break - # Check cache budget + # Infer the tokens that will be present in the batch if token budget is enough + request_tokens = self._infer_request_tokens(state, request_ids_to_remove_from_waiting) + # Account for token budget + request_len = min(len(request_tokens), token_budget) + + # This block checks cache budget: decode batches have infinite budget, but varlen batches don't, because KV + # cache is read through a fixed-sized index tensor. We keep track of the current budget in case the batch + # goes from decode to varlen + is_decode_eligible = request_len == 1 and state.position_offset < self.max_decode_fast_path_length read_cache_needed = state.current_len() if self.read_cache_limit is not None: read_cache_needed = min(read_cache_needed, self.read_cache_limit) - if cache_budget < read_cache_needed: + # A request that would change the batch from decode to varlen is rejected if the cache budget is too low + if not (decode_fast_path and is_decode_eligible) and cache_budget < read_cache_needed: continue - # Infer the tokens that will be present in the batch if token budget is enough - request_tokens = self._infer_request_tokens(state, request_ids_to_remove_from_waiting) - # Account for token budget - request_len = min(len(request_tokens), token_budget) # Check there will be enough cache for the new tokens allocation_successful = self._allocate_blocks_if_needed(state, request_len) @@ -273,7 +278,7 @@ def _process_candidates( request_ids_to_remove_from_waiting.add(req_id) # Early exit of the loop if we have no budget left - if token_budget == 0 or cache_budget == 0: + if token_budget == 0 or (cache_budget <= 0 and not decode_fast_path): break num_q_tokens = original_token_budget - token_budget diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py index ff3e54be374f..cd7c95f7bf4e 100644 --- a/tests/generation/test_continuous_batching.py +++ b/tests/generation/test_continuous_batching.py @@ 
-1274,16 +1274,16 @@ def test_memory_prediction( max_blocks_per_request=max_bpr, return_logprobs=logprobs, use_async_batching=use_async_batching, + block_size=block_size, ) handler = PagedAttentionMemoryHandler( - block_size=block_size, + continuous_batching_config=cb_config, page_size=page_size, num_groups=num_groups, group_size=group_size, - peak_activation_per_token=peak_act, + activation_peaks=[(0, peak_act)], num_attention_masks=num_attn_masks, - continuous_batching_config=cb_config, ) N = self.NUM_BLOCKS * block_size # num_pages From 706acf5c2e6783ce55a479fbc9b3e2d31c736508 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 23 Apr 2026 07:19:33 -0400 Subject: [PATCH 18/25] Allow for registered experts from kernels hub (#45577) * Allow for registered experts from kernels hub * remove deepgemm as that is also dynamic * Apply repo consistency fixes * Update src/transformers/modeling_utils.py * Update src/transformers/modeling_utils.py * Apply repo consistency fixes * Apply suggestion from @IlyasMoutawwakil * Apply repo consistency fixes * get rid of triton dependency * keep eager first --------- Co-authored-by: github-actions[bot] Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Co-authored-by: IlyasMoutawwakil --- .../integrations/finegrained_fp8.py | 18 +++++++++++------- src/transformers/modeling_utils.py | 11 ++++++++--- tests/utils/test_modeling_utils.py | 14 ++++++++++++++ 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 213b91e3a115..a6b9a517b20d 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -13,7 +13,6 @@ # limitations under the License. 
import torch import torch.nn as nn -import triton from torch.nn import functional as F from ..activations import ACT2FN @@ -159,6 +158,11 @@ def _load_deepgemm_kernel(): _deepgemm_available = True +def _cdiv(a: int, b: int) -> int: + """Ceiling division.""" + return (a + b - 1) // b + + def w8a8_fp8_matmul( A: torch.Tensor, B: torch.Tensor, @@ -603,8 +607,8 @@ def __init__( if self.has_gate: gu_proj_out, gu_proj_in = 2 * self.intermediate_dim, self.hidden_dim self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, gu_proj_out, gu_proj_in, dtype=dtype)) - gu_scale_out = triton.cdiv(gu_proj_out, self.block_size[0]) if self.block_size is not None else 1 - gu_scale_in = triton.cdiv(gu_proj_in, self.block_size[1]) if self.block_size is not None else 1 + gu_scale_out = _cdiv(gu_proj_out, self.block_size[0]) if self.block_size is not None else 1 + gu_scale_in = _cdiv(gu_proj_in, self.block_size[1]) if self.block_size is not None else 1 self.gate_up_proj_scale_inv = nn.Parameter( torch.empty(self.num_experts, gu_scale_out, gu_scale_in, dtype=torch.float32) ) @@ -612,8 +616,8 @@ def __init__( else: u_proj_out, u_proj_in = self.intermediate_dim, self.hidden_dim self.up_proj = nn.Parameter(torch.empty(self.num_experts, u_proj_out, u_proj_in, dtype=dtype)) - u_scale_out = triton.cdiv(u_proj_out, self.block_size[0]) if self.block_size is not None else 1 - u_scale_in = triton.cdiv(u_proj_in, self.block_size[1]) if self.block_size is not None else 1 + u_scale_out = _cdiv(u_proj_out, self.block_size[0]) if self.block_size is not None else 1 + u_scale_in = _cdiv(u_proj_in, self.block_size[1]) if self.block_size is not None else 1 self.up_proj_scale_inv = nn.Parameter( torch.empty(self.num_experts, u_scale_out, u_scale_in, dtype=torch.float32) ) @@ -621,8 +625,8 @@ def __init__( d_proj_out, d_proj_in = self.hidden_dim, self.intermediate_dim self.down_proj = nn.Parameter(torch.empty(self.num_experts, d_proj_out, d_proj_in, dtype=dtype)) - d_scale_out = triton.cdiv(d_proj_out, 
self.block_size[0]) if self.block_size is not None else 1 - d_scale_in = triton.cdiv(d_proj_in, self.block_size[1]) if self.block_size is not None else 1 + d_scale_out = _cdiv(d_proj_out, self.block_size[0]) if self.block_size is not None else 1 + d_scale_in = _cdiv(d_proj_in, self.block_size[1]) if self.block_size is not None else 1 self.down_proj_scale_inv = nn.Parameter( torch.empty(self.num_experts, d_scale_out, d_scale_in, dtype=torch.float32) ) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index eb092019b678..d58c9a52fd33 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -66,10 +66,12 @@ ) from .integrations.deepspeed import _load_state_dict_into_zero3_model from .integrations.eager_paged import eager_paged_attention_forward +from .integrations.finegrained_fp8 import ALL_FP8_EXPERTS_FUNCTIONS from .integrations.flash_attention import flash_attention_forward from .integrations.flash_paged import paged_attention_forward from .integrations.flex_attention import flex_attention_forward from .integrations.hub_kernels import allow_all_hub_kernels, is_kernel +from .integrations.moe import ALL_EXPERTS_FUNCTIONS from .integrations.peft import maybe_load_adapters from .integrations.sdpa_attention import sdpa_attention_forward from .integrations.sdpa_paged import sdpa_attention_paged_forward @@ -1969,11 +1971,14 @@ def get_correct_attn_implementation(self, requested_attention: str | None, is_in def get_correct_experts_implementation(self, requested_experts: str | None) -> str: applicable_experts = "grouped_mm" if requested_experts is None else requested_experts - if applicable_experts not in ["eager", "grouped_mm", "batched_mm", "deepgemm"]: + base_experts_fns = ["eager"] + list(set(ALL_EXPERTS_FUNCTIONS.keys()) | set(ALL_FP8_EXPERTS_FUNCTIONS.keys())) + valid_experts_str_list = [f'`experts_implementation="{fn}"`' for fn in base_experts_fns] + valid_experts_str_list[-1] = "and " + 
valid_experts_str_list[-1] + valid_experts_str = ", ".join(valid_experts_str_list) + if applicable_experts not in base_experts_fns: message = ( f'Specified `experts_implementation="{applicable_experts}"` is not supported. The only possible arguments are ' - '`experts_implementation="eager"`, `"experts_implementation=grouped_mm"`, `"experts_implementation=batched_mm"` ' - 'and `"experts_implementation=deepgemm"`.' + f"{valid_experts_str}." ) raise ValueError(message) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 6a27b6b5e0fb..fab48f9ddb8a 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -2823,6 +2823,20 @@ def test_error_wrong_attn_implementation(self): self.assertTrue('The only possible arguments are `attn_implementation="eager"' in str(cm.exception)) + def test_registered_experts_implementation_is_valid(self): + from transformers.integrations.moe import ALL_EXPERTS_FUNCTIONS + + def custom_experts_forward(*args, **kwargs): + pass + + experts_implementation = "custom_experts" + model = BaseModel(PreTrainedConfig()) + + with patch.dict(ALL_EXPERTS_FUNCTIONS._global_mapping, {}, clear=False): + ALL_EXPERTS_FUNCTIONS.register(experts_implementation, custom_experts_forward) + + self.assertEqual(model.get_correct_experts_implementation(experts_implementation), experts_implementation) + def test_not_available_flash(self): if is_flash_attn_2_available(): self.skipTest(reason="Please uninstall flash-attn package to run test_not_available_flash") From bd69ed2ad7979e8896d01fbc2fa5090d424fc8a8 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Thu, 23 Apr 2026 05:15:40 -0700 Subject: [PATCH 19/25] [docs] multi-turn tool calling (#45554) * docs * feedback --- docs/source/en/serve-cli/serving.md | 95 +++++++++++++++++++++++++---- 1 file changed, 84 insertions(+), 11 deletions(-) diff --git a/docs/source/en/serve-cli/serving.md 
b/docs/source/en/serve-cli/serving.md index 783eb0c8dd87..83dcb9e88d9a 100644 --- a/docs/source/en/serve-cli/serving.md +++ b/docs/source/en/serve-cli/serving.md @@ -456,7 +456,7 @@ data: {"id":"f47ac10b-58cc-4372-a567-0e02b2c3d479","choices":[{"delta":{"content ### Audio-based completions -Multimodal models like [Gemma 4](https://huggingface.co/google/gemma-4-E2B-it) and [Qwen2.5-Omni](https://huggingface.co/Qwen/Qwen2.5-Omni-3B) accept audio input using the OpenAI `input_audio` content type. The audio must be base64-encoded and the format (`mp3` or `wav`) must be specified. +Multimodal models like [Gemma 4](https://huggingface.co/google/gemma-4-E2B-it) and [Qwen2.5-Omni](https://huggingface.co/Qwen/Qwen2.5-Omni-3B) accept audio input through the OpenAI `input_audio` content type. Base64-encode the audio and specify the format (`mp3` or `wav`). @@ -695,7 +695,7 @@ data: {"id":"cb997e1d-98b9-414a-be89-1880288610ef","choices":[{"delta":{"content > [!WARNING] > The `audio_url` content type is an extension not part of the OpenAI standard and may change in future versions. -As a convenience, audio can also be passed by URL using the `audio_url` content type, avoiding the need for base64 encoding. +You can also pass audio by URL with the `audio_url` content type to skip base64 encoding. ```python completion = client.chat.completions.create( @@ -717,7 +717,7 @@ completion = client.chat.completions.create( > [!WARNING] > The `video_url` content type is an extension not part of the OpenAI standard and may change in future versions. -Video input is supported using the `video_url` content type. If the model supports audio (e.g. Gemma 4, Qwen2.5-Omni), the audio track is automatically extracted from the video and processed alongside the visual frames. +Use the `video_url` content type for video input. If the model supports audio (e.g. Gemma 4, Qwen2.5-Omni), the server extracts the audio track from the video and processes it with the visual frames. 
> [!TIP] > Video processing requires [torchcodec](https://github.com/pytorch/torchcodec). Install it with `pip install torchcodec`. @@ -934,7 +934,7 @@ data: {"id":"cb997e1d-98b9-414a-be89-1880288610ef","choices":[{"delta":{"content -### Multi-turn conversations +### Multi-turn conversations[[completions]] To have a multi-turn conversation, include the full conversation history in the `messages` list with alternating `user` and `assistant` roles. Like all OpenAI-compatible servers, the API is stateless, so every request must contain the complete conversation history. @@ -954,7 +954,7 @@ completion = client.chat.completions.create( print(completion.choices[0].message.content) ``` -The follow-up question "How many people live there?" relies on the prior context, and the model answers about Paris accordingly. +The follow-up question "How many people live there?" relies on the prior context, so the model answers about Paris. ``` As of 2021, the population of Paris is approximately 2.2 million people. @@ -1466,7 +1466,7 @@ data: {"content_index":0,"delta":"This ","item_id":"msg_a1b2c3d4","output_index" > [!WARNING] > The `audio_url` content type is an extension not part of the OpenAI standard and may change in future versions. -As a convenience, audio can also be passed by URL using the `audio_url` content type, avoiding the need for base64 encoding. +You can also pass audio by URL with the `audio_url` content type to skip base64 encoding. ```python response = client.responses.create( @@ -1621,7 +1621,7 @@ data: {"content_index":0,"delta":"Based ","item_id":"msg_b2c3d4e5","output_index -### Multi-turn conversations +### Multi-turn conversations[[responses]] For multi-turn conversations, pass a list of messages with `role` keys in the `input` field. Like all OpenAI-compatible servers, the API is stateless, so every request must contain the complete conversation history. 
@@ -1643,7 +1643,7 @@ response = client.responses.create( print(response.output[0].content[0].text) ``` -The follow-up question "How many people live there?" relies on the prior context, and the model answers about Paris accordingly. +The follow-up question "How many people live there?" relies on the prior context, so the model answers about Paris. ``` As of 2021, Paris has a population of approximately 2.8 million people. @@ -1734,7 +1734,7 @@ The stream ends with exactly one terminal event, `ready` (success) or `error` (f ## Timeout -`transformers serve` supports different requests by different models. Each model loads on demand and stays in GPU memory. Models unload automatically after 300 seconds of inactivity to free up GPU memory. Set `--model-timeout` to a different value in seconds, or `-1` to disable unloading entirely. +`transformers serve` handles requests for any model. Each model loads on demand and stays in GPU memory. Models unload automatically after 300 seconds of inactivity to free GPU memory. Set `--model-timeout` to a different value in seconds, or `-1` to disable unloading. ```shell transformers serve --model-timeout 400 @@ -1742,7 +1742,7 @@ transformers serve --model-timeout 400 ### Loading examples -See the example responses below for a freshly downloaded model, a model loaded from your local cache (skips the download stage), and a model that already exists in memory. +The examples below show responses for a freshly downloaded model, a model loaded from your local cache (skips the download stage), and a model already in memory. @@ -1784,7 +1784,7 @@ data: {"status": "ready", "model": "org/model@main", "cached": true} The `transformers serve` server supports OpenAI-style function calling. Models trained for tool-use generate structured function calls that your application executes. > [!NOTE] -> Tool calling is currently limited to the Qwen model family. +> Tool calling works with any model whose tokenizer declares tool call tokens. 
Qwen and Gemma 4 work out of the box. Open an [issue](https://github.com/huggingface/transformers/issues/new/choose) to request support for a specific model. Define tools as a list of function specifications following the OpenAI format. @@ -1846,6 +1846,79 @@ for event in response: print(event) ``` +### Multi-turn tool calling + +After the model returns a tool call, execute the function locally, then send the result back in a follow-up request to get the model's final answer. The pattern differs slightly between the two APIs. See the [OpenAI function calling guide](https://developers.openai.com/api/docs/guides/function-calling?api-mode=chat) for the full spec. + +The examples below reuse the `tools` list defined above. + + + + +Pass the tool result as a `role: "tool"` message with the matching `tool_call_id`. + +```py +# Model returns a tool call +messages = [{"role": "user", "content": "What's the weather like in San Francisco?"}] +response = client.chat.completions.create( + model="Qwen/Qwen2.5-7B-Instruct", + messages=messages, + tools=tools, +) +assistant_message = response.choices[0].message + +# Execute the tool locally +tool_call = assistant_message.tool_calls[0] +result = {"temperature": 22, "condition": "sunny"} # your actual function call here + +# Send the tool result back +messages.append(assistant_message) +messages.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "content": json.dumps(result), +}) +final_response = client.chat.completions.create( + model="Qwen/Qwen2.5-7B-Instruct", + messages=messages, + tools=tools, +) +print(final_response.choices[0].message.content) +``` + + + + +Pass the tool result as a `function_call_output` item in the `input` list of the follow-up request. 
+ +```py +user_message = {"role": "user", "content": "What's the weather like in San Francisco?"} +response = client.responses.create( + model="Qwen/Qwen2.5-7B-Instruct", + input=[user_message], + tools=tools, + stream=False, +) +tool_call = next(item for item in response.output if item.type == "function_call") + +result = {"temperature": 22, "condition": "sunny"} + +final_response = client.responses.create( + model="Qwen/Qwen2.5-7B-Instruct", + input=[ + user_message, + tool_call.model_dump(exclude_none=True), + {"type": "function_call_output", "call_id": tool_call.call_id, "output": json.dumps(result)}, + ], + tools=tools, + stream=False, +) +print(final_response.output_text) +``` + + + + ## Port forwarding Port forwarding lets you serve models from a remote server. Make sure you have SSH access to the server, then run this command on your local machine. From 8e64e5334f59a1819e7538ed8f1e4ae90b14e315 Mon Sep 17 00:00:00 2001 From: BADAOUI Abdennacer <106801897+Abdennacer-Badaoui@users.noreply.github.com> Date: Thu, 23 Apr 2026 14:20:50 +0200 Subject: [PATCH 20/25] [AMD CI] Fix expectations for Gemma3n (#45602) update expectations for gemma3n --- tests/models/gemma3n/test_modeling_gemma3n.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/models/gemma3n/test_modeling_gemma3n.py b/tests/models/gemma3n/test_modeling_gemma3n.py index 0d6d7e0446d0..65a622163c88 100644 --- a/tests/models/gemma3n/test_modeling_gemma3n.py +++ b/tests/models/gemma3n/test_modeling_gemma3n.py @@ -993,7 +993,7 @@ def test_model_4b_bf16(self): output_text = self.processor.batch_decode(output, skip_special_tokens=True) EXPECTED_TEXTS = Expectations({ ("cuda", None): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a clear blue ocean. 
The cow is facing the viewer with its head slightly'], - ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean under a clear blue sky. The cow is facing the viewer'], + ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a clear blue ocean. The cow is facing the viewer with its head slightly'], }).get_expectation() # fmt: skip self.assertEqual(output_text, EXPECTED_TEXTS) @@ -1077,7 +1077,7 @@ def test_model_4b_batch(self): output_text = self.processor.batch_decode(output, skip_special_tokens=True) EXPECTED_TEXTS = Expectations({ ("cuda", None): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a clear blue ocean. The cow is facing the viewer with its head slightly', "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Subject:** The first image features a cow"], - ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean under a clear blue sky. The cow is facing the viewer', "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Subject Matter:** The first image shows a"], + ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. The sky is blue with a few white clouds. 
The', "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Subject:** The first image features a cow"], ("xpu", None): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. The cow is facing the viewer with its head slightly turned', "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Subject:** The first image features a cow"], }).get_expectation() # fmt: skip self.assertEqual(output_text, EXPECTED_TEXTS) @@ -1104,7 +1104,7 @@ def test_model_4b_image(self): EXPECTED_TEXTS = Expectations({ ("cuda", None): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a clear blue ocean. The cow is facing the viewer with its head slightly'], ("xpu", None): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a clear blue ocean. The cow is facing the viewer with its head slightly'], - ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean under a clear blue sky. The cow is facing the viewer'], + ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a clear blue ocean. 
The cow is facing the viewer with its head slightly'], }).get_expectation() # fmt: skip self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES) self.assertEqual(output_text, EXPECTED_TEXTS) @@ -1146,7 +1146,7 @@ def test_model_4b_multiimage(self): EXPECTED_TEXTS = Expectations({ ("cuda", None): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nIn the image, I see a street scene in what appears to be a Chinatown district. Here are some of the key elements:\n\n* **A'], ("xpu", None): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nIn the image, I see a street scene in what appears to be a Chinatown district. Here are the key elements:\n\n* **A prominent red'], - ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nIn the image, I see a street scene in what appears to be a Chinatown district. \n\nHere are some key elements:\n\n* **A'], + ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nIn the image, I see a street scene in what appears to be a Chinatown district. Here are some of the key elements:\n\n* **A'], }).get_expectation() # fmt: skip self.assertEqual(output_text, EXPECTED_TEXTS) @@ -1191,7 +1191,7 @@ def test_generation_beyond_sliding_window(self): EXPECTED_COMPLETIONS = Expectations({ ("cuda", None): [" and the people are so friendly. I'm so glad I came here. I'm so", ", green, yellow, orange, purple, pink, brown, black, white.\n\nHere'"], - ("rocm", (9, 4)): [" and the food is delicious. I'm so glad I came here. I'm so glad", ", green, yellow, orange, purple, pink, brown, black, white.\n\nHere'"], + ("rocm", (9, 4)): [' and the food is delicious. The staff is friendly and helpful. 
The atmosphere is relaxed and welcoming.', ", green, yellow, orange, purple, pink, brown, black, white.\n\nHere'"], }).get_expectation() # fmt: skip self.assertEqual(output_text, EXPECTED_COMPLETIONS) From 03238980c9f197c407c4d1f205bf7b702f6fefd4 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 23 Apr 2026 08:22:14 -0400 Subject: [PATCH 21/25] fix transformers + torchao nvfp4 serialization (#45573) Summary: 1. fix torchao NVFP4 serialization with transformers 2. add a test to cover the fix While i'm here, also did the following bundled into this PR: 3. make the torchao serialization test have human readable names (easier to debug) 4. fix the float8 test (update the expected output) after this PR the test command for all torchao configs passes on an NVIDIA B200 Test Plan: ``` RUN_SLOW=1 pytest tests/quantization/torchao_integration/test_torchao.py -k "Serialization" -s ``` Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- .../quantizers/quantizer_torchao.py | 1 + .../torchao_integration/test_torchao.py | 26 ++++++++++++------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py index a76f73aeb562..fd117b08023b 100644 --- a/src/transformers/quantizers/quantizer_torchao.py +++ b/src/transformers/quantizers/quantizer_torchao.py @@ -184,6 +184,7 @@ def get_weight_conversions(self): source_patterns=[ "_weight_qdata", "_weight_scale_and_zero", + "_weight_per_tensor_scale", "_weight_scale", "_weight_zero_point", "_weight_act_pre_scale", diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py index ebcc08816d95..b188b4f9a0c3 100644 --- a/tests/quantization/torchao_integration/test_torchao.py +++ b/tests/quantization/torchao_integration/test_torchao.py @@ -39,6 +39,7 @@ from torchao.dtypes import ( AffineQuantizedTensor, ) + from torchao.prototype.mx_formats 
import NVFP4DynamicActivationNVFP4WeightConfig from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, Float8Tensor, @@ -587,13 +588,14 @@ class TorchAoSerializationTest(unittest.TestCase): test_params = ( [ - (Int8WeightOnlyConfig(version=2), ALL_DEVICES_COMMON), - (Int8DynamicActivationInt8WeightConfig(version=2), ALL_DEVICES_COMMON), - (Float8DynamicActivationFloat8WeightConfig(), Expectations({("cuda", None): "What are we having for dinner?\n\nJess: (smiling) I", ("xpu", None): "What are we having for dinner?\n\nJess: (smiling) I"})), - (Float8WeightOnlyConfig(), Expectations({("cuda", None): COMMON_OUTPUT, ("xpu", None): COMMON_OUTPUT})), - (Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), Expectations({("cuda", None): "What are we having for dinner?\nRed, white, and green beans,", ("xpu", None): COMMON_OUTPUT})), - (Int8DynamicActivationIntxWeightConfig(), Expectations({("cpu", None): COMMON_OUTPUT, ("cuda", 9): COMMON_OUTPUT, ("cuda", 8): "What are we having for dinner?\n\nJEN: (smiling) I", ("xpu", None): COMMON_OUTPUT})), - (IntxWeightOnlyConfig(), ALL_DEVICES_COMMON), + ("Int8WeightOnlyConfig", Int8WeightOnlyConfig(version=2), ALL_DEVICES_COMMON), + ("Int8DynamicActivationInt8WeightConfig", Int8DynamicActivationInt8WeightConfig(version=2), ALL_DEVICES_COMMON), + ("Float8DynamicActivationFloat8WeightConfig", Float8DynamicActivationFloat8WeightConfig(), Expectations({("cuda", None): COMMON_OUTPUT, ("xpu", None): "What are we having for dinner?\n\nJess: (smiling) I"})), + ("Float8WeightOnlyConfig", Float8WeightOnlyConfig(), Expectations({("cuda", None): COMMON_OUTPUT, ("xpu", None): COMMON_OUTPUT})), + ("Int4WeightOnlyConfig", Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), Expectations({("cuda", None): "What are we having for dinner?\nRed, white, and green beans,", ("xpu", None): COMMON_OUTPUT})), + ("Int8DynamicActivationIntxWeightConfig", Int8DynamicActivationIntxWeightConfig(), Expectations({("cpu", 
None): COMMON_OUTPUT, ("cuda", 9): COMMON_OUTPUT, ("cuda", 8): "What are we having for dinner?\n\nJEN: (smiling) I", ("xpu", None): COMMON_OUTPUT})), + ("IntxWeightOnlyConfig", IntxWeightOnlyConfig(), ALL_DEVICES_COMMON), + ("NVFP4DynamicActivationNVFP4WeightConfig", NVFP4DynamicActivationNVFP4WeightConfig(), Expectations({("cuda", None): "What are we having for dinner?\n\n10. Avoid using \"I"})), ] if is_torchao_available() else [] @@ -609,8 +611,12 @@ def _check_serialization(self, device, config, expected_output): if isinstance(config, (Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig)): if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 9): self.skipTest(f"{type(config).__name__} requires CUDA capability >= (8, 9)") + if isinstance(config, NVFP4DynamicActivationNVFP4WeightConfig): + if torch.cuda.is_available() and torch.cuda.get_device_capability() < (10, 0): + self.skipTest(f"{type(config).__name__} requires CUDA capability >= (10, 0) (SM100)") quant_config = TorchAoConfig(config) - dtype = torch.bfloat16 if isinstance(config, Int4WeightOnlyConfig) else "auto" + needs_bfloat16 = isinstance(config, Int4WeightOnlyConfig | NVFP4DynamicActivationNVFP4WeightConfig) + dtype = torch.bfloat16 if needs_bfloat16 else "auto" quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, dtype=dtype, @@ -629,7 +635,7 @@ def _check_serialization(self, device, config, expected_output): self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), expected_output) @parameterized.expand(test_params, skip_on_empty=True) - def test_serialization_cpu(self, config, expected_outputs): + def test_serialization_cpu(self, _name, config, expected_outputs): try: expected = expected_outputs.find_expectation(("cpu", None, None)) except ValueError: @@ -638,7 +644,7 @@ def test_serialization_cpu(self, config, expected_outputs): @parameterized.expand(test_params, skip_on_empty=True) @require_torch_accelerator - def 
test_serialization_accelerator(self, config, expected_outputs): + def test_serialization_accelerator(self, _name, config, expected_outputs): try: expected = expected_outputs.get_expectation() except ValueError: From 533c4e1a4ca714f2953e74f0e510853f08defaf9 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Thu, 23 Apr 2026 15:24:35 +0200 Subject: [PATCH 22/25] SonicMoe (#45433) * added sonic moe * use lazy_load_kernel * style * use concatenated revision * final touches * fix * merge conflict * simpler naming * style * add sonicmoe test * skip fp32 on sonic * add transposed support * fix --------- Co-authored-by: vasqu --- src/transformers/integrations/hub_kernels.py | 1 + src/transformers/integrations/moe.py | 13 +- src/transformers/integrations/sonicmoe.py | 124 ++++++++++++++++++ .../models/gpt_oss/modeling_gpt_oss.py | 2 +- .../models/gpt_oss/modular_gpt_oss.py | 2 +- .../modular_openai_privacy_filter.py | 2 + tests/test_modeling_common.py | 99 +++++++------- 7 files changed, 188 insertions(+), 55 deletions(-) create mode 100644 src/transformers/integrations/sonicmoe.py diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index b1e6c74ddf10..70a343424aa8 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -289,6 +289,7 @@ def register_kernel_mapping_transformers(*args, **kwargs): "falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1}, "finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1}, "deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1}, + "sonic-moe": {"repo_id": "kernels-community/sonic-moe", "version": 1}, } _KERNEL_MODULE_MAPPING: dict[str, ModuleType | None] = {} diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index d17522d26daa..c8a8e87f3621 100644 --- 
a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -23,6 +23,7 @@ is_torch_less_or_equal, is_torchdynamo_compiling, ) +from .sonicmoe import sonicmoe_experts_forward if is_torch_available(): @@ -31,6 +32,7 @@ logger = logging.get_logger(__name__) + # Examples of experts class with its eager mm implementation # class Experts(torch.nn.Module): # """Collection of expert weights stored as 3D tensors.""" @@ -458,6 +460,7 @@ class ExpertsInterface(GeneralInterface): """Interface for registering custom experts forward functions.""" _global_mapping = { + "sonicmoe": sonicmoe_experts_forward, "batched_mm": batched_mm_experts_forward, "grouped_mm": grouped_mm_experts_forward, } @@ -498,6 +501,7 @@ def use_experts_implementation( experts_class: type[torch.nn.Module] | None = None, *, experts_interface: ExpertsInterface = ALL_EXPERTS_FUNCTIONS, + is_concatenated: bool = True, is_transposed: bool = False, has_bias: bool = False, has_gate: bool = True, @@ -509,10 +513,16 @@ def use_experts_implementation( The experts class to modify. If not provided, returns a decorator that can be applied to the class. experts_interface (`ExpertsInterface`, *optional*, defaults to `ALL_EXPERTS_FUNCTIONS`): The experts interface to use for dispatching the forward method. + is_concatenated (`bool`, *optional*, defaults to `True`): + Whether the expert weights are stored in concatenated layout [gate;up] + or interleaved layout [gate0, up0, gate1, up1, ...]. is_transposed (`bool`, *optional*, defaults to `False`): Whether the expert weights are stored in transposed format. has_bias (`bool`, *optional*, defaults to `False`): - Whether the expert layers include bias terms. + Whether the expert layers include bias terms or not. + has_gate (`bool`, *optional*, defaults to `True`): + Whether the experts use a gating mechanism or not. + Whether it has gate_up_proj weights or just up_proj weights. Returns: `type[torch.nn.Module]`: The modified experts class. 
@@ -529,6 +539,7 @@ def __init__(self, config, *args, **kwargs): self.has_gate = has_gate self.has_bias = has_bias self.is_transposed = is_transposed + self.is_concatenated = is_concatenated @wraps(original_forward) def forward(self, *args, **kwargs): diff --git a/src/transformers/integrations/sonicmoe.py b/src/transformers/integrations/sonicmoe.py new file mode 100644 index 000000000000..e322bb4bc061 --- /dev/null +++ b/src/transformers/integrations/sonicmoe.py @@ -0,0 +1,124 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SonicMoE integration: fused MoE using CuteDSL kernels from `kernels-community/sonic-moe`. + +Provides `sonicmoe_experts_forward` registered as "sonicmoe" in the ExpertsInterface. +Requirements: CUDA, `kernels`, `nvidia-cutlass-dsl`, has_gate=True. +""" + +import functools + +import torch + +from ..utils import logging +from .hub_kernels import lazy_load_kernel + + +logger = logging.get_logger(__name__) + +# Map activation function names from HF config to SonicMoE epilogue names +ACT_MAP = {"silu": "swiglu", "gelu": "geglu", "relu": "reglu"} + + +@functools.cache +def _load_sonic_kernel(): + """ + Load sonic-moe once and return its required symbols. + + Raises: + ImportError if the kernel or required symbols are not found. + + Returns: + Tuple of (ActivationType, moe_general_routing_inputs function) from the sonic-moe kernel. 
+ """ + + kernel = lazy_load_kernel("sonic-moe") + if kernel is None: + raise ImportError( + "sonic-moe kernel not found. Make sure you have the `kernels` and `nvidia-cutlass-dsl` packages installed." + ) + + ActivationType = getattr(getattr(kernel, "enums", None), "ActivationType", None) + moe_general_routing_inputs = getattr(kernel, "moe_general_routing_inputs", None) + + missing = [ + name + for name, attr in [ + ("enums.ActivationType", ActivationType), + ("moe_general_routing_inputs", moe_general_routing_inputs), + ] + if attr is None + ] + if missing: + raise ImportError( + f"sonic-moe kernel is missing required symbols: {', '.join(missing)}. " + "Make sure you have the `kernels` package and `nvidia-cutlass-dsl` installed." + ) + + return ActivationType, moe_general_routing_inputs + + +def sonicmoe_experts_forward( + self: torch.nn.Module, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + if not self.has_gate: + raise ValueError("sonicmoe requires gated experts (has_gate=True)") + if hidden_states.device.type != "cuda": + raise ValueError("sonicmoe requires CUDA device") + + ActivationType, moe_general_routing_inputs = _load_sonic_kernel() + + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + + # Flatten — token_indices must be int32, sorted ascending (required by sonic-moe) + token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1).int() + router_scores = top_k_weights.reshape(-1).to(hidden_states.dtype) + expert_ids = top_k_index.reshape(-1).int() + + # Map activation function + act_name = getattr(self.config, "hidden_act", "silu").lower() + activation_type = getattr(ActivationType, ACT_MAP.get(act_name, "swiglu").upper(), ActivationType.SWIGLU) + + # Permute weights as expected by sonic-moe (E=num_experts, H=hidden_size, I=intermediate_size). 
+ # Non-transposed: gate_up_proj is (E, 2*I, H), down_proj is (E, H, I) -> permute(1, 2, 0). + # Transposed: gate_up_proj is (E, H, 2*I), down_proj is (E, I, H) -> permute(2, 1, 0). + perm = (2, 1, 0) if self.is_transposed else (1, 2, 0) + w1 = self.gate_up_proj.permute(*perm) # (2*I, H, E) + w2 = self.down_proj.permute(*perm) # (H, I, E) + b1 = self.gate_up_proj_bias if self.has_bias else None + b2 = self.down_proj_bias if self.has_bias else None + + output, _ = moe_general_routing_inputs( + hidden_states, + router_scores, + token_idx, + expert_ids, + w1, + b1, + w2, + b2, + E=self.num_experts, + activation_type=activation_type, + stream_id=torch.cuda.current_stream(device).cuda_stream, + is_inference_mode_enabled=not torch.is_grad_enabled(), + concat_layout=self.is_concatenated, + ) + + return output diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 00f9ac601b0e..55381a7e3c21 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -65,7 +65,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -@use_experts_implementation(is_transposed=True, has_bias=True) +@use_experts_implementation(is_concatenated=False, is_transposed=True, has_bias=True) class GptOssExperts(nn.Module): def __init__(self, config): super().__init__() diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index f7c89cab08e5..3354acef2196 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -62,7 +62,7 @@ def forward(self, hidden_states): return (self.weight * hidden_states).to(input_dtype) # main diff with Llama -@use_experts_implementation(is_transposed=True, has_bias=True) +@use_experts_implementation(is_concatenated=False, is_transposed=True, has_bias=True) class 
GptOssExperts(nn.Module): def __init__(self, config): super().__init__() diff --git a/src/transformers/models/openai_privacy_filter/modular_openai_privacy_filter.py b/src/transformers/models/openai_privacy_filter/modular_openai_privacy_filter.py index fc77aafbdcf5..422235d9da91 100644 --- a/src/transformers/models/openai_privacy_filter/modular_openai_privacy_filter.py +++ b/src/transformers/models/openai_privacy_filter/modular_openai_privacy_filter.py @@ -21,6 +21,7 @@ from torch.nn import functional as F from ...configuration_utils import PreTrainedConfig +from ...integrations import use_experts_implementation from ...masking_utils import create_bidirectional_sliding_window_mask from ...modeling_layers import GenericForTokenClassification from ...modeling_outputs import BaseModelOutput @@ -213,6 +214,7 @@ def forward( return attn_output, attn_weights +@use_experts_implementation(is_transposed=True, has_bias=True) class OpenAIPrivacyFilterExperts(GptOssExperts): def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: # Concatenated layout instead of interleaving diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index b909212b62cd..bc8f65891445 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -51,7 +51,11 @@ is_deepspeed_zero3_enabled, unset_hf_deepspeed_config, ) -from transformers.integrations.moe import batched_mm_experts_forward, grouped_mm_experts_forward +from transformers.integrations.moe import ( + batched_mm_experts_forward, + grouped_mm_experts_forward, + sonicmoe_experts_forward, +) from transformers.modeling_layers import GradientCheckpointingLayer from transformers.modeling_utils import FLASH_ATTN_KERNEL_FALLBACK, _get_tied_weight_keys from transformers.models.auto import get_values @@ -110,6 +114,7 @@ GENERATION_CONFIG_NAME, SAFE_WEIGHTS_NAME, ModelOutput, + is_kernels_available, is_torch_bf16_available_on_device, is_torch_fp16_available_on_device, ) @@ -576,59 +581,49 @@ def 
_test_eager_matches_batched_and_grouped_inference(self, name, dtype): model.save_pretrained(tmpdirname) model = model_class.from_pretrained(tmpdirname).eval().to(torch_device).to(dtype) - with torch.no_grad(): - inputs_dict = {k: v.to(dtype) if torch.is_floating_point(v) else v for k, v in inputs_dict.items()} - prepared_inputs = self._prepare_for_class(inputs_dict, model_class) - - mock_batched_mm_forward = Mock(wraps=batched_mm_experts_forward) - mock_grouped_mm_forward = Mock(wraps=grouped_mm_experts_forward) - with ( - # This is needed because we call the functions through the interface's global mapping - patch.dict( - "transformers.integrations.moe.ALL_EXPERTS_FUNCTIONS._global_mapping", - {"batched_mm": mock_batched_mm_forward, "grouped_mm": mock_grouped_mm_forward}, - ), - ): - model.set_experts_implementation("eager") - self.assertEqual(model.config._experts_implementation, "eager") - outputs_eager = model(**prepared_inputs) - mock_batched_mm_forward.assert_not_called() - mock_grouped_mm_forward.assert_not_called() + inputs_dict = {k: v.to(dtype) if torch.is_floating_point(v) else v for k, v in inputs_dict.items()} + prepared_inputs = self._prepare_for_class(inputs_dict, model_class) - mock_batched_mm_forward.reset_mock() - mock_grouped_mm_forward.reset_mock() + implementations = ["eager", "batched_mm", "grouped_mm"] + mocks = { + "batched_mm": Mock(wraps=batched_mm_experts_forward), + "grouped_mm": Mock(wraps=grouped_mm_experts_forward), + } - model.set_experts_implementation("batched_mm") - self.assertEqual(model.config._experts_implementation, "batched_mm") - outputs_batched_mm = model(**prepared_inputs) - mock_grouped_mm_forward.assert_not_called() - mock_batched_mm_forward.assert_called() - - mock_batched_mm_forward.reset_mock() - mock_grouped_mm_forward.reset_mock() - - model.set_experts_implementation("grouped_mm") - self.assertEqual(model.config._experts_implementation, "grouped_mm") - outputs_grouped_mm = model(**prepared_inputs) - 
mock_batched_mm_forward.assert_not_called() - mock_grouped_mm_forward.assert_called() - - mock_batched_mm_forward.reset_mock() - mock_grouped_mm_forward.reset_mock() - - # extract output tensors for comparison - outputs_eager = _get_output_tensors(outputs_eager) - outputs_batched_mm = _get_output_tensors(outputs_batched_mm) - outputs_grouped_mm = _get_output_tensors(outputs_grouped_mm) - - # make sure we have collected some tensors from the outputs - self.assertTrue(outputs_eager, "No outputs from eager implementation") - self.assertTrue(outputs_batched_mm, "No outputs from batched_mm implementation") - self.assertTrue(outputs_grouped_mm, "No outputs from grouped_mm implementation") - - # make sure all implementations give numerically close outputs - torch.testing.assert_close(outputs_eager, outputs_batched_mm, rtol=1e-4, atol=1e-4) - torch.testing.assert_close(outputs_eager, outputs_grouped_mm, rtol=1e-4, atol=1e-4) + if ( + dtype != torch.float32 + and is_kernels_available() + and torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) + ): + # we also need nvidia-cutlass-dsl and apache-tvm-ffi + mocks["sonicmoe"] = Mock(wraps=sonicmoe_experts_forward) + implementations.append("sonicmoe") + + outputs = {} + # This is needed because we call the functions through the interface's global mapping + with patch.dict("transformers.integrations.moe.ALL_EXPERTS_FUNCTIONS._global_mapping", mocks): + for impl in implementations: + model.set_experts_implementation(impl) + self.assertEqual(model.config._experts_implementation, impl) + + with torch.no_grad(): + outputs[impl] = _get_output_tensors(model(**prepared_inputs)) + + self.assertTrue(outputs[impl], f"No outputs from {impl} implementation") + + for name, mock in mocks.items(): + if name == impl: + mock.assert_called() + else: + mock.assert_not_called() + + mock.reset_mock() + + # all non-eager implementations must numerically match eager + eager_outputs = outputs.pop("eager") + for impl, impl_outputs 
in outputs.items(): + torch.testing.assert_close(eager_outputs, impl_outputs, rtol=1e-4, atol=1e-4) def _config_zero_init(config): From 1e071b25731afa4c9c8fda059ee15198efe5f99d Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Thu, 23 Apr 2026 10:30:02 -0400 Subject: [PATCH 23/25] Processing Utils: continue when content is a string (#45605) fix: continue when content is a string --- src/transformers/processing_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index bf5e0c431e42..bb1344a43dcf 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -1813,6 +1813,8 @@ def apply_chat_template( images, videos = [], [] for message in conversation: content = message.get("content") or [] + if isinstance(content, str): + continue visuals = [ content_block for content_block in content if content_block["type"] in ["image", "video"] ] From 57f9936a2619d2f2d4af89bde34d5eb611c2b728 Mon Sep 17 00:00:00 2001 From: Tarek Ziade Date: Thu, 23 Apr 2026 16:45:34 +0200 Subject: [PATCH 24/25] qa: bumped mlinter and allow local override (#45585) * qa: bumped mlinter and allow local override * bump version * Update utils/check_modeling_rules_doc.py Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> * license header * license header --------- Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> --- docs/source/en/modeling_rules.md | 16 +- setup.py | 4 +- src/transformers/dependency_versions_table.py | 2 +- utils/check_modeling_rules_doc.py | 49 ++-- utils/check_modeling_structure.py | 29 +- utils/rules.toml | 251 ++++++++++++++++++ 6 files changed, 323 insertions(+), 28 deletions(-) create mode 100644 utils/rules.toml diff --git a/docs/source/en/modeling_rules.md b/docs/source/en/modeling_rules.md index d3b6e48bd7c4..0591a79f89b3 100644 --- a/docs/source/en/modeling_rules.md +++ b/docs/source/en/modeling_rules.md @@ -13,22 
+13,22 @@ specific language governing permissions and limitations under the License. # Model structure rules -Transformers enforces a set of static rules on every `modeling_*.py`, `modular_*.py`, and `configuration_*.py` file. The [mlinter](https://github.com/huggingface/transformers-mlinter) tool checks them as part of `make typing` and errors out if violations are found. +Transformers enforces a set of static rules on every `modeling_*.py`, `modular_*.py`, and `configuration_*.py` file. The [mlinter](https://github.com/huggingface/transformers-mlinter) package provides the checker engine, and the repository keeps its active rule set in `utils/rules.toml`. That local TOML lets us enable, disable, or tweak rules quickly without waiting for a new `transformers-mlinter` release. These are the expected model conventions for adding or changing modeling code. They keep the codebase consistent and ensure compatibility with features like pipeline parallelism, device maps, and weight tying. ## Running the checker -`make typing` runs `mlinter` alongside the `ty` type checker. Run `mlinter` on its own with the following commands. +`make typing` runs `mlinter` alongside the `ty` type checker through the repo wrapper, so it picks up `utils/rules.toml`. Run the same wrapper directly with the following commands. ```bash -mlinter # check all modeling files -mlinter --changed-only # check only files changed vs origin/main -mlinter --list-rules # list all rules and their enabled status -mlinter --rule TRF001 # show built-in docs for a specific rule +python utils/check_modeling_structure.py # check all modeling files +python utils/check_modeling_structure.py --changed-only # check only files changed vs origin/main +python utils/check_modeling_structure.py --list-rules # list all rules and their enabled status +python utils/check_modeling_structure.py --rule TRF001 # show built-in docs for a specific rule ``` -The `--changed-only` flag is the fastest option during development. 
It only checks the files you've modified relative to the main branch. +The `--changed-only` flag is the fastest option during development. It only checks the files you've modified relative to the main branch. If you invoke `mlinter` directly instead of the wrapper, pass `--rules-toml utils/rules.toml` so local overrides are applied. ## Fixing a violation @@ -52,7 +52,7 @@ Use the rule ID to look up the fix in the [rules reference](#rules-reference). T ## Rules reference -Each rule below lists what it enforces and a diff showing the fix. Run `mlinter --rule TRF001` to see the built-in docs for any rule. +Each rule below lists what it enforces and a diff showing the fix. Run `python utils/check_modeling_structure.py --rule TRF001` to see the built-in docs for any rule with the repo's current rule set. diff --git a/setup.py b/setup.py index 2e6adca0315c..42c865b1b9ba 100644 --- a/setup.py +++ b/setup.py @@ -124,7 +124,9 @@ "rjieba", "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", "ruff==0.14.10", - "transformers-mlinter==0.1.0", + # When bumping `transformers-mlinter`, sync repo-local rule overrides from + # `utils/rules.toml` back into the released package. + "transformers-mlinter==0.1.1", "ty==0.0.20", # `sacrebleu` not used in `transformers`. However, it is needed in several tests, when a test calls # `evaluate.load("sacrebleu")`. 
This metric is used in the examples that we use to test the `Trainer` with, in the diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 399b0be222e9..1a721ca2a82a 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -56,7 +56,7 @@ "rjieba": "rjieba", "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", "ruff": "ruff==0.14.10", - "transformers-mlinter": "transformers-mlinter==0.1.0", + "transformers-mlinter": "transformers-mlinter==0.1.1", "ty": "ty==0.0.20", "sacrebleu": "sacrebleu>=1.4.12,<2.0.0", "sacremoses": "sacremoses", diff --git a/utils/check_modeling_rules_doc.py b/utils/check_modeling_rules_doc.py index 24e7b17fd925..8eaf8e57012d 100644 --- a/utils/check_modeling_rules_doc.py +++ b/utils/check_modeling_rules_doc.py @@ -13,7 +13,7 @@ # limitations under the License. """ Keep `## Rules reference` section of docs/source/en/modeling_rules.md in sync -with the rules defined in the mlinter package. +with the rules defined in utils/rules.toml via the installed mlinter package. Usage (from the root of the repo): @@ -31,21 +31,22 @@ """ import argparse -import os +from pathlib import Path CHECKER_CONFIG = { "name": "modeling_rules_doc", "label": "Modeling rules documentation", - # Depends on the installed `mlinter` package output, which cannot be expressed - # as repo file globs for the checker cache. + # Depends on utils/rules.toml plus the installed `mlinter` package output, + # which cannot be fully expressed as repo file globs for the checker cache. 
"file_globs": None, - "check_args": [], - "fix_args": ["--fix_and_overwrite"], + "check_args": ["--rules-toml", "utils/rules.toml"], + "fix_args": ["--rules-toml", "utils/rules.toml", "--fix_and_overwrite"], } -ROOT = os.path.dirname(os.path.dirname(__file__)) -DOC_PATH = os.path.join(ROOT, "docs", "source", "en", "modeling_rules.md") +ROOT = Path(__file__).resolve().parent.parent +DOC_PATH = ROOT / "docs" / "source" / "en" / "modeling_rules.md" +RULES_TOML_PATH = ROOT / "utils" / "rules.toml" BEGIN_MARKER = "" END_MARKER = "" @@ -54,21 +55,29 @@ def _require_mlinter(): try: import mlinter + from mlinter import mlinter as mlinter_impl except ModuleNotFoundError as error: raise ModuleNotFoundError( "This script requires the standalone `transformers-mlinter` package. " 'Install the repo quality dependencies with `pip install -e ".[quality]"` and retry.' ) from error - return mlinter + return mlinter, mlinter_impl -def generate_rules_reference() -> str: - return _require_mlinter().render_rules_reference() +def _resolve_path(path: Path) -> Path: + return path if path.is_absolute() else ROOT / path -def check_modeling_rules_doc(overwrite: bool = False): - with open(DOC_PATH, encoding="utf-8") as f: +def generate_rules_reference(rule_specs_path: Path = RULES_TOML_PATH) -> str: + mlinter, mlinter_impl = _require_mlinter() + # Reuse mlinter's registry-switching helper so docs rendering reflects the repo-local rule file. 
+ with mlinter_impl._using_rule_specs(_resolve_path(rule_specs_path)): + return mlinter.render_rules_reference() + + +def check_modeling_rules_doc(overwrite: bool = False, rule_specs_path: Path = RULES_TOML_PATH): + with DOC_PATH.open(encoding="utf-8") as f: content = f.read() begin_idx = content.find(BEGIN_MARKER) @@ -80,7 +89,7 @@ def check_modeling_rules_doc(overwrite: bool = False): ) after_begin = begin_idx + len(BEGIN_MARKER) - expected = "\n\n" + generate_rules_reference() + "\n" + expected = "\n\n" + generate_rules_reference(rule_specs_path) + "\n" current = content[after_begin:end_idx] if current == expected: @@ -88,22 +97,28 @@ def check_modeling_rules_doc(overwrite: bool = False): if overwrite: new_content = content[:after_begin] + expected + content[end_idx:] - with open(DOC_PATH, "w", encoding="utf-8") as f: + with DOC_PATH.open("w", encoding="utf-8") as f: f.write(new_content) print(f"Updated rules reference in {DOC_PATH}") else: raise ValueError( "The rules reference section in docs/source/en/modeling_rules.md is out of sync " - "with the mlinter package's rules. Run `make fix-repo` to regenerate it." + "with utils/rules.toml. Run `make fix-repo` to regenerate it." ) if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument( + "--rules-toml", + type=Path, + default=RULES_TOML_PATH, + help="Path to a rules TOML file. 
Defaults to utils/rules.toml.", + ) parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.") args = parser.parse_args() try: - check_modeling_rules_doc(args.fix_and_overwrite) + check_modeling_rules_doc(args.fix_and_overwrite, args.rules_toml) except ModuleNotFoundError as error: raise SystemExit(str(error)) from error diff --git a/utils/check_modeling_structure.py b/utils/check_modeling_structure.py index 447eabf8b8a6..6078672d7349 100644 --- a/utils/check_modeling_structure.py +++ b/utils/check_modeling_structure.py @@ -1,6 +1,23 @@ #!/usr/bin/env python +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Thin local entrypoint for the external mlinter package.""" +import sys +from pathlib import Path + + CHECKER_CONFIG = { "name": "modeling_structure", "label": "Modeling file structure", @@ -9,10 +26,12 @@ "src/transformers/models/**/modular_*.py", "src/transformers/models/**/configuration_*.py", ], - "check_args": [], + "check_args": ["--rules-toml", "utils/rules.toml"], "fix_args": None, } +RULES_TOML_PATH = Path(__file__).resolve().with_name("rules.toml") + def _require_mlinter(): try: @@ -26,8 +45,16 @@ def _require_mlinter(): return mlinter +def _add_default_rules_toml(argv: list[str]) -> list[str]: + if any(arg == "--rules-toml" or arg.startswith("--rules-toml=") for arg in argv[1:]): + return argv + + return [argv[0], "--rules-toml", str(RULES_TOML_PATH), *argv[1:]] + + if __name__ == "__main__": try: + sys.argv = _add_default_rules_toml(sys.argv) raise SystemExit(_require_mlinter().main()) except ModuleNotFoundError as error: raise SystemExit(str(error)) from error diff --git a/utils/rules.toml b/utils/rules.toml new file mode 100644 index 000000000000..1c7de0e729b0 --- /dev/null +++ b/utils/rules.toml @@ -0,0 +1,251 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file can carry repo-local rule overrides for faster iteration between +# `transformers-mlinter` releases. 
+# Keep it synced with the upstream package's rules.toml when possible so local +# behavior does not drift from the published checker longer than necessary. + +version = 1 + +[rules.TRF001] +description = "Class-level config_class on PreTrainedModel should match Config naming." +default_enabled = true +allowlist_models = ["qwen3_omni_moe"] + +[rules.TRF001.explanation] +what_it_does = "Checks naming consistency between PreTrainedModel and config_class." +why_bad = "Mismatched config_class can break loading, auto classes, and developer expectations." +diff = ''' + class AcmePreTrainedModel(PreTrainedModel): +- config_class = WileConfig ++ config_class = AcmeConfig +''' + +[rules.TRF002] +description = "base_model_prefix should be a non-empty canonical string when defined on PreTrainedModel classes." +default_enabled = true +allowlist_models = ["lighton_ocr"] + +[rules.TRF002.explanation] +what_it_does = "Checks that base_model_prefix, when set, is a non-empty, whitespace-free string literal." +why_bad = "Invalid prefixes can break weight loading key mapping and base model access patterns." +diff = ''' + class AcmePreTrainedModel(PreTrainedModel): +- base_model_prefix = "" ++ base_model_prefix = "model" +''' + +[rules.TRF003] +description = "forward() should use capture_output/can_return_tuple decorators instead of manual return_dict branching." +default_enabled = false +allowlist_models = [] + +[rules.TRF003.explanation] +what_it_does = "Detects forward methods that use the old 'if not return_dict: return (x,)' pattern." +why_bad = "The old return_dict branching pattern is error-prone and verbose. Use the capture_output or can_return_tuple decorators instead." +diff = ''' +-def forward(self, x, return_dict=None): +- if not return_dict: +- return (x,) +- return AcmeModelOutput(last_hidden_state=x) ++@can_return_tuple ++def forward(self, x): ++ return AcmeModelOutput(last_hidden_state=x) +''' + +[rules.TRF004] +description = "Models must never override tie_weights. 
Use _tied_weights_keys instead." +default_enabled = true +allowlist_models = ["data2vec", "hubert", "sew", "sew_d", "unispeech", "unispeech_sat", "wav2vec2", "wav2vec2_conformer", "wavlm"] + +[rules.TRF004.explanation] +what_it_does = "Checks that no model class defines a tie_weights method." +why_bad = "Overriding tie_weights leads to bad consequences for loading, device_map computation, and saving. Use _tied_weights_keys class attribute to declare tied weights instead." +diff = ''' +-def tie_weights(self): +- self.lm_head.weight = self.emb.weight ++class AcmeForCausalLM(AcmePreTrainedModel): ++ _tied_weights_keys = ["lm_head.weight"] +''' + +[rules.TRF005] +description = "_no_split_modules, when defined, should be a list/tuple of non-empty strings." +default_enabled = true +allowlist_models = ["d_fine", "deformable_detr", "glm46v", "lw_detr", "pp_doclayout_v3", "rt_detr", "rt_detr_v2", "voxtral", "voxtral_realtime"] + +[rules.TRF005.explanation] +what_it_does = "Checks the shape of _no_split_modules when present." +why_bad = "Malformed values can break device-map partitioning and sharding behavior." +diff = ''' +-_no_split_modules = [SomeLayerClass, ""] ++_no_split_modules = ["AcmeDecoderLayer", "AcmeAttention"] +''' + +[rules.TRF006] +description = "forward with cache arguments should reference cache control/state variables consistently." +default_enabled = true +allowlist_models = ["chinese_clip", "evolla", "idefics2", "llama4"] + +[rules.TRF006.explanation] +what_it_does = "Checks forward signatures that expose cache arguments for usage of those arguments in method body." +why_bad = "Unused cache arguments can indicate incomplete caching support and inconsistent API behavior." +diff = ''' + def forward(self, x, past_key_values=None, use_cache=False): ++ if use_cache: ++ ... + return x +''' + +[rules.TRF007] +description = "self.post_init() in __init__ should remain at the end of initialization for PreTrainedModel classes." 
+default_enabled = true +allowlist_models = ["distilbert", "lxmert", "mt5", "pix2struct", "pop2piano", "switch_transformers", "t5"] + +[rules.TRF007.explanation] +what_it_does = "Checks for self attribute assignments after self.post_init() in __init__." +why_bad = "Mutating model structure after post_init can bypass intended initialization/finalization logic." +diff = ''' + def __init__(self, config): + ... +- self.post_init() +- self.proj = nn.Linear(...) ++ self.proj = nn.Linear(...) ++ self.post_init() +''' + +[rules.TRF008] +description = "Doc decorators on PreTrainedModel classes should avoid empty add_start_docstrings usage." +default_enabled = true + +[rules.TRF008.explanation] +what_it_does = "Checks add_start_docstrings usage on model classes for non-empty docstring arguments." +why_bad = "Empty decorator usage produces unclear docs and weakens generated API documentation quality." +diff = ''' +-@add_start_docstrings("") ++@add_start_docstrings("The Acme model.") + class AcmeModel(AcmePreTrainedModel): + ... +''' + +[rules.TRF009] +description = "modeling_.py should avoid importing implementation code from another model package." +default_enabled = true +allowlist_models = ["dpr", "maskformer", "sam3_video", "vision_text_dual_encoder"] + +[rules.TRF009.explanation] +what_it_does = "Checks modeling files for cross-model imports such as transformers.models.other_model.* or from ..other_model.* imports." +why_bad = "Cross-model implementation imports violate the single-file policy and make model behavior harder to inspect and maintain." +diff = ''' +-from transformers.models.llama.modeling_llama import LlamaAttention ++# Keep implementation local to this file. ++# If reusing code, copy it with a # Copied from comment. +''' + +[rules.TRF010] +description = "Direct config definitions must use @strict(accept_kwargs=True)." 
+default_enabled = true +allowlist_models = ["nemotron_h", "vibevoice_asr"] + +[rules.TRF010.explanation] +what_it_does = "Checks direct PreTrainedConfig/PretrainedConfig subclasses in configuration_*.py and modular_*.py for an explicit @strict(accept_kwargs=True) decorator." +why_bad = "Without strict, new config classes miss the repo's runtime type-validation contract and drift from the dataclass-based config standard." +diff = ''' ++@strict(accept_kwargs=True) + class AcmeConfig(PreTrainedConfig): + ... +''' + +[rules.TRF011] +description = "forward() must not access non-nn.Module attributes on submodules (breaks pipeline parallelism with Identity replacement)." +default_enabled = true +allowlist_models = [] + +[rules.TRF011.explanation] +what_it_does = "In forward() methods of PreTrainedModel subclasses, checks for attribute accesses on submodules that would not exist on torch.nn.Identity. This includes attribute accesses on loop variables iterating over self.layers, and self.. chains where is not a standard nn.Module attribute." +why_bad = "Pipeline parallelism may replace any submodule with torch.nn.Identity. Accessing custom attributes (e.g. decoder_layer.attention_type) on a replaced module raises AttributeError at runtime. Per-layer metadata should be read from self.config instead." +diff = ''' + def forward(self, ...): +- for decoder_layer in self.layers: ++ for i, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, +- attention_mask=causal_mask_mapping[decoder_layer.attention_type], ++ attention_mask=causal_mask_mapping[self.config.layer_types[i]], + ) +''' + +[rules.TRF012] +description = "_init_weights must use init primitives, not in-place operations on module weights." +default_enabled = true +allowlist_models = [] + +[rules.TRF012.explanation] +what_it_does = "Checks that _init_weights(self, module) does not use in-place operations (e.g. .normal_(), .zero_()) directly on module weights." 
+why_bad = "We rely on internal flags set on parameters to track whether they need re-initialization. In-place ops bypass this mechanism. Use the `init` primitives instead." +diff = ''' ++from transformers import initialization as init ++ + def _init_weights(self, module): +- module.weight.normal_(mean=0.0, std=0.02) ++ init.normal_(module.weight, mean=0.0, std=0.02) +''' + +[rules.TRF013] +description = "PreTrainedModel __init__ must call self.post_init()." +default_enabled = true +allowlist_models = [] + +[rules.TRF013.explanation] +what_it_does = "Checks that every PreTrainedModel subclass with an __init__ method calls self.post_init(). In modular files, calling super().__init__() is also accepted since it propagates post_init from the parent." +why_bad = "post_init performs essential finalization (weight initialization, gradient checkpointing setup, etc.). Omitting it causes subtle runtime bugs." +diff = ''' + class AcmeModel(AcmePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList(...) ++ self.post_init() +''' + +[rules.TRF014] +description = "`trust_remote_code` should never be used in native model integrations." +default_enabled = true +allowlist_models = [] + +[rules.TRF014.explanation] +what_it_does = "Checks whether `trust_remote_code` is passed or used in code (e.g. as kwarg) within native model integration files." +why_bad = "`trust_remote_code` allows arbitrary loading, including binaries, which should only be a power feature for users, not a standard use-case. Native integrations must not depend on it, as remote code cannot be reviewed or maintained within transformers." +diff = ''' + class AcmeModel(AcmePreTrainedModel): + def __init__(self, config): + super().__init__(config) +- self.model = AutoModel.from_pretrained(..., trust_remote_code=True) ++ self.model = AutoModel.from_pretrained(...) 
+''' + +[rules.TRF015] +description = "Models with non-empty _tied_weights_keys must have tie_word_embeddings in their Config." +default_enabled = true +allowlist_models = [] + +[rules.TRF015.explanation] +what_it_does = "When a PreTrainedModel subclass defines _tied_weights_keys as a non-empty collection, checks that the corresponding configuration file declares a tie_word_embeddings field." +why_bad = "Without tie_word_embeddings in the config, users cannot control weight tying behavior. The model ties weights unconditionally, breaking serialization round-trips and preventing fine-tuning with untied heads." +diff = ''' + # configuration_foo.py + @strict(accept_kwargs=True) + class FooConfig(PreTrainedConfig): + hidden_size: int = 768 ++ tie_word_embeddings: bool = True +''' From fb1f387c209f0c0374b0579b7f78017cdc34cbd3 Mon Sep 17 00:00:00 2001 From: Harshal Janjani Date: Thu, 23 Apr 2026 19:00:16 +0400 Subject: [PATCH 25/25] fix: Fix loss coupling issue --- src/transformers/loss/loss_deimv2.py | 61 ++++++++++++++++------------ 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/src/transformers/loss/loss_deimv2.py b/src/transformers/loss/loss_deimv2.py index 88c1aa94d1fa..5c8573f7da44 100644 --- a/src/transformers/loss/loss_deimv2.py +++ b/src/transformers/loss/loss_deimv2.py @@ -208,28 +208,34 @@ def Deimv2ForObjectDetectionLoss( outputs_loss = {"logits": logits, "pred_boxes": pred_boxes.clamp(min=0, max=1)} auxiliary_outputs = None - if config.auxiliary_loss and denoising_meta_values is not None: - dn_out_coord, outputs_coord = torch.split( - outputs_coord.clamp(min=0, max=1), denoising_meta_values["dn_num_split"], dim=2 - ) - dn_out_class, outputs_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2) - # https://github.com/Intellindust-AI-Lab/DEIMv2/blob/main/engine/deim/deim_decoder.py#L562-L571 - # The original splits denoising queries in the decoder; here it happens in the loss since the decoder returns unsplit tensors. 
- _, logits = torch.split(logits, denoising_meta_values["dn_num_split"], dim=1) - _, pred_boxes = torch.split(pred_boxes, denoising_meta_values["dn_num_split"], dim=1) - dn_out_corners, out_corners = torch.split(predicted_corners, denoising_meta_values["dn_num_split"], dim=2) - dn_out_refs, out_refs = torch.split(initial_reference_points, denoising_meta_values["dn_num_split"], dim=2) - - outputs_loss["logits"] = logits - outputs_loss["pred_boxes"] = pred_boxes.clamp(min=0, max=1) + if config.auxiliary_loss: + if denoising_meta_values is not None: + dn_out_coord, normal_out_coord = torch.split( + outputs_coord.clamp(min=0, max=1), denoising_meta_values["dn_num_split"], dim=2 + ) + dn_out_class, normal_out_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2) + # https://github.com/Intellindust-AI-Lab/DEIMv2/blob/main/engine/deim/deim_decoder.py#L562-L571 + # The original splits denoising queries in the decoder; here it happens in the loss since the decoder returns unsplit tensors. 
+ _, normal_logits = torch.split(logits, denoising_meta_values["dn_num_split"], dim=1) + _, normal_pred_boxes = torch.split(pred_boxes, denoising_meta_values["dn_num_split"], dim=1) + dn_out_corners, out_corners = torch.split(predicted_corners, denoising_meta_values["dn_num_split"], dim=2) + dn_out_refs, out_refs = torch.split(initial_reference_points, denoising_meta_values["dn_num_split"], dim=2) + + outputs_loss["logits"] = normal_logits + outputs_loss["pred_boxes"] = normal_pred_boxes.clamp(min=0, max=1) + else: + normal_out_coord = outputs_coord.clamp(min=0, max=1) + normal_out_class = outputs_class + out_corners = predicted_corners + out_refs = initial_reference_points auxiliary_outputs = _set_aux_loss2( - outputs_class[:, :-1].transpose(0, 1), - outputs_coord[:, :-1].transpose(0, 1), + normal_out_class[:, :-1].transpose(0, 1), + normal_out_coord[:, :-1].transpose(0, 1), out_corners[:, :-1].transpose(0, 1), out_refs[:, :-1].transpose(0, 1), out_corners[:, -1], - outputs_class[:, -1], + normal_out_class[:, -1], ) outputs_loss["auxiliary_outputs"] = auxiliary_outputs @@ -237,16 +243,17 @@ def Deimv2ForObjectDetectionLoss( _set_aux_loss([enc_topk_logits], [enc_topk_bboxes.clamp(min=0, max=1)]) ) - dn_auxiliary_outputs = _set_aux_loss2( - dn_out_class.transpose(0, 1), - dn_out_coord.transpose(0, 1), - dn_out_corners.transpose(0, 1), - dn_out_refs.transpose(0, 1), - dn_out_corners[:, -1], - dn_out_class[:, -1], - ) - outputs_loss["dn_auxiliary_outputs"] = dn_auxiliary_outputs - outputs_loss["denoising_meta_values"] = denoising_meta_values + if denoising_meta_values is not None: + dn_auxiliary_outputs = _set_aux_loss2( + dn_out_class.transpose(0, 1), + dn_out_coord.transpose(0, 1), + dn_out_corners.transpose(0, 1), + dn_out_refs.transpose(0, 1), + dn_out_corners[:, -1], + dn_out_class[:, -1], + ) + outputs_loss["dn_auxiliary_outputs"] = dn_auxiliary_outputs + outputs_loss["denoising_meta_values"] = denoising_meta_values loss_dict = criterion(outputs_loss, labels)