Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
aea6547
initial commit
agamjots05 Jul 22, 2025
d49fa73
initial setup
agamjots05 Jul 22, 2025
aead221
Overiding imageGPT specific functions
agamjots05 Jul 25, 2025
dc1b191
imported is_torch_available and utilized it for importing torch in im…
ayaayethan Jul 28, 2025
daedee9
Created init and ImageGPTFastImageProcessorKwargs
ayaayethan Jul 28, 2025
8608e19
added return_tensors, data_format, and input_data_format to ImageGPTF…
ayaayethan Jul 28, 2025
b772356
set up arguments and process and _preprocess definitions
ayaayethan Jul 28, 2025
9e80e0a
Added arguments to _preprocess
chrisandthentran Aug 1, 2025
f9f3ad8
Added additional optional arguments
chrisandthentran Aug 1, 2025
870cd9a
Copied logic over from base imageGPT processor
ayaayethan Aug 1, 2025
3604c7a
Implemented 2nd draft of fast imageGPT preprocess using batch processing
ayaayethan Aug 5, 2025
fd5c136
Implemented 3rd draft of imageGPT fast _preprocessor. Pulled logic fr…
ayaayethan Aug 5, 2025
ec1681b
modified imageGPT test file to properly run fast processor tests
ayaayethan Aug 7, 2025
432e8f3
converts images to torch.float32 from torch.unit8
ayaayethan Aug 7, 2025
040678c
fixed a typo with self.image_processor_list in the imagegpt test file
ayaayethan Aug 7, 2025
6e0c670
updated more instances of image_processing = self.image_processing_cl…
ayaayethan Aug 7, 2025
a020d5f
standardized normalization to not use image mean or std
ayaayethan Aug 11, 2025
56b3546
Merged changes from solution2 branch
ayaayethan Aug 11, 2025
0b21bff
Merged changes from solution2 test file
ayaayethan Aug 11, 2025
e98d5fa
fixed testing through baseImageGPT processor file
agamjots05 Aug 11, 2025
f3b0a8c
Fixed check_code_quality test. Removed unncessary list comprehension.
ayaayethan Aug 12, 2025
43c6171
reorganized imports in image_processing_imagegpt_fast
ayaayethan Aug 12, 2025
178b2c0
Merge branch 'main' into fastimage-imagegpt
ayaayethan Aug 12, 2025
5bb6d5a
formatted image_processing_imagegpt_fast.py
ayaayethan Aug 12, 2025
e575ba7
Added arg documentation
chrisandthentran Aug 12, 2025
25bd8ac
Added FastImageProcessorKwargs class + Docs for new kwargs
chrisandthentran Aug 12, 2025
ff89353
Reformatted previous
chrisandthentran Aug 12, 2025
a1b2e7f
Added F to normalization
chrisandthentran Aug 12, 2025
cd4d063
fixed ruff linting and cleaned up fast processor file
agamjots05 Aug 13, 2025
6b8458c
Merge branch 'huggingface:main' into fastimage-imagegpt
agamjots05 Aug 13, 2025
0b4af99
Merge branch 'main' into fastimage-imagegpt
agamjots05 Aug 13, 2025
1315a98
implemented requested changes
agamjots05 Aug 22, 2025
0f34fd1
fixed ruff checks
agamjots05 Aug 22, 2025
4f34393
fixed formatting issues
agamjots05 Aug 22, 2025
b382d09
Merge branch 'main' into fastimage-imagegpt
agamjots05 Aug 26, 2025
898a807
fix(ruff after merging main)
agamjots05 Aug 26, 2025
24f21c4
Merge remote-tracking branch 'upstream/main' into fastimage-imagegpt
yonigozlan Sep 4, 2025
f3815ce
simplify logic and reuse standard equivalenec tests
yonigozlan Sep 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/imagegpt.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ If you're interested in submitting a resource to be included here, please feel f
[[autodoc]] ImageGPTImageProcessor
- preprocess

## ImageGPTImageProcessorFast

[[autodoc]] ImageGPTImageProcessorFast
- preprocess

## ImageGPTModel

[[autodoc]] ImageGPTModel
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/auto/image_processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@
("idefics2", ("Idefics2ImageProcessor", "Idefics2ImageProcessorFast")),
("idefics3", ("Idefics3ImageProcessor", "Idefics3ImageProcessorFast")),
("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
("imagegpt", ("ImageGPTImageProcessor", None)),
("imagegpt", ("ImageGPTImageProcessor", "ImageGPTImageProcessorFast")),
("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
("instructblipvideo", ("InstructBlipVideoImageProcessor", None)),
("janus", ("JanusImageProcessor", "JanusImageProcessorFast")),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/imagegpt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .configuration_imagegpt import *
from .feature_extraction_imagegpt import *
from .image_processing_imagegpt import *
from .image_processing_imagegpt_fast import *
from .modeling_imagegpt import *
else:
import sys
Expand Down
28 changes: 19 additions & 9 deletions src/transformers/models/imagegpt/image_processing_imagegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_flat_list_of_images,
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
Expand Down Expand Up @@ -238,7 +238,7 @@ def preprocess(
clusters = clusters if clusters is not None else self.clusters
clusters = np.array(clusters)

images = make_flat_list_of_images(images)
images = make_list_of_images(images)

if not valid_images(images):
raise ValueError(
Expand All @@ -247,7 +247,7 @@ def preprocess(
)

# Here, normalize() is using a constant factor to divide pixel values.
# hence, the method does not need image_mean and image_std.
# hence, the method does not need iamge_mean and image_std.
validate_preprocess_arguments(
do_resize=do_resize,
size=size,
Expand Down Expand Up @@ -291,14 +291,24 @@ def preprocess(

# We need to convert back to a list of images to keep consistent behaviour across processors.
images = list(images)
data = {"input_ids": images}
else:
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
for image in images
]

data = {"input_ids": images}
images = [to_channel_dimension_format(image, data_format, input_data_format) for image in images]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)

def to_dict(self):
output = super().to_dict()
# Ensure clusters are JSON/equality friendly
if output.get("clusters") is not None and isinstance(output["clusters"], np.ndarray):
output["clusters"] = output["clusters"].tolist()
# Need to set missing keys from slow processor to match the expected behavior in save/load tests compared to fast processor
missing_keys = ["image_mean", "image_std", "rescale_factor", "do_rescale"]
for key in missing_keys:
if key in output:
output[key] = None

return output


__all__ = ["ImageGPTImageProcessor"]
209 changes: 209 additions & 0 deletions src/transformers/models/imagegpt/image_processing_imagegpt_fast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for ImageGPT."""

from typing import Optional, Union

import numpy as np

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
)
from ...image_transforms import group_images_by_shape, reorder_images
from ...image_utils import PILImageResampling
from ...processing_utils import Unpack
from ...utils import (
TensorType,
auto_docstring,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
)


if is_torch_available():
import torch

if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F


def squared_euclidean_distance_torch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
Compute squared Euclidean distances between all pixels and clusters.

Args:
a: (N, 3) tensor of pixel RGB values
b: (M, 3) tensor of cluster RGB values

Returns:
(N, M) tensor of squared distances
"""
b = b.t() # (3, M)
a2 = torch.sum(a**2, dim=1) # (N,)
b2 = torch.sum(b**2, dim=0) # (M,)
ab = torch.matmul(a, b) # (N, M)
d = a2[:, None] - 2 * ab + b2[None, :] # Squared Euclidean Distance: a^2 - 2ab + b^2
return d # (N, M) tensor of squared distances


def color_quantize_torch(x: torch.Tensor, clusters: torch.Tensor) -> torch.Tensor:
"""
Assign each pixel to its nearest color cluster.

Args:
x: (H*W, 3) tensor of flattened pixel RGB values
clusters: (n_clusters, 3) tensor of cluster RGB values

Returns:
(H*W,) tensor of cluster indices
"""
d = squared_euclidean_distance_torch(x, clusters)
return torch.argmin(d, dim=1)


class ImageGPTFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
"""
clusters (`np.ndarray` or `list[list[int]]` or `torch.Tensor`, *optional*):
The color clusters to use, of shape `(n_clusters, 3)` when color quantizing. Can be overridden by `clusters`
in `preprocess`.
do_color_quantize (`bool`, *optional*, defaults to `True`):
Controls whether to apply color quantization to convert continuous pixel values to discrete cluster indices.
When True, each pixel is assigned to its nearest color cluster, enabling ImageGPT's discrete token modeling.
"""

clusters: Optional[Union[np.ndarray, list[list[int]], torch.Tensor]]
do_color_quantize: Optional[bool]


@auto_docstring
class ImageGPTImageProcessorFast(BaseImageProcessorFast):
model_input_names = ["input_ids"]
resample = PILImageResampling.BILINEAR
do_color_quantize = True
clusters = None
image_mean = [0.5, 0.5, 0.5]
image_std = [0.5, 0.5, 0.5]
do_rescale = True
do_normalize = True
valid_kwargs = ImageGPTFastImageProcessorKwargs

def __init__(
self,
clusters: Optional[Union[list, np.ndarray, torch.Tensor]] = None, # keep as arg for backwards compatibility
**kwargs: Unpack[ImageGPTFastImageProcessorKwargs],
):
r"""
clusters (`np.ndarray` or `list[list[int]]` or `torch.Tensor`, *optional*):
The color clusters to use, of shape `(n_clusters, 3)` when color quantizing. Can be overridden by `clusters`
in `preprocess`.
"""
clusters = torch.as_tensor(clusters, dtype=torch.float32) if clusters is not None else None
super().__init__(clusters=clusters, **kwargs)

def _preprocess(
self,
images: list["torch.Tensor"],
do_resize: bool,
size: dict[str, int],
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: dict[str, int],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
do_color_quantize: Optional[bool] = None,
clusters: Optional[Union[list, np.ndarray, torch.Tensor]] = None,
disable_grouping: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
):
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index)

# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:
stacked_images = self.center_crop(stacked_images, crop_size)
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images

pixel_values = reorder_images(processed_images_grouped, grouped_images_index)

# If color quantization is requested, perform it; otherwise return pixel values
if do_color_quantize:
# Prepare clusters
if clusters is None:
raise ValueError("Clusters must be provided for color quantization.")
# Convert to torch tensor if needed (clusters might be passed as list/numpy)
clusters_torch = (
torch.as_tensor(clusters, dtype=torch.float32) if not isinstance(clusters, torch.Tensor) else clusters
).to(pixel_values[0].device, dtype=pixel_values[0].dtype)

# Group images by shape for batch processing
# We need to check if the pixel values are a tensor or a list of tensors
grouped_images, grouped_images_index = group_images_by_shape(
pixel_values, disable_grouping=disable_grouping
)
# Process each group
input_ids_grouped = {}

for shape, stacked_images in grouped_images.items():
input_ids = color_quantize_torch(
stacked_images.permute(0, 2, 3, 1).reshape(-1, 3), clusters_torch
) # (B*H*W, C)
input_ids_grouped[shape] = input_ids.reshape(stacked_images.shape[0], -1).reshape(
stacked_images.shape[0], -1
) # (B, H, W)

input_ids = reorder_images(input_ids_grouped, grouped_images_index)

return BatchFeature(
data={"input_ids": torch.stack(input_ids, dim=0) if return_tensors else input_ids},
tensor_type=return_tensors,
)

pixel_values = torch.stack(pixel_values, dim=0) if return_tensors else pixel_values
return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)

def to_dict(self):
# Convert torch tensors to lists for JSON serialization
output = super().to_dict()
if output.get("clusters") is not None and isinstance(output["clusters"], torch.Tensor):
output["clusters"] = output["clusters"].tolist()

return output


__all__ = ["ImageGPTImageProcessorFast"]
Loading