190 changes: 55 additions & 135 deletions src/transformers/models/owlvit/modeling_owlvit.py
@@ -1,3 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/owlvit/modular_owlvit.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_owlvit.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2022 Google AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,30 +17,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OWL-ViT model."""

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import torch
from torch import Tensor, nn
from torch import nn

from ... import initialization as init
from ...activations import ACT2FN
from ...loss.loss_for_object_detection import box_iou, generalized_box_iou
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
ModelOutput,
TransformersKwargs,
auto_docstring,
is_vision_available,
logging,
torch_int,
)
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, is_vision_available, torch_int
from ...utils.generic import can_return_tuple, merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from .configuration_owlvit import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig
@@ -44,24 +43,6 @@
from transformers.image_transforms import center_to_corners_format


logger = logging.get_logger(__name__)


# See all OwlViT models at https://huggingface.co/models?filter=owlvit


# Copied from transformers.models.clip.modeling_clip.contrastive_loss with clip->owlvit
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))


# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->owlvit
def owlvit_loss(similarity: torch.Tensor) -> torch.Tensor:
caption_loss = contrastive_loss(similarity)
image_loss = contrastive_loss(similarity.t())
return (caption_loss + image_loss) / 2.0


@dataclass
@auto_docstring
class OwlViTOutput(ModelOutput):
@@ -100,74 +81,6 @@ def to_tuple(self) -> tuple[Any]:
)


# Copied from transformers.loss.loss_for_object_detection._upcast
def _upcast(t: Tensor) -> Tensor:
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
if t.is_floating_point():
return t if t.dtype in (torch.float32, torch.float64) else t.float()
else:
return t if t.dtype in (torch.int32, torch.int64) else t.int()


# Copied from transformers.loss.loss_for_object_detection.box_area
def box_area(boxes: Tensor) -> Tensor:
"""
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.

Args:
boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
< x2` and `0 <= y1 < y2`.

Returns:
`torch.FloatTensor`: a tensor containing the area for each box.
"""
boxes = _upcast(boxes)
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


# Copied from transformers.loss.loss_for_object_detection.box_iou
def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)

left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]

width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]

union = area1[:, None] + area2 - inter

iou = inter / union
return iou, union


# Copied from transformers.loss.loss_for_object_detection.generalized_box_iou
def generalized_box_iou(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.

Returns:
`torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
"""
# degenerate boxes gives inf / nan results
# so do an early check
if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
iou, union = box_iou(boxes1, boxes2)

top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
area = width_height[:, :, 0] * width_height[:, :, 1]

return iou - (area - union) / area

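# Illustrative sketch, not part of the upstream file: a tiny, hedged demo of the box helpers
# that this diff removes from the file (they are now imported from
# ...loss.loss_for_object_detection, so the names stay available at module level).
# `_demo_box_iou` is a hypothetical name introduced here only for illustration.
def _demo_box_iou():
    # Two 2x2 boxes in (x1, y1, x2, y2) corner format, overlapping in a 1x1 region.
    boxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0]])  # area 4
    boxes2 = torch.tensor([[1.0, 1.0, 3.0, 3.0]])  # area 4
    iou, union = box_iou(boxes1, boxes2)  # inter = 1, union = 4 + 4 - 1 = 7, iou ≈ 0.143
    giou = generalized_box_iou(boxes1, boxes2)  # enclosing area 9, giou ≈ 0.143 - 2/9 ≈ -0.079
    return iou, union, giou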

@dataclass
@auto_docstring(
custom_intro="""
@@ -274,25 +187,26 @@ def to_tuple(self) -> tuple[Any]:
class OwlViTVisionEmbeddings(nn.Module):
def __init__(self, config: OwlViTVisionConfig):
super().__init__()
self.patch_size = config.patch_size
self.config = config
self.embed_dim = config.hidden_size
self.class_embedding = nn.Parameter(torch.randn(config.hidden_size))
self.image_size = config.image_size
self.patch_size = config.patch_size

self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=config.patch_size,
stride=config.patch_size,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)

self.num_patches = (config.image_size // config.patch_size) ** 2
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

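# Illustrative note (an assumption about the released checkpoints, not text from the
# upstream file): with the base config's image_size=768 and patch_size=32 the grid is
# 768 // 32 = 24 patches per side, so num_patches = 24 ** 2 = 576 and
# num_positions = 577 once the class token is counted.
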
# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
@@ -322,18 +236,21 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:
sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

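# Illustrative note (an assumption, not text from the upstream file): with patch_size=32
# and a pretrained image_size of 768, the 576 patch position embeddings form a 24 x 24
# grid. Feeding an 896 x 896 image produces 28 x 28 = 784 patches, so that grid is
# bicubically resized to 28 x 28 and re-concatenated with the class-token embedding.
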
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values) # shape = [batch_size, num_channels, height, width]
patch_embeds = self.patch_embedding(pixel_values)
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
@@ -347,8 +264,10 @@ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: boo
class OwlViTTextEmbeddings(nn.Module):
def __init__(self, config: OwlViTTextConfig):
super().__init__()
self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
embed_dim = config.hidden_size

self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer(
@@ -375,49 +294,36 @@ def forward(
return embeddings

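# Illustrative note (an assumption, not text from the upstream file): the OWL-ViT text
# tower uses short queries, so max_position_embeddings is small (16 in the released
# configs); the registered position_ids buffer then has shape (1, 16), and the forward
# pass adds the looked-up position embeddings to the token embeddings.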

# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float | None = None,
scaling: float,
dropout: float = 0.0,
**kwargs: Unpack[TransformersKwargs],
):
if scaling is None:
scaling = query.size(-1) ** -0.5

# Take the dot product between "query" and "key" to get the raw attention scores.
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attn_weights

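# Illustrative sketch, not part of the upstream file: a minimal call to
# `eager_attention_forward` on random tensors in the [batch, num_heads, seq_len, head_dim]
# layout it expects. `_demo_eager_attention` is a hypothetical name used only here.
def _demo_eager_attention():
    batch, num_heads, seq_len, head_dim = 2, 4, 7, 16
    module = nn.Module()  # only its `.training` flag is read (by the dropout call)
    q = torch.randn(batch, num_heads, seq_len, head_dim)
    k = torch.randn(batch, num_heads, seq_len, head_dim)
    v = torch.randn(batch, num_heads, seq_len, head_dim)
    out, weights = eager_attention_forward(module, q, k, v, None, scaling=head_dim**-0.5)
    # `out` comes back transposed to [batch, seq_len, num_heads, head_dim];
    # `weights` keeps the [batch, num_heads, seq_len, seq_len] softmax map.
    return out.shape, weights.shape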

class OwlViTAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, config):
def __init__(self, config: OwlViTVisionConfig | OwlViTTextConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.is_causal = False
@@ -433,35 +339,39 @@ def forward(
attention_mask: torch.Tensor | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor | None]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
"""Input shape: Batch x Time x Channel"""

batch_size, seq_length, embed_dim = hidden_states.shape

query_states = self.q_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)

queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)

attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)

attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
queries,
keys,
values,
attention_mask,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
**kwargs,
)

attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
attn_output = self.out_proj(attn_output)

return attn_output, attn_weights

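# Illustrative note (an assumption about the base vision config, not text from the
# upstream file): with hidden_size=768 and num_attention_heads=12 (head_dim=64), a
# [batch, 577, 768] sequence is projected to queries/keys/values, reshaped to
# [batch, 12, 577, 64] for the attention kernel, and merged back to [batch, 577, 768]
# before `out_proj`.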

# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->OwlViT
class OwlViTMLP(nn.Module):
def __init__(self, config):
super().__init__()
@@ -477,7 +387,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->OwlViT
class OwlViTEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: OwlViTVisionConfig | OwlViTTextConfig):
super().__init__()
@@ -577,7 +486,6 @@ def _init_weights(self, module: nn.Module):
init.zeros_(module.bias)


# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->OwlViT
class OwlViTEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
@@ -587,7 +495,7 @@ class OwlViTEncoder(nn.Module):
config: OwlViTConfig
"""

def __init__(self, config: OwlViTConfig):
def __init__(self, config: OwlViTTextConfig | OwlViTVisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList([OwlViTEncoderLayer(config) for _ in range(config.num_hidden_layers)])
@@ -831,6 +739,18 @@ def forward(
)


# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/2021-03-07-owlvit.html
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))


def owlvit_loss(similarity: torch.Tensor) -> torch.Tensor:
caption_loss = contrastive_loss(similarity)
image_loss = contrastive_loss(similarity.t())
return (caption_loss + image_loss) / 2.0

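# Illustrative sketch, not part of the upstream file: the symmetric CLIP-style loss above
# treats the diagonal of the image-text similarity matrix as the correct pairing.
# `_demo_owlvit_loss` is a hypothetical name used only for illustration.
def _demo_owlvit_loss():
    # 3 images paired with 3 captions; larger diagonal logits mean better alignment.
    similarity = torch.tensor(
        [[3.0, 0.1, 0.2],
         [0.0, 2.5, 0.3],
         [0.1, 0.4, 2.8]]
    )
    # Cross-entropy over rows and over columns of the similarity matrix, averaged.
    return owlvit_loss(similarity)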

@auto_docstring
class OwlViTModel(OwlViTPreTrainedModel):
config: OwlViTConfig