diff --git a/docs/source/en/model_doc/imagegpt.md b/docs/source/en/model_doc/imagegpt.md index e3c5db15247e..d995a92ec912 100644 --- a/docs/source/en/model_doc/imagegpt.md +++ b/docs/source/en/model_doc/imagegpt.md @@ -104,6 +104,11 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] ImageGPTImageProcessor - preprocess +## ImageGPTImageProcessorFast + +[[autodoc]] ImageGPTImageProcessorFast + - preprocess + ## ImageGPTModel [[autodoc]] ImageGPTModel diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 1240a677d97c..499bfb5b2bdf 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -111,7 +111,7 @@ ("idefics2", ("Idefics2ImageProcessor", "Idefics2ImageProcessorFast")), ("idefics3", ("Idefics3ImageProcessor", "Idefics3ImageProcessorFast")), ("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")), - ("imagegpt", ("ImageGPTImageProcessor", None)), + ("imagegpt", ("ImageGPTImageProcessor", "ImageGPTImageProcessorFast")), ("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")), ("instructblipvideo", ("InstructBlipVideoImageProcessor", None)), ("janus", ("JanusImageProcessor", "JanusImageProcessorFast")), diff --git a/src/transformers/models/imagegpt/__init__.py b/src/transformers/models/imagegpt/__init__.py index cb79cea50d6e..098ffb6296f5 100644 --- a/src/transformers/models/imagegpt/__init__.py +++ b/src/transformers/models/imagegpt/__init__.py @@ -21,6 +21,7 @@ from .configuration_imagegpt import * from .feature_extraction_imagegpt import * from .image_processing_imagegpt import * + from .image_processing_imagegpt_fast import * from .modeling_imagegpt import * else: import sys diff --git a/src/transformers/models/imagegpt/image_processing_imagegpt.py b/src/transformers/models/imagegpt/image_processing_imagegpt.py index 0ec9ef5e4333..1f2026627515 100644 --- a/src/transformers/models/imagegpt/image_processing_imagegpt.py +++ b/src/transformers/models/imagegpt/image_processing_imagegpt.py @@ -26,7 +26,7 @@ PILImageResampling, infer_channel_dimension_format, is_scaled_image, - make_flat_list_of_images, + make_list_of_images, to_numpy_array, valid_images, validate_preprocess_arguments, @@ -238,7 +238,7 @@ def preprocess( clusters = clusters if clusters is not None else self.clusters clusters = np.array(clusters) - images = make_flat_list_of_images(images) + images = make_list_of_images(images) if not valid_images(images): raise ValueError( @@ -247,7 +247,7 @@ def preprocess( ) # Here, normalize() is using a constant factor to divide pixel values. - # hence, the method does not need image_mean and image_std. + # hence, the method does not need iamge_mean and image_std. validate_preprocess_arguments( do_resize=do_resize, size=size, @@ -291,14 +291,24 @@ def preprocess( # We need to convert back to a list of images to keep consistent behaviour across processors. images = list(images) + data = {"input_ids": images} else: - images = [ - to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) - for image in images - ] - - data = {"input_ids": images} + images = [to_channel_dimension_format(image, data_format, input_data_format) for image in images] + data = {"pixel_values": images} return BatchFeature(data=data, tensor_type=return_tensors) + def to_dict(self): + output = super().to_dict() + # Ensure clusters are JSON/equality friendly + if output.get("clusters") is not None and isinstance(output["clusters"], np.ndarray): + output["clusters"] = output["clusters"].tolist() + # Need to set missing keys from slow processor to match the expected behavior in save/load tests compared to fast processor + missing_keys = ["image_mean", "image_std", "rescale_factor", "do_rescale"] + for key in missing_keys: + if key in output: + output[key] = None + + return output + __all__ = ["ImageGPTImageProcessor"] diff --git a/src/transformers/models/imagegpt/image_processing_imagegpt_fast.py b/src/transformers/models/imagegpt/image_processing_imagegpt_fast.py new file mode 100644 index 000000000000..736666fd28a0 --- /dev/null +++ b/src/transformers/models/imagegpt/image_processing_imagegpt_fast.py @@ -0,0 +1,209 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for ImageGPT.""" + +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, +) +from ...image_transforms import group_images_by_shape, reorder_images +from ...image_utils import PILImageResampling +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +def squared_euclidean_distance_torch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Compute squared Euclidean distances between all pixels and clusters. + + Args: + a: (N, 3) tensor of pixel RGB values + b: (M, 3) tensor of cluster RGB values + + Returns: + (N, M) tensor of squared distances + """ + b = b.t() # (3, M) + a2 = torch.sum(a**2, dim=1) # (N,) + b2 = torch.sum(b**2, dim=0) # (M,) + ab = torch.matmul(a, b) # (N, M) + d = a2[:, None] - 2 * ab + b2[None, :] # Squared Euclidean Distance: a^2 - 2ab + b^2 + return d # (N, M) tensor of squared distances + + +def color_quantize_torch(x: torch.Tensor, clusters: torch.Tensor) -> torch.Tensor: + """ + Assign each pixel to its nearest color cluster. + + Args: + x: (H*W, 3) tensor of flattened pixel RGB values + clusters: (n_clusters, 3) tensor of cluster RGB values + + Returns: + (H*W,) tensor of cluster indices + """ + d = squared_euclidean_distance_torch(x, clusters) + return torch.argmin(d, dim=1) + + +class ImageGPTFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + clusters (`np.ndarray` or `list[list[int]]` or `torch.Tensor`, *optional*): + The color clusters to use, of shape `(n_clusters, 3)` when color quantizing. Can be overridden by `clusters` + in `preprocess`. + do_color_quantize (`bool`, *optional*, defaults to `True`): + Controls whether to apply color quantization to convert continuous pixel values to discrete cluster indices. + When True, each pixel is assigned to its nearest color cluster, enabling ImageGPT's discrete token modeling. + """ + + clusters: Optional[Union[np.ndarray, list[list[int]], torch.Tensor]] + do_color_quantize: Optional[bool] + + +@auto_docstring +class ImageGPTImageProcessorFast(BaseImageProcessorFast): + model_input_names = ["input_ids"] + resample = PILImageResampling.BILINEAR + do_color_quantize = True + clusters = None + image_mean = [0.5, 0.5, 0.5] + image_std = [0.5, 0.5, 0.5] + do_rescale = True + do_normalize = True + valid_kwargs = ImageGPTFastImageProcessorKwargs + + def __init__( + self, + clusters: Optional[Union[list, np.ndarray, torch.Tensor]] = None, # keep as arg for backwards compatibility + **kwargs: Unpack[ImageGPTFastImageProcessorKwargs], + ): + r""" + clusters (`np.ndarray` or `list[list[int]]` or `torch.Tensor`, *optional*): + The color clusters to use, of shape `(n_clusters, 3)` when color quantizing. Can be overridden by `clusters` + in `preprocess`. + """ + clusters = torch.as_tensor(clusters, dtype=torch.float32) if clusters is not None else None + super().__init__(clusters=clusters, **kwargs) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: dict[str, int], + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: dict[str, int], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + do_color_quantize: Optional[bool] = None, + clusters: Optional[Union[list, np.ndarray, torch.Tensor]] = None, + disable_grouping: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ): + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + pixel_values = reorder_images(processed_images_grouped, grouped_images_index) + + # If color quantization is requested, perform it; otherwise return pixel values + if do_color_quantize: + # Prepare clusters + if clusters is None: + raise ValueError("Clusters must be provided for color quantization.") + # Convert to torch tensor if needed (clusters might be passed as list/numpy) + clusters_torch = ( + torch.as_tensor(clusters, dtype=torch.float32) if not isinstance(clusters, torch.Tensor) else clusters + ).to(pixel_values[0].device, dtype=pixel_values[0].dtype) + + # Group images by shape for batch processing + # We need to check if the pixel values are a tensor or a list of tensors + grouped_images, grouped_images_index = group_images_by_shape( + pixel_values, disable_grouping=disable_grouping + ) + # Process each group + input_ids_grouped = {} + + for shape, stacked_images in grouped_images.items(): + input_ids = color_quantize_torch( + stacked_images.permute(0, 2, 3, 1).reshape(-1, 3), clusters_torch + ) # (B*H*W, C) + input_ids_grouped[shape] = input_ids.reshape(stacked_images.shape[0], -1).reshape( + stacked_images.shape[0], -1 + ) # (B, H, W) + + input_ids = reorder_images(input_ids_grouped, grouped_images_index) + + return BatchFeature( + data={"input_ids": torch.stack(input_ids, dim=0) if return_tensors else input_ids}, + tensor_type=return_tensors, + ) + + pixel_values = torch.stack(pixel_values, dim=0) if return_tensors else pixel_values + return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors) + + def to_dict(self): + # Convert torch tensors to lists for JSON serialization + output = super().to_dict() + if output.get("clusters") is not None and isinstance(output["clusters"], torch.Tensor): + output["clusters"] = output["clusters"].tolist() + + return output + + +__all__ = ["ImageGPTImageProcessorFast"] diff --git a/tests/models/imagegpt/test_image_processing_imagegpt.py b/tests/models/imagegpt/test_image_processing_imagegpt.py index 35bf08fce3a1..8c04d9585022 100644 --- a/tests/models/imagegpt/test_image_processing_imagegpt.py +++ b/tests/models/imagegpt/test_image_processing_imagegpt.py @@ -19,11 +19,21 @@ import unittest import numpy as np +import pytest +import requests from datasets import load_dataset +from packaging import version from transformers import AutoImageProcessor -from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_vision, slow -from transformers.utils import is_torch_available, is_vision_available +from transformers.testing_utils import ( + check_json_file_has_correct_format, + require_torch, + require_torch_accelerator, + require_vision, + slow, + torch_device, +) +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -36,6 +46,9 @@ from transformers import ImageGPTImageProcessor + if is_torchvision_available(): + from transformers import ImageGPTImageProcessorFast + class ImageGPTImageProcessingTester: def __init__( @@ -94,6 +107,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F @require_vision class ImageGPTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = ImageGPTImageProcessor if is_vision_available() else None + fast_image_processing_class = ImageGPTImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -104,50 +118,54 @@ def image_processor_dict(self): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "clusters")) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "do_normalize")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "clusters")) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_normalize")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"height": 18, "width": 18}) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 18, "width": 18}) - image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42) - self.assertEqual(image_processor.size, {"height": 42, "width": 42}) + image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42) + self.assertEqual(image_processor.size, {"height": 42, "width": 42}) def test_image_processor_to_json_string(self): - image_processor = self.image_processing_class(**self.image_processor_dict) - obj = json.loads(image_processor.to_json_string()) - for key, value in self.image_processor_dict.items(): - if key == "clusters": - self.assertTrue(np.array_equal(value, obj[key])) - else: - self.assertEqual(obj[key], value) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + obj = json.loads(image_processor.to_json_string()) + for key, value in self.image_processor_dict.items(): + if key == "clusters": + self.assertTrue(np.array_equal(value, obj[key])) + else: + self.assertEqual(obj[key], value) def test_image_processor_to_json_file(self): - image_processor_first = self.image_processing_class(**self.image_processor_dict) + for image_processing_class in self.image_processor_list: + image_processor_first = image_processing_class(**self.image_processor_dict) - with tempfile.TemporaryDirectory() as tmpdirname: - json_file_path = os.path.join(tmpdirname, "image_processor.json") - image_processor_first.to_json_file(json_file_path) - image_processor_second = self.image_processing_class.from_json_file(json_file_path).to_dict() + with tempfile.TemporaryDirectory() as tmpdirname: + json_file_path = os.path.join(tmpdirname, "image_processor.json") + image_processor_first.to_json_file(json_file_path) + image_processor_second = image_processing_class.from_json_file(json_file_path).to_dict() - image_processor_first = image_processor_first.to_dict() - for key, value in image_processor_first.items(): - if key == "clusters": - self.assertTrue(np.array_equal(value, image_processor_second[key])) - else: - self.assertEqual(value, value) + image_processor_first = image_processor_first.to_dict() + for key, value in image_processor_first.items(): + if key == "clusters": + self.assertTrue(np.array_equal(value, image_processor_second[key])) + else: + self.assertEqual(image_processor_first[key], value) def test_image_processor_from_and_save_pretrained(self): for image_processing_class in self.image_processor_list: - image_processor_first = self.image_processing_class(**self.image_processor_dict) + image_processor_first = image_processing_class(**self.image_processor_dict) with tempfile.TemporaryDirectory() as tmpdirname: image_processor_first.save_pretrained(tmpdirname) - image_processor_second = self.image_processing_class.from_pretrained(tmpdirname).to_dict() + image_processor_second = image_processing_class.from_pretrained(tmpdirname).to_dict() image_processor_first = image_processor_first.to_dict() for key, value in image_processor_first.items(): @@ -181,43 +199,45 @@ def test_init_without_params(self): # Override the test from ImageProcessingTestMixin as ImageGPT model takes input_ids as input def test_call_pil(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random PIL images - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) - for image in image_inputs: - self.assertIsInstance(image, Image.Image) - - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_tensors="pt").input_ids - expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(encoded_images) - self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) - - # Test batched - encoded_images = image_processing(image_inputs, return_tensors="pt").input_ids - self.assertEqual( - tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) - ) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").input_ids + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(encoded_images) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").input_ids + self.assertEqual( + tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) + ) # Override the test from ImageProcessingTestMixin as ImageGPT model takes input_ids as input def test_call_numpy(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random numpy tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) - for image in image_inputs: - self.assertIsInstance(image, np.ndarray) - - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_tensors="pt").input_ids - expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(encoded_images) - self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) - - # Test batched - encoded_images = image_processing(image_inputs, return_tensors="pt").input_ids - self.assertEqual( - tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) - ) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").input_ids + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(encoded_images) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").input_ids + self.assertEqual( + tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) + ) @unittest.skip(reason="ImageGPT assumes clusters for 3 channels") def test_call_numpy_4_channels(self): @@ -225,24 +245,93 @@ def test_call_numpy_4_channels(self): # Override the test from ImageProcessingTestMixin as ImageGPT model takes input_ids as input def test_call_pytorch(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random PyTorch tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) - expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) - - for image in image_inputs: - self.assertIsInstance(image, torch.Tensor) - - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_tensors="pt").input_ids - self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) - - # Test batched - encoded_images = image_processing(image_inputs, return_tensors="pt").input_ids - self.assertEqual( - tuple(encoded_images.shape), - (self.image_processor_tester.batch_size, *expected_output_image_shape), + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) + + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").input_ids + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").input_ids + self.assertEqual( + tuple(encoded_images.shape), + (self.image_processor_tester.batch_size, *expected_output_image_shape), + ) + + # For quantization-based processors, use absolute tolerance only to avoid infinity issues + @require_vision + @require_torch + def test_slow_fast_equivalence(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + dummy_image = Image.open( + requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw + ) + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_image, return_tensors="pt") + self._assert_slow_fast_tensors_equivalence( + encoding_slow.input_ids.float(), encoding_fast.input_ids.float(), atol=1.0, rtol=0 + ) + + @require_vision + @require_torch + def test_slow_fast_equivalence_batched(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop: + self.skipTest( + reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors" + ) + + dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_images, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_images, return_tensors="pt") + + self._assert_slow_fast_tensors_equivalence( + encoding_slow.input_ids.float(), encoding_fast.input_ids.float(), atol=1.0, rtol=0 + ) + + @slow + @require_torch_accelerator + @require_vision + @pytest.mark.torch_compile_test + def test_can_compile_fast_image_processor(self): + if self.fast_image_processing_class is None: + self.skipTest("Skipping compilation test as fast image processor is not defined") + if version.parse(torch.__version__) < version.parse("2.3"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + torch.compiler.reset() + input_image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8) + image_processor = self.fast_image_processing_class(**self.image_processor_dict) + output_eager = image_processor(input_image, device=torch_device, return_tensors="pt") + + image_processor = torch.compile(image_processor, mode="reduce-overhead") + output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt") + self._assert_slow_fast_tensors_equivalence( + output_eager.input_ids.float(), output_compiled.input_ids.float(), atol=1.0, rtol=0 )