
Fast image processor for VitMatte added and bug in slow version fixed #37616

Merged
yonigozlan merged 4 commits into huggingface:main from
henrikm11:fast-image-processor-ViTMatte
Apr 28, 2025

Conversation

@henrikm11
Contributor

- Added a fast image processor for VitMatte.
- Fixed a bug in the slow image processor: for inputs in ChannelDimension.FIRST format, the trimaps were concatenated along the wrong dimension. This bug was also reflected in the tests, which passed incorrectly shaped trimaps to the preprocess function.
- Adjusted the existing tests to also cover the fast image processor.
- Expanded the tests.
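The channel-dimension bug can be illustrated with a small NumPy sketch (the helper name is hypothetical, not the processor's actual API): the trimap has to be expanded and concatenated along the channel axis, which is axis 0 for ChannelDimension.FIRST and -1 for ChannelDimension.LAST.

```python
import numpy as np

def concat_trimap(image, trimap, channels_first):
    # Hypothetical helper illustrating the fix: pick the channel axis from
    # the data format instead of hard-coding axis=-1, which was wrong for
    # channels-first (C, H, W) inputs.
    axis = 0 if channels_first else -1
    return np.concatenate([image, np.expand_dims(trimap, axis=axis)], axis=axis)

trimap = np.zeros((4, 5))
print(concat_trimap(np.zeros((3, 4, 5)), trimap, channels_first=True).shape)   # (4, 4, 5)
print(concat_trimap(np.zeros((4, 5, 3)), trimap, channels_first=False).shape)  # (4, 5, 4)
```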

@github-actions github-actions Bot marked this pull request as draft April 18, 2025 14:04
@github-actions
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@henrikm11 henrikm11 marked this pull request as ready for review April 18, 2025 14:07
@github-actions github-actions Bot requested review from ydshieh and yonigozlan April 18, 2025 14:07
@Rocketknight1
Member

cc @yonigozlan

Member

@yonigozlan yonigozlan left a comment


Thanks a lot @henrikm11 for your contribution! Looks great overall, but I think the processor can be simplified a bit, and a lot more can be done in batches.
Also, you'll need to rebase/merge with main.

Comment on lines -254 to +255
np.concatenate([image, np.expand_dims(trimap, axis=-1)], axis=-1) for image, trimap in zip(images, trimaps)
np.concatenate([image, np.expand_dims(trimap, axis=axis)], axis=axis)
Member


Thanks for fixing!


if pad_width + pad_height > 0:
padding = (0, pad_width, 0, pad_height)
image = torch.nn.functional.pad(image, padding)
Member


could you use torchvision functional pad? You should be able to do it by batch with stacked images as well
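The idea of padding a whole stacked batch at once can be sketched in NumPy (with torch, `torch.nn.functional.pad` or torchvision's functional pad accept a stacked (N, C, H, W) tensor the same way); the helper name here is illustrative, not the processor's actual code:

```python
import numpy as np

def pad_batch(images, size_divisibility=32):
    # Illustrative sketch: pad bottom/right so H and W become multiples of
    # size_divisibility, applied to the whole stacked (N, C, H, W) batch at
    # once rather than image by image.
    _, _, h, w = images.shape
    pad_h = (-h) % size_divisibility
    pad_w = (-w) % size_divisibility
    return np.pad(images, ((0, 0), (0, 0), (0, pad_h), (0, pad_w)))

batch = np.ones((2, 4, 100, 64))
print(pad_batch(batch).shape)  # (2, 4, 128, 64)
```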

Comment on lines +246 to +248
# do the padding
if do_pad:
processed_images = [self._pad_image(image, size_divisibility) for image in processed_images]
Member


You should pad on stacked images grouped by shape

Comment on lines +220 to +254
grouped_images, grouped_images_index = group_images_by_shape(images)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)

# same for trimaps
grouped_trimaps, grouped_trimaps_index = group_images_by_shape(trimaps)
processed_trimaps_grouped = {}
for shape, stacked_trimaps in grouped_trimaps.items():
# no normalization for trimaps
stacked_trimaps = self.rescale_and_normalize(
stacked_trimaps, do_rescale, rescale_factor, False, image_mean, image_std
)
processed_trimaps_grouped[shape] = stacked_trimaps
processed_trimaps = reorder_images(processed_trimaps_grouped, grouped_trimaps_index)

# concatenate images and trimaps
processed_images = [
torch.cat([image, trimap], dim=0) for image, trimap in zip(processed_images, processed_trimaps)
]

# do the padding
if do_pad:
processed_images = [self._pad_image(image, size_divisibility) for image in processed_images]

# finish things up
grouped_processed, grouped_processed_index = group_images_by_shape(processed_images)
processed_images_grouped = {}
for shape, stacked_processed_images in grouped_processed.items():
processed_images_grouped[shape] = stacked_processed_images
Member


you can just have one loop, as images and trimaps should have the same "shape" (height and width). You can even concat images and trimaps, then pad inside the loop. Something like this should work:

Suggested change
grouped_images, grouped_images_index = group_images_by_shape(images)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
# same for trimaps
grouped_trimaps, grouped_trimaps_index = group_images_by_shape(trimaps)
processed_trimaps_grouped = {}
for shape, stacked_trimaps in grouped_trimaps.items():
# no normalization for trimaps
stacked_trimaps = self.rescale_and_normalize(
stacked_trimaps, do_rescale, rescale_factor, False, image_mean, image_std
)
processed_trimaps_grouped[shape] = stacked_trimaps
processed_trimaps = reorder_images(processed_trimaps_grouped, grouped_trimaps_index)
# concatenate images and trimaps
processed_images = [
torch.cat([image, trimap], dim=0) for image, trimap in zip(processed_images, processed_trimaps)
]
# do the padding
if do_pad:
processed_images = [self._pad_image(image, size_divisibility) for image in processed_images]
# finish things up
grouped_processed, grouped_processed_index = group_images_by_shape(processed_images)
processed_images_grouped = {}
for shape, stacked_processed_images in grouped_processed.items():
processed_images_grouped[shape] = stacked_processed_images
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_trimaps, grouped_trimaps_index = group_images_by_shape(trimaps)
processed_images_grouped = {}
for shape in grouped_images:
stacked_images = grouped_images[shape]
stacked_trimaps = grouped_trimaps[shape]
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
stacked_trimaps = self.rescale_and_normalize(
stacked_trimaps, do_rescale, rescale_factor, False, image_mean, image_std
)
stacked_images = torch.cat([stacked_images, stacked_trimaps], dim=1)
# do padding by batch....
processed_images_grouped[shape] = stacked_padded_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
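The group/reorder pattern in the suggestion above can be mimicked in plain Python to show why a single loop over shapes suffices when images and trimaps share heights and widths. The helpers below are simplified stand-ins for the transformers utilities, not their actual implementations:

```python
def group_images_by_shape(images):
    # Simplified stand-in: map each distinct shape to its list of images,
    # and record (shape, position) for each input so order can be restored.
    grouped, index = {}, []
    for img in images:
        shape = img["shape"]
        grouped.setdefault(shape, []).append(img)
        index.append((shape, len(grouped[shape]) - 1))
    return grouped, index

def reorder_images(grouped, index):
    # Restore the original input order from the (shape, position) index.
    return [grouped[shape][pos] for shape, pos in index]

images = [{"shape": (2, 2), "id": 0}, {"shape": (3, 3), "id": 1}, {"shape": (2, 2), "id": 2}]
grouped, index = group_images_by_shape(images)
restored = reorder_images(grouped, index)
print([img["id"] for img in restored])  # [0, 1, 2]
```

Because images and their trimaps group into identical shape buckets, one pass over `grouped` can rescale, normalize, concatenate, and pad each stacked batch before a single final reorder.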

Comment on lines +250 to +255
# finish things up
grouped_processed, grouped_processed_index = group_images_by_shape(processed_images)
processed_images_grouped = {}
for shape, stacked_processed_images in grouped_processed.items():
processed_images_grouped[shape] = stacked_processed_images
processed_images = reorder_images(processed_images_grouped, grouped_processed_index)
Member


I don't think this does anything in this state

image_std: Optional[Union[float, list[float]]] = IMAGENET_STANDARD_STD
do_pad: bool = True
size_divisibility: int = 32
# input_data_format = ChannelDimension.FIRST
Member


Suggested change
# input_data_format = ChannelDimension.FIRST

Comment on lines +291 to +336
def test_slow_fast_equivalence(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")

if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")

dummy_image = Image.open(
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
dummy_trimap = np.random.randint(0, 3, size=dummy_image.size[::-1])
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

encoding_slow = image_processor_slow(dummy_image, trimaps=dummy_trimap, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, trimaps=dummy_trimap, return_tensors="pt")
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)

def test_slow_fast_equivalence_batched(self):
# this only checks on equal resolution, since the slow processor doesn't work otherwise
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")

if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")

if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
self.skipTest(
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
)

dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
dummy_trimaps = [np.random.randint(0, 3, size=image.shape[1:]) for image in dummy_images]
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

encoding_slow = image_processor_slow(dummy_images, trimaps=dummy_trimaps, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, trimaps=dummy_trimaps, return_tensors="pt")

self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
Member


Great thanks for testing thoroughly!

@henrikm11
Contributor Author

@yonigozlan thanks a lot for the detailed review, seems like I got sidetracked by the bug a bit and lost sight of other things. Also, I hope I force-pushed correctly following the contribution guide.

Member

@yonigozlan yonigozlan left a comment


Thanks for iterating @henrikm11 ! Mostly formatting/styling issues left, then LGTM!


def _pad_image(
self,
image: "torch.tensor",
Member


Maybe we can make it clearer that we can have multiple stacked images here and below

Suggested change
image: "torch.tensor",
images: "torch.tensor",

Comment on lines +190 to +193

Returns:
'torch.tensor':
padded tensor
Member


Suggested change
Returns:
'torch.tensor':
padded tensor

'torch.tensor':
padded tensor
"""

Member


Suggested change

Comment on lines +238 to +240
stacked_padded_images = (
self._pad_image(stacked_images, self.size_divisibility) if do_pad else stacked_images
)
Member


let's just have

if do_pad:
    ...

otherwise it's difficult to read

Comment on lines +221 to +222
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
Member


Suggested change
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes

stacked_trimaps, do_rescale, rescale_factor, False, image_mean, image_std
)
stacked_images = torch.cat([stacked_images, stacked_trimaps], dim=1)
# do padding by batch....
Member


Suggested change
# do padding by batch....

@require_vision
class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = VitMatteImageProcessor if is_vision_available() else None
fast_image_processing_class = VitMatteImageProcessorFast if is_torch_available else None
Member


Suggested change
fast_image_processing_class = VitMatteImageProcessorFast if is_torch_available else None
fast_image_processing_class = VitMatteImageProcessorFast if is_torchvision_available() else None

…ts, fixed a bug in the slow image processor that processed images incorrectly for input format ChannelDimension.FIRST, in which case the trimaps were not added in the correct dimension; this bug was also reflected in the tests through incorrectly shaped trimaps being passed
Member

@yonigozlan yonigozlan left a comment


Looks good, LGTM! Thanks for contributing :)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yonigozlan yonigozlan merged commit a847d4a into huggingface:main Apr 28, 2025
20 checks passed
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
…huggingface#37616)

* added fast image processor for VitMatte including updated and new tests, fixed a bug in the slow image processor that processed images incorrectly for input format ChannelDimension.FIRST, in which case the trimaps were not added in the correct dimension; this bug was also reflected in the tests through incorrectly shaped trimaps being passed

* final edits for fast vitmatte image processor and tests

* final edits for fast vitmatte image processor and tests

---------

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
