
added fast image processor for ZoeDepth and expanded tests accordingly#38515

Merged
yonigozlan merged 4 commits into huggingface:main from
henrikm11:fast-image-processor-zoedepth
Jun 4, 2025

Conversation

@henrikm11
Contributor

@henrikm11 henrikm11 commented Jun 1, 2025

potential reviewer: @yonigozlan

Member

@yonigozlan yonigozlan left a comment

Hi @henrikm11! Thanks for working on this, it looks almost ready to go! There are still a few things left to simplify in the fast image processor.

Comment on lines +166 to +162
    def _resize(
        self,
        images: "torch.Tensor",
        size: SizeDict,
        keep_aspect_ratio: bool = False,
        ensure_multiple_of: int = 1,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
    ) -> "torch.Tensor":
        """
        Resize an image or batched images to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image
        is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is
        set, the image is resized to a size that is a multiple of this value.

        Args:
            images (`torch.Tensor`):
                Images to resize.
            size (`Dict[str, int]`):
                Target size of the output image.
            keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
                If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
            ensure_multiple_of (`int`, *optional*, defaults to 1):
                The image is resized to a size that is a multiple of this value.
            interpoation (`F.InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                Defines the resampling filter to use if resizing the image. Otherwise, the image is resized to size
                specified in `size`.
        """
        if not size.height or not size.width:
            raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size}")
        output_size = get_resize_output_image_size(
            images,
            output_size=(size.height, size.width),
            keep_aspect_ratio=keep_aspect_ratio,
            multiple=ensure_multiple_of,
            input_data_format=ChannelDimension.FIRST,
        )
        height, width = output_size

        resample_to_mode = {PILImageResampling.BILINEAR: "bilinear", PILImageResampling.BICUBIC: "bicubic"}
        mode = resample_to_mode[resample]

        resized_images = torch.nn.functional.interpolate(
            images, (int(height), int(width)), mode=mode, align_corners=True
        )

        return resized_images
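The `ensure_multiple_of` behavior can be sketched standalone. The helpers below are hypothetical (the real rounding rules live in `get_resize_output_image_size` and may differ in edge cases): each target dimension is rounded to the nearest multiple of the given value.

```python
def constrain_to_multiple_of(value: float, multiple: int) -> int:
    """Round value to the nearest multiple of `multiple` (hypothetical helper)."""
    rounded = round(value / multiple) * multiple
    if rounded == 0:
        # Never collapse a dimension to zero for tiny inputs.
        rounded = multiple
    return int(rounded)


def output_size(height: int, width: int, multiple: int) -> tuple:
    """Target (height, width) with both dimensions constrained to a multiple."""
    return constrain_to_multiple_of(height, multiple), constrain_to_multiple_of(width, multiple)


print(output_size(512, 672, 32))  # → (512, 672): already multiples of 32
print(output_size(500, 650, 32))  # → (512, 640)
```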
Member

Why not use the resize function from BaseImageProcessorFast?

Contributor Author

Isn't align_corners=True an issue here? AFAIK that's simply not available in torchvision.transforms.functional (it is effectively set to False), and thus one gets significantly different results.
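The concern is that the two conventions sample the source image at different coordinates. A pure-Python sketch of the documented coordinate formulas (assumed to match torch's behavior for `interpolate`; not code from the PR) makes the difference concrete:

```python
def source_coords(out_len: int, in_len: int, align_corners: bool) -> list:
    """Source sampling coordinates for 1-D bilinear resizing under the two
    conventions (sketch of the documented formulas, assumed here)."""
    if align_corners:
        # Corner samples of input and output are aligned exactly.
        scale = (in_len - 1) / (out_len - 1)
        return [i * scale for i in range(out_len)]
    # Otherwise sample centers are aligned (torchvision's behavior).
    scale = in_len / out_len
    return [(i + 0.5) * scale - 0.5 for i in range(out_len)]


print(source_coords(4, 8, align_corners=True))
print(source_coords(4, 8, align_corners=False))  # [0.5, 2.5, 4.5, 6.5]
```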

Member

I see, in this case let's override the resize function from BaseImageProcessorFast. Also, you shouldn't need to define resample_to_mode: just use the interpolation arg in _preprocess instead of resample, as it is already the PILImageResampling converted to a torch InterpolationMode. And there's still no need to override the whole preprocess function.

Contributor Author

@henrikm11 henrikm11 Jun 2, 2025

Indeed, thanks for pointing that out; not even sure why I overrode preprocess in the first place now... All other suggested changes make perfect sense to me, sorry for not being as diligent about cleaning and simplifying the code as I should have been. Will make all the changes tomorrow.

Member

No worries at all!

    def _pad_images(
        self,
        images: "torch.Tensor",
        mode: TorchPaddingMode = TorchPaddingMode.REFLECT,
Member

Let's just have

Suggested change
        mode: TorchPaddingMode = TorchPaddingMode.REFLECT,
        mode="reflect",

So that we can remove TorchPaddingMode altogether.
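For context, `"reflect"` padding mirrors the samples around the edge without repeating the edge sample itself. A minimal 1-D sketch of those semantics (a hypothetical helper, not the processor's `_pad_images`):

```python
def reflect_pad_1d(xs: list, pad: int) -> list:
    """Reflect-pad a 1-D sequence: mirror around each edge without repeating
    the edge sample (what mode="reflect" does per spatial dimension)."""
    left = [xs[i] for i in range(pad, 0, -1)]   # xs[pad], ..., xs[1]
    right = [xs[-2 - i] for i in range(pad)]    # xs[-2], xs[-3], ...
    return left + list(xs) + right


print(reflect_pad_1d([1, 2, 3, 4], 2))  # → [3, 2, 1, 2, 3, 4, 3, 2]
```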

Comment on lines +128 to +164
    @auto_docstring
    def preprocess(
        self,
        images: ImageInput,
        *args,
        **kwargs: Unpack[ZoeDepthFastImageProcessorKwargs],
    ) -> BatchFeature:
        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
        # Set default kwargs from self. This ensures that if a kwarg is not provided
        # by the user, it gets its default value from the instance, or is set to None.

        for kwarg_name in self.valid_kwargs.__annotations__:
            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))

        # Extract parameters that are only used for preparing the input images
        do_convert_rgb = kwargs.pop("do_convert_rgb")
        input_data_format = kwargs.pop("input_data_format")
        device = kwargs.pop("device")

        # Prepare input images
        images = self._prepare_input_images(
            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
        )
        # return list of torch.Tensor in ChannelDimension.FIRST format
        # Update kwargs that need further processing before being validated
        kwargs = self._further_process_kwargs(**kwargs)

        # Validate kwargs
        self._validate_preprocess_kwargs(**kwargs)

        # Pop kwargs that are not needed in _preprocess
        kwargs.pop("default_to_square")
        kwargs.pop("data_format")
        kwargs.pop("do_center_crop")
        kwargs.pop("crop_size")

        return self._preprocess(images, *args, **kwargs)
Member

I don't think there's a need to override anything here, especially if we use the BaseImageProcessorFast resize. You can just have this to account for ZoeDepthFastImageProcessorKwargs in the signature:

Suggested change
    @auto_docstring
    def preprocess(
        self,
        images: ImageInput,
        *args,
        **kwargs: Unpack[ZoeDepthFastImageProcessorKwargs],
    ) -> BatchFeature:
        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
        # Set default kwargs from self. This ensures that if a kwarg is not provided
        # by the user, it gets its default value from the instance, or is set to None.
        for kwarg_name in self.valid_kwargs.__annotations__:
            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
        # Extract parameters that are only used for preparing the input images
        do_convert_rgb = kwargs.pop("do_convert_rgb")
        input_data_format = kwargs.pop("input_data_format")
        device = kwargs.pop("device")
        # Prepare input images
        images = self._prepare_input_images(
            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
        )
        # return list of torch.Tensor in ChannelDimension.FIRST format
        # Update kwargs that need further processing before being validated
        kwargs = self._further_process_kwargs(**kwargs)
        # Validate kwargs
        self._validate_preprocess_kwargs(**kwargs)
        # Pop kwargs that are not needed in _preprocess
        kwargs.pop("default_to_square")
        kwargs.pop("data_format")
        kwargs.pop("do_center_crop")
        kwargs.pop("crop_size")
        return self._preprocess(images, *args, **kwargs)
    @auto_docstring
    def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[ZoeDepthFastImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)
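The `kwargs.setdefault(..., getattr(self, ...))` pattern used in the override being removed here, filling in any kwarg the caller omits from instance defaults, can be sketched with a hypothetical class (not the actual transformers base class):

```python
class Processor:
    """Hypothetical processor illustrating the default-filling pattern."""

    do_resize = True
    size = {"height": 384, "width": 512}
    valid_kwargs = ("do_resize", "size")

    def preprocess(self, **kwargs):
        # Any kwarg the caller omits falls back to the instance default,
        # or None if the instance has no such attribute.
        for name in self.valid_kwargs:
            kwargs.setdefault(name, getattr(self, name, None))
        return kwargs


p = Processor()
print(p.preprocess(do_resize=False))
# → {'do_resize': False, 'size': {'height': 384, 'width': 512}}
```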

        if do_rescale:
            stacked_images = self.rescale(stacked_images, rescale_factor)
        if do_pad:
            stacked_images = self._pad_images(images=stacked_images, mode=PaddingMode.REFLECT)
Member

No need to provide PaddingMode.REFLECT, as there's no way to specify the mode in the slow image processor.

        stacked_images = self._pad_images(images=stacked_images, mode=PaddingMode.REFLECT)
        if do_resize:
            stacked_images = self._resize(stacked_images, size, keep_aspect_ratio, ensure_multiple_of, resample)
        print(stacked_images.dtype)
Member

Suggested change
        print(stacked_images.dtype)

        print(stacked_images.dtype)
        if do_normalize:
            stacked_images = self.normalize(stacked_images, image_mean, image_std)
        print(stacked_images.dtype)
Member

Suggested change
        print(stacked_images.dtype)

Comment on lines +266 to +268
        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
Member

Suggested change
        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
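The idea behind `group_images_by_shape` can be sketched with a hypothetical helper operating on shape tuples rather than tensors: bucket same-sized images together so each bucket can be stacked and processed as one batch.

```python
from collections import defaultdict


def group_by_shape(shapes: list) -> dict:
    """Bucket image indices by (height, width) so same-sized images can be
    stacked into one batch (sketch of the idea, not the real API)."""
    groups = defaultdict(list)
    for idx, shape in enumerate(shapes):
        groups[shape].append(idx)
    return dict(groups)


print(group_by_shape([(384, 512), (256, 256), (384, 512)]))
# → {(384, 512): [0, 2], (256, 256): [1]}
```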


        return F.pad(images, padding=(pad_width, pad_height), padding_mode=mode)

    @filter_out_non_signature_kwargs()
Member

Suggested change
    @filter_out_non_signature_kwargs()

        modified_dict = self.image_processor_dict
        modified_dict["size"] = 42
        image_processor = image_processing_class(**modified_dict)
        print(self.image_processor_dict)
Member

Suggested change
        print(self.image_processor_dict)

        self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672])
        self.assertTrue(pixel_values.shape[2] % multiple == 0)
        self.assertTrue(pixel_values.shape[3] % multiple == 0)
        self.image_processor_tester.do_pad = True
Member

Suggested change
        self.image_processor_tester.do_pad = True

@henrikm11 henrikm11 force-pushed the fast-image-processor-zoedepth branch from 8fe819e to bde23db Compare June 3, 2025 19:35
Member

@yonigozlan yonigozlan left a comment

Very nice! Two very small things left to change; after that, LGTM!

        size: SizeDict,
        keep_aspect_ratio: bool = False,
        ensure_multiple_of: int = 1,
        interpolation: "F.InterpolationMode" = InterpolationMode.BILINEAR,
Member

I think having InterpolationMode.BILINEAR as a default can cause some issues with the CI when torchvision is not available... Since we have resample = PILImageResampling.BILINEAR as a default class attribute, we shouldn't need to set a default here and can just set it to None. That way we can also remove the `from torchvision.transforms import InterpolationMode` above.

Contributor Author

Removing the default is no problem at all; the import is still required, however, because post_process_depth_estimation uses it.

                If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
            ensure_multiple_of (`int`, *optional*, defaults to 1):
                The image is resized to a size that is a multiple of this value.
            interpoation (`F.InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
Member

Suggested change
            interpoation (`F.InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
            interpolation (`F.InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@henrikm11 henrikm11 force-pushed the fast-image-processor-zoedepth branch from 64444e5 to be46058 Compare June 4, 2025 20:38
@henrikm11
Contributor Author

henrikm11 commented Jun 4, 2025

The two failed tests are entirely unrelated to the changes I have made; do you have any suggestions for how to deal with this? Even worse, one fails because of a request timeout, the other because the max difference of (I assume random) tensors is 5e-05 rather than 1e-05, neither of which seems to be a fully deterministic condition... maybe you can just rerun the tests @yonigozlan?

@yonigozlan
Member

The two failed tests are entirely unrelated to the changes I have made; do you have any suggestions for how to deal with this? Even worse, one fails because of a request timeout, the other because the max difference of (I assume random) tensors is 5e-05 rather than 1e-05, neither of which seems to be a fully deterministic condition... maybe you can just rerun the tests @yonigozlan?

The CI can be a bit flaky sometimes, running the tests again and merging when they pass!
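The tolerance issue described above is easy to illustrate with `math.isclose`: a maximum difference of 5e-05 passes at an absolute tolerance of 1e-4 but fails at 1e-5 (illustrative numbers only; the real test compares full tensors).

```python
import math

# A max element difference of 5e-05 is within atol=1e-4 but not atol=1e-5.
diff = 5e-05
print(math.isclose(diff, 0.0, abs_tol=1e-4))  # True
print(math.isclose(diff, 0.0, abs_tol=1e-5))  # False
```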

@yonigozlan yonigozlan enabled auto-merge (squash) June 4, 2025 22:46
@yonigozlan yonigozlan merged commit 1fed616 into huggingface:main Jun 4, 2025
20 checks passed
bvantuan pushed a commit to bvantuan/transformers that referenced this pull request Jun 12, 2025
huggingface#38515)

* added fast image processor for ZoeDepth and expanded tests accordingly

* added fast image processor for ZoeDepth and expanded tests accordingly, hopefully fixed repo consistency issue too now

* final edits for zoedept fast image processor

* final minor edit for zoedepth fast imate procesor
@henrikm11 henrikm11 deleted the fast-image-processor-zoedepth branch June 23, 2025 19:24
3 participants