diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index b54775dd6e10..098d56f1b19a 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/owlvit/modular_owlvit.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_owlvit.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Copyright 2022 Google AI and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,30 +17,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch OWL-ViT model.""" from collections.abc import Callable from dataclasses import dataclass from typing import Any import torch -from torch import Tensor, nn +from torch import nn from ... import initialization as init from ...activations import ACT2FN +from ...loss.loss_for_object_detection import box_iou, generalized_box_iou from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - ModelOutput, - TransformersKwargs, - auto_docstring, - is_vision_available, - logging, - torch_int, -) +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, is_vision_available, torch_int from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from .configuration_owlvit import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig @@ -44,24 +43,6 @@ from transformers.image_transforms import center_to_corners_format -logger = logging.get_logger(__name__) - - -# See all OwlViT models at https://huggingface.co/models?filter=owlvit - - -# Copied from transformers.models.clip.modeling_clip.contrastive_loss with clip->owlvit -def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: - return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) - - -# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->owlvit -def owlvit_loss(similarity: torch.Tensor) -> torch.Tensor: - caption_loss = contrastive_loss(similarity) - image_loss = contrastive_loss(similarity.t()) - return (caption_loss + image_loss) / 2.0 - - @dataclass @auto_docstring class OwlViTOutput(ModelOutput): @@ -100,74 +81,6 @@ def to_tuple(self) -> tuple[Any]: ) -# Copied from transformers.loss.loss_for_object_detection._upcast -def _upcast(t: Tensor) -> Tensor: - # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type - if t.is_floating_point(): - return t if t.dtype in (torch.float32, torch.float64) else t.float() - else: - return t if t.dtype in (torch.int32, torch.int64) else t.int() - - -# Copied from transformers.loss.loss_for_object_detection.box_area -def box_area(boxes: Tensor) -> Tensor: - """ - Computes the area of a set of bounding boxes, which are specified 
by its (x1, y1, x2, y2) coordinates. - - Args: - boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`): - Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1 - < x2` and `0 <= y1 < y2`. - - Returns: - `torch.FloatTensor`: a tensor containing the area for each box. - """ - boxes = _upcast(boxes) - return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) - - -# Copied from transformers.loss.loss_for_object_detection.box_iou -def box_iou(boxes1, boxes2): - area1 = box_area(boxes1) - area2 = box_area(boxes2) - - left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] - right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] - - width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2] - inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M] - - union = area1[:, None] + area2 - inter - - iou = inter / union - return iou, union - - -# Copied from transformers.loss.loss_for_object_detection.generalized_box_iou -def generalized_box_iou(boxes1, boxes2): - """ - Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format. - - Returns: - `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2) - """ - # degenerate boxes gives inf / nan results - # so do an early check - if not (boxes1[:, 2:] >= boxes1[:, :2]).all(): - raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}") - if not (boxes2[:, 2:] >= boxes2[:, :2]).all(): - raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}") - iou, union = box_iou(boxes1, boxes2) - - top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2]) - bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) - - width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2] - area = width_height[:, :, 0] * width_height[:, :, 1] - - return iou - (area - union) / area - - @dataclass @auto_docstring( custom_intro=""" @@ -274,25 +187,26 @@ def to_tuple(self) -> tuple[Any]: class OwlViTVisionEmbeddings(nn.Module): def __init__(self, config: OwlViTVisionConfig): super().__init__() - self.patch_size = config.patch_size self.config = config self.embed_dim = config.hidden_size - self.class_embedding = nn.Parameter(torch.randn(config.hidden_size)) + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, out_channels=self.embed_dim, - kernel_size=config.patch_size, - stride=config.patch_size, + kernel_size=self.patch_size, + stride=self.patch_size, bias=False, ) - self.num_patches = (config.image_size // config.patch_size) ** 2 + self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches + 1 self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.interpolate_pos_encoding def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution @@ -322,18 +236,21 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: sqrt_num_positions = torch_int(num_positions**0.5) 
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( patch_pos_embed, size=(new_height, new_width), mode="bicubic", align_corners=False, ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: batch_size, _, height, width = pixel_values.shape - patch_embeds = self.patch_embedding(pixel_values) # shape = [batch_size, num_channels, height, width] + patch_embeds = self.patch_embedding(pixel_values) patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) @@ -347,8 +264,10 @@ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: boo class OwlViTTextEmbeddings(nn.Module): def __init__(self, config: OwlViTTextConfig): super().__init__() - self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size) - self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size) + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer( @@ -375,49 +294,36 @@ def forward( return embeddings -# Copied from transformers.models.bert.modeling_bert.eager_attention_forward def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor | None, - scaling: float | None = None, + scaling: float, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): - if scaling is None: - scaling = query.size(-1) ** -0.5 - - # Take the dot product between "query" and "key" to get the raw attention scores. - attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling - + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, attn_weights class OwlViTAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config): + def __init__(self, config: OwlViTVisionConfig | OwlViTTextConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." 
- ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout self.is_causal = False @@ -433,12 +339,17 @@ def forward( attention_mask: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor | None]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape - query_states = self.q_proj(hidden_states).view(*hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(*hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward @@ -446,22 +357,21 @@ def forward( attn_output, attn_weights = attention_interface( self, - query_states, - key_states, - value_states, + queries, + keys, + values, attention_mask, scaling=self.scale, dropout=0.0 if not self.training else self.dropout, **kwargs, ) - attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() attn_output = self.out_proj(attn_output) return attn_output, attn_weights -# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->OwlViT class OwlViTMLP(nn.Module): def __init__(self, config): super().__init__() @@ -477,7 +387,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->OwlViT class OwlViTEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: OwlViTVisionConfig | OwlViTTextConfig): super().__init__() @@ -577,7 +486,6 @@ def _init_weights(self, module: nn.Module): init.zeros_(module.bias) -# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->OwlViT class OwlViTEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. 
Each layer is a @@ -587,7 +495,7 @@ class OwlViTEncoder(nn.Module): config: OwlViTConfig """ - def __init__(self, config: OwlViTConfig): + def __init__(self, config: OwlViTTextConfig | OwlViTVisionConfig): super().__init__() self.config = config self.layers = nn.ModuleList([OwlViTEncoderLayer(config) for _ in range(config.num_hidden_layers)]) @@ -831,6 +739,18 @@ def forward( ) +# contrastive loss function, adapted from +# https://sachinruk.github.io/blog/2021-03-07-owlvit.html +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +def owlvit_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + @auto_docstring class OwlViTModel(OwlViTPreTrainedModel): config: OwlViTConfig diff --git a/src/transformers/models/owlvit/modular_owlvit.py b/src/transformers/models/owlvit/modular_owlvit.py new file mode 100644 index 000000000000..5fc21d2388d5 --- /dev/null +++ b/src/transformers/models/owlvit/modular_owlvit.py @@ -0,0 +1,1184 @@ +# Copyright 2022 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OWL-ViT model (modular; see `modeling_owlvit.py`, auto-generated).""" + +from dataclasses import dataclass +from typing import Any + +import torch +from torch import nn + +from ... import initialization as init +from ...loss.loss_for_object_detection import box_iou, generalized_box_iou +from ...masking_utils import create_causal_mask +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + ModelOutput, + TransformersKwargs, + auto_docstring, + is_vision_available, +) +from ...utils.generic import can_return_tuple, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from ..clip.modeling_clip import ( + CLIPMLP, + CLIPAttention, + CLIPEncoder, + CLIPEncoderLayer, + CLIPTextEmbeddings, + CLIPVisionEmbeddings, + contrastive_loss, +) +from .configuration_owlvit import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig + + +if is_vision_available(): + from transformers.image_transforms import center_to_corners_format + + +def owlvit_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + +@dataclass +@auto_docstring +class OwlViTOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. 
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size * num_max_text_queries, output_dim)`): + The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`): + The image embeddings obtained by applying the projection layer to the pooled output of + [`OwlViTVisionModel`]. + text_model_output (tuple[`BaseModelOutputWithPooling`]): + The output of the [`OwlViTTextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`OwlViTVisionModel`]. + """ + + loss: torch.FloatTensor | None = None + logits_per_image: torch.FloatTensor | None = None + logits_per_text: torch.FloatTensor | None = None + text_embeds: torch.FloatTensor | None = None + image_embeds: torch.FloatTensor | None = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`OwlViTForObjectDetection`]. + """ +) +class OwlViTObjectDetectionOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided): + Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`): + Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to retrieve the + unnormalized bounding boxes. + text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim)`): + The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`): + Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes + image embeddings for each patch. + class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`): + Class embeddings of all image patches. OWL-ViT represents images as a set of image patches where the total + number of patches is (image_size / patch_size)**2. + text_model_output (tuple[`BaseModelOutputWithPooling`]): + The output of the [`OwlViTTextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`OwlViTVisionModel`].
+ """ + + loss: torch.FloatTensor | None = None + loss_dict: dict | None = None + logits: torch.FloatTensor | None = None + pred_boxes: torch.FloatTensor | None = None + text_embeds: torch.FloatTensor | None = None + image_embeds: torch.FloatTensor | None = None + class_embeds: torch.FloatTensor | None = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`OwlViTForObjectDetection.image_guided_detection`]. + """ +) +class OwlViTImageGuidedObjectDetectionOutput(ModelOutput): + r""" + logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`): + Classification logits (including no-object) for all queries. + image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`): + Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes + image embeddings for each patch. + query_image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`): + Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes + image embeddings for each patch. + target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual target image in the batch + (disregarding possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to + retrieve the unnormalized bounding boxes. + query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual query image in the batch + (disregarding possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to + retrieve the unnormalized bounding boxes. + class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`): + Class embeddings of all image patches. OWL-ViT represents images as a set of image patches where the total + number of patches is (image_size / patch_size)**2. + text_model_output (tuple[`BaseModelOutputWithPooling`]): + The output of the [`OwlViTTextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`OwlViTVisionModel`]. 
+ """ + + logits: torch.FloatTensor | None = None + image_embeds: torch.FloatTensor | None = None + query_image_embeds: torch.FloatTensor | None = None + target_pred_boxes: torch.FloatTensor | None = None + query_pred_boxes: torch.FloatTensor | None = None + class_embeds: torch.FloatTensor | None = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +class OwlViTVisionEmbeddings(CLIPVisionEmbeddings): + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + patch_embeds = self.patch_embedding(pixel_values) + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class OwlViTTextEmbeddings(CLIPTextEmbeddings): + def forward( + self, + input_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class OwlViTAttention(CLIPAttention): + pass + + +class OwlViTMLP(CLIPMLP): + pass + + +class OwlViTEncoderLayer(CLIPEncoderLayer): + def __init__(self, config: OwlViTVisionConfig | OwlViTTextConfig): + super().__init__(config) + self.self_attn = OwlViTAttention(config) + self.mlp = OwlViTMLP(config) + + +@auto_docstring +class OwlViTPreTrainedModel(PreTrainedModel): + config: OwlViTConfig + base_model_prefix = "owlvit" + input_modalities = ("image", "text") + supports_gradient_checkpointing = True + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = True + _no_split_modules = ["OwlViTEncoderLayer"] + _can_record_outputs = { + "hidden_states": OwlViTEncoderLayer, + "attentions": OwlViTAttention, + } + _keys_to_ignore_on_load_unexpected = [ + r".*text_model\.embeddings\.position_ids", + r".*vision_model\.embeddings\.position_ids", + ] + + @torch.no_grad() + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, OwlViTTextEmbeddings): + init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02) + init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02) + init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1))) + elif isinstance(module, OwlViTVisionEmbeddings): + init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + init.normal_(module.position_embedding.weight, std=module.config.initializer_range 
* factor) + init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1))) + elif isinstance(module, OwlViTAttention): + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor + init.normal_(module.q_proj.weight, std=in_proj_std) + init.normal_(module.k_proj.weight, std=in_proj_std) + init.normal_(module.v_proj.weight, std=in_proj_std) + init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, OwlViTMLP): + in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + init.normal_(module.fc1.weight, std=fc_std) + init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, OwlViTModel): + init.normal_( + module.text_projection.weight, + std=module.text_embed_dim**-0.5 * factor, + ) + init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim**-0.5 * factor, + ) + init.constant_(module.logit_scale, self.config.logit_scale_init_value) + elif isinstance(module, OwlViTForObjectDetection): + init.copy_(module.box_bias, module.compute_box_bias(module.num_patches_height, module.num_patches_width)) + if isinstance(module, nn.LayerNorm): + init.zeros_(module.bias) + init.ones_(module.weight) + if isinstance(module, nn.Linear): + init.normal_(module.weight, mean=0.0, std=factor) + if module.bias is not None: + init.zeros_(module.bias) + + +class OwlViTEncoder(CLIPEncoder): + def __init__(self, config: OwlViTTextConfig | OwlViTVisionConfig): + super().__init__(config) + self.layers = nn.ModuleList([OwlViTEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + + +class OwlViTTextTransformer(OwlViTPreTrainedModel): + def __init__(self, config: OwlViTTextConfig): + super().__init__(config) + + embed_dim = config.hidden_size + self.embeddings = OwlViTTextEmbeddings(config) + self.encoder = OwlViTEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. 
[What are input + IDs?](../glossary#input-ids) + """ + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + attention_mask = create_causal_mask( + config=self.config, + inputs_embeds=hidden_states, + attention_mask=attention_mask, + past_key_values=None, + ) + + kwargs.pop("is_causal", None) + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + is_causal=True, + **kwargs, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # take features from the end of tokens embedding (end of token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(torch.int).argmax(dim=-1).to(last_hidden_state.device), + ] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + ) + + +class OwlViTTextModel(OwlViTPreTrainedModel): + config: OwlViTTextConfig + input_modalities = ("text",) + + def __init__(self, config: OwlViTTextConfig): + super().__init__(config) + self.text_model = OwlViTTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @auto_docstring + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input + IDs?](../glossary#input-ids) + + Examples: + ```python + >>> from transformers import AutoProcessor, OwlViTTextModel + + >>> model = OwlViTTextModel.from_pretrained("google/owlvit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32") + >>> inputs = processor( + ... text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt" + ... 
) + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs, + ) + + +class OwlViTVisionTransformer(OwlViTPreTrainedModel): + def __init__(self, config: OwlViTVisionConfig): + super().__init__(config) + + self.embeddings = OwlViTVisionEmbeddings(config) + self.pre_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.encoder = OwlViTEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: bool | None = False, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + # Cast the input to the expected `dtype` + expected_input_dtype = self.embeddings.patch_embedding.weight.dtype + pixel_values = pixel_values.to(expected_input_dtype) + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + hidden_states = self.pre_layernorm(hidden_states) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + **kwargs, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + ) + + +class OwlViTVisionModel(OwlViTPreTrainedModel): + config: OwlViTVisionConfig + main_input_name = "pixel_values" + input_modalities = ("image",) + + def __init__(self, config: OwlViTVisionConfig): + super().__init__(config) + self.vision_model = OwlViTVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor | None = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + r""" + Examples: + ```python + >>> from PIL import Image + >>> import httpx + >>> from io import BytesIO + >>> from transformers import AutoProcessor, OwlViTVisionModel + + >>> model = OwlViTVisionModel.from_pretrained("google/owlvit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> with httpx.stream("GET", url) as response: + ... 
image = Image.open(BytesIO(response.read())) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + + return self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + + +@auto_docstring +class OwlViTModel(OwlViTPreTrainedModel): + config: OwlViTConfig + + def __init__(self, config: OwlViTConfig): + super().__init__(config) + + text_config = config.text_config + vision_config = config.vision_config + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = OwlViTTextTransformer(text_config) + self.vision_model = OwlViTVisionTransformer(vision_config) + + self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + self.logit_scale = nn.Parameter(torch.tensor(config.logit_scale_init_value)) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def get_text_features( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input + IDs?](../glossary#input-ids) + + Examples: + ```python + >>> import torch + >>> from transformers import AutoProcessor, OwlViTModel + + >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32") + >>> inputs = processor( + ... text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt" + ... ) + >>> with torch.inference_mode(): + ... text_features = model.get_text_features(**inputs) + ```""" + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs, + ) + pooled_output = text_outputs.pooler_output + text_outputs.pooler_output = self.text_projection(pooled_output) + + return text_outputs + + @can_return_tuple + @auto_docstring + def get_image_features( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + Examples: + ```python + >>> import torch + >>> from transformers.image_utils import load_image + >>> from transformers import AutoProcessor, OwlViTModel + + >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = load_image(url) + + >>> inputs = processor(images=image, return_tensors="pt") + >>> with torch.inference_mode(): + ... 
image_features = model.get_image_features(**inputs) + ```""" + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + vision_outputs.pooler_output = self.visual_projection(vision_outputs.pooler_output) + + return vision_outputs + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + return_loss: bool | None = None, + interpolate_pos_encoding: bool = False, + return_base_image_embeds: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | OwlViTOutput: + r""" + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + return_base_image_embeds (`bool`, *optional*): + Whether or not to return the base image embeddings. + + Examples: + ```python + >>> from PIL import Image + >>> import httpx + >>> from io import BytesIO + >>> from transformers import AutoProcessor, OwlViTModel + + >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> with httpx.stream("GET", url) as response: + ... image = Image.open(BytesIO(response.read())) + >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + + # Get embeddings for all text queries in all batch samples + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs, + ) + + text_embeds = text_outputs.pooler_output + text_embeds = self.text_projection(text_embeds) + image_embeds = vision_outputs.pooler_output + image_embeds = self.visual_projection(image_embeds) + + # normalized features + image_embeds = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True) + text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True) + + # cosine similarity as logits and set it on the correct device + logit_scale = self.logit_scale.exp().to(image_embeds.device) + + logits_per_text = torch.matmul(text_embeds_norm, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + loss = owlvit_loss(logits_per_text) + + text_embeds = text_embeds_norm + + return OwlViTOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +class OwlViTBoxPredictionHead(nn.Module): + def __init__(self, config: OwlViTConfig, out_dim: int = 4): + super().__init__() + + width = config.vision_config.hidden_size + self.dense0 = nn.Linear(width, width) + self.dense1 = nn.Linear(width, width) + self.gelu = nn.GELU() + self.dense2 = nn.Linear(width, out_dim) + + def forward(self, image_features: torch.Tensor) -> torch.FloatTensor: + output = 
self.dense0(image_features) + output = self.gelu(output) + output = self.dense1(output) + output = self.gelu(output) + output = self.dense2(output) + return output + + +class OwlViTClassPredictionHead(nn.Module): + def __init__(self, config: OwlViTConfig): + super().__init__() + + out_dim = config.text_config.hidden_size + self.query_dim = config.vision_config.hidden_size + + self.dense0 = nn.Linear(self.query_dim, out_dim) + self.logit_shift = nn.Linear(self.query_dim, 1) + self.logit_scale = nn.Linear(self.query_dim, 1) + self.elu = nn.ELU() + + def forward( + self, + image_embeds: torch.FloatTensor, + query_embeds: torch.FloatTensor | None, + query_mask: torch.Tensor | None, + ) -> tuple[torch.FloatTensor]: + image_class_embeds = self.dense0(image_embeds) + if query_embeds is None: + device = image_class_embeds.device + batch_size, num_patches = image_class_embeds.shape[:2] + pred_logits = torch.zeros((batch_size, num_patches, self.query_dim)).to(device) + return (pred_logits, image_class_embeds) + + # Normalize image and text features + image_class_embeds = image_class_embeds / (torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6) + query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6) + + # Get class predictions + pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds) + + # Apply a learnable shift and scale to logits + logit_shift = self.logit_shift(image_embeds) + logit_scale = self.logit_scale(image_embeds) + logit_scale = self.elu(logit_scale) + 1 + pred_logits = (pred_logits + logit_shift) * logit_scale + + if query_mask is not None: + if query_mask.ndim > 1: + query_mask = torch.unsqueeze(query_mask, dim=-2) + + pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits) + pred_logits = pred_logits.to(torch.float32) + + return (pred_logits, image_class_embeds) + + +class OwlViTForObjectDetection(OwlViTPreTrainedModel): + config: OwlViTConfig + + def __init__(self, config: OwlViTConfig): + super().__init__(config) + + self.owlvit = OwlViTModel(config) + self.class_head = OwlViTClassPredictionHead(config) + self.box_head = OwlViTBoxPredictionHead(config) + + self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps) + self.sigmoid = nn.Sigmoid() + self.config = config + self.num_patches_height = self.config.vision_config.image_size // self.config.vision_config.patch_size + self.num_patches_width = self.config.vision_config.image_size // self.config.vision_config.patch_size + self.register_buffer( + "box_bias", self.compute_box_bias(self.num_patches_height, self.num_patches_width), persistent=False + ) + + self.post_init() + + @staticmethod + def normalize_grid_corner_coordinates(num_patches_height: int, num_patches_width: int) -> torch.Tensor: + # Create grid coordinates using torch + x_coordinates = torch.arange(1, num_patches_width + 1, dtype=torch.float32) + y_coordinates = torch.arange(1, num_patches_height + 1, dtype=torch.float32) + xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy") + + # Stack the coordinates and divide by their respective patch counts + box_coordinates = torch.stack((xx, yy), dim=-1) + box_coordinates[..., 0] /= num_patches_width + box_coordinates[..., 1] /= num_patches_height + + # Flatten (h, w, 2) -> (h*w, 2) + box_coordinates = box_coordinates.view(-1, 2) + + return box_coordinates + + def compute_box_bias(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor: 
+ # The box center is biased to its position on the feature grid + box_coordinates = self.normalize_grid_corner_coordinates(num_patches_height, num_patches_width) + box_coordinates = torch.clip(box_coordinates, 0.0, 1.0) + + # Unnormalize xy + box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4) + + # The box size is biased to the patch size + box_size = torch.full_like(box_coord_bias, 1.0) + box_size[..., 0] /= num_patches_width + box_size[..., 1] /= num_patches_height + box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4) + + # Compute box bias + box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1) + return box_bias + + def box_predictor( + self, + image_feats: torch.FloatTensor, + feature_map: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + """ + Args: + image_feats: + Features extracted from the image, returned by the `image_text_embedder` method. + feature_map: + A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method. + interpolate_pos_encoding: + Whether to interpolate the pre-trained position encodings. + Returns: + pred_boxes: + List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary. + """ + # Bounding box detection head [batch_size, num_boxes, 4]. + pred_boxes = self.box_head(image_feats) + + # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction + if interpolate_pos_encoding: + _, num_patches_height, num_patches_width, _ = feature_map.shape + box_bias = self.compute_box_bias(num_patches_height, num_patches_width) + else: + box_bias = self.box_bias + + box_bias = box_bias.to(feature_map.device) + pred_boxes += box_bias + pred_boxes = self.sigmoid(pred_boxes) + return pred_boxes + + def class_predictor( + self, + image_feats: torch.FloatTensor, + query_embeds: torch.FloatTensor | None = None, + query_mask: torch.Tensor | None = None, + ) -> tuple[torch.FloatTensor]: + """ + Args: + image_feats: + Features extracted from the `image_text_embedder`. + query_embeds: + Text query embeddings. + query_mask: + Must be provided with query_embeddings. A mask indicating which query embeddings are valid. 
+ """ + (pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask) + + return (pred_logits, image_class_embeds) + + def image_text_embedder( + self, + input_ids: torch.Tensor, + pixel_values: torch.FloatTensor, + attention_mask: torch.Tensor, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor]: + outputs = self.owlvit( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + + if interpolate_pos_encoding: + _, _, height, width = pixel_values.shape + num_patches_height = height // self.config.vision_config.patch_size + num_patches_width = width // self.config.vision_config.patch_size + else: + num_patches_height = self.num_patches_height + num_patches_width = self.num_patches_width + + # Get image embeddings + last_hidden_state = outputs.vision_model_output[0] + image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state) + + # Resize class token + class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape) + + # Merge image embedding with class tokens + image_embeds = image_embeds[:, 1:, :] * class_token_out + image_embeds = self.layer_norm(image_embeds) + + # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size] + new_size = ( + image_embeds.shape[0], + num_patches_height, + num_patches_width, + image_embeds.shape[-1], + ) + image_embeds = image_embeds.reshape(new_size) + text_embeds = outputs[-4] + + return (text_embeds, image_embeds, outputs) + + def image_embedder( + self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor]: + # Get OwlViTModel vision embeddings (same as CLIP) + vision_outputs: BaseModelOutputWithPooling = self.owlvit.vision_model( + pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs + ) + + if interpolate_pos_encoding: + _, _, height, width = pixel_values.shape + num_patches_height = height // self.config.vision_config.patch_size + num_patches_width = width // self.config.vision_config.patch_size + else: + num_patches_height = self.num_patches_height + num_patches_width = self.num_patches_width + + # Apply post_layernorm to last_hidden_state, return non-projected output + last_hidden_state = vision_outputs[0] + image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state) + + # Resize class token + class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape) + + # Merge image embedding with class tokens + image_embeds = image_embeds[:, 1:, :] * class_token_out + image_embeds = self.layer_norm(image_embeds) + + # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size] + new_size = ( + image_embeds.shape[0], + num_patches_height, + num_patches_width, + image_embeds.shape[-1], + ) + image_embeds = image_embeds.reshape(new_size) + + return (image_embeds, vision_outputs) + + def embed_image_query( + self, + query_image_features: torch.FloatTensor, + query_feature_map: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + _, class_embeds = self.class_predictor(query_image_features) + pred_boxes = self.box_predictor(query_image_features, query_feature_map, interpolate_pos_encoding) + pred_boxes_as_corners = center_to_corners_format(pred_boxes) + + # Loop over query images + best_class_embeds = [] + 
best_box_indices = [] + pred_boxes_device = pred_boxes_as_corners.device + + for i in range(query_image_features.shape[0]): + each_query_box = torch.tensor([[0, 0, 1, 1]], device=pred_boxes_device) + each_query_pred_boxes = pred_boxes_as_corners[i] + ious, _ = box_iou(each_query_box, each_query_pred_boxes) + + # If there are no overlapping boxes, fall back to generalized IoU + if torch.all(ious[0] == 0.0): + ious = generalized_box_iou(each_query_box, each_query_pred_boxes) + + # Use an adaptive threshold to include all boxes within 80% of the best IoU + iou_threshold = torch.max(ious) * 0.8 + + selected_inds = (ious[0] >= iou_threshold).nonzero() + if selected_inds.numel(): + selected_embeddings = class_embeds[i][selected_inds.squeeze(1)] + mean_embeds = torch.mean(class_embeds[i], axis=0) + mean_sim = torch.einsum("d,id->i", mean_embeds, selected_embeddings) + best_box_ind = selected_inds[torch.argmin(mean_sim)] + best_class_embeds.append(class_embeds[i][best_box_ind]) + best_box_indices.append(best_box_ind) + + if best_class_embeds: + query_embeds = torch.stack(best_class_embeds) + box_indices = torch.stack(best_box_indices) + else: + query_embeds, box_indices = None, None + + return query_embeds, box_indices, pred_boxes + + @can_return_tuple + @auto_docstring + def image_guided_detection( + self, + pixel_values: torch.FloatTensor, + query_pixel_values: torch.FloatTensor | None = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> OwlViTImageGuidedObjectDetectionOutput: + r""" + query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values of query image(s) to be detected. Pass in one query image per target image. + + Examples: + ```python + >>> import httpx + >>> from io import BytesIO + >>> from PIL import Image + >>> import torch + >>> from transformers import AutoProcessor, OwlViTForObjectDetection + + >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch16") + >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> with httpx.stream("GET", url) as response: + ... image = Image.open(BytesIO(response.read())) + >>> query_url = "http://images.cocodataset.org/val2017/000000001675.jpg" + >>> with httpx.stream("GET", query_url) as response: + ... query_image = Image.open(BytesIO(response.read())) + >>> inputs = processor(images=image, query_images=query_image, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model.image_guided_detection(**inputs) + >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2] + >>> target_sizes = torch.Tensor([image.size[::-1]]) + >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> results = processor.post_process_image_guided_detection( + ... outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes + ... ) + >>> i = 0 # Retrieve predictions for the first image + >>> boxes, scores = results[i]["boxes"], results[i]["scores"] + >>> for box, score in zip(boxes, scores): + ... box = [round(i, 2) for i in box.tolist()] + ... 
print(f"Detected similar object with confidence {round(score.item(), 3)} at location {box}") + Detected similar object with confidence 0.856 at location [10.94, 50.4, 315.8, 471.39] + Detected similar object with confidence 1.0 at location [334.84, 25.33, 636.16, 374.71] + ```""" + # Compute feature maps for the input and query images + query_feature_map = self.image_embedder( + pixel_values=query_pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + )[0] + feature_map, vision_outputs = self.image_embedder( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + + batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape + image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)) + + batch_size, num_patches_height, num_patches_width, hidden_dim = query_feature_map.shape + query_image_feats = torch.reshape( + query_feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim) + ) + # Get top class embedding and best box index for each query image in batch + query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query( + query_image_feats, query_feature_map, interpolate_pos_encoding + ) + + # Predict object classes [batch_size, num_patches, num_queries+1] + (pred_logits, class_embeds) = self.class_predictor(image_feats=image_feats, query_embeds=query_embeds) + + # Predict object boxes + target_pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding) + + return OwlViTImageGuidedObjectDetectionOutput( + image_embeds=feature_map, + query_image_embeds=query_feature_map, + target_pred_boxes=target_pred_boxes, + query_pred_boxes=query_pred_boxes, + logits=pred_logits, + class_embeds=class_embeds, + text_model_output=None, + vision_model_output=vision_outputs, + ) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.Tensor, + pixel_values: torch.FloatTensor, + attention_mask: torch.Tensor | None = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> OwlViTObjectDetectionOutput: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input + IDs?](../glossary#input-ids). + + Examples: + ```python + >>> import httpx + >>> from io import BytesIO + >>> from PIL import Image + >>> import torch + + >>> from transformers import OwlViTProcessor, OwlViTForObjectDetection + + >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") + >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> with httpx.stream("GET", url) as response: + ... image = Image.open(BytesIO(response.read())) + >>> text_labels = [["a photo of a cat", "a photo of a dog"]] + >>> inputs = processor(text=text_labels, images=image, return_tensors="pt") + >>> outputs = model(**inputs) + + >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2] + >>> target_sizes = torch.tensor([(image.height, image.width)]) + >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> results = processor.post_process_grounded_object_detection( + ... 
outputs=outputs, target_sizes=target_sizes, threshold=0.1, text_labels=text_labels + ... ) + >>> # Retrieve predictions for the first image for the corresponding text queries + >>> result = results[0] + >>> boxes, scores, text_labels = result["boxes"], result["scores"], result["text_labels"] + >>> for box, score, text_label in zip(boxes, scores, text_labels): + ... box = [round(i, 2) for i in box.tolist()] + ... print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}") + Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29] + Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17] + ```""" + # Embed images and text queries + query_embeds, feature_map, outputs = self.image_text_embedder( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + + # Text and vision model outputs + text_outputs = outputs.text_model_output + vision_outputs = outputs.vision_model_output + + batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape + image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)) + + # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim] + max_text_queries = input_ids.shape[0] // batch_size + query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1]) + + # If first token is 0, then this is a padded query [batch_size, num_queries]. + input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1]) + query_mask = input_ids[..., 0] > 0 + + # Predict object classes [batch_size, num_patches, num_queries+1] + (pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds, query_mask) + + # Predict object boxes + pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding) + + return OwlViTObjectDetectionOutput( + image_embeds=feature_map, + text_embeds=query_embeds, + pred_boxes=pred_boxes, + logits=pred_logits, + class_embeds=class_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +__all__ = ["OwlViTModel", "OwlViTPreTrainedModel", "OwlViTTextModel", "OwlViTVisionModel", "OwlViTForObjectDetection"]
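
For reference, a minimal self-contained sketch of the symmetric contrastive objective that the `contrastive_loss`/`owlvit_loss` helpers above implement; the 3x3 similarity matrix below is made up for illustration and is not real model output.

```python
import torch
from torch import nn


def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    # the i-th text matches the i-th image, so the target class is the diagonal index
    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))


def owlvit_loss(similarity: torch.Tensor) -> torch.Tensor:
    # average the text->image and image->text cross-entropies
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0


# hypothetical logits_per_text for 3 matched image/text pairs
similarity = torch.tensor([[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0]])
print(owlvit_loss(similarity))  # near zero, since the diagonal already dominates each row and column
```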
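Similarly, a minimal sketch of the box-bias scheme used by `compute_box_bias` and `box_predictor` above: the bias is the inverse sigmoid of each grid cell's normalized coordinates and of a patch-sized width/height, so `sigmoid(raw_prediction + box_bias)` decodes a zero prediction from the box head into a patch-sized box anchored on its cell. The 2x2 grid here is a toy size chosen for illustration; real checkpoints use `image_size // patch_size` patches per side.

```python
import torch

num_patches_height = num_patches_width = 2  # toy grid, for illustration only

# Normalized grid coordinates, as in normalize_grid_corner_coordinates
x = torch.arange(1, num_patches_width + 1, dtype=torch.float32)
y = torch.arange(1, num_patches_height + 1, dtype=torch.float32)
xx, yy = torch.meshgrid(x, y, indexing="xy")
coords = torch.stack((xx / num_patches_width, yy / num_patches_height), dim=-1).view(-1, 2)
coords = torch.clip(coords, 0.0, 1.0)

# Inverse sigmoid (logit) of the coordinates and of a patch-sized box
coord_bias = torch.log(coords + 1e-4) - torch.log1p(-coords + 1e-4)
size = torch.full_like(coords, 1.0)
size[..., 0] /= num_patches_width
size[..., 1] /= num_patches_height
size_bias = torch.log(size + 1e-4) - torch.log1p(-size + 1e-4)
box_bias = torch.cat([coord_bias, size_bias], dim=-1)  # (num_patches, 4)

# A zero output from the box head decodes to a patch-sized box on each grid cell,
# so the head only needs to learn offsets from these defaults.
default_boxes = torch.sigmoid(torch.zeros_like(box_bias) + box_bias)
print(default_boxes)  # cxcywh; the first row is roughly [0.5, 0.5, 0.5, 0.5] for this 2x2 grid
```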