From 9fe675865fc8be0cdc79f02141a6dc4c10b82373 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Fri, 13 Sep 2024 11:14:35 +0100 Subject: [PATCH] Process list of list of images --- .../pixtral/image_processing_pixtral.py | 127 +++++++++++++----- .../models/pixtral/processing_pixtral.py | 40 ++++-- 2 files changed, 117 insertions(+), 50 deletions(-) diff --git a/src/transformers/models/pixtral/image_processing_pixtral.py b/src/transformers/models/pixtral/image_processing_pixtral.py index 3481b793ec7e..f3a67f481a0b 100644 --- a/src/transformers/models/pixtral/image_processing_pixtral.py +++ b/src/transformers/models/pixtral/image_processing_pixtral.py @@ -14,7 +14,7 @@ # limitations under the License. """Image processor class for Pixtral.""" -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -31,6 +31,7 @@ get_image_size, infer_channel_dimension_format, is_scaled_image, + is_valid_image, make_list_of_images, to_numpy_array, valid_images, @@ -48,7 +49,40 @@ import PIL -# Adapted from function in image_transforms.py t oensure any transparent pixels are converted to white. +# Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images +def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]: + """ + Convert a single image or a list of images to a list of numpy arrays. + + Args: + images (`ImageInput`): + A single image or a list of images. + + Returns: + A list of numpy arrays. 
+ """ + # If it's a single image, convert it to a list of lists + if is_valid_image(images): + images = [[images]] + # If it's a list of images, it's a single batch, so convert it to a list of lists + elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]): + images = [images] + # If it's a list of batches, it's already in the right format + elif ( + isinstance(images, (list, tuple)) + and len(images) > 0 + and isinstance(images[0], (list, tuple)) + and is_valid_image(images[0][0]) + ): + pass + else: + raise ValueError( + "Invalid input type. Must be a single image, a list of images, or a list of batches of images." + ) + return images + + +# Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white. def convert_to_rgb(image: ImageInput) -> ImageInput: """ Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image @@ -134,6 +168,18 @@ def get_resize_output_image_size( return num_height_tokens * patch_height, num_width_tokens * patch_width +# Hack to get tensor conversion used in BatchFeature without batching the images +def _get_is_as_tensor_fns(tensor_type: Union[str, TensorType]) -> Tuple[Callable, Callable]: + return BatchFeature()._get_is_as_tensor_fns(tensor_type) + + +def convert_to_tensor(array, tensor_type: Union[str, TensorType]) -> Any: + is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type) + if is_tensor(array): + return array + return as_tensor(array) + + class PixtralImageProcessor(BaseImageProcessor): r""" Constructs a Pixtral image processor. @@ -333,11 +379,11 @@ def preprocess( - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
""" + patch_size = patch_size if patch_size is not None else self.patch_size patch_size = get_size_dict(patch_size, default_to_square=True) do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size - patch_size = patch_size if patch_size is not None else self.patch_size resample = resample if resample is not None else self.resample do_rescale = do_rescale if do_rescale is not None else self.do_rescale rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor @@ -348,13 +394,14 @@ def preprocess( validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) - images = make_list_of_images(images) + images_list = make_list_of_images(images) - if not valid_images(images): + if not valid_images(images_list[0][0]): raise ValueError( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) + validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, @@ -367,12 +414,12 @@ def preprocess( ) if do_convert_rgb: - images = [convert_to_rgb(image) for image in images] + images_list = [[convert_to_rgb(image) for image in images] for images in images_list] # All transformations expect numpy arrays. - images = [to_numpy_array(image) for image in images] + images_list = [[to_numpy_array(image) for image in images] for images in images_list] - if is_scaled_image(images[0]) and do_rescale: + if is_scaled_image(images_list[0][0]) and do_rescale: logger.warning_once( "It looks like you are trying to rescale already rescaled images. If the input" " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." @@ -380,33 +427,41 @@ def preprocess( if input_data_format is None: # We assume that all images have the same channel dimension format. 
- input_data_format = infer_channel_dimension_format(images[0]) - - all_images = [] - for image in images: - if do_resize: - image = self.resize( - image=image, - size=size, - patch_size=patch_size, - resample=resample, - input_data_format=input_data_format, - ) - - if do_rescale: - image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) - - if do_normalize: - image = self.normalize( - image=image, mean=image_mean, std=image_std, input_data_format=input_data_format - ) - - all_images.append(image) - - images = [ - to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) - for image in all_images + input_data_format = infer_channel_dimension_format(images_list[0][0]) + + batch_images = [] + batch_image_sizes = [] + for sample_images in images_list: + images = [] + image_sizes = [] + for image in sample_images: + if do_resize: + image = self.resize( + image=image, + size=size, + patch_size=patch_size, + resample=resample, + input_data_format=input_data_format, + ) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + images.append(image) + image_sizes.append(get_image_size(image, input_data_format)) + batch_images.append(images) + batch_image_sizes.append(image_sizes) + + images_list = [ + [to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images] + for images in batch_images ] - data = {"pixel_values": images} - return BatchFeature(data=data, tensor_type=return_tensors) + # Convert to tensor type outside of BatchFeature to avoid batching the images of different sizes + images_list = [[convert_to_tensor(image, return_tensors) for image in images] for images in images_list] + return BatchFeature(data={"images": images_list, "image_sizes": batch_image_sizes}, 
tensor_type=None) diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index c1da48808036..9a9b7a23a413 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -19,7 +19,7 @@ from typing import List, Optional, Union from ...feature_extraction_utils import BatchFeature -from ...image_utils import ImageInput, get_image_size, to_numpy_array +from ...image_utils import ImageInput from ...processing_utils import ProcessorMixin from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy from ...utils import TensorType, logging @@ -146,21 +146,33 @@ def __call__( # try to expand inputs in processing if we have the necessary parts prompt_strings = text - if image_inputs.get("pixel_values") is not None: + if image_inputs.get("images") is not None: # Replace the image token with the expanded image token sequence - pixel_values = image_inputs["pixel_values"] - height, width = get_image_size(to_numpy_array(pixel_values[0])) - num_height_tokens = height // self.patch_size - num_width_tokens = width // self.patch_size - + images = image_inputs["images"] + image_sizes = image_inputs.pop("image_sizes") prompt_strings = [] - replace_tokens = [[self.image_token] * num_width_tokens + [self.image_break_token]] * num_height_tokens - # Flatten list - replace_tokens = [item for sublist in replace_tokens for item in sublist] - replace_tokens[-1] = self.image_end_token - replace_str = "".join(replace_tokens) - for sample in text: - sample = sample.replace(self.image_token, replace_str) + + for sample_images, sample_image_sizes, sample in zip(images, image_sizes, text): + replace_strings = [] + # First calculate the number of tokens needed for each image and put in a placeholder + for image, image_size in zip(sample_images, sample_image_sizes): + height, width = image_size + num_height_tokens = height // 
self.patch_size
+                    num_width_tokens = width // self.patch_size
+                    replace_tokens = [
+                        [self.image_token] * num_width_tokens + [self.image_break_token]
+                    ] * num_height_tokens
+                    # Flatten list
+                    replace_tokens = [item for sublist in replace_tokens for item in sublist]
+                    replace_tokens[-1] = self.image_end_token
+                    replace_str = "".join(replace_tokens)
+                    replace_strings.append(replace_str)
+                    sample = sample.replace(self.image_token, "<placeholder>", 1)
+
+                while "<placeholder>" in sample:
+                    replace_str = replace_strings.pop(0)
+                    sample = sample.replace("<placeholder>", replace_str, 1)
+                prompt_strings.append(sample)

         text_inputs = self.tokenizer(