From 938e490edf04ac19473fac27b6c44c805121d5c8 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Thu, 12 Sep 2024 18:59:34 +0100 Subject: [PATCH 1/6] rough outline --- src/transformers/__init__.py | 19 +- src/transformers/models/__init__.py | 2 +- .../models/auto/configuration_auto.py | 4 +- .../models/auto/image_processing_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 4 +- .../models/auto/processing_auto.py | 2 +- .../models/auto/tokenization_auto.py | 2 +- src/transformers/models/pixtral/__init__.py | 21 +- .../models/pixtral/configuration_pixtral.py | 1 - .../pixtral/image_processing_pixtral.py | 383 ++++++++++++++++++ .../models/pixtral/processing_pixtral.py | 191 +++++++++ 11 files changed, 611 insertions(+), 19 deletions(-) create mode 100644 src/transformers/models/pixtral/image_processing_pixtral.py create mode 100644 src/transformers/models/pixtral/processing_pixtral.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 94bfae8ebcde..a2be4b145ac8 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -528,7 +528,6 @@ ], "models.pixtral": [ "PixtralConfig", - ], "models.llava_next": [ "LlavaNextConfig", @@ -647,6 +646,7 @@ "models.phi": ["PhiConfig"], "models.phi3": ["Phi3Config"], "models.phobert": ["PhobertTokenizer"], + "models.pixtral": ["PixtralConfig", "PixtralProcessor"], "models.pix2struct": [ "Pix2StructConfig", "Pix2StructProcessor", @@ -1202,6 +1202,7 @@ _import_structure["models.owlv2"].append("Owlv2ImageProcessor") _import_structure["models.owlvit"].extend(["OwlViTFeatureExtractor", "OwlViTImageProcessor"]) _import_structure["models.perceiver"].extend(["PerceiverFeatureExtractor", "PerceiverImageProcessor"]) + _import_structure["models.pixtral"].append("PixtralImageProcessor") _import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"]) _import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"]) _import_structure["models.pvt"].extend(["PvtImageProcessor"]) @@ -5300,10 +5301,6 @@ LlavaConfig, LlavaProcessor, ) - from .models.pixtral import ( - PixtralConfig, - - ) from .models.llava_next import ( LlavaNextConfig, LlavaNextProcessor, @@ -5448,6 +5445,9 @@ Pix2StructTextConfig, Pix2StructVisionConfig, ) + from .models.pixtral import ( + PixtralConfig, PixtralProcessor, + ) from .models.plbart import PLBartConfig from .models.poolformer import ( PoolFormerConfig, @@ -6027,6 +6027,7 @@ PoolFormerFeatureExtractor, PoolFormerImageProcessor, ) + from .models.pixtral import PixtralImageProcessor from .models.pvt import PvtImageProcessor from .models.qwen2_vl import Qwen2VLImageProcessor from .models.rt_detr import RTDetrImageProcessor @@ -7111,10 +7112,6 @@ LlavaForConditionalGeneration, LlavaPreTrainedModel, ) - from .models.pixtral import ( - PixtralModel, - PixtralPreTrainedModel, - ) from .models.llava_next import ( LlavaNextForConditionalGeneration, LlavaNextPreTrainedModel, @@ -7466,6 +7463,10 @@ Pix2StructTextModel, Pix2StructVisionModel, ) + from .models.pixtral import ( + PixtralModel, + PixtralPreTrainedModel, + ) from .models.plbart import ( PLBartForCausalLM, PLBartForConditionalGeneration, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 0ecadd8d2216..2022048cd455 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -130,7 +130,6 @@ lilt, llama, llava, - pixtral, llava_next, llava_next_video, llava_onevision, 
@@ -188,6 +187,7 @@ phi3, phobert, pix2struct, + pixtral, plbart, poolformer, pop2piano, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index ac2b3ca34949..97d9c60fa41c 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -147,7 +147,6 @@ ("lilt", "LiltConfig"), ("llama", "LlamaConfig"), ("llava", "LlavaConfig"), - ("pixtral", "PixtralConfig"), ("llava_next", "LlavaNextConfig"), ("llava_next_video", "LlavaNextVideoConfig"), ("llava_onevision", "LlavaOnevisionConfig"), @@ -206,6 +205,7 @@ ("phi", "PhiConfig"), ("phi3", "Phi3Config"), ("pix2struct", "Pix2StructConfig"), + ("pixtral", "PixtralConfig"), ("plbart", "PLBartConfig"), ("poolformer", "PoolFormerConfig"), ("pop2piano", "Pop2PianoConfig"), @@ -444,7 +444,6 @@ ("llama2", "Llama2"), ("llama3", "Llama3"), ("llava", "LLaVa"), - ("pixtral", "Pixtral"), ("llava_next", "LLaVA-NeXT"), ("llava_next_video", "LLaVa-NeXT-Video"), ("llava_onevision", "LLaVA-Onevision"), @@ -511,6 +510,7 @@ ("phi3", "Phi3"), ("phobert", "PhoBERT"), ("pix2struct", "Pix2Struct"), + ("pixtral", "Pixtral"), ("plbart", "PLBart"), ("poolformer", "PoolFormer"), ("pop2piano", "Pop2Piano"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index c83c43518a6a..95d9ddef8f79 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -114,6 +114,7 @@ ("owlvit", ("OwlViTImageProcessor",)), ("perceiver", ("PerceiverImageProcessor",)), ("pix2struct", ("Pix2StructImageProcessor",)), + ("pixtral", ("PixtralImageProcessor",)), ("poolformer", ("PoolFormerImageProcessor",)), ("pvt", ("PvtImageProcessor",)), ("pvt_v2", ("PvtImageProcessor",)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 1f860c25e8db..eb2ab82c960d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -193,6 +193,7 @@ ("persimmon", "PersimmonModel"), ("phi", "PhiModel"), ("phi3", "Phi3Model"), + ("pixtral", "PixtralModel"), ("plbart", "PLBartModel"), ("poolformer", "PoolFormerModel"), ("prophetnet", "ProphetNetModel"), @@ -277,7 +278,6 @@ ("xmod", "XmodModel"), ("yolos", "YolosModel"), ("yoso", "YosoModel"), - ("pixtral", "PixtralModel"), ] ) @@ -729,12 +729,12 @@ ("instructblipvideo", "InstructBlipVideoForConditionalGeneration"), ("kosmos-2", "Kosmos2ForConditionalGeneration"), ("llava", "LlavaForConditionalGeneration"), - ("pixtral", "PixtralModel"), ("llava_next", "LlavaNextForConditionalGeneration"), ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), ("paligemma", "PaliGemmaForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"), + ("pixtral", "PixtralModel"), ("qwen2_vl", "Qwen2VLForConditionalGeneration"), ("video_llava", "VideoLlavaForConditionalGeneration"), ("vipllava", "VipLlavaForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index d055693316df..82d325248eab 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -71,7 +71,6 @@ ("layoutlmv2", "LayoutLMv2Processor"), ("layoutlmv3", "LayoutLMv3Processor"), ("llava", "LlavaProcessor"), - ("pixtral", 
"PixtralProcessor"), ("llava_next", "LlavaNextProcessor"), ("llava_next_video", "LlavaNextVideoProcessor"), ("llava_onevision", "LlavaOnevisionProcessor"), @@ -83,6 +82,7 @@ ("owlvit", "OwlViTProcessor"), ("paligemma", "PaliGemmaProcessor"), ("pix2struct", "Pix2StructProcessor"), + ("pixtral", "PixtralProcessor"), ("pop2piano", "Pop2PianoProcessor"), ("qwen2_audio", "Qwen2AudioProcessor"), ("qwen2_vl", "Qwen2VLProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 00bbef64d99a..2f0e8591740d 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -257,7 +257,6 @@ ), ), ("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), - ("pixtral", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("llava-onevision", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("llava_next_video", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), @@ -386,6 +385,7 @@ ("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("phobert", ("PhobertTokenizer", None)), ("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)), + ("pixtral", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)), ("prophetnet", ("ProphetNetTokenizer", None)), ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/pixtral/__init__.py b/src/transformers/models/pixtral/__init__.py index cb43777e2d5a..2020287559a1 100644 --- a/src/transformers/models/pixtral/__init__.py +++ b/src/transformers/models/pixtral/__init__.py @@ -13,11 +13,12 @@ # limitations under the License. 
from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available _import_structure = { "configuration_pixtral": ["PixtralConfig"], + "processing_pixtral": ["PixtralProcessor"], } @@ -32,9 +33,17 @@ "PixtralPreTrainedModel", ] +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["image_processing_pixtral"] = ["PixtralImageProcessor"] + if TYPE_CHECKING: from .configuration_pixtral import PixtralConfig + from .processing_pixtral import PixtralProcessor try: if not is_torch_available(): @@ -47,6 +56,14 @@ PixtralPreTrainedModel, ) + try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .image_processing_pixtral import PixtralImageProcessor + else: import sys diff --git a/src/transformers/models/pixtral/configuration_pixtral.py b/src/transformers/models/pixtral/configuration_pixtral.py index 90ce7785f040..667466872ff1 100644 --- a/src/transformers/models/pixtral/configuration_pixtral.py +++ b/src/transformers/models/pixtral/configuration_pixtral.py @@ -15,7 +15,6 @@ from ...configuration_utils import PretrainedConfig from ...utils import logging -from ..auto import CONFIG_MAPPING logger = logging.get_logger(__name__) diff --git a/src/transformers/models/pixtral/image_processing_pixtral.py b/src/transformers/models/pixtral/image_processing_pixtral.py new file mode 100644 index 000000000000..6d354dbbd5fd --- /dev/null +++ b/src/transformers/models/pixtral/image_processing_pixtral.py @@ -0,0 +1,383 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Pixtral.""" + +from typing import Dict, List, Optional, Union, Tuple + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_kwargs, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_vision_available, logging + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +# Adapted from the function in image_transforms.py to ensure any transparent pixels are converted to white. +def convert_to_rgb(image: ImageInput) -> ImageInput: + """ + Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image + as is. + Args: + image (Image): + The image to convert.
+ """ + requires_backends(convert_to_rgb, ["vision"]) + + if not isinstance(image, PIL.Image.Image): + return image + + if image.mode == "RGB": + return image + + # First we convert to RGBA to set background to white. + image = image.convert("RGBA") + + # Create a new image with a white background. + new_image = PIL.Image.new("RGBA", image.size, "WHITE") + new_image.paste(image, (0, 0), image) + new_image = new_image.convert("RGB") + return new_image + + +def _num_image_tokens(image_size: Tuple[int, int], patch_size: Tuple[int, int]) -> int: + """ + Calculate the number of image tokens given the image size and patch size. + + Args: + image_size (`Tuple[int, int]`): + The size of the image as `(height, width)`. + patch_size (`Tuple[int, int]`): + The patch size as `(height, width)`. + + Returns: + `int`: The number of image tokens. + """ + height, width = image_size + patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size) + num_width_tokens = (width - 1) // patch_width + 1 + num_height_tokens = (height - 1) // patch_height + 1 + return num_height_tokens, num_width_tokens + + +def get_resize_output_image_size( + input_image: np.ndarray, + patch_size: Union[int, Tuple[int, int], List[int], Tuple[int]], + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> tuple: + """ + Find the target (height, width) dimension of the output image after resizing given the input image and the desired + size. + + Args: + input_image (`np.ndarray`): + The image to resize. + patch_size (`int` or `Tuple[int, int]`): + The patch_size as `(height, width)` to use for resizing the image. If patch_size is an integer, `(patch_size, patch_size)` + will be used + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If unset, will use the inferred format from the input. + + Returns: + `tuple`: The target (height, width) dimension of the output image after resizing. + """ + patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size) + height, width = get_image_size(input_image, input_data_format) + + ratio = max(height / patch_height, width / patch_width) + + if ratio > 1: + # Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results + height = int(np.ceil(height / ratio)) + width = int(np.ceil(width / ratio)) + + num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width)) + return num_height_tokens * patch_height, num_width_tokens * patch_width + + +class PixtralImageProcessor(BaseImageProcessor): + r""" + Constructs a Pixtral image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + patch_size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": }`): + Size of the patches in the model, used to calculate the output image size. Can be overridden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. 
+ + +class PixtralImageProcessor(BaseImageProcessor): + r""" + Constructs a Pixtral image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`): + Size of the output image after resizing. Can be overridden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 224} + size = get_size_dict(size, default_to_square=False) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073] + self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711] + self.do_convert_rgb = do_convert_rgb + self._valid_processor_keys = [ + "images", + "do_resize", + "size", + "resample", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "do_convert_rgb", + "return_tensors", + "data_format", + "input_data_format", + ] + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The image is resized to fit within the requested size while keeping the input aspect ratio, + and its dimensions are rounded up to a whole number of patches. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred.
+ """ + if "height" in patch_size and "width" in patch_size: + size = (size["height"], size["width"]) + else: + raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") + + output_size = get_resize_output_image_size( + image, + patch_size=patch_size, + default_to_square=default_to_square, + input_data_format=input_data_format, + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + patch_size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`): + Patch size in the model. Used to calculate the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. 
+ input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + patch_size = patch_size if patch_size is not None else self.patch_size + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=patch_size, + resample=resample, + ) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + all_images = [] + for image in images: + if do_resize: + image = self.resize(image=image, patch_size=patch_size, resample=resample, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + all_images.append(image) + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in all_images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py new file mode 100644 index 000000000000..783543f92075 --- /dev/null +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -0,0 +1,191 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Pixtral. +""" + +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, get_image_size, to_numpy_array +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from ...utils import TensorType, logging + + +logger = logging.get_logger(__name__) + + +class PixtralProcessor(ProcessorMixin): + r""" + Constructs a Pixtral processor which wraps a Pixtral image processor and a Pixtral tokenizer into a single processor. + + [`PixtralProcessor`] offers all the functionalities of [`PixtralImageProcessor`] and [`LlamaTokenizerFast`]. See the + [`~PixtralProcessor.__call__`] and [`~PixtralProcessor.decode`] for more information. + + Args: + image_processor ([`PixtralImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`LlamaTokenizerFast`], *optional*): + The tokenizer is a required input. + patch_size (`int`, *optional*): + Patch size from the vision tower. + vision_feature_select_strategy (`str`, *optional*): + The feature selection strategy used to select the vision feature from the vision backbone. + Should be the same as in the model's config. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + image_token (`str`, *optional*, defaults to `""`): + Special token used to denote image location. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template", "patch_size", "vision_feature_select_strategy", "image_token"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor=None, + tokenizer=None, + patch_size=None, + vision_feature_select_strategy=None, + chat_template=None, + image_token="", # set the default and let users change if they have peculiar special tokens in rare cases + **kwargs, + ): + self.patch_size = patch_size + self.vision_feature_select_strategy = vision_feature_select_strategy + self.image_token = image_token + super().__init__(image_processor, tokenizer, chat_template=chat_template)
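A hedged usage sketch of the processor so far (the checkpoint path is hypothetical, and it assumes the saved processor config sets `image_token` to the placeholder used in the prompt):

    from PIL import Image
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("path/to/pixtral-checkpoint")  # hypothetical path
    image = Image.open("australia.jpg")
    inputs = processor(text="USER: [IMG]\nWhat's in the image? ASSISTANT:", images=image, return_tensors="pt")
    print(inputs.keys())  # input_ids, attention_mask, pixel_values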
+ + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + images: ImageInput = None, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length=None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text` + and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to + PixtralImageProcessor's [`~PixtralImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence is provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + if images is not None: + image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors=return_tensors) + else: + image_inputs = {} + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + # try to expand inputs in processing if we have the necessary parts + prompt_strings = text + if image_inputs.get("pixel_values") is not None: + if self.patch_size is not None and self.vision_feature_select_strategy is not None: + # Replace the image token with the expanded image token sequence + pixel_values = image_inputs["pixel_values"] + height, width = get_image_size(to_numpy_array(pixel_values[0])) + num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + + prompt_strings = [] + for sample in text: + sample = sample.replace(self.image_token, self.image_token * num_image_tokens) + prompt_strings.append(sample) + else: + logger.warning_once( + "Expanding inputs for image tokens in Pixtral should be done in processing. " + "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " + "with `processor.patch_size = {{patch_size}}` and `processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." + )
" + "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " + "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." + ) + + text_inputs = self.tokenizer( + prompt_strings, + return_tensors=return_tensors, + padding=padding, + truncation=truncation, + max_length=max_length, + ) + return BatchFeature(data={**text_inputs, **image_inputs}) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) From ab898efacbbf1738cb804bd20d3413d23a05ed1c Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Thu, 12 Sep 2024 19:20:42 +0100 Subject: [PATCH 2/6] Add in image break and end tokens --- .../models/pixtral/processing_pixtral.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index 783543f92075..602d0b9c37ca 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -47,12 +47,13 @@ class PixtralProcessor(ProcessorMixin): Shoudl be same as in model's config chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. - image_token (`str`, *optional*, defaults to `""`): + image_token (`str`, *optional*, defaults to `"[IMG]"`): Special token used to denote image location. 
diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index 783543f92075..602d0b9c37ca 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -47,12 +47,13 @@ class PixtralProcessor(ProcessorMixin): Should be the same as in the model's config. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. - image_token (`str`, *optional*, defaults to `""`): + image_token (`str`, *optional*, defaults to `"[IMG]"`): Special token used to denote image location. + image_break_token (`str`, *optional*, defaults to `"[IMG_BREAK]"`): + Special token used to denote the end of a row of patches within an image. + image_end_token (`str`, *optional*, defaults to `"[IMG_END]"`): + Special token used to denote the end of an image input. """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template", "patch_size", "vision_feature_select_strategy", "image_token"] + valid_kwargs = ["chat_template", "patch_size", "vision_feature_select_strategy", "image_token", "image_break_token", "image_end_token"] image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" @@ -63,12 +64,16 @@ def __init__( image_processor=None, tokenizer=None, patch_size=None, vision_feature_select_strategy=None, chat_template=None, - image_token="", # set the default and let users change if they have peculiar special tokens in rare cases + image_token="[IMG]", # set the default and let users change if they have peculiar special tokens in rare cases + image_break_token="[IMG_BREAK]", + image_end_token="[IMG_END]", **kwargs, ): self.patch_size = patch_size self.vision_feature_select_strategy = vision_feature_select_strategy self.image_token = image_token + self.image_break_token = image_break_token + self.image_end_token = image_end_token super().__init__(image_processor, tokenizer, chat_template=chat_template) @@ -142,13 +147,17 @@ def __call__( # Replace the image token with the expanded image token sequence pixel_values = image_inputs["pixel_values"] height, width = get_image_size(to_numpy_array(pixel_values[0])) - num_image_tokens = (height // self.patch_size) * (width // self.patch_size) - if self.vision_feature_select_strategy == "default": - num_image_tokens -= 1 + num_height_tokens = height // self.patch_size + num_width_tokens = width // self.patch_size prompt_strings = [] + # Build one row of [IMG] tokens per row of patches, end each row with [IMG_BREAK], + # then flatten and replace the final break with [IMG_END]. + replace_tokens = [[self.image_token] * num_width_tokens + [self.image_break_token]] * num_height_tokens + replace_tokens = [item for sublist in replace_tokens for item in sublist] + replace_tokens[-1] = self.image_end_token + replace_str = "".join(replace_tokens) for sample in text: - sample = sample.replace(self.image_token, self.image_token * num_image_tokens) + sample = sample.replace(self.image_token, replace_str) prompt_strings.append(sample) else: logger.warning_once( From 5b0ce05c20f4b974e97ecce7236efaa7607d2f15 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Thu, 12 Sep 2024 19:58:00 +0100 Subject: [PATCH 3/6] Fix --- src/transformers/__init__.py | 5 +- .../pixtral/convert_pixtral_weights_to_hf.py | 538 ++++++++++++++++-- .../pixtral/image_processing_pixtral.py | 57 +- .../models/pixtral/modeling_pixtral.py | 51 +- .../models/pixtral/processing_pixtral.py | 22 +- 5 files changed, 559 insertions(+), 114 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a2be4b145ac8..5ff1da3f1e0a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -5446,7 +5446,8 @@ Pix2StructVisionConfig, ) from .models.pixtral import ( - PixtralConfig, PixtralProcessor, + PixtralConfig, + PixtralProcessor, ) from .models.plbart import PLBartConfig from .models.poolformer import ( @@ -6023,11 +6024,11 @@ from .models.owlvit import OwlViTFeatureExtractor, OwlViTImageProcessor from .models.perceiver import PerceiverFeatureExtractor, PerceiverImageProcessor from .models.pix2struct import Pix2StructImageProcessor + from .models.pixtral import PixtralImageProcessor from .models.poolformer import ( PoolFormerFeatureExtractor, PoolFormerImageProcessor, ) - from .models.pixtral import PixtralImageProcessor from .models.pvt import PvtImageProcessor from .models.qwen2_vl import Qwen2VLImageProcessor from .models.rt_detr import RTDetrImageProcessor diff --git
a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py index 1e608a4699d8..3118e07225b8 100644 --- a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py +++ b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py @@ -1,16 +1,19 @@ -from transformers import LlavaConfig, LlavaForConditionalGeneration, AutoTokenizer, MistralConfig, PixtralConfig, PreTrainedTokenizerFast - -import torch -from safetensors.torch import load_file as safe_load_file import regex as re - -from PIL import Image import requests -from transformers import AutoProcessor - +import torch +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from PIL import Image +from safetensors.torch import load_file as safe_load_file +from transformers import ( + AutoProcessor, + LlavaConfig, + LlavaForConditionalGeneration, + MistralConfig, + PixtralConfig, + PreTrainedTokenizerFast, +) -from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # Load Mistral tokenizer @@ -19,11 +22,16 @@ tokenizer = MistralTokenizer.from_model(model_name) vocab = tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial -all_special = [token.value if hasattr(token,"value") else token for token in tokenizer.instruct_tokenizer.tokenizer._all_special_tokens] -specials_tokens = {token : all_special.index(token) for token in all_special} +all_special = [ + token.value if hasattr(token, "value") else token + for token in tokenizer.instruct_tokenizer.tokenizer._all_special_tokens +] +specials_tokens = {token: all_special.index(token) for token in all_special} specials_tokens.update(vocab) vocab = specials_tokens from transformers.convert_slow_tokenizer import * + + class MistralConverter: """ A general tiktoken converter. @@ -46,13 +54,13 @@ def __init__( def extract_vocab_merges_from_model(self, vocab: str): try: - from tiktoken.load import load_tiktoken_bpe + pass except Exception: raise ValueError( "`tiktoken` is required to read a `tiktoken` file. Install it with " "`pip install tiktoken`." 
) - bpe_ranks = vocab + bpe_ranks = vocab byte_encoder = bytes_to_unicode() def token_bytes_to_string(b): @@ -100,7 +108,10 @@ def converted(self) -> Tokenizer: return tokenizer -tokenizer = PreTrainedTokenizerFast(tokenizer_object = MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted()) + +tokenizer = PreTrainedTokenizerFast( + tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted() +) text_config = MistralConfig( @@ -121,7 +132,7 @@ def converted(self) -> Tokenizer: rope_theta=1000000000.0, sliding_window=None, tie_word_embeddings=False, - vocab_size=131072 + vocab_size=131072, ) vision_config = PixtralConfig() @@ -130,62 +141,59 @@ def converted(self) -> Tokenizer: config.text_config.head_dim = 128 config.save_pretrained("../pixtral") -tokenizer.model_input_names = ['input_ids', 'attention_mask'] +tokenizer.model_input_names = ["input_ids", "attention_mask"] original_state_dict = safe_load_file("../pixtral/consolidated.safetensors") OLD_KEY_TO_NEW_KEY_MAPPING = { # Layer Normalization Weights - r"vision_encoder.transformer.layers.(\d+).input_layernorm.weight": r"vision_tower.transformer.layers.\1.attention_norm.weight", - r"vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"vision_tower.transformer.layers.\1.ffn_norm.weight", - + r"vision_encoder.transformer.layers.(\d+).input_layernorm.weight": r"vision_tower.transformer.layers.\1.attention_norm.weight", + r"vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"vision_tower.transformer.layers.\1.ffn_norm.weight", # Self Attention Projections - r"vision_encoder.transformer.layers.(\d+).attention.wq.weight": r"vision_tower.transformer.layers.\1.attention.q_proj.weight", - r"vision_encoder.transformer.layers.(\d+).attention.wk.weight": r"vision_tower.transformer.layers.\1.attention.k_proj.weight", - r"vision_encoder.transformer.layers.(\d+).attention.wv.weight": r"vision_tower.transformer.layers.\1.attention.v_proj.weight", - r"vision_encoder.transformer.layers.(\d+).attention.wo.weight": r"vision_tower.transformer.layers.\1.attention.o_proj.weight", - + r"vision_encoder.transformer.layers.(\d+).attention.wq.weight": r"vision_tower.transformer.layers.\1.attention.q_proj.weight", + r"vision_encoder.transformer.layers.(\d+).attention.wk.weight": r"vision_tower.transformer.layers.\1.attention.k_proj.weight", + r"vision_encoder.transformer.layers.(\d+).attention.wv.weight": r"vision_tower.transformer.layers.\1.attention.v_proj.weight", + r"vision_encoder.transformer.layers.(\d+).attention.wo.weight": r"vision_tower.transformer.layers.\1.attention.o_proj.weight", # MLP Projections - r"vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight", - r"vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"vision_tower.transformer.layers.\1.feed_forward.down_proj.weight", - r"vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"vision_tower.transformer.layers.\1.feed_forward.up_proj.weight", - + r"vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight", + r"vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"vision_tower.transformer.layers.\1.feed_forward.down_proj.weight", + r"vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"vision_tower.transformer.layers.\1.feed_forward.up_proj.weight", # Additional mappings - r"vision_encoder": r"vision_tower", - 
r"vision_language_adapter.w_in": r"multi_modal_projector.linear_1", - r"vision_language_adapter.w_out": r"multi_modal_projector.linear_2", - r"layers.(\d+).attention.wq.weight": r"language_model.model.layers.\1.self_attn.q_proj.weight", - r"layers.(\d+).attention.wk.weight": r"language_model.model.layers.\1.self_attn.k_proj.weight", - r"layers.(\d+).attention.wv.weight": r"language_model.model.layers.\1.self_attn.v_proj.weight", - r"layers.(\d+).attention.wo.weight": r"language_model.model.layers.\1.self_attn.o_proj.weight", - r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.mlp.gate_proj.weight", - r"layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.mlp.down_proj.weight", - r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.mlp.up_proj.weight", - r"layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight", - r"layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight", - r"tok_embeddings.weight": r"language_model.model.embed_tokens.weight", - r"output.weight": r"language_model.lm_head.weight", - r"norm.weight": r"language_model.model.norm.weight" - + r"vision_encoder": r"vision_tower", + r"vision_language_adapter.w_in": r"multi_modal_projector.linear_1", + r"vision_language_adapter.w_out": r"multi_modal_projector.linear_2", + r"layers.(\d+).attention.wq.weight": r"language_model.model.layers.\1.self_attn.q_proj.weight", + r"layers.(\d+).attention.wk.weight": r"language_model.model.layers.\1.self_attn.k_proj.weight", + r"layers.(\d+).attention.wv.weight": r"language_model.model.layers.\1.self_attn.v_proj.weight", + r"layers.(\d+).attention.wo.weight": r"language_model.model.layers.\1.self_attn.o_proj.weight", + r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.mlp.gate_proj.weight", + r"layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.mlp.down_proj.weight", + r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.mlp.up_proj.weight", + r"layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight", + r"layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight", + r"tok_embeddings.weight": r"language_model.model.embed_tokens.weight", + r"output.weight": r"language_model.lm_head.weight", + r"norm.weight": r"language_model.model.norm.weight", } -new_state_dict = {} -all_keys = "\n"+ "\n".join(original_state_dict.keys()) +new_state_dict = {} +all_keys = "\n" + "\n".join(original_state_dict.keys()) old_keys = all_keys for old, new in OLD_KEY_TO_NEW_KEY_MAPPING.items(): - all_keys = re.sub(r"\n"+ old,r"\n"+new,all_keys) + all_keys = re.sub(r"\n" + old, r"\n" + new, all_keys) OLD_TO_NEW = dict(zip(old_keys.split("\n"), all_keys.split("\n"))) -new_dict={} +new_dict = {} + def permute_for_rope(value, n_heads, config): - dim1 = value.shape[0] - dim2 = config.hidden_size - return value.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + dim1 = value.shape[0] + dim2 = config.hidden_size + return value.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) -for key, value in original_state_dict.items(): +for key, value in original_state_dict.items(): new_key = OLD_TO_NEW[key] if "vision_encoder" in key: _config = vision_config @@ -198,9 +206,8 @@ def permute_for_rope(value, n_heads, config): num_attention_heads = _config.num_key_value_heads # convert the text model (basically 
mistral model) - if "q_proj" in new_key or "k_proj" in new_key: - value = permute_for_rope(value,num_attention_heads, _config) + value = permute_for_rope(value, num_attention_heads, _config) new_dict[new_key] = value @@ -214,17 +221,432 @@ def permute_for_rope(value, n_heads, config): config.image_token_index = 10 config.vision_feature_select_strategy = "full" model = LlavaForConditionalGeneration.from_pretrained("../pixtral", config=config).to("cuda") -processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", image_token = "[IMG]") +processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", image_token="[IMG]") processor.tokenizer = tokenizer prompt = "USER: \nWhat's the content of the image? ASSISTANT:" url = "https://www.ilankelman.org/stopsigns/australia.jpg" image = Image.open(requests.get(url, stream=True).raw) -prompt = '[INST][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_END]Describe this image in one sentence.[/INST]' -input_ids_ = torch.tensor([[1, 3, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 13, 5847, 13089, 1593, 3937, 1294, 1925, 19286, 1046, 4]]).long() +prompt = 
"[INST][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_END]Describe this image in one sentence.[/INST]" +input_ids_ = torch.tensor( + [ + [ + 1, + 3, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 12, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 13, + 5847, + 13089, + 1593, + 3937, + 1294, + 1925, + 19286, + 1046, + 4, + ] + ] +).long() inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda") -input_ids = torch.tensor([[1, 5, 1091, 19227, 4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 1689, 45971, 1095, 45629, 1897, 1429, 14653, 2811, 1429, 4147, 1278, 3519, 17253, 1897, 1429, 26204, 2811, 16753, 4994, 2811, 1429, 
6371, 1897, 1429, 48649, 2811, 16753, 17611, 2811, 16753, 4994, 2811, 1429, 3607, 1897, 1429, 14653, 2811, 1429, 1784, 5970, 1321, 3468, 1044, 1324, 3596, 1046, 5151, 12717, 1044, 13461, 50666, 1429, 8092, 2811, 16753, 4994, 2811, 1429, 3607, 1897, 1429, 31222, 2811, 12161, 1099, 79092, 1897, 1429, 38600, 10432, 31597, 1429, 14653, 2811, 1429, 1784, 6138, 5476, 1317, 2210, 1046, 90463, 1593, 1562, 1278, 8616, 7285, 2613, 47579, 1429, 15760, 2811, 12161, 17611, 1897, 1429, 8092, 4964, 2821, 27028, 6, 3, 7493, 1681, 1278, 17253, 2479, 9406, 1294, 6993, 4]]) +input_ids = torch.tensor( + [ + [ + 1, + 5, + 1091, + 19227, + 4994, + 2811, + 1429, + 5165, + 1897, + 1429, + 5165, + 2811, + 16753, + 2391, + 2811, + 1429, + 1689, + 45971, + 1095, + 45629, + 1897, + 1429, + 14653, + 2811, + 1429, + 4147, + 1278, + 3519, + 17253, + 1897, + 1429, + 26204, + 2811, + 16753, + 4994, + 2811, + 1429, + 6371, + 1897, + 1429, + 48649, + 2811, + 16753, + 17611, + 2811, + 16753, + 4994, + 2811, + 1429, + 3607, + 1897, + 1429, + 14653, + 2811, + 1429, + 1784, + 5970, + 1321, + 3468, + 1044, + 1324, + 3596, + 1046, + 5151, + 12717, + 1044, + 13461, + 50666, + 1429, + 8092, + 2811, + 16753, + 4994, + 2811, + 1429, + 3607, + 1897, + 1429, + 31222, + 2811, + 12161, + 1099, + 79092, + 1897, + 1429, + 38600, + 10432, + 31597, + 1429, + 14653, + 2811, + 1429, + 1784, + 6138, + 5476, + 1317, + 2210, + 1046, + 90463, + 1593, + 1562, + 1278, + 8616, + 7285, + 2613, + 47579, + 1429, + 15760, + 2811, + 12161, + 17611, + 1897, + 1429, + 8092, + 4964, + 2821, + 27028, + 6, + 3, + 7493, + 1681, + 1278, + 17253, + 2479, + 9406, + 1294, + 6993, + 4, + ] + ] +) # Generate diff --git a/src/transformers/models/pixtral/image_processing_pixtral.py b/src/transformers/models/pixtral/image_processing_pixtral.py index 6d354dbbd5fd..3481b793ec7e 100644 --- a/src/transformers/models/pixtral/image_processing_pixtral.py +++ b/src/transformers/models/pixtral/image_processing_pixtral.py @@ -14,7 +14,7 @@ # limitations under the License. """Image processor class for Pixtral.""" -from typing import Dict, List, Optional, Union, Tuple +from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -25,11 +25,10 @@ to_channel_dimension_format, ) from ...image_utils import ( - OPENAI_CLIP_MEAN, - OPENAI_CLIP_STD, ChannelDimension, ImageInput, PILImageResampling, + get_image_size, infer_channel_dimension_format, is_scaled_image, make_list_of_images, @@ -39,6 +38,7 @@ validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging +from ...utils.import_utils import requires_backends logger = logging.get_logger(__name__) @@ -97,6 +97,7 @@ def _num_image_tokens(image_size: Tuple[int, int], patch_size: Tuple[int, int]) def get_resize_output_image_size( input_image: np.ndarray, + size: Union[int, Tuple[int, int], List[int], Tuple[int]], patch_size: Union[int, Tuple[int, int], List[int], Tuple[int]], input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> tuple: @@ -107,6 +108,8 @@ def get_resize_output_image_size( Args: input_image (`np.ndarray`): The image to resize. + size (`int` or `Tuple[int, int]`): + Max image size an input image can be. Must be a dictionary with the key "longest_edge". patch_size (`int` or `Tuple[int, int]`): The patch_size as `(height, width)` to use for resizing the image. 
If patch_size is an integer, `(patch_size, patch_size)` will be used @@ -116,10 +119,11 @@ def get_resize_output_image_size( Returns: `tuple`: The target (height, width) dimension of the output image after resizing. """ + max_height, max_width = size if isinstance(size, (tuple, list)) else (size, size) patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size) height, width = get_image_size(input_image, input_data_format) - ratio = max(height / patch_height, width / patch_width) + ratio = max(height / max_height, width / max_width) if ratio > 1: # Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results @@ -138,8 +142,9 @@ class PixtralImageProcessor(BaseImageProcessor): do_resize (`bool`, *optional*, defaults to `True`): Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by `do_resize` in the `preprocess` method. - patch_size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": }`): - Size of the patches in the model, used to calculate the output image size. Can be overridden by `size` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"longest_edge": 1024}`): + patch_size (`Dict[str, int]` *optional*, defaults to `{"height": 16, "width": 16}`): + Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method. resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. do_rescale (`bool`, *optional*, defaults to `True`): @@ -167,6 +172,7 @@ def __init__( self, do_resize: bool = True, size: Dict[str, int] = None, + patch_size: Dict[str, int] = None, resample: PILImageResampling = PILImageResampling.BICUBIC, do_rescale: bool = True, rescale_factor: Union[int, float] = 1 / 255, @@ -177,11 +183,12 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - size = size if size is not None else {"shortest_edge": 224} - size = get_size_dict(size, default_to_square=False) + size = size if size is not None else {"longest_edge": 1024} + patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16} self.do_resize = do_resize self.size = size + self.patch_size = patch_size self.resample = resample self.do_rescale = do_rescale self.rescale_factor = rescale_factor @@ -209,6 +216,7 @@ def resize( self, image: np.ndarray, size: Dict[str, int], + patch_size: Dict[str, int], resample: PILImageResampling = PILImageResampling.BICUBIC, data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, @@ -222,7 +230,9 @@ def resize( image (`np.ndarray`): Image to resize. size (`Dict[str, int]`): - Size of the output image. + Dict containing the longest possible edge of the image. + patch_size (`Dict[str, int]`): + Patch size used to calculate the size of the output image. resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): Resampling filter to use when resiizing the image. data_format (`str` or `ChannelDimension`, *optional*): @@ -230,15 +240,22 @@ def resize( input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format of the input image. If not provided, it will be inferred. 
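A sketch may help here. The hunks above change the sizing rule so that an image is first capped to fit `size` (its longest edge) and then snapped up to a whole number of patches. A minimal illustration of that computation, assuming floor-based downscaling (consistent with the comment above about avoiding `round`'s bankers rounding) and ceiling division for the patch grid; the helper name and defaults are illustrative, not the library API:

import math

def sketch_output_size(height: int, width: int, longest_edge: int = 1024, patch: int = 16):
    # Cap the longest edge while preserving aspect ratio; never upscale.
    ratio = max(height / longest_edge, width / longest_edge)
    if ratio > 1:
        height = int(math.floor(height / ratio))
        width = int(math.floor(width / ratio))
    # Snap up to whole patches so the patch convolution tiles the image exactly.
    num_height_tokens = math.ceil(height / patch)
    num_width_tokens = math.ceil(width / patch)
    return num_height_tokens * patch, num_width_tokens * patch

# e.g. 2048x1536 with the defaults: ratio 2.0 -> 1024x768 -> a 64x48 patch grid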
""" - if "height" in patch_size and "width" in patch_size: + if "longest_edge" in size: + size = (size["longest_edge"], size["longest_edge"]) + elif "height" in size and "width" in size: size = (size["height"], size["width"]) else: - raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") + raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.") + + if "height" in patch_size and "width" in patch_size: + patch_size = (patch_size["height"], patch_size["width"]) + else: + raise ValueError("patch_size must contain either 'shortest_edge' or 'height' and 'width'.") output_size = get_resize_output_image_size( image, + size=size, patch_size=patch_size, - default_to_square=default_to_square, input_data_format=input_data_format, ) return resize( @@ -254,6 +271,7 @@ def preprocess( self, images: ImageInput, do_resize: bool = None, + size: Dict[str, int] = None, patch_size: Dict[str, int] = None, resample: PILImageResampling = None, do_rescale: bool = None, @@ -276,6 +294,8 @@ def preprocess( passing in images with pixel values between 0 and 1, set `do_rescale=False`. do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Describes the maximum input dimensions to the model. patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`): Patch size in the model. Used to calculate the image after resizing. resample (`int`, *optional*, defaults to `self.resample`): @@ -313,7 +333,10 @@ def preprocess( - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. """ + patch_size = get_size_dict(patch_size, default_to_square=True) + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size patch_size = patch_size if patch_size is not None else self.patch_size resample = resample if resample is not None else self.resample do_rescale = do_rescale if do_rescale is not None else self.do_rescale @@ -339,7 +362,7 @@ def preprocess( image_mean=image_mean, image_std=image_std, do_resize=do_resize, - size=patch_size, + size=size, resample=resample, ) @@ -362,7 +385,13 @@ def preprocess( all_images = [] for image in images: if do_resize: - image = self.resize(image=image, patch_size=patch_size, resample=resample, input_data_format=input_data_format) + image = self.resize( + image=image, + size=size, + patch_size=patch_size, + resample=resample, + input_data_format=input_data_format, + ) if do_rescale: image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index e5a4d5b34522..91d93d643f39 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -23,16 +23,13 @@ from ... 
import PreTrainedModel from ...activations import ACT2FN -from ...modeling_outputs import ModelOutput, BaseModelOutput +from ...modeling_outputs import BaseModelOutput, ModelOutput from ...utils import ( add_start_docstrings, - add_start_docstrings_to_model_forward, logging, - replace_return_docstrings, ) -from ..auto import AutoModelForCausalLM from .configuration_pixtral import PixtralConfig -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS + logger = logging.get_logger(__name__) @@ -84,24 +81,25 @@ class PixtralCausalLMOutputWithPast(ModelOutput): class PixtralRotaryEmbedding(nn.Module): """ - The key with pixtral embedding is just that you have a frequency for each pixel positions. - If you have height x width pixels (or embedding pixels) + The key with pixtral embedding is just that you have a frequency for each pixel positions. + If you have height x width pixels (or embedding pixels) - then the frequency used for ROPE is given by indexing the pre_computed frequency on the - width and height. + then the frequency used for ROPE is given by indexing the pre_computed frequency on the + width and height. - What you output is of dimension batch, height * width, dim with dim the embed dim. + What you output is of dimension batch, height * width, dim with dim the embed dim. - This simply means that for each image hidden states, you are going to add - a corresponding positional embedding, based on it's index in the grid. + This simply means that for each image hidden states, you are going to add + a corresponding positional embedding, based on it's index in the grid. """ + def __init__(self, config, device): super().__init__() self.rope_type = "default" self.dim = config.head_dim self.base = config.rope_theta max_patches_per_side = config.image_size // config.patch_size - freqs = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float() / self.dim)) + freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) h = torch.arange(max_patches_per_side, device=freqs.device) w = torch.arange(max_patches_per_side, device=freqs.device) @@ -114,7 +112,7 @@ def __init__(self, config, device): freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), ], dim=-1, - ).reshape(-1, self.dim) # we reshape to only index on the position indexes, not tuple of indexes + ).reshape(-1, self.dim) # we reshape to only index on the position indexes, not tuple of indexes # Different from paper, but it uses a different permutation in order to obtain the same calculation # TODO maybe make it torch compatible later on. 
We can also just slice @@ -137,7 +135,6 @@ def forward(self, x, position_ids): sin = emb.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - def _dynamic_frequency_update(self, position_ids, device): """ dynamic RoPE layers should recompute `inv_freq` in the following situations: @@ -157,7 +154,6 @@ def _dynamic_frequency_update(self, position_ids, device): self.max_seq_len_cached = self.original_max_seq_len - # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -193,6 +189,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=0): k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed + class PixtralAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -251,6 +248,7 @@ def forward( return attn_output, attn_weights + # Copied from gemma2 class PixtralMLP(nn.Module): def __init__(self, config): @@ -292,21 +290,19 @@ def position_ids_in_meshgrid(patch_embeds_list): for patch in patch_embeds_list: height, width = patch.shape[-2:] mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") - h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2,-1) + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) ids = h_grid * height + v_grid - positions.append(ids[:,0]) + positions.append(ids[:, 0]) return torch.cat(positions) - class PixtralAttentionLayer(nn.Module): def __init__(self, config): super().__init__() self.attention_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5) self.feed_forward = PixtralMLP(config) self.attention = PixtralAttention(config) - self.ffn_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5) - + self.ffn_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5) def forward( self, @@ -348,7 +344,6 @@ def forward( return outputs - class PixtralTransformer(nn.Module): def __init__(self, config): super().__init__() @@ -433,8 +428,6 @@ def forward( ) - - PIXTRAL_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -488,6 +481,7 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + PIXTRAL_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -562,12 +556,14 @@ def _init_weights(self, module): the complete sequence length. 
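The `position_ids_in_meshgrid` hunk above is easiest to see with toy shapes: each image's patch grid is enumerated row-major, and the resulting flat ids later index the precomputed rotary table. A small sketch that follows the hunk verbatim (the 2x3 grid is illustrative):

import torch

height, width = 2, 3  # one toy grid of patch embeddings
mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1)
ids = h_grid * height + v_grid  # stride is the per-image height, exactly as in the hunk
print(ids[:, 0])  # tensor([0, 1, 2, 2, 3, 4])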
""" + @add_start_docstrings( """The PIXTRAL model which consists of a vision backbone and a language model.""", PIXTRAL_START_DOCSTRING, ) class PixtralModel(PixtralPreTrainedModel): base_model_prefix = "vision_encoder" + def __init__(self, config): super().__init__(config) self.config = config @@ -604,13 +600,10 @@ def forward(self, images: List[torch.Tensor], output_hidden_states=False, *kwarg all tokens of all images of shape (N_toks, D) """ # pass images through initial convolution independently - patch_embeds_list = [ - self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images - ] + patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images] # flatten to a single sequence - patch_embeds = torch.cat( - [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1) + patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1) patch_embeds = self.ln_pre(patch_embeds) # positional embeddings diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index 602d0b9c37ca..6b13b9b03123 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -42,9 +42,6 @@ class PixtralProcessor(ProcessorMixin): The tokenizer is a required input. patch_size (`int`, *optional*): Patch size from the vision tower. - vision_feature_select_strategy (`str`, *optional*): - The feature selection strategy used to select the vision feature from the vision backbone. - Shoudl be same as in model's config chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. image_token (`str`, *optional*, defaults to `"[IMG]"`): @@ -53,7 +50,13 @@ class PixtralProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template", "patch_size", "vision_feature_select_strategy", "image_token", "image_break_token"] + valid_kwargs = [ + "chat_template", + "patch_size", + "vision_feature_select_strategy", + "image_token", + "image_break_token", + ] image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" @@ -62,7 +65,6 @@ def __init__( image_processor=None, tokenizer=None, patch_size=None, - vision_feature_select_strategy=None, chat_template=None, image_token="[IMG]", # set the default and let users change if they have peculiar special tokens in rare cases image_break_token="[IMG_BREAK]", @@ -143,17 +145,15 @@ def __call__( # try to expand inputs in processing if we have the necessary parts prompt_strings = text if image_inputs.get("pixel_values") is not None: - if self.patch_size is not None and self.vision_feature_select_strategy is not None: + if self.patch_size is not None: # Replace the image token with the expanded image token sequence pixel_values = image_inputs["pixel_values"] height, width = get_image_size(to_numpy_array(pixel_values[0])) num_height_tokens = height // self.patch_size num_width_tokens = width // self.patch_size - if self.vision_feature_select_strategy == "default": - num_image_tokens -= 1 prompt_strings = [] - replace_tokens = [self.image_token] * num_width_tokens + [self.image_break_token * num_height_tokens] + replace_tokens = [self.image_token] * num_width_tokens + [self.image_break_token] * num_height_tokens replace_tokens[-1] = self.image_end_token replace_str = "".join(replace_tokens) for sample in text: @@ -162,8 +162,8 @@ def __call__( else: 
logger.warning_once( "Expanding inputs for image tokens in Pixtral should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " + "Please add `patch_size` to the model's processing config or set directly " + "with `processor.patch_size = {{patch_size}}`" "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." ) From d28c7ebb647588fcb877daa13a7b55720beb8c86 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Thu, 12 Sep 2024 19:59:50 +0100 Subject: [PATCH 4/6] Udo some formatting changes --- .../pixtral/convert_pixtral_weights_to_hf.py | 538 ++---------------- .../models/pixtral/modeling_pixtral.py | 51 +- 2 files changed, 87 insertions(+), 502 deletions(-) diff --git a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py index 3118e07225b8..1e608a4699d8 100644 --- a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py +++ b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py @@ -1,19 +1,16 @@ -import regex as re -import requests +from transformers import LlavaConfig, LlavaForConditionalGeneration, AutoTokenizer, MistralConfig, PixtralConfig, PreTrainedTokenizerFast + import torch -from mistral_common.tokens.tokenizers.mistral import MistralTokenizer -from PIL import Image from safetensors.torch import load_file as safe_load_file +import regex as re + +from PIL import Image +import requests +from transformers import AutoProcessor + -from transformers import ( - AutoProcessor, - LlavaConfig, - LlavaForConditionalGeneration, - MistralConfig, - PixtralConfig, - PreTrainedTokenizerFast, -) +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # Load Mistral tokenizer @@ -22,16 +19,11 @@ tokenizer = MistralTokenizer.from_model(model_name) vocab = tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial -all_special = [ - token.value if hasattr(token, "value") else token - for token in tokenizer.instruct_tokenizer.tokenizer._all_special_tokens -] -specials_tokens = {token: all_special.index(token) for token in all_special} +all_special = [token.value if hasattr(token,"value") else token for token in tokenizer.instruct_tokenizer.tokenizer._all_special_tokens] +specials_tokens = {token : all_special.index(token) for token in all_special} specials_tokens.update(vocab) vocab = specials_tokens from transformers.convert_slow_tokenizer import * - - class MistralConverter: """ A general tiktoken converter. @@ -54,13 +46,13 @@ def __init__( def extract_vocab_merges_from_model(self, vocab: str): try: - pass + from tiktoken.load import load_tiktoken_bpe except Exception: raise ValueError( "`tiktoken` is required to read a `tiktoken` file. Install it with " "`pip install tiktoken`." 
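            # For orientation (hypothetical values, not the converter's API): the
            # surrounding method maps each Tekken rank entry onto the string-keyed
            # vocab that a `tokenizers` BPE model expects, roughly:
            #   byte_encoder = bytes_to_unicode()
            #   vocab["".join(byte_encoder[b] for b in token_bytes)] = rank
            # and recovers merges by splitting each multi-byte token into its
            # lowest-rank pair of sub-tokens, mirroring tiktoken's BPE.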
) - bpe_ranks = vocab + bpe_ranks = vocab byte_encoder = bytes_to_unicode() def token_bytes_to_string(b): @@ -108,10 +100,7 @@ def converted(self) -> Tokenizer: return tokenizer - -tokenizer = PreTrainedTokenizerFast( - tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted() -) +tokenizer = PreTrainedTokenizerFast(tokenizer_object = MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted()) text_config = MistralConfig( @@ -132,7 +121,7 @@ def converted(self) -> Tokenizer: rope_theta=1000000000.0, sliding_window=None, tie_word_embeddings=False, - vocab_size=131072, + vocab_size=131072 ) vision_config = PixtralConfig() @@ -141,59 +130,62 @@ def converted(self) -> Tokenizer: config.text_config.head_dim = 128 config.save_pretrained("../pixtral") -tokenizer.model_input_names = ["input_ids", "attention_mask"] +tokenizer.model_input_names = ['input_ids', 'attention_mask'] original_state_dict = safe_load_file("../pixtral/consolidated.safetensors") OLD_KEY_TO_NEW_KEY_MAPPING = { # Layer Normalization Weights - r"vision_encoder.transformer.layers.(\d+).input_layernorm.weight": r"vision_tower.transformer.layers.\1.attention_norm.weight", - r"vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"vision_tower.transformer.layers.\1.ffn_norm.weight", + r"vision_encoder.transformer.layers.(\d+).input_layernorm.weight": r"vision_tower.transformer.layers.\1.attention_norm.weight", + r"vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"vision_tower.transformer.layers.\1.ffn_norm.weight", + # Self Attention Projections - r"vision_encoder.transformer.layers.(\d+).attention.wq.weight": r"vision_tower.transformer.layers.\1.attention.q_proj.weight", - r"vision_encoder.transformer.layers.(\d+).attention.wk.weight": r"vision_tower.transformer.layers.\1.attention.k_proj.weight", - r"vision_encoder.transformer.layers.(\d+).attention.wv.weight": r"vision_tower.transformer.layers.\1.attention.v_proj.weight", - r"vision_encoder.transformer.layers.(\d+).attention.wo.weight": r"vision_tower.transformer.layers.\1.attention.o_proj.weight", + r"vision_encoder.transformer.layers.(\d+).attention.wq.weight": r"vision_tower.transformer.layers.\1.attention.q_proj.weight", + r"vision_encoder.transformer.layers.(\d+).attention.wk.weight": r"vision_tower.transformer.layers.\1.attention.k_proj.weight", + r"vision_encoder.transformer.layers.(\d+).attention.wv.weight": r"vision_tower.transformer.layers.\1.attention.v_proj.weight", + r"vision_encoder.transformer.layers.(\d+).attention.wo.weight": r"vision_tower.transformer.layers.\1.attention.o_proj.weight", + # MLP Projections - r"vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight", - r"vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"vision_tower.transformer.layers.\1.feed_forward.down_proj.weight", - r"vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"vision_tower.transformer.layers.\1.feed_forward.up_proj.weight", + r"vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight", + r"vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"vision_tower.transformer.layers.\1.feed_forward.down_proj.weight", + r"vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"vision_tower.transformer.layers.\1.feed_forward.up_proj.weight", + # Additional mappings - r"vision_encoder": r"vision_tower", - 
r"vision_language_adapter.w_in": r"multi_modal_projector.linear_1", - r"vision_language_adapter.w_out": r"multi_modal_projector.linear_2", - r"layers.(\d+).attention.wq.weight": r"language_model.model.layers.\1.self_attn.q_proj.weight", - r"layers.(\d+).attention.wk.weight": r"language_model.model.layers.\1.self_attn.k_proj.weight", - r"layers.(\d+).attention.wv.weight": r"language_model.model.layers.\1.self_attn.v_proj.weight", - r"layers.(\d+).attention.wo.weight": r"language_model.model.layers.\1.self_attn.o_proj.weight", - r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.mlp.gate_proj.weight", - r"layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.mlp.down_proj.weight", - r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.mlp.up_proj.weight", - r"layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight", - r"layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight", - r"tok_embeddings.weight": r"language_model.model.embed_tokens.weight", - r"output.weight": r"language_model.lm_head.weight", - r"norm.weight": r"language_model.model.norm.weight", + r"vision_encoder": r"vision_tower", + r"vision_language_adapter.w_in": r"multi_modal_projector.linear_1", + r"vision_language_adapter.w_out": r"multi_modal_projector.linear_2", + r"layers.(\d+).attention.wq.weight": r"language_model.model.layers.\1.self_attn.q_proj.weight", + r"layers.(\d+).attention.wk.weight": r"language_model.model.layers.\1.self_attn.k_proj.weight", + r"layers.(\d+).attention.wv.weight": r"language_model.model.layers.\1.self_attn.v_proj.weight", + r"layers.(\d+).attention.wo.weight": r"language_model.model.layers.\1.self_attn.o_proj.weight", + r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.mlp.gate_proj.weight", + r"layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.mlp.down_proj.weight", + r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.mlp.up_proj.weight", + r"layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight", + r"layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight", + r"tok_embeddings.weight": r"language_model.model.embed_tokens.weight", + r"output.weight": r"language_model.lm_head.weight", + r"norm.weight": r"language_model.model.norm.weight" + } -new_state_dict = {} -all_keys = "\n" + "\n".join(original_state_dict.keys()) +new_state_dict = {} +all_keys = "\n"+ "\n".join(original_state_dict.keys()) old_keys = all_keys for old, new in OLD_KEY_TO_NEW_KEY_MAPPING.items(): - all_keys = re.sub(r"\n" + old, r"\n" + new, all_keys) + all_keys = re.sub(r"\n"+ old,r"\n"+new,all_keys) OLD_TO_NEW = dict(zip(old_keys.split("\n"), all_keys.split("\n"))) -new_dict = {} - +new_dict={} def permute_for_rope(value, n_heads, config): - dim1 = value.shape[0] - dim2 = config.hidden_size - return value.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) - + dim1 = value.shape[0] + dim2 = config.hidden_size + return value.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) for key, value in original_state_dict.items(): + new_key = OLD_TO_NEW[key] if "vision_encoder" in key: _config = vision_config @@ -206,8 +198,9 @@ def permute_for_rope(value, n_heads, config): num_attention_heads = _config.num_key_value_heads # convert the text model (basically mistral model) + if "q_proj" in new_key or 
"k_proj" in new_key: - value = permute_for_rope(value, num_attention_heads, _config) + value = permute_for_rope(value,num_attention_heads, _config) new_dict[new_key] = value @@ -221,432 +214,17 @@ def permute_for_rope(value, n_heads, config): config.image_token_index = 10 config.vision_feature_select_strategy = "full" model = LlavaForConditionalGeneration.from_pretrained("../pixtral", config=config).to("cuda") -processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", image_token="[IMG]") +processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", image_token = "[IMG]") processor.tokenizer = tokenizer prompt = "USER: \nWhat's the content of the image? ASSISTANT:" url = "https://www.ilankelman.org/stopsigns/australia.jpg" image = Image.open(requests.get(url, stream=True).raw) -prompt = "[INST][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_END]Describe this image in one sentence.[/INST]" -input_ids_ = torch.tensor( - [ - [ - 1, - 3, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 
10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 12, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 10, - 13, - 5847, - 13089, - 1593, - 3937, - 1294, - 1925, - 19286, - 1046, - 4, - ] - ] -).long() +prompt = '[INST][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG][IMG_END]Describe this image in one sentence.[/INST]' +input_ids_ = torch.tensor([[1, 3, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 13, 5847, 13089, 1593, 3937, 1294, 1925, 19286, 1046, 4]]).long() inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda") -input_ids = torch.tensor( - [ - [ - 1, - 5, - 1091, - 19227, - 4994, - 2811, - 1429, - 5165, - 1897, - 1429, - 5165, - 2811, - 16753, - 2391, - 2811, - 1429, - 1689, - 45971, - 1095, - 45629, - 1897, - 1429, - 14653, - 2811, - 1429, - 4147, - 1278, - 3519, - 17253, - 1897, - 1429, - 26204, - 2811, - 16753, - 4994, - 
2811, - 1429, - 6371, - 1897, - 1429, - 48649, - 2811, - 16753, - 17611, - 2811, - 16753, - 4994, - 2811, - 1429, - 3607, - 1897, - 1429, - 14653, - 2811, - 1429, - 1784, - 5970, - 1321, - 3468, - 1044, - 1324, - 3596, - 1046, - 5151, - 12717, - 1044, - 13461, - 50666, - 1429, - 8092, - 2811, - 16753, - 4994, - 2811, - 1429, - 3607, - 1897, - 1429, - 31222, - 2811, - 12161, - 1099, - 79092, - 1897, - 1429, - 38600, - 10432, - 31597, - 1429, - 14653, - 2811, - 1429, - 1784, - 6138, - 5476, - 1317, - 2210, - 1046, - 90463, - 1593, - 1562, - 1278, - 8616, - 7285, - 2613, - 47579, - 1429, - 15760, - 2811, - 12161, - 17611, - 1897, - 1429, - 8092, - 4964, - 2821, - 27028, - 6, - 3, - 7493, - 1681, - 1278, - 17253, - 2479, - 9406, - 1294, - 6993, - 4, - ] - ] -) +input_ids = torch.tensor([[1, 5, 1091, 19227, 4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 1689, 45971, 1095, 45629, 1897, 1429, 14653, 2811, 1429, 4147, 1278, 3519, 17253, 1897, 1429, 26204, 2811, 16753, 4994, 2811, 1429, 6371, 1897, 1429, 48649, 2811, 16753, 17611, 2811, 16753, 4994, 2811, 1429, 3607, 1897, 1429, 14653, 2811, 1429, 1784, 5970, 1321, 3468, 1044, 1324, 3596, 1046, 5151, 12717, 1044, 13461, 50666, 1429, 8092, 2811, 16753, 4994, 2811, 1429, 3607, 1897, 1429, 31222, 2811, 12161, 1099, 79092, 1897, 1429, 38600, 10432, 31597, 1429, 14653, 2811, 1429, 1784, 6138, 5476, 1317, 2210, 1046, 90463, 1593, 1562, 1278, 8616, 7285, 2613, 47579, 1429, 15760, 2811, 12161, 17611, 1897, 1429, 8092, 4964, 2821, 27028, 6, 3, 7493, 1681, 1278, 17253, 2479, 9406, 1294, 6993, 4]]) # Generate diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index 91d93d643f39..e5a4d5b34522 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -23,13 +23,16 @@ from ... import PreTrainedModel from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_outputs import ModelOutput, BaseModelOutput from ...utils import ( add_start_docstrings, + add_start_docstrings_to_model_forward, logging, + replace_return_docstrings, ) +from ..auto import AutoModelForCausalLM from .configuration_pixtral import PixtralConfig - +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS logger = logging.get_logger(__name__) @@ -81,25 +84,24 @@ class PixtralCausalLMOutputWithPast(ModelOutput): class PixtralRotaryEmbedding(nn.Module): """ - The key with pixtral embedding is just that you have a frequency for each pixel positions. - If you have height x width pixels (or embedding pixels) + The key with pixtral embedding is just that you have a frequency for each pixel positions. + If you have height x width pixels (or embedding pixels) - then the frequency used for ROPE is given by indexing the pre_computed frequency on the - width and height. + then the frequency used for ROPE is given by indexing the pre_computed frequency on the + width and height. - What you output is of dimension batch, height * width, dim with dim the embed dim. + What you output is of dimension batch, height * width, dim with dim the embed dim. - This simply means that for each image hidden states, you are going to add - a corresponding positional embedding, based on it's index in the grid. + This simply means that for each image hidden states, you are going to add + a corresponding positional embedding, based on it's index in the grid. 
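In code, the scheme this docstring describes amounts to precomputing one row of frequencies per (row, column) pixel position and indexing the table with a flat position id. A toy sketch with small dimensions; both the even/odd split of frequencies between height and width and the final `dim // 2` width (each entry later contributes one cos and one sin) are assumptions of this sketch, since the visible hunk shows only the concatenation and reshape:

import torch

dim, base, max_side = 8, 10_000.0, 4  # toy head_dim, rope_theta, image_size // patch_size
freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # dim/2 inverse frequencies
freqs_h = torch.outer(torch.arange(max_side), freqs[::2])   # rows rotate on even channels
freqs_w = torch.outer(torch.arange(max_side), freqs[1::2])  # columns rotate on odd channels
table = torch.cat(
    [
        freqs_h[:, None, :].repeat(1, max_side, 1),
        freqs_w[None, :, :].repeat(max_side, 1, 1),
    ],
    dim=-1,
).reshape(-1, dim // 2)
print(table.shape)  # torch.Size([16, 4]): one frequency row per (row, col) position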
""" - def __init__(self, config, device): super().__init__() self.rope_type = "default" self.dim = config.head_dim self.base = config.rope_theta max_patches_per_side = config.image_size // config.patch_size - freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + freqs = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float() / self.dim)) h = torch.arange(max_patches_per_side, device=freqs.device) w = torch.arange(max_patches_per_side, device=freqs.device) @@ -112,7 +114,7 @@ def __init__(self, config, device): freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), ], dim=-1, - ).reshape(-1, self.dim) # we reshape to only index on the position indexes, not tuple of indexes + ).reshape(-1, self.dim) # we reshape to only index on the position indexes, not tuple of indexes # Different from paper, but it uses a different permutation in order to obtain the same calculation # TODO maybe make it torch compatible later on. We can also just slice @@ -135,6 +137,7 @@ def forward(self, x, position_ids): sin = emb.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + def _dynamic_frequency_update(self, position_ids, device): """ dynamic RoPE layers should recompute `inv_freq` in the following situations: @@ -154,6 +157,7 @@ def _dynamic_frequency_update(self, position_ids, device): self.max_seq_len_cached = self.original_max_seq_len + # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -189,7 +193,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=0): k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed - class PixtralAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -248,7 +251,6 @@ def forward( return attn_output, attn_weights - # Copied from gemma2 class PixtralMLP(nn.Module): def __init__(self, config): @@ -290,19 +292,21 @@ def position_ids_in_meshgrid(patch_embeds_list): for patch in patch_embeds_list: height, width = patch.shape[-2:] mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") - h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2,-1) ids = h_grid * height + v_grid - positions.append(ids[:, 0]) + positions.append(ids[:,0]) return torch.cat(positions) + class PixtralAttentionLayer(nn.Module): def __init__(self, config): super().__init__() self.attention_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5) self.feed_forward = PixtralMLP(config) self.attention = PixtralAttention(config) - self.ffn_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5) + self.ffn_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5) + def forward( self, @@ -344,6 +348,7 @@ def forward( return outputs + class PixtralTransformer(nn.Module): def __init__(self, config): super().__init__() @@ -428,6 +433,8 @@ def forward( ) + + PIXTRAL_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -481,7 +488,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - PIXTRAL_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -556,14 +562,12 @@ def _init_weights(self, module): the complete sequence length. 
""" - @add_start_docstrings( """The PIXTRAL model which consists of a vision backbone and a language model.""", PIXTRAL_START_DOCSTRING, ) class PixtralModel(PixtralPreTrainedModel): base_model_prefix = "vision_encoder" - def __init__(self, config): super().__init__(config) self.config = config @@ -600,10 +604,13 @@ def forward(self, images: List[torch.Tensor], output_hidden_states=False, *kwarg all tokens of all images of shape (N_toks, D) """ # pass images through initial convolution independently - patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images] + patch_embeds_list = [ + self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images + ] # flatten to a single sequence - patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1) + patch_embeds = torch.cat( + [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1) patch_embeds = self.ln_pre(patch_embeds) # positional embeddings From 81c883ad99521973d6afd706eca7138b96ec4742 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Thu, 12 Sep 2024 20:03:15 +0100 Subject: [PATCH 5/6] Set patch_size default --- .../models/pixtral/processing_pixtral.py | 38 ++++++++----------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index 6b13b9b03123..8202ac32c1e6 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -40,7 +40,7 @@ class PixtralProcessor(ProcessorMixin): The image processor is a required input. tokenizer ([`LlamaTokenizerFast`], *optional*): The tokenizer is a required input. - patch_size (`int`, *optional*): + patch_size (`int`, *optional*, defaults to 16): Patch size from the vision tower. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. @@ -64,7 +64,7 @@ def __init__( self, image_processor=None, tokenizer=None, - patch_size=None, + patch_size: int = 16, chat_template=None, image_token="[IMG]", # set the default and let users change if they have peculiar special tokens in rare cases image_break_token="[IMG_BREAK]", @@ -145,27 +145,19 @@ def __call__( # try to expand inputs in processing if we have the necessary parts prompt_strings = text if image_inputs.get("pixel_values") is not None: - if self.patch_size is not None: - # Replace the image token with the expanded image token sequence - pixel_values = image_inputs["pixel_values"] - height, width = get_image_size(to_numpy_array(pixel_values[0])) - num_height_tokens = height // self.patch_size - num_width_tokens = width // self.patch_size - - prompt_strings = [] - replace_tokens = [self.image_token] * num_width_tokens + [self.image_break_token] * num_height_tokens - replace_tokens[-1] = self.image_end_token - replace_str = "".join(replace_tokens) - for sample in text: - sample = sample.replace(self.image_token, replace_str) - prompt_strings.append(sample) - else: - logger.warning_once( - "Expanding inputs for image tokens in Pixtral should be done in processing. " - "Please add `patch_size` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}`" - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
- ) + # Replace the image token with the expanded image token sequence + pixel_values = image_inputs["pixel_values"] + height, width = get_image_size(to_numpy_array(pixel_values[0])) + num_height_tokens = height // self.patch_size + num_width_tokens = width // self.patch_size + + prompt_strings = [] + replace_tokens = [self.image_token] * num_width_tokens + [self.image_break_token] * num_height_tokens + replace_tokens[-1] = self.image_end_token + replace_str = "".join(replace_tokens) + for sample in text: + sample = sample.replace(self.image_token, replace_str) + prompt_strings.append(sample) text_inputs = self.tokenizer( prompt_strings, From 8df9433bd82e266e33d890e8a400570edd2ddece Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Thu, 12 Sep 2024 20:10:15 +0100 Subject: [PATCH 6/6] Fix --- src/transformers/models/pixtral/processing_pixtral.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index 8202ac32c1e6..e9d3d581189a 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -53,7 +53,6 @@ class PixtralProcessor(ProcessorMixin): valid_kwargs = [ "chat_template", "patch_size", - "vision_feature_select_strategy", "image_token", "image_break_token", ] @@ -72,7 +71,6 @@ def __init__( **kwargs, ): self.patch_size = patch_size - self.vision_feature_select_strategy = vision_feature_select_strategy self.image_token = image_token self.image_break_token = image_break_token self.image_end_token = image_end_token
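For orientation on the expansion that the last two patches leave in place (now unconditional, with `patch_size` defaulting to 16): the hard-coded test prompt earlier in the series lays image tokens out row by row, 13 `[IMG]` per row, an `[IMG_BREAK]` closing each of the 19 rows, and the final break replaced by `[IMG_END]`. Below is a toy sketch of that layout; the grid size is illustrative, and `__call__` derives the two counts from `height // patch_size` and `width // patch_size`. Note that the `replace_tokens` expression in the hunks above instead appends all break tokens after the full `[IMG]` run, so the two layouts differ:

num_height_tokens, num_width_tokens = 19, 13  # hypothetical patch grid
replace_tokens = (["[IMG]"] * num_width_tokens + ["[IMG_BREAK]"]) * num_height_tokens
replace_tokens[-1] = "[IMG_END]"
replace_str = "".join(replace_tokens)

prompt = "[INST][IMG]Describe this image in one sentence.[/INST]"
expanded = prompt.replace("[IMG]", replace_str)  # single pass over the original string
# expanded now matches the [INST][IMG]...[IMG_BREAK]...[IMG_END] test prompt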