diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py
index 4ed0930605e8..84e368379e29 100755
--- a/src/transformers/models/altclip/modeling_altclip.py
+++ b/src/transformers/models/altclip/modeling_altclip.py
@@ -32,7 +32,7 @@
 )
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, torch_int
 from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig
@@ -100,6 +100,8 @@
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -137,6 +139,8 @@
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -1009,15 +1013,63 @@ def __init__(self, config: AltCLIPVisionConfig):
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
         self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

-    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-        batch_size = pixel_values.shape[0]
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing.
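+
+        If `height` or `width` is not an exact multiple of `self.patch_size`, the remainder is dropped by the
+        floor division (`height // self.patch_size`), matching the grid produced by the patch embedding convolution.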
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + self.position_embeddings = self.position_embedding.weight.unsqueeze(0) + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})." + ) target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) return embeddings @@ -1097,6 +1149,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -1111,7 +1164,7 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layrnorm(hidden_states) encoder_outputs = self.encoder( @@ -1156,6 +1209,7 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -1186,6 +1240,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1546,6 +1601,7 @@ def 
get_image_features( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> torch.FloatTensor: r""" @@ -1578,6 +1634,7 @@ def get_image_features( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1598,6 +1655,7 @@ def forward( return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, AltCLIPOutput]: r""" @@ -1642,6 +1700,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 81785e147db9..c7b32de0d08e 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -34,7 +34,13 @@ ) from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig @@ -111,6 +117,8 @@ output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, defaults to `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -276,15 +284,63 @@ def __init__(self, config: BridgeTowerVisionConfig): self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
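+
+        If `height` or `width` is not an exact multiple of `self.patch_size`, the remainder is dropped by the
+        floor division (`height // self.patch_size`), matching the grid produced by the patch embedding convolution.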
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + self.position_embeddings = self.position_embedding.weight.unsqueeze(0) + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})." + ) target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) return embeddings @@ -302,8 +358,13 @@ def __init__(self, config): [nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) for _ in range(config.num_hidden_layers)] ) - def forward(self, pixel_values: torch.Tensor, attention_mask): - hidden_states = self.embeddings(pixel_values) + def forward( + self, + pixel_values: torch.Tensor, + attention_mask, + interpolate_pos_encoding: bool = False, + ): + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding) hidden_states = self.ln_pre(hidden_states) # NLD -> LND hidden_states = hidden_states.permute(1, 0, 2) @@ -324,8 +385,12 @@ def forward(self, pixel_values: torch.Tensor, attention_mask): hidden_states = torch.stack(hidden_states_stack, dim=0) return hidden_states - def forward_pre(self, pixel_values: torch.Tensor): - hidden_states = self.embeddings(pixel_values) + def forward_pre( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = False, + ): + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.ln_pre(hidden_states) # NLD -> LND hidden_states = hidden_states.permute(1, 0, 2) @@ -1015,8 +1080,8 @@ def __init__(self, 
config): def dtype(self): return self.visual.embeddings.patch_embedding.weight.dtype - def forward(self, image, image_mask=None): - return self.visual(image.type(self.dtype), image_mask) + def forward(self, image, image_mask=None, interpolate_pos_encoding=False): + return self.visual(image.type(self.dtype), image_mask, interpolate_pos_encoding) class BridgeTowerTextModel(BridgeTowerPreTrainedModel): @@ -1280,6 +1345,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple[torch.Tensor], BridgeTowerModelOutput]: r""" output_hidden_states (`bool`, *optional*): @@ -1352,7 +1418,9 @@ def forward( all_hidden_states_text += (text_embeds,) if image_embeds is None: - image_embeds = self.vision_model.visual.forward_pre(pixel_values.type(self.vision_model.dtype)) + image_embeds = self.vision_model.visual.forward_pre( + pixel_values.type(self.vision_model.dtype), interpolate_pos_encoding=interpolate_pos_encoding + ) else: # Permute as BridgeTowerResidualAttention has batch_first=True image_embeds = image_embeds.permute(1, 0, 2) diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 6fbd9459f5ad..393d5784bb44 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -38,6 +38,7 @@ add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from .configuration_chinese_clip import ChineseCLIPConfig, ChineseCLIPTextConfig, ChineseCLIPVisionConfig @@ -188,15 +189,63 @@ def __init__(self, config: ChineseCLIPVisionConfig): self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
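+
+        If `height` or `width` is not an exact multiple of `self.patch_size`, the remainder is dropped by the
+        floor division (`height // self.patch_size`), matching the grid produced by the patch embedding convolution.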
+
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+        """
+
+        num_patches = embeddings.shape[1] - 1
+        self.position_embeddings = self.position_embedding.weight.unsqueeze(0)
+        num_positions = self.position_embeddings.shape[1] - 1
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+            return self.position_embeddings
+
+        class_pos_embed = self.position_embeddings[:, :1]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+
+        dim = embeddings.shape[-1]
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            size=(new_height, new_width),
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
+        batch_size, _, height, width = pixel_values.shape
+        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
+            raise ValueError(
+                f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
+            )
         target_dtype = self.patch_embedding.weight.dtype
         patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)
         embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-        embeddings = embeddings + self.position_embedding(self.position_ids)
+        if interpolate_pos_encoding:
+            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            embeddings = embeddings + self.position_embedding(self.position_ids)
         return embeddings
@@ -798,6 +847,8 @@ def _init_weights(self, module):
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -813,6 +864,8 @@ def _init_weights(self, module):
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
""" @@ -1052,6 +1105,7 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -1066,7 +1120,7 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layrnorm(hidden_states) encoder_outputs = self.encoder( @@ -1299,6 +1353,7 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -1329,6 +1384,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1425,6 +1481,7 @@ def get_image_features( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> torch.FloatTensor: r""" @@ -1461,6 +1518,7 @@ def get_image_features( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1481,6 +1539,7 @@ def forward( return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, ChineseCLIPOutput]: r""" @@ -1516,6 +1575,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 64eb027e9e22..370f17f47965 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -36,6 +36,7 @@ is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, + torch_int, ) from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig @@ -196,15 +197,63 @@ def __init__(self, config: CLIPVisionConfig): self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+        """
+
+        num_patches = embeddings.shape[1] - 1
+        self.position_embeddings = self.position_embedding.weight.unsqueeze(0)
+        num_positions = self.position_embeddings.shape[1] - 1
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+            return self.position_embeddings
+
+        class_pos_embed = self.position_embeddings[:, :1]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+
+        dim = embeddings.shape[-1]
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            size=(new_height, new_width),
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
+        batch_size, _, height, width = pixel_values.shape
+        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
+            raise ValueError(
+                f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
+            )
         target_dtype = self.patch_embedding.weight.dtype
         patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)
         embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-        embeddings = embeddings + self.position_embedding(self.position_ids)
+        if interpolate_pos_encoding:
+            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            embeddings = embeddings + self.position_embedding(self.position_ids)
         return embeddings
@@ -704,6 +753,8 @@ def _init_weights(self, module):
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -741,6 +792,8 @@ def _init_weights(self, module):
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
""" @@ -1023,6 +1076,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -1037,7 +1091,7 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layrnorm(hidden_states) encoder_outputs = self.encoder( @@ -1087,6 +1141,7 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -1118,6 +1173,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) @@ -1214,6 +1270,7 @@ def get_image_features( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> torch.FloatTensor: r""" @@ -1249,6 +1306,7 @@ def get_image_features( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1268,6 +1326,7 @@ def forward( return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, CLIPOutput]: r""" @@ -1305,6 +1364,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1466,6 +1526,7 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, CLIPVisionModelOutput]: r""" @@ -1495,6 +1556,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index a6507e431f68..90520524fa88 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -33,6 +33,7 @@ add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig @@ -163,40 +164,62 @@ def __init__(self, config: CLIPSegVisionConfig): self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - def interpolate_position_embeddings(self, new_size): - if len(new_size) != 2: - raise ValueError("new_size should consist of 2 values") + def interpolate_pos_encoding(self, 
embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. - num_patches_one_direction = int(self.num_patches**0.5) - # we interpolate the position embeddings in 2D - a = self.position_embedding.weight[1:].T.view( - 1, self.config.hidden_size, num_patches_one_direction, num_patches_one_direction - ) - b = ( - nn.functional.interpolate(a, new_size, mode="bicubic", align_corners=False) - .squeeze(0) - .view(self.config.hidden_size, new_size[0] * new_size[1]) - .T + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + self.position_embeddings = self.position_embedding.weight.unsqueeze(0) + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, ) - result = torch.cat([self.position_embedding.weight[:1], b]) - return result + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})." + ) patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - - if embeddings.shape[1] != self.num_positions: - new_shape = int(math.sqrt(embeddings.shape[1] - 1)) - embeddings = embeddings + self.interpolate_position_embeddings((new_shape, new_shape)) - embeddings = embeddings.to(embeddings.dtype) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) else: embeddings = embeddings + self.position_embedding(self.position_ids) - return embeddings @@ -512,6 +535,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. 
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -549,6 +574,8 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -825,6 +852,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -839,7 +867,7 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layrnorm(hidden_states) encoder_outputs = self.encoder( @@ -884,6 +912,7 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -912,6 +941,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1005,6 +1035,7 @@ def get_image_features( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> torch.FloatTensor: r""" @@ -1040,6 +1071,7 @@ def get_image_features( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1059,6 +1091,7 @@ def forward( return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, CLIPSegOutput]: r""" @@ -1096,6 +1129,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1363,6 +1397,7 @@ def forward( labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, CLIPSegOutput]: r""" @@ -1402,6 +1437,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=True, # we need the intermediate hidden states + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) pooled_output = self.clip.visual_projection(vision_outputs[1]) diff --git a/src/transformers/models/git/modeling_git.py 
b/src/transformers/models/git/modeling_git.py
index 59d3a406ec35..5fa29d70d8a5 100644
--- a/src/transformers/models/git/modeling_git.py
+++ b/src/transformers/models/git/modeling_git.py
@@ -37,7 +37,13 @@
 )
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+    torch_int,
+)
 from .configuration_git import GitConfig, GitVisionConfig
@@ -602,6 +608,8 @@ def _init_weights(self, module):
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -631,15 +639,63 @@ def __init__(self, config: GitVisionConfig):
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
         self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

-    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-        batch_size = pixel_values.shape[0]
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing.
+
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+        """
+
+        num_patches = embeddings.shape[1] - 1
+        self.position_embeddings = self.position_embedding.weight.unsqueeze(0)
+        num_positions = self.position_embeddings.shape[1] - 1
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+            return self.position_embeddings
+
+        class_pos_embed = self.position_embeddings[:, :1]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+
+        dim = embeddings.shape[-1]
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            size=(new_height, new_width),
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
+        batch_size, _, height, width = pixel_values.shape
+        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
+            raise ValueError(
+                f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
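+                # callers can pass `interpolate_pos_encoding=True` to resize the position grid for this resolution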
+            )
         target_dtype = self.patch_embedding.weight.dtype
         patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)
         embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-        embeddings = embeddings + self.position_embedding(self.position_ids)
+        if interpolate_pos_encoding:
+            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            embeddings = embeddings + self.position_embedding(self.position_ids)
         return embeddings
@@ -924,6 +980,8 @@ def forward(
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -948,6 +1006,7 @@ def forward(
         pixel_values: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: Optional[bool] = False,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutput]:
         r"""
@@ -963,7 +1022,7 @@ def forward(
         if pixel_values is None:
             raise ValueError("You have to specify pixel_values")

-        hidden_states = self.embeddings(pixel_values)
+        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
         hidden_states = self.pre_layrnorm(hidden_states)

         encoder_outputs = self.encoder(
@@ -1012,6 +1071,7 @@ def forward(
         pixel_values: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutput]:
         r"""
@@ -1041,6 +1101,7 @@ def forward(
             pixel_values=pixel_values,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
         )
@@ -1167,6 +1228,7 @@ def forward(
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
         r"""
@@ -1235,13 +1297,17 @@ def forward(
         if pixel_values is not None:
             if pixel_values.ndim == 4:
                 # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
-                visual_features = self.image_encoder(pixel_values).last_hidden_state
+                visual_features = self.image_encoder(
+                    pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+                ).last_hidden_state

             elif pixel_values.ndim == 5:
                 # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
                 visual_features = []
                 for frame_idx in range(pixel_values.shape[1]):
-                    visual_features_frame = self.image_encoder(pixel_values[:, frame_idx, :, :]).last_hidden_state
+                    visual_features_frame = self.image_encoder(
+                        pixel_values[:, frame_idx, :, :], interpolate_pos_encoding=interpolate_pos_encoding
+                    ).last_hidden_state
                     visual_features_frame += self.img_temperal_embedding[frame_idx]
                     visual_features.append(visual_features_frame)
@@ -1358,6 +1424,7 @@ def forward(
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
         r"""
@@ -1511,6 +1578,7 @@ def forward(
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
         )
diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py
index 90e21ed2f558..5adc48a3a2ef 100644
--- a/src/transformers/models/kosmos2/modeling_kosmos2.py
+++ b/src/transformers/models/kosmos2/modeling_kosmos2.py
@@ -38,6 +38,7 @@
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from .configuration_kosmos2 import Kosmos2Config, Kosmos2TextConfig, Kosmos2VisionConfig
@@ -121,6 +122,8 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -259,6 +262,8 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -401,15 +406,63 @@ def __init__(self, config: Kosmos2VisionConfig):
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
         self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

-    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-        batch_size = pixel_values.shape[0]
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing.
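+
+        If `height` or `width` is not an exact multiple of `self.patch_size`, the remainder is dropped by the
+        floor division (`height // self.patch_size`), matching the grid produced by the patch embedding convolution.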
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + self.position_embeddings = self.position_embedding.weight.unsqueeze(0) + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})." + ) target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) return embeddings @@ -701,6 +754,7 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -712,7 +766,7 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layrnorm(hidden_states) encoder_outputs = self.encoder( @@ -1443,6 +1497,7 @@ def forward( pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -1453,6 +1508,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, 
output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) @@ -1769,6 +1825,7 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, Kosmos2ModelOutput]: r""" @@ -1820,6 +1877,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`. diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 791e501d1737..d05db378443e 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -32,6 +32,7 @@ add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from .configuration_x_clip import XCLIPConfig, XCLIPTextConfig, XCLIPVisionConfig @@ -121,15 +122,63 @@ def __init__(self, config: XCLIPVisionConfig): self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+        """
+
+        num_patches = embeddings.shape[1] - 1
+        self.position_embeddings = self.position_embedding.weight.unsqueeze(0)
+        num_positions = self.position_embeddings.shape[1] - 1
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+            return self.position_embeddings
+
+        class_pos_embed = self.position_embeddings[:, :1]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+
+        dim = embeddings.shape[-1]
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            size=(new_height, new_width),
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
+        batch_size, _, height, width = pixel_values.shape
+        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
+            raise ValueError(
+                f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
+            )
         target_dtype = self.patch_embedding.weight.dtype
         patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)
         embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-        embeddings = embeddings + self.position_embedding(self.position_ids)
+        if interpolate_pos_encoding:
+            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            embeddings = embeddings + self.position_embedding(self.position_ids)
         return embeddings
@@ -567,6 +616,8 @@ def _init_weights(self, module):
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -604,6 +655,8 @@ def _init_weights(self, module):
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
""" @@ -954,6 +1007,7 @@ def forward( pixel_values: torch.FloatTensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" @@ -966,7 +1020,7 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layernorm(hidden_states) encoder_outputs = self.encoder( @@ -1455,6 +1509,7 @@ def forward( return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, XCLIPOutput]: r""" @@ -1555,6 +1610,7 @@ def forward( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) diff --git a/tests/models/altclip/test_modeling_altclip.py b/tests/models/altclip/test_modeling_altclip.py index 83b6d60595d3..f4ac29479c5f 100755 --- a/tests/models/altclip/test_modeling_altclip.py +++ b/tests/models/altclip/test_modeling_altclip.py @@ -597,3 +597,44 @@ def test_inference(self): expected_probs = torch.tensor([[9.9942e-01, 5.7805e-04]], device=torch_device) self.assertTrue(torch.allclose(probs, expected_probs, atol=5e-3)) + + @slow + def test_inference_interpolate_pos_encoding(self): + # ViT models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. 
+        model_name = "BAAI/AltCLIP"
+        model = AltCLIPModel.from_pretrained(model_name).to(torch_device)
+
+        image_processor = AltCLIPProcessor.from_pretrained(
+            model_name, size={"shortest_edge": 180}, crop_size={"height": 180, "width": 180}
+        )
+
+        image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+        inputs = image_processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device)
+
+        # interpolate_pos_encoding=False should raise a ValueError
+        with self.assertRaisesRegex(ValueError, "doesn't match model"):
+            with torch.no_grad():
+                model(**inputs, interpolate_pos_encoding=False)
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs, interpolate_pos_encoding=True)
+
+        # verify the last hidden states
+        expected_shape = torch.Size((1, 145, 1024))
+
+        self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)
+
+        expected_slice = torch.tensor(
+            [[-0.3589, -0.5939, 0.3534], [0.4346, 0.1647, 0.7071], [1.1404, -0.4716, 0.1664]]
+        ).to(torch_device)
+
+        self.assertTrue(
+            torch.allclose(outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4)
+        )
diff --git a/tests/models/bridgetower/test_modeling_bridgetower.py b/tests/models/bridgetower/test_modeling_bridgetower.py
index 44e6a404f623..cceeee4912dc 100644
--- a/tests/models/bridgetower/test_modeling_bridgetower.py
+++ b/tests/models/bridgetower/test_modeling_bridgetower.py
@@ -656,3 +656,37 @@ def test_training(self):
         for name, param in model.named_parameters():
             if self._is_layer_used(model_class, name):
                 self.assertIsNotNone(param.grad, f"Gradients should not be None - got {param.grad} for {name}")
+
+    @slow
+    def test_inference_interpolate_pos_encoding(self):
+        # ViT models have an `interpolate_pos_encoding` argument in their forward method,
+        # allowing to interpolate the pre-trained position embeddings in order to use
+        # the model on higher resolutions. The DINO model by Facebook AI leverages this
+        # to visualize self-attention on higher resolution images.
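+        # The expected sequence length of 122 below corresponds to an 11x11 patch grid plus the CLS token
+        # (180 // 16 = 11 with patch size 16, after the image is resized and cropped to 180x180).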
+ model_name = "BridgeTower/bridgetower-base" + model = BridgeTowerModel.from_pretrained(model_name).to(torch_device) + + image_processor = BridgeTowerProcessor.from_pretrained(model_name, size={"shortest_edge": 180}) + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + inputs = image_processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device) + + # interpolate_pos_encodiung false should return value error + with self.assertRaises(ValueError, msg="doesn't match model"): + with torch.no_grad(): + model(**inputs, interpolate_pos_encoding=False) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((1, 122, 768)) + + self.assertEqual(outputs.image_features.shape, expected_shape) + + expected_slice = torch.tensor( + [[-0.6518, 0.4978, -0.4544], [-2.6672, -0.0843, -0.4210], [-2.4510, -0.1002, -0.3458]] + ).to(torch_device) + + self.assertTrue(torch.allclose(outputs.image_features[0, :3, :3], expected_slice, atol=1e-4)) diff --git a/tests/models/chinese_clip/test_modeling_chinese_clip.py b/tests/models/chinese_clip/test_modeling_chinese_clip.py index 7046f28b5f94..647b3ac7b73a 100644 --- a/tests/models/chinese_clip/test_modeling_chinese_clip.py +++ b/tests/models/chinese_clip/test_modeling_chinese_clip.py @@ -740,3 +740,41 @@ def test_inference(self): expected_probs = torch.tensor([[1.2686e-03, 5.4499e-02, 6.7968e-04, 9.4355e-01]], device=torch_device) self.assertTrue(torch.allclose(probs, expected_probs, atol=5e-3)) + + @slow + def test_inference_interpolate_pos_encoding(self): + # ViT models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. 
+ model_name = "OFA-Sys/chinese-clip-vit-base-patch16" + model = ChineseCLIPModel.from_pretrained(model_name).to(torch_device) + + image_processor = ChineseCLIPProcessor.from_pretrained( + model_name, size={"height": 180, "width": 180}, crop_size={"height": 180, "width": 180} + ) + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + inputs = image_processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device) + + # interpolate_pos_encodiung false should return value error + with self.assertRaises(ValueError, msg="doesn't match model"): + with torch.no_grad(): + model(**inputs, interpolate_pos_encoding=False) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((1, 122, 768)) + + self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [[-0.3990, 0.2983, -0.1239], [-0.1452, -0.2759, 0.0403], [-0.3149, -0.4763, 0.8555]] + ).to(torch_device) + + self.assertTrue( + torch.allclose(outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4) + ) diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index 3b6994428088..c94aa0412653 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -1179,3 +1179,40 @@ def test_inference(self): expected_logits = torch.tensor([[24.5701, 19.3049]], device=torch_device) self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3)) + + @slow + def test_inference_interpolate_pos_encoding(self): + # CLIP models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. 
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(torch_device) + + processor = CLIPProcessor.from_pretrained( + "openai/clip-vit-base-patch32", size={"height": 180, "width": 180}, crop_size={"height": 180, "width": 180} + ) + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device) + + # interpolate_pos_encodiung false should return value error + with self.assertRaises(ValueError, msg="doesn't match model"): + with torch.no_grad(): + model(**inputs, interpolate_pos_encoding=False) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((1, 26, 768)) + + self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [[-0.1538, 0.0322, -0.3235], [0.2893, 0.1135, -0.5708], [0.0461, 0.1540, -0.6018]] + ).to(torch_device) + + self.assertTrue( + torch.allclose(outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4) + ) diff --git a/tests/models/clipseg/test_modeling_clipseg.py b/tests/models/clipseg/test_modeling_clipseg.py index a6f286c4c6b7..c5edf7cb757b 100644 --- a/tests/models/clipseg/test_modeling_clipseg.py +++ b/tests/models/clipseg/test_modeling_clipseg.py @@ -796,7 +796,7 @@ def test_inference_image_segmentation(self): # forward pass with torch.no_grad(): - outputs = model(**inputs) + outputs = model(**inputs, interpolate_pos_encoding=True) # verify the predicted masks self.assertEqual( @@ -804,8 +804,9 @@ def test_inference_image_segmentation(self): torch.Size((3, 352, 352)), ) expected_masks_slice = torch.tensor( - [[-7.4613, -7.4785, -7.3628], [-7.3268, -7.0899, -7.1333], [-6.9838, -6.7900, -6.8913]] + [[-7.4613, -7.4785, -7.3627], [-7.3268, -7.0898, -7.1333], [-6.9838, -6.7900, -6.8913]] ).to(torch_device) + self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3)) # verify conditional and pooled output @@ -813,3 +814,40 @@ def test_inference_image_segmentation(self): expected_pooled_output = torch.tensor([0.5036, -0.2681, -0.2644]).to(torch_device) self.assertTrue(torch.allclose(outputs.conditional_embeddings[0, :3], expected_conditional, atol=1e-3)) self.assertTrue(torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3)) + + @slow + def test_inference_interpolate_pos_encoding(self): + # ViT models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. 
+ model = CLIPSegModel.from_pretrained("openai/clip-vit-base-patch32").to(torch_device) + + processor = CLIPSegProcessor.from_pretrained( + "openai/clip-vit-base-patch32", size={"height": 180, "width": 180}, crop_size={"height": 180, "width": 180} + ) + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device) + + # interpolate_pos_encodiung false should return value error + with self.assertRaises(ValueError, msg="doesn't match model"): + with torch.no_grad(): + model(**inputs, interpolate_pos_encoding=False) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((1, 26, 768)) + + self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [[-0.1538, 0.0322, -0.3235], [0.2893, 0.1135, -0.5708], [0.0461, 0.1540, -0.6018]] + ).to(torch_device) + + self.assertTrue( + torch.allclose(outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4) + ) diff --git a/tests/models/git/test_modeling_git.py b/tests/models/git/test_modeling_git.py index 02a3c033c452..33da9e26cba0 100644 --- a/tests/models/git/test_modeling_git.py +++ b/tests/models/git/test_modeling_git.py @@ -614,3 +614,38 @@ def test_batched_generation(self): generated_captions = processor.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(generated_captions, ["two cats sleeping on a pink blanket next to remotes."] * 2) + + @slow + def test_inference_interpolate_pos_encoding(self): + # CLIP family models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. 
+ model = GitModel.from_pretrained("microsoft/git-base").to(torch_device) + + processor = GitProcessor.from_pretrained( + "microsoft/git-base", size={"height": 180, "width": 180}, crop_size={"height": 180, "width": 180} + ) + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device) + + # interpolate_pos_encodiung false should return value error + with self.assertRaises(ValueError, msg="doesn't match model"): + with torch.no_grad(): + model(**inputs, interpolate_pos_encoding=False) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((1, 130, 768)) + + self.assertEqual(outputs.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [[-1.0296, 2.5960, 0.8703], [1.7027, 1.3302, -0.4543], [-1.4932, -0.1084, 0.0502]] + ).to(torch_device) + + self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4)) diff --git a/tests/models/kosmos2/test_modeling_kosmos2.py b/tests/models/kosmos2/test_modeling_kosmos2.py index 6f34689004ef..913111c0a088 100644 --- a/tests/models/kosmos2/test_modeling_kosmos2.py +++ b/tests/models/kosmos2/test_modeling_kosmos2.py @@ -762,3 +762,40 @@ def test_snowman_image_captioning_batch(self): self.assertEqual(processed_text[0], EXPECTED_PROCESSED_TEXT_0) self.assertEqual(all_final_text[0], EXPECTED_FINAL_TEXT_0) self.assertListEqual(all_entities[0], EXPECTED_ENTITIES_0) + + @slow + def test_inference_interpolate_pos_encoding(self): + # ViT models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. 
+ model = Kosmos2Model.from_pretrained("microsoft/kosmos-2-patch14-224").to(torch_device) + + processor = AutoProcessor.from_pretrained( + "microsoft/kosmos-2-patch14-224", size={"shortest_edge": 180}, crop_size={"height": 180, "width": 180} + ) + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device) + + # interpolate_pos_encodiung false should return value error + with self.assertRaises(ValueError, msg="doesn't match model"): + with torch.no_grad(): + model(**inputs, interpolate_pos_encoding=False) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((1, 145, 1024)) + + self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [[1.0022, -1.1901, 3.2887], [2.6164, 0.0515, -0.8270], [1.8315, 0.1272, -0.8590]] + ).to(torch_device) + + self.assertTrue( + torch.allclose(outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4) + ) diff --git a/tests/models/x_clip/test_modeling_x_clip.py b/tests/models/x_clip/test_modeling_x_clip.py index 70e7bb341c7e..8b91019bae18 100644 --- a/tests/models/x_clip/test_modeling_x_clip.py +++ b/tests/models/x_clip/test_modeling_x_clip.py @@ -731,3 +731,39 @@ def test_inference(self): expected_logits = torch.tensor([[14.0181, 20.2771, 14.4776]], device=torch_device) self.assertTrue(torch.allclose(outputs.logits_per_video, expected_logits, atol=1e-3)) + + @slow + def test_inference_interpolate_pos_encoding(self): + # XCLIP models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. + model = XCLIPModel.from_pretrained("microsoft/xclip-base-patch32").to(torch_device) + + processor = XCLIPProcessor.from_pretrained( + "microsoft/xclip-base-patch32", size=180, crop_size={"height": 180, "width": 180} + ) + + video = prepare_video() + inputs = processor(text="what's in the video", videos=video, return_tensors="pt").to(torch_device) + + # interpolate_pos_encodiung false should return value error + with self.assertRaises(ValueError, msg="doesn't match model"): + with torch.no_grad(): + model(**inputs, interpolate_pos_encoding=False) + # forward pass + with torch.no_grad(): + outputs = model(**inputs, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((8, 26, 768)) + + self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [[0.0126, 0.2109, 0.0609], [0.0448, 0.5862, -0.1688], [-0.0881, 0.8525, -0.3044]] + ).to(torch_device) + + self.assertTrue( + torch.allclose(outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4) + )