134 changes: 134 additions & 0 deletions docs/source/en/model_doc/deimv2.md
@@ -0,0 +1,134 @@
<!--Copyright 2025 The HuggingFace Team.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

*This model was released in 2025 and added to Hugging Face Transformers in October 2025.*

# DEIMv2

<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="Object Detection" src="https://img.shields.io/badge/Object%20Detection-0ea5e9?style=flat">
<img alt="AutoBackbone" src="https://img.shields.io/badge/AutoBackbone-16a34a?style=flat">
</div>
</div>

## Overview

DEIMv2 is a real-time object detection architecture built on DINOv3 features. It introduces a Spatial Tuning Adapter (STA) that converts single-scale ViT features into a lightweight multi-scale pyramid, a simplified decoder, and an upgraded Dense one-to-one (O2O) matching strategy.

This integration uses the AutoBackbone API so DINO-family backbones can be reused without re-implementing them in the detection head. The initial release targets DINOv3/ViT backbones, with tiny HGNetv2 variants planned as follow-ups.
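
As a minimal sketch of how a backbone is selected through the configuration (the preset name and the `from_preset` helper come from this integration, and the backbone checkpoint ids are placeholders):

```py
from transformers import Deimv2Config, Deimv2ForObjectDetection

# Build a config from a named preset (this fetches the backbone config from the Hub),
# or pass your own `backbone_config` dict to reuse a different DINO-family backbone.
config = Deimv2Config.from_preset("base-dinov3-b", num_labels=91)
model = Deimv2ForObjectDetection(config)
```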

> [!TIP]
> The smallest working example below shows how to run inference and obtain boxes, scores, and labels from post-processing.

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
from PIL import Image
from transformers import pipeline

detector = pipeline(
    task="object-detection",
    model="your-org/deimv2-dinov3-base",  # placeholder repo id, replace with a released checkpoint
)
image = Image.open("path/to/your/image.jpg")
outputs = detector(image)
print(outputs[:3])
```

</hfoption>
<hfoption id="AutoModel">

```py
from PIL import Image
import requests
from transformers import Deimv2ImageProcessor, Deimv2ForObjectDetection

ckpt = "your-org/deimv2-dinov3-base"  # replace when a checkpoint is available
model = Deimv2ForObjectDetection.from_pretrained(ckpt)
processor = Deimv2ImageProcessor.from_pretrained(ckpt)

url = "https://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor.preprocess([image], return_tensors="pt")
outputs = model(**inputs)
results = processor.post_process_object_detection(outputs, threshold=0.5)
print(results)
```

</hfoption>
<hfoption id="transformers CLI">

```bash
echo -e "https://images.cocodataset.org/val2017/000000039769.jpg" | transformers run \
    --task object-detection \
    --model your-org/deimv2-dinov3-base
```

</hfoption>
</hfoptions>

## Model notes

- Backbone via AutoBackbone: loads DINOv3/ViT variants and exposes their feature maps to the DEIMv2 head.
- Spatial Tuning Adapter (STA): turns single-scale features into a multi-scale pyramid for accurate localization with minimal overhead (see the sketch after this list).
- Decoder and Dense O2O: a streamlined decoder with one-to-one assignment for stable training and real-time throughput.
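
As a rough illustration of the STA idea (a toy sketch, not the reference implementation), a single feature map can be projected and pooled into progressively coarser levels:

```py
import torch
import torch.nn as nn

def toy_pyramid(feat: torch.Tensor, num_scales: int = 4):
    # Project once, then downsample to build a small multi-scale pyramid.
    proj = nn.Conv2d(feat.shape[1], feat.shape[1], kernel_size=1)
    levels = []
    x = feat
    for _ in range(num_scales):
        levels.append(proj(x))
        x = nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
    return levels

levels = toy_pyramid(torch.randn(1, 256, 64, 64))
print([tuple(level.shape) for level in levels])  # 64x64, 32x32, 16x16, 8x8 grids
```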

## Expected inputs and outputs

- Inputs: `pixel_values` of shape `(batch_size, 3, height, width)`, produced by `Deimv2ImageProcessor.preprocess`.
- Outputs: class `logits` of shape `(batch_size, num_queries, num_labels)` and normalized `pred_boxes` of shape `(batch_size, num_queries, 4)`; use `post_process_object_detection` to filter by score and convert to absolute coordinates (see the example after this list).
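
Continuing the AutoModel example above, a sketch of converting the normalized boxes to pixel coordinates; the box format is assumed to be normalized corner coordinates, and labels fall back to generic names without a fine-tuned checkpoint:

```py
target_sizes = [image.size[::-1]]  # (height, width) for each image
results = processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)
for score, label, box in zip(results[0]["scores"], results[0]["labels"], results[0]["boxes"]):
    print(f"{model.config.id2label[label.item()]}: {score.item():.2f} at {box.tolist()}")
```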

## Configuration

[[autodoc]] Deimv2Config
- init

This configuration defines the backbone settings, query count, decoder depth, and STA parameters, and sets `model_type="deimv2"`.
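
A minimal sketch of constructing the configuration directly; the backbone checkpoint below is a stand-in for whichever DINO-family backbone you want to reuse:

```py
from transformers import AutoConfig, Deimv2Config

backbone_config = AutoConfig.from_pretrained("facebook/dinov2-base").to_dict()
config = Deimv2Config(
    backbone_config=backbone_config,
    hidden_dim=256,
    num_queries=300,
    num_decoder_layers=6,
    num_labels=91,
)
print(config.model_type)  # "deimv2"
```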

## Base model

[[autodoc]] Deimv2Model
- forward

This module wires the backbone to the STA and the decoder, returning decoder hidden states for the detection head.
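
A sketch of calling the base model directly to obtain decoder states; the preset builds a randomly initialized model here, since no public DEIMv2 checkpoint id is confirmed yet:

```py
import torch
from transformers import Deimv2Config, Deimv2Model

model = Deimv2Model(Deimv2Config.from_preset("base-dinov3-b"))
pixel_values = torch.randn(1, 3, 512, 512)
decoder_states = model(pixel_values)["decoder_hidden_states"]
print(decoder_states.shape)  # (1, num_queries, hidden_dim)
```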

## Task head

[[autodoc]] Deimv2ForObjectDetection
- forward

This head predicts class logits and normalized bounding boxes for a fixed set of queries.

## Image Processor

[[autodoc]] Deimv2ImageProcessor
- preprocess
- post_process_object_detection

Handles resizing, normalization, and batching of input images, and converts model outputs into boxes, scores, and labels.
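
A short sketch of the preprocessing round trip; the default `size` comes from this integration, and images are resized to a square before batching:

```py
import numpy as np
from PIL import Image
from transformers import Deimv2ImageProcessor

processor = Deimv2ImageProcessor(size=1024)
image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
batch = processor.preprocess([image], return_tensors="pt")
print(batch["pixel_values"].shape)  # (1, 3, 1024, 1024)
```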

## Resources

- Paper: “Real-Time Object Detection Meets DINOv3”.
- Official repository and model zoo for reference implementations and weights.
- AutoBackbone documentation for reusing vision backbones.

## Citations

Please cite the original DEIMv2 paper, “Real-Time Object Detection Meets DINOv3”, when using this model.
11 changes: 11 additions & 0 deletions src/transformers/models/deimv2/__init__.py
@@ -0,0 +1,11 @@
from .configuration_deimv2 import Deimv2Config
from .image_processing_deimv2 import Deimv2ImageProcessor
from .modeling_deimv2 import Deimv2Model, Deimv2ForObjectDetection

__all__ = [
"Deimv2Config",
"Deimv2ImageProcessor",
"Deimv2Model",
"Deimv2ForObjectDetection",
]

57 changes: 57 additions & 0 deletions src/transformers/models/deimv2/configuration_deimv2.py
@@ -0,0 +1,57 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional

from ...configuration_utils import PretrainedConfig
from ..auto.configuration_auto import AutoConfig


@dataclass
class Deimv2Preset:
    hidden_dim: int
    num_queries: int
    num_decoder_layers: int
    backbone: str


# Placeholder presets: DINOv2 checkpoints stand in for the target DINOv3 backbones
# until public DINOv3 backbone checkpoints are wired in.
DEIMV2_PRESETS: Dict[str, Deimv2Preset] = {
    "base-dinov3-s": Deimv2Preset(hidden_dim=256, num_queries=300, num_decoder_layers=6, backbone="facebook/dinov2-small"),
    "base-dinov3-b": Deimv2Preset(hidden_dim=256, num_queries=300, num_decoder_layers=6, backbone="facebook/dinov2-base"),
}

class Deimv2Config(PretrainedConfig):
    model_type = "deimv2"

    def __init__(
        self,
        backbone_config: Optional[Dict[str, Any]] = None,
        hidden_dim: int = 256,
        num_queries: int = 300,
        num_decoder_layers: int = 6,
        num_labels: int = 91,
        # STA and decoder knobs (placeholders)
        sta_num_scales: int = 4,
        use_dense_o2o: bool = True,
        layer_norm_type: str = "rms",
        activation: str = "swish",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if backbone_config is None:
            # Default to the base preset; note this fetches the backbone config from the Hub.
            backbone_config = AutoConfig.from_pretrained(DEIMV2_PRESETS["base-dinov3-b"].backbone).to_dict()
        self.backbone_config = backbone_config
        self.hidden_dim = hidden_dim
        self.num_queries = num_queries
        self.num_decoder_layers = num_decoder_layers
        self.num_labels = num_labels
        self.sta_num_scales = sta_num_scales
        self.use_dense_o2o = use_dense_o2o
        self.layer_norm_type = layer_norm_type
        self.activation = activation

    @classmethod
    def from_preset(cls, preset_name: str, **kwargs) -> "Deimv2Config":
        if preset_name not in DEIMV2_PRESETS:
            raise ValueError(f"Preset '{preset_name}' not found. Available presets: {list(DEIMV2_PRESETS.keys())}")
        preset = DEIMV2_PRESETS[preset_name]
        backbone_config = AutoConfig.from_pretrained(preset.backbone).to_dict()
        return cls(
            backbone_config=backbone_config,
            hidden_dim=preset.hidden_dim,
            num_queries=preset.num_queries,
            num_decoder_layers=preset.num_decoder_layers,
            **kwargs,
        )
45 changes: 45 additions & 0 deletions src/transformers/models/deimv2/image_processing_deimv2.py
@@ -0,0 +1,45 @@
from typing import Any, Dict, List, Union

import numpy as np
import torch
from PIL import Image

from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import normalize, resize, to_channel_dimension_format


class Deimv2ImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values"]

    def __init__(self, size: int = 1024, image_mean=None, image_std=None, **kwargs):
        super().__init__(**kwargs)
        self.size = size
        self.image_mean = image_mean or [0.485, 0.456, 0.406]
        self.image_std = image_std or [0.229, 0.224, 0.225]

    def preprocess(self, images: List[Union[Image.Image, np.ndarray, torch.Tensor]], return_tensors="pt", **kwargs) -> BatchFeature:
        pixel_values = []
        for img in images:
            # Convert every supported input type to a channels-last numpy array.
            if isinstance(img, torch.Tensor):
                img = img.numpy()
            if isinstance(img, Image.Image):
                img = np.array(img.convert("RGB"))
            # Square resize so differently sized images can be stacked into one batch.
            img = resize(img, size=(self.size, self.size))
            img = img.astype(np.float32) / 255.0  # rescale to [0, 1] before normalization
            img = normalize(img, mean=self.image_mean, std=self.image_std)
            img = to_channel_dimension_format(img, "channels_first")
            pixel_values.append(torch.as_tensor(img, dtype=torch.float32))
        pixel_values = torch.stack(pixel_values, dim=0)
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)

    def post_process_object_detection(self, outputs, threshold: float = 0.5, target_sizes=None) -> List[Dict[str, Any]]:
        # Minimal decoding; boxes are assumed to be normalized (x0, y0, x1, y1).
        logits = outputs["logits"]
        boxes = outputs["pred_boxes"]
        probs = logits.sigmoid()
        results = []
        for prob, box in zip(probs, boxes):
            scores, labels = prob.max(dim=-1)
            keep = scores > threshold
            results.append({"scores": scores[keep], "labels": labels[keep], "boxes": box[keep]})
        if target_sizes is not None:
            # Rescale normalized boxes to absolute (pixel) coordinates.
            for result, size in zip(results, target_sizes):
                img_h, img_w = size
                boxes = result["boxes"]
                boxes = boxes * torch.tensor([img_w, img_h, img_w, img_h], dtype=boxes.dtype, device=boxes.device)
                result["boxes"] = boxes
        return results
93 changes: 93 additions & 0 deletions src/transformers/models/deimv2/modeling_deimv2.py
@@ -0,0 +1,93 @@
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn

from ...modeling_utils import PreTrainedModel
from ...utils import logging
from ..auto import AutoBackbone, AutoConfig
from .configuration_deimv2 import Deimv2Config


logger = logging.get_logger(__name__)


class Deimv2PreTrainedModel(PreTrainedModel):
    config_class = Deimv2Config
    base_model_prefix = "deimv2"
    _no_split_modules = []

class SpatialTuningAdapter(nn.Module):
    def __init__(self, hidden_dim: int, num_scales: int):
        super().__init__()
        self.proj = nn.ModuleList([nn.Conv2d(hidden_dim, hidden_dim, 1) for _ in range(num_scales)])

    def forward(self, feat: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # feat: (B, C, H, W); create a toy pyramid by striding
        feats = []
        x = feat
        for i, p in enumerate(self.proj):
            feats.append(p(x))
            if i < len(self.proj) - 1:
                x = nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return tuple(feats)

class SimpleDecoder(nn.Module):
    def __init__(self, hidden_dim: int, num_layers: int, num_queries: int):
        super().__init__()
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        # nn.TransformerDecoder deep-copies the layer `num_layers` times, so a single
        # prototype layer is enough; a separate ModuleList would duplicate parameters.
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim * 4, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

    def forward(self, feats: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        # Use the highest-resolution feature map as a stub attention target
        bs = feats[0].size(0)
        tgt = self.query_embed.weight.unsqueeze(0).expand(bs, -1, -1)
        # Flatten spatial dims: (B, C, H, W) -> (B, HW, C)
        memory = feats[0].flatten(2).transpose(1, 2)
        hs = self.decoder(tgt, memory)  # (B, Q, C)
        return hs

class Deimv2Model(Deimv2PreTrainedModel):
    def __init__(self, config: Deimv2Config):
        super().__init__(config)
        backbone_config = config.backbone_config
        if isinstance(backbone_config, dict):
            # backbone_config is stored as a dict for serialization; rebuild the config object
            backbone_config = AutoConfig.for_model(**backbone_config)
        self.backbone = AutoBackbone.from_config(backbone_config)
        out_channels = self.backbone.channels
        hidden = config.hidden_dim
        if isinstance(out_channels, (tuple, list)):
            backbone_dim = out_channels[0]
        else:
            backbone_dim = out_channels
        self.input_proj = nn.Conv2d(backbone_dim, hidden, kernel_size=1)
        self.sta = SpatialTuningAdapter(hidden_dim=hidden, num_scales=config.sta_num_scales)
        self.decoder = SimpleDecoder(hidden_dim=hidden, num_layers=config.num_decoder_layers, num_queries=config.num_queries)

    def forward(self, pixel_values: torch.Tensor, return_dict: bool = True, **kwargs) -> Dict[str, torch.Tensor]:
        features = self.backbone(pixel_values).feature_maps  # tuple of (B, C, H, W)
        x = features[0]
        x = self.input_proj(x)
        feats = self.sta(x)
        hs = self.decoder(feats)  # (B, Q, C)
        return {"decoder_hidden_states": hs}

class Deimv2ForObjectDetection(Deimv2PreTrainedModel):
    def __init__(self, config: Deimv2Config):
        super().__init__(config)
        self.model = Deimv2Model(config)
        hidden = config.hidden_dim
        self.class_head = nn.Linear(hidden, config.num_labels)
        self.box_head = nn.Linear(hidden, 4)

    def forward(self, pixel_values: torch.Tensor, labels: Optional[Dict[str, torch.Tensor]] = None, **kwargs) -> Dict[str, torch.Tensor]:
        outputs = self.model(pixel_values, return_dict=True)
        hs = outputs["decoder_hidden_states"]
        logits = self.class_head(hs)
        boxes = self.box_head(hs).sigmoid()
        out = {"logits": logits, "pred_boxes": boxes}
        # TODO: compute loss if labels provided
        return out

    def freeze_backbone(self):
        for param in self.model.backbone.parameters():
            param.requires_grad = False
        logger.info("Backbone frozen.")
        self.model.backbone.eval()

7 changes: 7 additions & 0 deletions tests/models/deimv2/test_configuration_deimv2.py
@@ -0,0 +1,7 @@
import json

from transformers import Deimv2Config


def test_roundtrip():
    cfg = Deimv2Config()
    s = cfg.to_json_string()
    # PretrainedConfig has no `from_json_string`; round-trip through a dict instead.
    cfg2 = Deimv2Config.from_dict(json.loads(s))
    assert cfg2.model_type == "deimv2"
    assert cfg2.hidden_dim == cfg.hidden_dim
14 changes: 14 additions & 0 deletions tests/models/deimv2/test_image_processing_deimv2.py
@@ -0,0 +1,14 @@
import numpy as np
import torch
from PIL import Image

from transformers import Deimv2ImageProcessor


def test_preprocess_postprocess():
    proc = Deimv2ImageProcessor(size=256)
    img = Image.fromarray((np.random.rand(256, 256, 3) * 255).astype("uint8"))
    batch = proc.preprocess([img])
    assert "pixel_values" in batch
    dummy = {"logits": torch.randn(1, 300, 91), "pred_boxes": torch.rand(1, 300, 4)}
    res = proc.post_process_object_detection(dummy, threshold=0.9)
    assert isinstance(res, list)
    assert "scores" in res[0]
12 changes: 12 additions & 0 deletions tests/models/deimv2/test_modeling_deimv2.py
@@ -0,0 +1,12 @@
import torch

from transformers import Deimv2Config
from transformers.models.deimv2.modeling_deimv2 import Deimv2ForObjectDetection


def test_forward_shapes():
    cfg = Deimv2Config()
    model = Deimv2ForObjectDetection(cfg)
    pixel_values = torch.randn(2, 3, 512, 512)
    out = model(pixel_values)
    assert out["logits"].shape[:2] == (2, cfg.num_queries)
    assert out["pred_boxes"].shape[-1] == 4