From 4f0c0908d4c40b3154914ae802c9752f591f8924 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Mon, 28 Jul 2025 17:46:49 +0000 Subject: [PATCH 1/2] add fast image processor Janus, deepseek_vl, deepseek_vl_hybrid --- docs/source/en/model_doc/deepseek_vl.md | 4 + .../source/en/model_doc/deepseek_vl_hybrid.md | 4 + docs/source/en/model_doc/janus.md | 14 +- .../models/auto/image_processing_auto.py | 6 +- .../deepseek_vl/configuration_deepseek_vl.py | 4 +- .../image_processing_deepseek_vl.py | 4 - .../image_processing_deepseek_vl_fast.py | 199 +++++++++++ .../models/deepseek_vl/modular_deepseek_vl.py | 13 + .../models/deepseek_vl_hybrid/__init__.py | 1 + .../image_processing_deepseek_vl_hybrid.py | 5 - ...mage_processing_deepseek_vl_hybrid_fast.py | 325 ++++++++++++++++++ .../modular_deepseek_vl_hybrid.py | 204 ++++++++++- src/transformers/models/janus/__init__.py | 1 + .../janus/image_processing_janus_fast.py | 245 +++++++++++++ .../test_image_processing_deepseek_vl.py | 39 ++- ...est_image_processing_deepseek_vl_hybrid.py | 61 +++- .../janus/test_image_processing_janus.py | 190 ++++++---- 17 files changed, 1227 insertions(+), 92 deletions(-) create mode 100644 src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py create mode 100644 src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py create mode 100644 src/transformers/models/janus/image_processing_janus_fast.py diff --git a/docs/source/en/model_doc/deepseek_vl.md b/docs/source/en/model_doc/deepseek_vl.md index 625a2c90b01f..b01ef7064a73 100644 --- a/docs/source/en/model_doc/deepseek_vl.md +++ b/docs/source/en/model_doc/deepseek_vl.md @@ -209,6 +209,10 @@ model = DeepseekVLForConditionalGeneration.from_pretrained( [[autodoc]] DeepseekVLImageProcessor +## DeepseekVLImageProcessorFast + +[[autodoc]] DeepseekVLImageProcessorFast + ## DeepseekVLModel [[autodoc]] DeepseekVLModel diff --git a/docs/source/en/model_doc/deepseek_vl_hybrid.md b/docs/source/en/model_doc/deepseek_vl_hybrid.md index 86e1672bce59..e713782748c9 100644 --- a/docs/source/en/model_doc/deepseek_vl_hybrid.md +++ b/docs/source/en/model_doc/deepseek_vl_hybrid.md @@ -208,6 +208,10 @@ model = DeepseekVLHybridForConditionalGeneration.from_pretrained( [[autodoc]] DeepseekVLHybridImageProcessor +## DeepseekVLHybridImageProcessorFast + +[[autodoc]] DeepseekVLHybridImageProcessorFast + ## DeepseekVLHybridModel [[autodoc]] DeepseekVLHybridModel diff --git a/docs/source/en/model_doc/janus.md b/docs/source/en/model_doc/janus.md index d3973c45c11a..f2825cbc9738 100644 --- a/docs/source/en/model_doc/janus.md +++ b/docs/source/en/model_doc/janus.md @@ -44,11 +44,11 @@ Here is the example of visual understanding with a single image. > Note that the model has been trained with a specific prompt format for chatting. Use `processor.apply_chat_template(my_conversation_dict)` to correctly format your prompts. ```python -import torch -from PIL import Image -import requests +import torch +from PIL import Image +import requests -from transformers import JanusForConditionalGeneration, JanusProcessor +from transformers import JanusForConditionalGeneration, JanusProcessor model_id = "deepseek-community/Janus-Pro-1B" # Prepare Input for generation. @@ -64,7 +64,7 @@ messages = [ # Set generation mode to `text` to perform text generation. 
processor = JanusProcessor.from_pretrained(model_id) -model = JanusForConditionalGeneration.from_pretrained(model_id, +model = JanusForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto") @@ -209,6 +209,10 @@ for i, image in enumerate(images['pixel_values']): [[autodoc]] JanusImageProcessor +## JanusImageProcessorFast + +[[autodoc]] JanusImageProcessorFast + ## JanusVisionModel [[autodoc]] JanusVisionModel diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 0a0cc6a38ca4..594529cf4726 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -77,8 +77,8 @@ ("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), ("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), ("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")), - ("deepseek_vl", ("DeepseekVLImageProcessor")), - ("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor")), + ("deepseek_vl", ("DeepseekVLImageProcessor", "DeepseekVLImageProcessorFast")), + ("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor", "DeepseekVLHybridImageProcessorFast")), ("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")), ("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")), ("depth_anything", ("DPTImageProcessor", "DPTImageProcessorFast")), @@ -112,7 +112,7 @@ ("imagegpt", ("ImageGPTImageProcessor",)), ("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")), ("instructblipvideo", ("InstructBlipVideoImageProcessor",)), - ("janus", ("JanusImageProcessor")), + ("janus", ("JanusImageProcessor", "JanusImageProcessorFast")), ("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")), ("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")), diff --git a/src/transformers/models/deepseek_vl/configuration_deepseek_vl.py b/src/transformers/models/deepseek_vl/configuration_deepseek_vl.py index af99ac9eeb3f..cfe008635090 100644 --- a/src/transformers/models/deepseek_vl/configuration_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/configuration_deepseek_vl.py @@ -20,7 +20,9 @@ from ...configuration_utils import PretrainedConfig -from ...utils import logging +from ...utils import ( + logging, +) from ..auto import CONFIG_MAPPING, AutoConfig diff --git a/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py b/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py index fad24220ef87..8cf4acbf937f 100644 --- a/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py @@ -406,9 +406,5 @@ def pad_to_square( return result - def postprocess(self): - """Applies post-processing to the decoded image tokens by reversing transformations applied during preprocessing.""" - raise AttributeError("Not needed for DeepseekVL") - __all__ = ["DeepseekVLImageProcessor"] diff --git a/src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py b/src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py new file mode 100644 index 000000000000..5bebf43c9b6c --- /dev/null +++ b/src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py @@ -0,0 +1,199 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from 
src/transformers/models/deepseek_vl/modular_deepseek_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_deepseek_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+    BaseImageProcessorFast,
+    DefaultFastImageProcessorKwargs,
+    group_images_by_shape,
+    reorder_images,
+)
+from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling, SizeDict
+from ...processing_utils import Unpack
+from ...utils import (
+    TensorType,
+    auto_docstring,
+    is_torch_available,
+    is_torchvision_available,
+    is_torchvision_v2_available,
+)
+
+
+if is_torch_available():
+    import torch
+if is_torchvision_v2_available():
+    from torchvision.transforms.v2 import functional as F
+elif is_torchvision_available():
+    from torchvision.transforms import functional as F
+
+
+class DeepseekVLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+    r"""
+    min_size (`int`, *optional*, defaults to 14):
+        The minimum allowed size for the resized image. Ensures that neither the height nor width
+        falls below this value after resizing.
+    """
+
+    min_size: int
+
+
+@auto_docstring
+class DeepseekVLImageProcessorFast(BaseImageProcessorFast):
+    resample = PILImageResampling.BICUBIC
+    image_mean = OPENAI_CLIP_MEAN
+    image_std = OPENAI_CLIP_STD
+    size = {"height": 384, "width": 384}
+    min_size = 14
+    do_resize = True
+    do_rescale = True
+    do_normalize = True
+    valid_kwargs = DeepseekVLFastImageProcessorKwargs
+
+    def __init__(self, **kwargs: Unpack[DeepseekVLFastImageProcessorKwargs]):
+        super().__init__(**kwargs)
+        if kwargs.get("image_mean", None) is None:
+            background_color = (127, 127, 127)
+        else:
+            background_color = tuple([int(x * 255) for x in kwargs.get("image_mean")])
+        self.background_color = tuple(background_color)
+
+    def resize(
+        self,
+        image: "torch.Tensor",
+        size: SizeDict,
+        min_size: int,
+        interpolation: "F.InterpolationMode" = None,
+        antialias: bool = True,
+        **kwargs,
+    ) -> "torch.Tensor":
+        if size.height is None or size.width is None or size.height != size.width:
+            raise ValueError(
+                f"Output height and width must be the same. Got height={size['height']} and width={size['width']}"
+            )
+        size = size.height
+
+        height, width = image.shape[-2:]
+        max_size = max(height, width)
+
+        delta = size / max_size
+        # Largest side becomes `size` and the other side is scaled according to the aspect ratio.
+        output_size_nonpadded = SizeDict(
+            height=max(int(height * delta), min_size),
+            width=max(int(width * delta), min_size),
+        )
+
+        return super().resize(image, size=output_size_nonpadded, interpolation=interpolation, antialias=antialias)
+
+    def pad_to_square(
+        self,
+        images: "torch.Tensor",
+        background_color: Union[int, tuple[int, int, int]] = 0,
+    ) -> "torch.Tensor":
+        """
+        Pads an image to a square based on the longest edge.
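+        For example (illustrative): a `3x384x256` input comes back as `3x384x384`, with the
+        new columns filled with `background_color`.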
+
+        Args:
+            images (`torch.Tensor`):
+                The images to pad.
+            background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
+                The color to use for the padding. Can be an integer for a single channel or a
+                tuple of integers for multi-channel images. If passed as an integer in
+                multi-channel mode, the remaining channels default to `0`.
+
+        Returns:
+            `torch.Tensor`: The padded images.
+        """
+        height, width = images.shape[-2:]
+        num_channels = images.shape[1]
+        batch_size = images.shape[0]
+
+        if height == width:
+            return images
+
+        max_dim = max(height, width)
+
+        # Ensure background_color is the correct shape
+        if isinstance(background_color, int):
+            background_color = [background_color]
+        elif len(background_color) != num_channels:
+            raise ValueError(
+                f"background_color must have exactly {num_channels} elements to match the number of channels"
+            )
+
+        padded_images = torch.zeros(
+            (batch_size, num_channels, max_dim, max_dim), dtype=images.dtype, device=images.device
+        )
+        for i, color in enumerate(background_color):
+            padded_images[:, i, :, :] = color
+        if width > height:
+            start = (max_dim - height) // 2
+            padded_images[:, :, start : start + height, :] = images
+        else:
+            start = (max_dim - width) // 2
+            padded_images[:, :, :, start : start + width] = images
+
+        return padded_images
+
+    def _preprocess(
+        self,
+        images: list["torch.Tensor"],
+        do_resize: bool,
+        size: SizeDict,
+        min_size: int,
+        interpolation: Optional["F.InterpolationMode"],
+        do_rescale: bool,
+        rescale_factor: float,
+        do_normalize: bool,
+        image_mean: Optional[Union[float, list[float]]],
+        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
+        return_tensors: Optional[Union[str, TensorType]],
+        do_pad: bool = True,
+        **kwargs,
+    ) -> BatchFeature:
+        # Group images by size for batched resizing
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+        resized_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_resize:
+                stacked_images = self.resize(
+                    image=stacked_images, size=size, min_size=min_size, interpolation=interpolation
+                )
+            resized_images_grouped[shape] = stacked_images
+        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+        # Group images by size for further processing
+        # Needed in case do_resize is False, or resize returns images with different sizes
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+        processed_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_pad:
+                stacked_images = self.pad_to_square(stacked_images, background_color=self.background_color)
+            # Fused rescale and normalize
+            stacked_images = self.rescale_and_normalize(
+                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+            )
+            processed_images_grouped[shape] = stacked_images
+
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+
+__all__ = ["DeepseekVLImageProcessorFast"]
diff --git a/src/transformers/models/deepseek_vl/modular_deepseek_vl.py b/src/transformers/models/deepseek_vl/modular_deepseek_vl.py
index a5190a280b02..b9f3fc37ba7a 100644
--- a/src/transformers/models/deepseek_vl/modular_deepseek_vl.py
+++ 
b/src/transformers/models/deepseek_vl/modular_deepseek_vl.py
@@ -33,6 +33,7 @@
 from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel
 from ..idefics.modeling_idefics import IdeficsBaseModelOutputWithPast, IdeficsCausalLMOutputWithPast
 from ..janus.image_processing_janus import JanusImageProcessor
+from ..janus.image_processing_janus_fast import JanusImageProcessorFast
 from ..janus.modeling_janus import JanusForConditionalGeneration, JanusModel, JanusPreTrainedModel
 
 
@@ -181,6 +182,9 @@ def generate(self):
 
 
 class DeepseekVLImageProcessor(JanusImageProcessor):
+    def __init__(self, **super_kwargs):
+        super().__init__(**super_kwargs)
+
     def postprocess(self):
         raise AttributeError("Not needed for DeepseekVL")
 
@@ -188,6 +192,14 @@ def unnormalize(self):
         raise AttributeError("Not needed for DeepseekVL")
 
 
+class DeepseekVLImageProcessorFast(JanusImageProcessorFast):
+    def __init__(self, **super_kwargs):
+        super().__init__(**super_kwargs)
+
+    def postprocess(self):
+        raise AttributeError("Not needed for DeepseekVL")
+
+
 class DeepseekVLProcessorKwargs(ProcessingKwargs, total=False):
     _defaults = {
         "text_kwargs": {"padding": False},
@@ -322,5 +334,6 @@ def model_input_names(self):
     "DeepseekVLModel",
     "DeepseekVLForConditionalGeneration",
     "DeepseekVLImageProcessor",
+    "DeepseekVLImageProcessorFast",
     "DeepseekVLProcessor",
 ]
diff --git a/src/transformers/models/deepseek_vl_hybrid/__init__.py b/src/transformers/models/deepseek_vl_hybrid/__init__.py
index 1836d196ac0b..da85178ccc84 100644
--- a/src/transformers/models/deepseek_vl_hybrid/__init__.py
+++ b/src/transformers/models/deepseek_vl_hybrid/__init__.py
@@ -21,5 +21,6 @@
     from .configuration_deepseek_vl_hybrid import *
     from .image_processing_deepseek_vl_hybrid import *
+    from .image_processing_deepseek_vl_hybrid_fast import *
     from .modeling_deepseek_vl_hybrid import *
     from .processing_deepseek_vl_hybrid import *
 else:
diff --git a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py
index d42cfbe38bf4..3dc112d3a525 100644
--- a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py
+++ b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py
@@ -361,7 +361,6 @@ def preprocess(
         # high_res_image: resize (high) -> rescale -> normalize (high)
         # low_res_image: resize (high) -> rescale -> resize (low) -> normalize (low)
         high_res_image = image
-
         if do_resize:
             high_res_image = self.resize(
                 image=high_res_image,
@@ -475,9 +474,5 @@ def pad_to_square(
 
         return result
 
-    def postprocess(self):
-        """Applies post-processing to the decoded image tokens by reversing transformations applied during preprocessing."""
-        raise AttributeError("Not needed for DeepseekVLHybrid")
-
 
 __all__ = ["DeepseekVLHybridImageProcessor"]
diff --git a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py
new file mode 100644
index 000000000000..6bb5696282f9
--- /dev/null
+++ b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py
@@ -0,0 +1,325 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py. 
+# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_deepseek_vl_hybrid.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import torch + +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + BatchFeature, + DefaultFastImageProcessorKwargs, + get_size_dict, + group_images_by_shape, + reorder_images, +) +from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ChannelDimension, PILImageResampling, SizeDict +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + is_torchvision_available, + is_torchvision_v2_available, +) + + +if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + + from ...image_utils import pil_torch_interpolation_mapping +elif is_torchvision_available(): + from torchvision.transforms import functional as F + + from ...image_utils import pil_torch_interpolation_mapping + + +class DeepseekVLHybridFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + r""" + min_size (`int`, *optional*, defaults to 14): + The minimum allowed size for the resized image. Ensures that neither the height nor width + falls below this value after resizing. + high_res_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`): + Size of the high resolution output image after resizing. Can be overridden by the `high_res_size` parameter in the `preprocess` + method. + high_res_resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be + overridden by the `high_res_resample` parameter in the `preprocess` method. + high_res_image_mean (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`): + Mean to use if normalizing the high resolution image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `high_res_image_mean` parameter in the `preprocess` method. + high_res_image_std (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_STD`): + Standard deviation to use if normalizing the high resolution image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `high_res_image_std` parameter in the `preprocess` method. 
+    """
+
+    min_size: int
+    high_res_size: dict
+    high_res_resample: "PILImageResampling"
+    high_res_image_mean: list[float]
+    high_res_image_std: list[float]
+
+
+@auto_docstring
+class DeepseekVLHybridImageProcessorFast(BaseImageProcessorFast):
+    resample = PILImageResampling.BICUBIC
+    image_mean = OPENAI_CLIP_MEAN
+    image_std = OPENAI_CLIP_STD
+    size = {"height": 384, "width": 384}
+    min_size = 14
+    do_resize = True
+    do_rescale = True
+    do_normalize = True
+    valid_kwargs = DeepseekVLHybridFastImageProcessorKwargs
+    high_res_image_mean = OPENAI_CLIP_MEAN
+    high_res_image_std = OPENAI_CLIP_STD
+    high_res_size = {"height": 1024, "width": 1024}
+    high_res_resample = PILImageResampling.BICUBIC
+
+    def __init__(self, **kwargs: Unpack[DeepseekVLHybridFastImageProcessorKwargs]):
+        if kwargs.get("image_mean", None) is None:
+            background_color = (127, 127, 127)
+        else:
+            background_color = tuple([int(x * 255) for x in kwargs.get("image_mean")])
+        if kwargs.get("high_res_image_mean", None) is None:
+            background_color = (127, 127, 127)
+        else:
+            background_color = tuple([int(x * 255) for x in kwargs.get("high_res_image_mean")])
+        super().__init__(**kwargs)
+        self.background_color = tuple(background_color)
+
+    def resize(
+        self,
+        image: "torch.Tensor",
+        size: SizeDict,
+        min_size: int,
+        interpolation: "F.InterpolationMode" = None,
+        antialias: bool = True,
+        **kwargs,
+    ) -> "torch.Tensor":
+        if size.height is None or size.width is None or size.height != size.width:
+            raise ValueError(
+                f"Output height and width must be the same. Got height={size['height']} and width={size['width']}"
+            )
+        size = size.height
+
+        height, width = image.shape[-2:]
+        max_size = max(height, width)
+
+        delta = size / max_size
+        # Largest side becomes `size` and the other side is scaled according to the aspect ratio.
+        output_size_nonpadded = SizeDict(
+            height=max(int(height * delta), min_size),
+            width=max(int(width * delta), min_size),
+        )
+
+        return super().resize(image, size=output_size_nonpadded, interpolation=interpolation, antialias=antialias)
+
+    def pad_to_square(
+        self,
+        images: "torch.Tensor",
+        background_color: Union[int, tuple[int, int, int]] = 0,
+    ) -> "torch.Tensor":
+        """
+        Pads an image to a square based on the longest edge.
+
+        Args:
+            images (`torch.Tensor`):
+                The images to pad.
+            background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
+                The color to use for the padding. Can be an integer for a single channel or a
+                tuple of integers for multi-channel images. If passed as an integer in
+                multi-channel mode, the remaining channels default to `0`.
+
+        Returns:
+            `torch.Tensor`: The padded images.
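+
+        Example (illustrative, with assumed shapes):
+
+        ```python
+        >>> import torch
+        >>> processor = DeepseekVLHybridImageProcessorFast()
+        >>> processor.pad_to_square(torch.zeros(1, 3, 1024, 768), background_color=(127, 127, 127)).shape
+        torch.Size([1, 3, 1024, 1024])
+        ```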
+        """
+        height, width = images.shape[-2:]
+        num_channels = images.shape[1]
+        batch_size = images.shape[0]
+
+        if height == width:
+            return images
+
+        max_dim = max(height, width)
+
+        # Ensure background_color is the correct shape
+        if isinstance(background_color, int):
+            background_color = [background_color]
+        elif len(background_color) != num_channels:
+            raise ValueError(
+                f"background_color must have exactly {num_channels} elements to match the number of channels"
+            )
+
+        padded_images = torch.zeros(
+            (batch_size, num_channels, max_dim, max_dim), dtype=images.dtype, device=images.device
+        )
+        for i, color in enumerate(background_color):
+            padded_images[:, i, :, :] = color
+        if width > height:
+            start = (max_dim - height) // 2
+            padded_images[:, :, start : start + height, :] = images
+        else:
+            start = (max_dim - width) // 2
+            padded_images[:, :, :, start : start + width] = images
+
+        return padded_images
+
+    def _preprocess(
+        self,
+        images: list["torch.Tensor"],
+        do_resize: bool,
+        size: SizeDict,
+        high_res_size: SizeDict,
+        min_size: int,
+        interpolation: Optional["F.InterpolationMode"],
+        high_res_interpolation: Optional["F.InterpolationMode"],
+        do_rescale: bool,
+        rescale_factor: float,
+        do_normalize: bool,
+        image_mean: Optional[Union[float, list[float]]],
+        image_std: Optional[Union[float, list[float]]],
+        high_res_image_mean: Optional[Union[float, list[float]]],
+        high_res_image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
+        return_tensors: Optional[Union[str, TensorType]],
+        do_pad: bool = True,
+        **kwargs,
+    ) -> BatchFeature:
+        # Group images by size for batched resizing
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+        high_res_resized_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            stacked_high_res_images = stacked_images
+            if do_resize:
+                stacked_high_res_images = self.resize(
+                    image=stacked_images, size=high_res_size, min_size=min_size, interpolation=high_res_interpolation
+                )
+            high_res_resized_images_grouped[shape] = stacked_high_res_images
+        high_res_resized_images = reorder_images(high_res_resized_images_grouped, grouped_images_index)
+
+        # Group images by size for further processing
+        # Needed in case do_resize is False, or resize returns images with different sizes
+        grouped_high_res_images, grouped_high_res_images_index = group_images_by_shape(
+            high_res_resized_images, disable_grouping=disable_grouping
+        )
+        high_res_padded_images = {}
+        high_res_processed_images_grouped = {}
+        for shape, stacked_high_res_images in grouped_high_res_images.items():
+            if do_pad:
+                stacked_high_res_images = self.pad_to_square(
+                    stacked_high_res_images, background_color=self.background_color
+                )
+            high_res_padded_images[shape] = stacked_high_res_images
+            # Fused rescale and normalize
+            stacked_high_res_images = self.rescale_and_normalize(
+                stacked_high_res_images,
+                do_rescale,
+                rescale_factor,
+                do_normalize,
+                high_res_image_mean,
+                high_res_image_std,
+            )
+            high_res_processed_images_grouped[shape] = stacked_high_res_images
+        high_res_processed_images = reorder_images(high_res_processed_images_grouped, grouped_high_res_images_index)
+        high_res_processed_images = (
+            torch.stack(high_res_processed_images, dim=0) if return_tensors else high_res_processed_images
+        )
+
+        resized_images_grouped = {}
+        for shape, stacked_high_res_padded_images in high_res_padded_images.items():
+            stacked_images = stacked_high_res_padded_images
+            if do_resize:
+                stacked_images = self.resize(
+                    image=stacked_high_res_padded_images, size=size, min_size=min_size, 
interpolation=interpolation + ) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_high_res_images_index) + + grouped_resized_images, grouped_resized_images_index = group_images_by_shape( + resized_images, disable_grouping=disable_grouping + ) + processed_images_grouped = {} + for shape, stacked_images in grouped_resized_images.items(): + if do_pad: + stacked_images = self.pad_to_square(stacked_images, background_color=self.background_color) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + processed_images = reorder_images(processed_images_grouped, grouped_resized_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature( + data={"pixel_values": processed_images, "high_res_pixel_values": high_res_processed_images}, + tensor_type=return_tensors, + ) + + def _further_process_kwargs( + self, + size: Optional[SizeDict] = None, + high_res_size: Optional[SizeDict] = None, + default_to_square: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + high_res_image_mean: Optional[Union[float, list[float]]] = None, + high_res_image_std: Optional[Union[float, list[float]]] = None, + data_format: Optional[ChannelDimension] = None, + **kwargs, + ) -> dict: + """ + Update kwargs that need further processing before being validated + Can be overridden by subclasses to customize the processing of kwargs. + """ + if kwargs is None: + kwargs = {} + if size is not None: + size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) + if high_res_size is not None: + high_res_size = SizeDict(**get_size_dict(size=high_res_size, default_to_square=default_to_square)) + if isinstance(image_mean, list): + image_mean = tuple(image_mean) + if isinstance(image_std, list): + image_std = tuple(image_std) + if isinstance(high_res_image_mean, list): + high_res_image_mean = tuple(high_res_image_mean) + if isinstance(high_res_image_std, list): + high_res_image_std = tuple(high_res_image_std) + if data_format is None: + data_format = ChannelDimension.FIRST + + high_res_resample = kwargs.pop("high_res_resample") + kwargs["high_res_interpolation"] = ( + pil_torch_interpolation_mapping[high_res_resample] + if isinstance(high_res_resample, (int, PILImageResampling)) + else high_res_resample + ) + + kwargs["size"] = size + kwargs["high_res_size"] = high_res_size + kwargs["default_to_square"] = default_to_square + kwargs["image_mean"] = image_mean + kwargs["image_std"] = image_std + kwargs["high_res_image_mean"] = high_res_image_mean + kwargs["high_res_image_std"] = high_res_image_std + kwargs["data_format"] = data_format + + return kwargs + + +__all__ = ["DeepseekVLHybridImageProcessorFast"] diff --git a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py index aa0a4f87ba3e..776af4a5cfbc 100644 --- a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py @@ -20,7 +20,10 @@ from ...cache_utils import Cache from ...image_processing_utils_fast import ( BatchFeature, + DefaultFastImageProcessorKwargs, get_size_dict, + 
group_images_by_shape, + reorder_images, ) from ...image_transforms import convert_to_rgb, to_channel_dimension_format from ...image_utils import ( @@ -29,6 +32,7 @@ ChannelDimension, ImageInput, PILImageResampling, + SizeDict, infer_channel_dimension_format, is_scaled_image, make_flat_list_of_images, @@ -48,11 +52,14 @@ auto_docstring, can_return_tuple, filter_out_non_signature_kwargs, + is_torchvision_available, + is_torchvision_v2_available, logging, ) from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel from ..deepseek_vl.configuration_deepseek_vl import DeepseekVLConfig from ..deepseek_vl.image_processing_deepseek_vl import DeepseekVLImageProcessor +from ..deepseek_vl.image_processing_deepseek_vl_fast import DeepseekVLImageProcessorFast from ..deepseek_vl.modeling_deepseek_vl import ( DeepseekVLForConditionalGeneration, DeepseekVLModel, @@ -63,6 +70,16 @@ from ..sam.modeling_sam import SamLayerNorm, SamVisionNeck +if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + + from ...image_utils import pil_torch_interpolation_mapping +elif is_torchvision_available(): + from torchvision.transforms import functional as F + + from ...image_utils import pil_torch_interpolation_mapping + + logger = logging.get_logger(__name__) @@ -654,7 +671,6 @@ def preprocess( # high_res_image: resize (high) -> rescale -> normalize (high) # low_res_image: resize (high) -> rescale -> resize (low) -> normalize (low) high_res_image = image - if do_resize: high_res_image = self.resize( image=high_res_image, @@ -695,6 +711,191 @@ def preprocess( return BatchFeature(data=data, tensor_type=return_tensors) +class DeepseekVLHybridFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + r""" + min_size (`int`, *optional*, defaults to 14): + The minimum allowed size for the resized image. Ensures that neither the height nor width + falls below this value after resizing. + high_res_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`): + Size of the high resolution output image after resizing. Can be overridden by the `high_res_size` parameter in the `preprocess` + method. + high_res_resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be + overridden by the `high_res_resample` parameter in the `preprocess` method. + high_res_image_mean (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`): + Mean to use if normalizing the high resolution image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `high_res_image_mean` parameter in the `preprocess` method. + high_res_image_std (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_STD`): + Standard deviation to use if normalizing the high resolution image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `high_res_image_std` parameter in the `preprocess` method. 
+    """
+
+    min_size: int
+    high_res_size: dict
+    high_res_resample: "PILImageResampling"
+    high_res_image_mean: list[float]
+    high_res_image_std: list[float]
+
+
+class DeepseekVLHybridImageProcessorFast(DeepseekVLImageProcessorFast):
+    high_res_image_mean = OPENAI_CLIP_MEAN
+    high_res_image_std = OPENAI_CLIP_STD
+    high_res_size = {"height": 1024, "width": 1024}
+    high_res_resample = PILImageResampling.BICUBIC
+
+    def __init__(self, **kwargs: Unpack[DeepseekVLHybridFastImageProcessorKwargs]):
+        if kwargs.get("image_mean", None) is None:
+            background_color = (127, 127, 127)
+        else:
+            background_color = tuple([int(x * 255) for x in kwargs.get("image_mean")])
+        if kwargs.get("high_res_image_mean", None) is None:
+            background_color = (127, 127, 127)
+        else:
+            background_color = tuple([int(x * 255) for x in kwargs.get("high_res_image_mean")])
+        super().__init__(**kwargs)
+        self.background_color = tuple(background_color)
+
+    def _further_process_kwargs(
+        self,
+        size: Optional[SizeDict] = None,
+        high_res_size: Optional[SizeDict] = None,
+        default_to_square: Optional[bool] = None,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
+        high_res_image_mean: Optional[Union[float, list[float]]] = None,
+        high_res_image_std: Optional[Union[float, list[float]]] = None,
+        data_format: Optional[ChannelDimension] = None,
+        **kwargs,
+    ) -> dict:
+        """
+        Update kwargs that need further processing before being validated
+        Can be overridden by subclasses to customize the processing of kwargs.
+        """
+        if kwargs is None:
+            kwargs = {}
+        if size is not None:
+            size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
+        if high_res_size is not None:
+            high_res_size = SizeDict(**get_size_dict(size=high_res_size, default_to_square=default_to_square))
+        if isinstance(image_mean, list):
+            image_mean = tuple(image_mean)
+        if isinstance(image_std, list):
+            image_std = tuple(image_std)
+        if isinstance(high_res_image_mean, list):
+            high_res_image_mean = tuple(high_res_image_mean)
+        if isinstance(high_res_image_std, list):
+            high_res_image_std = tuple(high_res_image_std)
+        if data_format is None:
+            data_format = ChannelDimension.FIRST
+
+        high_res_resample = kwargs.pop("high_res_resample")
+        kwargs["high_res_interpolation"] = (
+            pil_torch_interpolation_mapping[high_res_resample]
+            if isinstance(high_res_resample, (int, PILImageResampling))
+            else high_res_resample
+        )
+
+        kwargs["size"] = size
+        kwargs["high_res_size"] = high_res_size
+        kwargs["default_to_square"] = default_to_square
+        kwargs["image_mean"] = image_mean
+        kwargs["image_std"] = image_std
+        kwargs["high_res_image_mean"] = high_res_image_mean
+        kwargs["high_res_image_std"] = high_res_image_std
+        kwargs["data_format"] = data_format
+
+        return kwargs
+
+    def _preprocess(
+        self,
+        images: list["torch.Tensor"],
+        do_resize: bool,
+        size: SizeDict,
+        high_res_size: SizeDict,
+        min_size: int,
+        interpolation: Optional["F.InterpolationMode"],
+        high_res_interpolation: Optional["F.InterpolationMode"],
+        do_rescale: bool,
+        rescale_factor: float,
+        do_normalize: bool,
+        image_mean: Optional[Union[float, list[float]]],
+        image_std: Optional[Union[float, list[float]]],
+        high_res_image_mean: Optional[Union[float, list[float]]],
+        high_res_image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
+        return_tensors: Optional[Union[str, TensorType]],
+        do_pad: bool = True,
+        **kwargs,
+    ) -> BatchFeature:
+        # Group images by size for 
batched resizing
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+        high_res_resized_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            stacked_high_res_images = stacked_images
+            if do_resize:
+                stacked_high_res_images = self.resize(
+                    image=stacked_images, size=high_res_size, min_size=min_size, interpolation=high_res_interpolation
+                )
+            high_res_resized_images_grouped[shape] = stacked_high_res_images
+        high_res_resized_images = reorder_images(high_res_resized_images_grouped, grouped_images_index)
+
+        # Group images by size for further processing
+        # Needed in case do_resize is False, or resize returns images with different sizes
+        grouped_high_res_images, grouped_high_res_images_index = group_images_by_shape(
+            high_res_resized_images, disable_grouping=disable_grouping
+        )
+        high_res_padded_images = {}
+        high_res_processed_images_grouped = {}
+        for shape, stacked_high_res_images in grouped_high_res_images.items():
+            if do_pad:
+                stacked_high_res_images = self.pad_to_square(
+                    stacked_high_res_images, background_color=self.background_color
+                )
+            high_res_padded_images[shape] = stacked_high_res_images
+            # Fused rescale and normalize
+            stacked_high_res_images = self.rescale_and_normalize(
+                stacked_high_res_images,
+                do_rescale,
+                rescale_factor,
+                do_normalize,
+                high_res_image_mean,
+                high_res_image_std,
+            )
+            high_res_processed_images_grouped[shape] = stacked_high_res_images
+        high_res_processed_images = reorder_images(high_res_processed_images_grouped, grouped_high_res_images_index)
+        high_res_processed_images = (
+            torch.stack(high_res_processed_images, dim=0) if return_tensors else high_res_processed_images
+        )
+
+        resized_images_grouped = {}
+        for shape, stacked_high_res_padded_images in high_res_padded_images.items():
+            stacked_images = stacked_high_res_padded_images
+            if do_resize:
+                stacked_images = self.resize(
+                    image=stacked_high_res_padded_images, size=size, min_size=min_size, interpolation=interpolation
+                )
+            resized_images_grouped[shape] = stacked_images
+        resized_images = reorder_images(resized_images_grouped, grouped_high_res_images_index)
+
+        grouped_resized_images, grouped_resized_images_index = group_images_by_shape(
+            resized_images, disable_grouping=disable_grouping
+        )
+        processed_images_grouped = {}
+        for shape, stacked_images in grouped_resized_images.items():
+            if do_pad:
+                stacked_images = self.pad_to_square(stacked_images, background_color=self.background_color)
+            # Fused rescale and normalize
+            stacked_images = self.rescale_and_normalize(
+                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+            )
+            processed_images_grouped[shape] = stacked_images
+        processed_images = reorder_images(processed_images_grouped, grouped_resized_images_index)
+        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+        return BatchFeature(
+            data={"pixel_values": processed_images, "high_res_pixel_values": high_res_processed_images},
+            tensor_type=return_tensors,
+        )
+
+
 class DeepseekVLHybridProcessorKwargs(DeepseekVLProcessorKwargs):
     pass
 
 
@@ -773,5 +974,6 @@ def __call__(
     "DeepseekVLHybridModel",
     "DeepseekVLHybridForConditionalGeneration",
     "DeepseekVLHybridImageProcessor",
+    "DeepseekVLHybridImageProcessorFast",
     "DeepseekVLHybridProcessor",
 ]
diff --git a/src/transformers/models/janus/__init__.py b/src/transformers/models/janus/__init__.py
index 06bc90cd938a..8aacc2ed6fdb 100644
--- a/src/transformers/models/janus/__init__.py
+++ b/src/transformers/models/janus/__init__.py
@@ -20,6 +20,7 @@
 if TYPE_CHECKING:
     from 
.configuration_janus import * from .image_processing_janus import * + from .image_processing_janus_fast import * from .modeling_janus import * from .processing_janus import * else: diff --git a/src/transformers/models/janus/image_processing_janus_fast.py b/src/transformers/models/janus/image_processing_janus_fast.py new file mode 100644 index 000000000000..81f9bafed767 --- /dev/null +++ b/src/transformers/models/janus/image_processing_janus_fast.py @@ -0,0 +1,245 @@ +# coding=utf-8 +# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional, Union + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ImageInput, + PILImageResampling, + SizeDict, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) + + +if is_torch_available(): + import torch +if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F +elif is_torchvision_available(): + from torchvision.transforms import functional as F + + +class JanusFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + r""" + min_size (`int`, *optional*, defaults to 14): + The minimum allowed size for the resized image. Ensures that neither the height nor width + falls below this value after resizing. + """ + + min_size: int + + +@auto_docstring +class JanusImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + size = {"height": 384, "width": 384} + min_size = 14 + do_resize = True + do_rescale = True + do_normalize = True + valid_kwargs = JanusFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[JanusFastImageProcessorKwargs]): + if kwargs.get("image_mean", None) is None: + background_color = (127, 127, 127) + else: + background_color = tuple([int(x * 255) for x in kwargs.get("image_mean")]) + super().__init__(**kwargs) + self.background_color = tuple(background_color) + + def resize( + self, + image: "torch.Tensor", + size: SizeDict, + min_size: int, + interpolation: "F.InterpolationMode" = None, + antialias: bool = True, + **kwargs, + ) -> "torch.Tensor": + if size.height is None or size.width is None or size.height != size.width: + raise ValueError( + f"Output height and width must be the same. Got height={size['height']} and width={size['width']}" + ) + size = size.height + + height, width = image.shape[-2:] + max_size = max(height, width) + + delta = size / max_size + # Largest side becomes `size` and the other side is scaled according to the aspect ratio. 
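+        # e.g. (illustrative): a 480x640 image with size=384 gives delta = 384 / 640 = 0.6,
+        # so the unpadded output is 288x384; `pad_to_square` later restores a square 384x384.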
+        output_size_nonpadded = SizeDict(
+            height=max(int(height * delta), min_size),
+            width=max(int(width * delta), min_size),
+        )
+
+        return super().resize(image, size=output_size_nonpadded, interpolation=interpolation, antialias=antialias)
+
+    def pad_to_square(
+        self,
+        images: "torch.Tensor",
+        background_color: Union[int, tuple[int, int, int]] = 0,
+    ) -> "torch.Tensor":
+        """
+        Pads an image to a square based on the longest edge.
+
+        Args:
+            images (`torch.Tensor`):
+                The images to pad.
+            background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
+                The color to use for the padding. Can be an integer for a single channel or a
+                tuple of integers for multi-channel images. If passed as an integer in
+                multi-channel mode, the remaining channels default to `0`.
+
+        Returns:
+            `torch.Tensor`: The padded images.
+        """
+        height, width = images.shape[-2:]
+        num_channels = images.shape[1]
+        batch_size = images.shape[0]
+
+        if height == width:
+            return images
+
+        max_dim = max(height, width)
+
+        # Ensure background_color is the correct shape
+        if isinstance(background_color, int):
+            background_color = [background_color]
+        elif len(background_color) != num_channels:
+            raise ValueError(
+                f"background_color must have exactly {num_channels} elements to match the number of channels"
+            )
+
+        padded_images = torch.zeros(
+            (batch_size, num_channels, max_dim, max_dim), dtype=images.dtype, device=images.device
+        )
+        for i, color in enumerate(background_color):
+            padded_images[:, i, :, :] = color
+        if width > height:
+            start = (max_dim - height) // 2
+            padded_images[:, :, start : start + height, :] = images
+        else:
+            start = (max_dim - width) // 2
+            padded_images[:, :, :, start : start + width] = images
+
+        return padded_images
+
+    def _preprocess(
+        self,
+        images: list["torch.Tensor"],
+        do_resize: bool,
+        size: SizeDict,
+        min_size: int,
+        interpolation: Optional["F.InterpolationMode"],
+        do_rescale: bool,
+        rescale_factor: float,
+        do_normalize: bool,
+        image_mean: Optional[Union[float, list[float]]],
+        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
+        return_tensors: Optional[Union[str, TensorType]],
+        do_pad: bool = True,
+        **kwargs,
+    ) -> BatchFeature:
+        # Group images by size for batched resizing
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+        resized_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_resize:
+                stacked_images = self.resize(
+                    image=stacked_images, size=size, min_size=min_size, interpolation=interpolation
+                )
+            resized_images_grouped[shape] = stacked_images
+        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+        # Group images by size for further processing
+        # Needed in case do_resize is False, or resize returns images with different sizes
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+        processed_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_pad:
+                stacked_images = self.pad_to_square(stacked_images, background_color=self.background_color)
+            # Fused rescale and normalize
+            stacked_images = self.rescale_and_normalize(
+                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+            )
+            processed_images_grouped[shape] = stacked_images
+
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        processed_images = 
torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + def postprocess( + self, + images: ImageInput, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[list[float]] = None, + image_std: Optional[list[float]] = None, + return_tensors: Optional[str] = None, + ) -> "torch.Tensor": + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = 1.0 / self.rescale_factor if rescale_factor is None else rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + image_mean = tuple(-rescale_factor * mean / std for mean, std in zip(image_mean, image_std)) + image_std = tuple(1 / std for std in image_std) + + images = self.preprocess( + images, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=False, + do_pad=False, + return_tensors=return_tensors, + ).pixel_values + if do_rescale: + images = [image.clip(0, 255).to(torch.uint8) for image in images] + + if do_normalize and do_rescale and return_tensors == "PIL.Image.Image": + images = [F.to_pil_image(image) for image in images] + + data = {"pixel_values": images} + return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None + + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["JanusImageProcessorFast"] diff --git a/tests/models/deepseek_vl/test_image_processing_deepseek_vl.py b/tests/models/deepseek_vl/test_image_processing_deepseek_vl.py index c1092f05d336..3156bd043456 100644 --- a/tests/models/deepseek_vl/test_image_processing_deepseek_vl.py +++ b/tests/models/deepseek_vl/test_image_processing_deepseek_vl.py @@ -17,14 +17,21 @@ import unittest from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_vision_available +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs +if is_torch_available(): + import torch + + if is_vision_available(): from transformers import DeepseekVLImageProcessor + if is_torchvision_available(): + from transformers import DeepseekVLImageProcessorFast + # Copied from tests.models.vit.test_image_processing_vit.ViTImageProcessingTester with ViT->DeepseekVL class DeepseekVLImageProcessingTester: @@ -83,10 +90,9 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F @require_torch @require_vision -# Copied from tests.models.vit.test_image_processing_vit.ViTImageProcessingTest with ViT->DeepseekVL class DeepseekVLImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): - # Ignore copy image_processing_class = DeepseekVLImageProcessor if is_vision_available() else None + fast_image_processing_class = DeepseekVLImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -113,6 +119,33 @@ def test_image_processor_from_dict_with_kwargs(self): image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42) self.assertEqual(image_processor.size, {"height": 42, "width": 42}) + @require_vision + @require_torch + 
def test_slow_fast_equivalence_batched(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop: + self.skipTest( + reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors" + ) + + dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_images, return_tensors=None) + encoding_fast = image_processor_fast(dummy_images, return_tensors=None) + + # Overwrite as the outputs are not always all of the same shape (kept for BC) + for i in range(len(encoding_slow.pixel_values)): + self._assert_slow_fast_tensors_equivalence( + torch.from_numpy(encoding_slow.pixel_values[i]), encoding_fast.pixel_values[i] + ) + # Ignore copy @unittest.skip(reason="Not supported") def test_call_numpy_4_channels(self): diff --git a/tests/models/deepseek_vl_hybrid/test_image_processing_deepseek_vl_hybrid.py b/tests/models/deepseek_vl_hybrid/test_image_processing_deepseek_vl_hybrid.py index b7eaefd71a81..554219f1cc4c 100644 --- a/tests/models/deepseek_vl_hybrid/test_image_processing_deepseek_vl_hybrid.py +++ b/tests/models/deepseek_vl_hybrid/test_image_processing_deepseek_vl_hybrid.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- import unittest import numpy as np +import requests from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -32,6 +32,9 @@ from transformers import DeepseekVLHybridImageProcessor + if is_torchvision_available(): + from transformers import DeepseekVLHybridImageProcessorFast + class DeepseekVLHybridImageProcessingTester: def __init__( @@ -104,6 +107,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F @require_vision class DeepseekVLHybridImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = DeepseekVLHybridImageProcessor if is_vision_available() else None + fast_image_processing_class = DeepseekVLHybridImageProcessorFast if is_torchvision_available() else None # Copied from tests.models.vit.test_image_processing_vit.ViTImageProcessingTester.setUp with ViT->DeepseekVLHybrid def setUp(self): @@ -213,6 +217,59 @@ def test_call_pytorch_high_res(self): (self.image_processor_tester.batch_size, *expected_output_image_shape), ) + @require_vision + @require_torch + def test_slow_fast_equivalence(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + dummy_image = Image.open( + requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw + ) + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_image, return_tensors="pt") + self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values) + self._assert_slow_fast_tensors_equivalence( + encoding_slow.high_res_pixel_values, encoding_fast.high_res_pixel_values + ) + + @require_vision + @require_torch + def test_slow_fast_equivalence_batched(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop: + self.skipTest( + reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors" + ) + + dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_images, return_tensors=None) + encoding_fast = image_processor_fast(dummy_images, return_tensors=None) + + # Overwrite as the outputs are not always all of the same shape (kept for BC) + for i in 
range(len(encoding_slow.pixel_values)): + self._assert_slow_fast_tensors_equivalence( + torch.from_numpy(encoding_slow.pixel_values[i]), encoding_fast.pixel_values[i] + ) + for i in range(len(encoding_slow.high_res_pixel_values)): + self._assert_slow_fast_tensors_equivalence( + torch.from_numpy(encoding_slow.high_res_pixel_values[i]), encoding_fast.high_res_pixel_values[i] + ) + @unittest.skip(reason="Not supported") def test_call_numpy_4_channels(self): pass diff --git a/tests/models/janus/test_image_processing_janus.py b/tests/models/janus/test_image_processing_janus.py index 184f669e6a58..843ef834ac91 100644 --- a/tests/models/janus/test_image_processing_janus.py +++ b/tests/models/janus/test_image_processing_janus.py @@ -18,7 +18,7 @@ import numpy as np from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -31,6 +31,9 @@ from transformers import JanusImageProcessor + if is_torchvision_available(): + from transformers import JanusImageProcessorFast + class JanusImageProcessingTester: def __init__( @@ -44,8 +47,8 @@ def __init__( do_resize=True, size=None, do_normalize=True, - image_mean=[1.0, 1.0, 1.0], - image_std=[1.0, 1.0, 1.0], + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], do_convert_rgb=True, ): size = size if size is not None else {"height": 384, "width": 384} @@ -89,6 +92,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F @require_vision class JanusImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = JanusImageProcessor if is_vision_available() else None + fast_image_processing_class = JanusImageProcessorFast if is_torchvision_available() else None # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->Janus def setUp(self): @@ -101,87 +105,137 @@ def image_processor_dict(self): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) - self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"height": 384, "width": 384}) - self.assertEqual(image_processor.image_mean, [1.0, 1.0, 1.0]) + for image_processing_class in self.image_processor_list: + image_processor = 
image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 384, "width": 384}) + self.assertEqual(image_processor.image_mean, [0.48145466, 0.4578275, 0.40821073]) - image_processor = self.image_processing_class.from_dict( - self.image_processor_dict, size=42, image_mean=[1.0, 2.0, 1.0] - ) - self.assertEqual(image_processor.size, {"height": 42, "width": 42}) - self.assertEqual(image_processor.image_mean, [1.0, 2.0, 1.0]) + image_processor = image_processing_class.from_dict( + self.image_processor_dict, size=42, image_mean=[1.0, 2.0, 1.0] + ) + self.assertEqual(image_processor.size, {"height": 42, "width": 42}) + self.assertEqual(image_processor.image_mean, [1.0, 2.0, 1.0]) def test_call_pil(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) - for image in image_inputs: - self.assertIsInstance(image, Image.Image) - - # Test Non batched input - encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values - expected_output_image_shape = (1, 3, 384, 384) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - - # Test batched - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 3, 384, 384) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test Non batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 3, 384, 384) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 384, 384) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) def test_call_numpy(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True) - for image in image_inputs: - self.assertIsInstance(image, np.ndarray) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) - encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values - expected_output_image_shape = (1, 3, 384, 384) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 3, 384, 384) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 3, 384, 384) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + 
expected_output_image_shape = (7, 3, 384, 384) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) def test_call_pytorch(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) - for image in image_inputs: - self.assertIsInstance(image, torch.Tensor) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) - encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values - expected_output_image_shape = (1, 3, 384, 384) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 3, 384, 384) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 3, 384, 384) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 384, 384) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) def test_nested_input(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) - - # Test batched as a list of images. - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 3, 384, 384) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - - # Test batched as a nested list of images, where each sublist is one batch. - image_inputs_nested = [image_inputs[:3], image_inputs[3:]] - encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 3, 384, 384) - self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape) - - # Image processor should return same pixel values, independently of input format. - self.assertTrue((encoded_images_nested == encoded_images).all()) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + + # Test batched as a list of images. + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 384, 384) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched as a nested list of images, where each sublist is one batch. + image_inputs_nested = [image_inputs[:3], image_inputs[3:]] + encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 384, 384) + self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape) + + # Image processor should return same pixel values, independently of input format. 
+            self.assertTrue((encoded_images_nested == encoded_images).all())
+
+    @require_vision
+    @require_torch
+    def test_slow_fast_equivalence_batched(self):
+        if not self.test_slow_image_processor or not self.test_fast_image_processor:
+            self.skipTest(reason="Skipping slow/fast equivalence test")
+
+        if self.image_processing_class is None or self.fast_image_processing_class is None:
+            self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
+
+        if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
+            self.skipTest(
+                reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
+            )
+
+        dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
+        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
+        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
+
+        encoding_slow = image_processor_slow(dummy_images, return_tensors=None)
+        encoding_fast = image_processor_fast(dummy_images, return_tensors=None)
+
+        # Overwrite as the outputs are not always all of the same shape (kept for BC)
+        for i in range(len(encoding_slow.pixel_values)):
+            self._assert_slow_fast_tensors_equivalence(
+                torch.from_numpy(encoding_slow.pixel_values[i]), encoding_fast.pixel_values[i]
+            )
+
+    @require_vision
+    @require_torch
+    def test_slow_fast_equivalence_postprocess(self):
+        if not self.test_slow_image_processor or not self.test_fast_image_processor:
+            self.skipTest(reason="Skipping slow/fast equivalence test")
+
+        if self.image_processing_class is None or self.fast_image_processing_class is None:
+            self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
+
+        dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
+        dummy_images = [image / 255.0 for image in dummy_images]
+        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
+        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
+
+        encoding_slow = image_processor_slow.postprocess(dummy_images, return_tensors=None)
+        encoding_fast = image_processor_fast.postprocess(dummy_images, return_tensors=None)
+
+        # Overwrite as the outputs are not always all of the same shape (kept for BC)
+        for i in range(len(encoding_slow.pixel_values)):
+            self._assert_slow_fast_tensors_equivalence(
+                torch.from_numpy(encoding_slow.pixel_values[i]).float(), encoding_fast.pixel_values[i].float()
+            )
 
     @unittest.skip(reason="Not supported")
     def test_call_numpy_4_channels(self):

From 5e3c5faa5432186178776d7740cc73aef581bd12 Mon Sep 17 00:00:00 2001
From: yonigozlan
Date: Thu, 31 Jul 2025 17:13:38 +0000
Subject: [PATCH 2/2] fix after review

---
 .../image_processing_deepseek_vl.py           |  8 +++++++-
 .../image_processing_deepseek_vl_hybrid.py    | 19 +++++++++++++++----
 ...mage_processing_deepseek_vl_hybrid_fast.py |  7 ++++---
 .../modular_deepseek_vl_hybrid.py             | 18 ++++++++++++------
 .../models/janus/image_processing_janus.py    |  8 +++++++-
 .../models/janus/modular_janus.py             |  8 +++++++-
 6 files changed, 52 insertions(+), 16 deletions(-)

diff --git a/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py b/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py
index 8cf4acbf937f..8df016a80eeb 100644
--- a/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py
+++ b/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py
@@ -131,6 +131,7 @@ def resize(
         self,
         image: np.ndarray,
         size: Union[dict[str, int], int],
+        background_color: Optional[tuple[int, int, int]] = None,
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
@@ -142,6 +143,10 @@ def resize(
         Args:
             image (`np.ndarray`):
                 Image to resize.
+            size (`dict[str, int]` or `int`):
+                The size to resize the image to. If a dictionary, it should have the keys `"height"` and `"width"`.
+            background_color (`tuple[int, int, int]`, *optional*):
+                The background color to use for the padding. If `None`, the processor's `background_color` is used.
             resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                 `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
             data_format (`ChannelDimension` or `str`, *optional*):
@@ -160,6 +165,7 @@ def resize(
         Returns:
             `np.ndarray`: The resized image.
         """
+        background_color = background_color if background_color is not None else self.background_color
         if input_data_format is None:
             input_data_format = infer_channel_dimension_format(image)
 
@@ -191,7 +197,7 @@ def resize(
         # Expand and pad the images to obtain a square image of dimensions `size x size`
         image = self.pad_to_square(
             image=image,
-            background_color=self.background_color,
+            background_color=background_color,
             input_data_format=input_data_format,
         )
         return image
diff --git a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py
index 3dc112d3a525..a2772894b2ae 100644
--- a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py
+++ b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py
@@ -154,14 +154,15 @@ def __init__(
         self.background_color = tuple([int(x * 255) for x in image_mean])
 
         if high_res_image_mean is None:
-            self.background_color = (127, 127, 127)
+            self.high_res_background_color = (127, 127, 127)
         else:
-            self.background_color = tuple([int(x * 255) for x in high_res_image_mean])
+            self.high_res_background_color = tuple([int(x * 255) for x in high_res_image_mean])
 
     def resize(
         self,
         image: np.ndarray,
         size: Union[dict[str, int], int],
+        background_color: Optional[tuple[int, int, int]] = None,
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
@@ -173,6 +174,10 @@ def resize(
         Args:
             image (`np.ndarray`):
                 Image to resize.
+            size (`dict[str, int]` or `int`):
+                The size to resize the image to. If a dictionary, it should have the keys `"height"` and `"width"`.
+            background_color (`tuple[int, int, int]`, *optional*):
+                The background color to use for the padding. If `None`, the processor's `background_color` is used.
             resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                 `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
             data_format (`ChannelDimension` or `str`, *optional*):
@@ -191,6 +196,7 @@ def resize(
         Returns:
             `np.ndarray`: The resized image.
""" + background_color = background_color if background_color is not None else self.background_color if input_data_format is None: input_data_format = infer_channel_dimension_format(image) @@ -222,7 +228,7 @@ def resize( # Expand and pad the images to obtain a square image of dimensions `size x size` image = self.pad_to_square( image=image, - background_color=self.background_color, + background_color=background_color, input_data_format=input_data_format, ) return image @@ -365,11 +371,16 @@ def preprocess( high_res_image = self.resize( image=high_res_image, size=high_res_size_dict, + background_color=self.high_res_background_color, resample=high_res_resample, input_data_format=input_data_format, ) image = self.resize( - image=high_res_image, size=size_dict, resample=resample, input_data_format=input_data_format + image=high_res_image, + size=size_dict, + background_color=self.background_color, + resample=resample, + input_data_format=input_data_format, ) if do_rescale: diff --git a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py index 6bb5696282f9..2120a65dd429 100644 --- a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +++ b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py @@ -98,11 +98,12 @@ def __init__(self, **kwargs: Unpack[DeepseekVLHybridFastImageProcessorKwargs]): else: background_color = tuple([int(x * 255) for x in kwargs.get("image_mean")]) if kwargs.get("high_res_image_mean", None) is None: - background_color = (127, 127, 127) + high_res_background_color = (127, 127, 127) else: - background_color = tuple([int(x * 255) for x in kwargs.get("high_res_image_mean")]) + high_res_background_color = tuple([int(x * 255) for x in kwargs.get("high_res_image_mean")]) super().__init__(**kwargs) self.background_color = tuple(background_color) + self.high_res_background_color = tuple(high_res_background_color) def resize( self, @@ -223,7 +224,7 @@ def _preprocess( for shape, stacked_high_res_images in grouped_high_res_images.items(): if do_pad: stacked_high_res_images = self.pad_to_square( - stacked_high_res_images, background_color=self.background_color + stacked_high_res_images, background_color=self.high_res_background_color ) high_res_padded_images[shape] = stacked_high_res_images # Fused rescale and normalize diff --git a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py index 776af4a5cfbc..e3149b420286 100644 --- a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py @@ -533,9 +533,9 @@ def __init__( ) if high_res_image_mean is None: - self.background_color = (127, 127, 127) + self.high_res_background_color = (127, 127, 127) else: - self.background_color = tuple([int(x * 255) for x in high_res_image_mean]) + self.high_res_background_color = tuple([int(x * 255) for x in high_res_image_mean]) @filter_out_non_signature_kwargs() def preprocess( @@ -675,11 +675,16 @@ def preprocess( high_res_image = self.resize( image=high_res_image, size=high_res_size_dict, + background_color=self.high_res_background_color, resample=high_res_resample, input_data_format=input_data_format, ) image = self.resize( - image=high_res_image, size=size_dict, resample=resample, input_data_format=input_data_format + 
image=high_res_image,
+            size=size_dict,
+            background_color=self.background_color,
+            resample=resample,
+            input_data_format=input_data_format,
         )
 
         if do_rescale:
@@ -749,11 +754,12 @@ def __init__(self, **kwargs: Unpack[DeepseekVLHybridFastImageProcessorKwargs]):
         else:
             background_color = tuple([int(x * 255) for x in kwargs.get("image_mean")])
         if kwargs.get("high_res_image_mean", None) is None:
-            background_color = (127, 127, 127)
+            high_res_background_color = (127, 127, 127)
         else:
-            background_color = tuple([int(x * 255) for x in kwargs.get("high_res_image_mean")])
+            high_res_background_color = tuple([int(x * 255) for x in kwargs.get("high_res_image_mean")])
         DeepseekVLImageProcessorFast().__init__(**kwargs)
         self.background_color = tuple(background_color)
+        self.high_res_background_color = tuple(high_res_background_color)
 
     def _further_process_kwargs(
         self,
@@ -848,7 +854,7 @@ def _preprocess(
         for shape, stacked_high_res_images in grouped_high_res_images.items():
             if do_pad:
                 stacked_high_res_images = self.pad_to_square(
-                    stacked_high_res_images, background_color=self.background_color
+                    stacked_high_res_images, background_color=self.high_res_background_color
                 )
             high_res_padded_images[shape] = stacked_high_res_images
         # Fused rescale and normalize
diff --git a/src/transformers/models/janus/image_processing_janus.py b/src/transformers/models/janus/image_processing_janus.py
index 3b1236061300..f99041748fcb 100644
--- a/src/transformers/models/janus/image_processing_janus.py
+++ b/src/transformers/models/janus/image_processing_janus.py
@@ -134,6 +134,7 @@ def resize(
         self,
         image: np.ndarray,
         size: Union[dict[str, int], int],
+        background_color: Optional[tuple[int, int, int]] = None,
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
@@ -145,6 +146,10 @@ def resize(
         Args:
             image (`np.ndarray`):
                 Image to resize.
+            size (`dict[str, int]` or `int`):
+                The size to resize the image to. If a dictionary, it should have the keys `"height"` and `"width"`.
+            background_color (`tuple[int, int, int]`, *optional*):
+                The background color to use for the padding. If `None`, the processor's `background_color` is used.
             resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                 `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
             data_format (`ChannelDimension` or `str`, *optional*):
@@ -163,6 +168,7 @@ def resize(
         Returns:
             `np.ndarray`: The resized image.
""" + background_color = background_color if background_color is not None else self.background_color if input_data_format is None: input_data_format = infer_channel_dimension_format(image) @@ -194,7 +200,7 @@ def resize( # Expand and pad the images to obtain a square image of dimensions `size x size` image = self.pad_to_square( image=image, - background_color=self.background_color, + background_color=background_color, input_data_format=input_data_format, ) return image diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 11b0848620b7..5d9bde2e9840 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -1419,6 +1419,7 @@ def resize( self, image: np.ndarray, size: Union[dict[str, int], int], + background_color: Optional[tuple[int, int, int]] = None, resample: PILImageResampling = PILImageResampling.BICUBIC, data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, @@ -1430,6 +1431,10 @@ def resize( Args: image (`np.ndarray`): Image to resize. + size (`dict[str, int]` or `int`): + The size to resize the image to. If a dictionary, it should have the keys `"height"` and `"width"`. + background_color (`tuple[int, int, int]`): + The background color to use for the padding. resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`. data_format (`ChannelDimension` or `str`, *optional*): @@ -1448,6 +1453,7 @@ def resize( Returns: `np.ndarray`: The resized image. """ + background_color = background_color if background_color is not None else self.background_color if input_data_format is None: input_data_format = infer_channel_dimension_format(image) @@ -1479,7 +1485,7 @@ def resize( # Expand and pad the images to obtain a square image of dimensions `size x size` image = self.pad_to_square( image=image, - background_color=self.background_color, + background_color=background_color, input_data_format=input_data_format, ) return image