Fast image processor for VitMatte added and bug in slow version fixed #37616
yonigozlan merged 4 commits into huggingface:main from henrikm11:fast-image-processor-ViTMatte
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the
cc @yonigozlan
yonigozlan
left a comment
Thanks a lot @henrikm11 for your contribution! Looks great overall, but I think the processor can be simplified a bit, and a lot more can be done by batch.
Also you'll need to rebase/merge with main.
```diff
-np.concatenate([image, np.expand_dims(trimap, axis=-1)], axis=-1) for image, trimap in zip(images, trimaps)
+np.concatenate([image, np.expand_dims(trimap, axis=axis)], axis=axis)
```
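The fix above can be illustrated with a minimal NumPy sketch (shapes chosen for illustration): the trimap has to be concatenated along the channel axis, and that axis differs between channels-last and channels-first inputs, which is exactly what the slow processor got wrong for `ChannelDimension.FIRST`.

```python
import numpy as np

# Channels-last input: channels sit at axis -1.
image_last = np.zeros((64, 64, 3))
# Channels-first (ChannelDimension.FIRST) input: channels sit at axis 0.
image_first = np.zeros((3, 64, 64))
trimap = np.random.randint(0, 3, size=(64, 64))

# Concatenate the trimap along whichever axis holds the channels.
merged_last = np.concatenate([image_last, np.expand_dims(trimap, axis=-1)], axis=-1)
merged_first = np.concatenate([image_first, np.expand_dims(trimap, axis=0)], axis=0)
assert merged_last.shape == (64, 64, 4)
assert merged_first.shape == (4, 64, 64)
```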
```python
if pad_width + pad_height > 0:
    padding = (0, pad_width, 0, pad_height)
    image = torch.nn.functional.pad(image, padding)
```
Could you use the torchvision functional pad? You should be able to do it by batch with stacked images as well.
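A minimal NumPy sketch of what batched size-divisibility padding looks like (the helper name and shapes are illustrative assumptions, not the PR's code); a torch/torchvision `pad` call takes analogous right/bottom amounts and likewise accepts a whole stacked `(batch, channels, height, width)` tensor:

```python
import numpy as np

def pad_to_divisible(batch, size_divisibility=32):
    """Pad a stacked (batch, channels, height, width) array on the right and
    bottom so that height and width are multiples of size_divisibility."""
    height, width = batch.shape[-2:]
    pad_height = (size_divisibility - height % size_divisibility) % size_divisibility
    pad_width = (size_divisibility - width % size_divisibility) % size_divisibility
    if pad_height + pad_width == 0:
        return batch
    # Pad only the trailing (spatial) dims, on the bottom/right edge.
    return np.pad(batch, [(0, 0), (0, 0), (0, pad_height), (0, pad_width)])

batch = np.zeros((4, 4, 30, 45))  # e.g. 4 image+trimap stacks of 30x45
padded = pad_to_divisible(batch)
assert padded.shape == (4, 4, 32, 64)
```

Padding the whole stack in one call is what makes the grouped-by-shape approach pay off: every tensor in a group shares a shape, so the pad amounts are computed once per group.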
```python
# do the padding
if do_pad:
    processed_images = [self._pad_image(image, size_divisibility) for image in processed_images]
```
You should pad on stacked images grouped by shape
```python
grouped_images, grouped_images_index = group_images_by_shape(images)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
    # Fused rescale and normalize
    stacked_images = self.rescale_and_normalize(
        stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
    )
    processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)

# same for trimaps
grouped_trimaps, grouped_trimaps_index = group_images_by_shape(trimaps)
processed_trimaps_grouped = {}
for shape, stacked_trimaps in grouped_trimaps.items():
    # no normalization for trimaps
    stacked_trimaps = self.rescale_and_normalize(
        stacked_trimaps, do_rescale, rescale_factor, False, image_mean, image_std
    )
    processed_trimaps_grouped[shape] = stacked_trimaps
processed_trimaps = reorder_images(processed_trimaps_grouped, grouped_trimaps_index)

# concatenate images and trimaps
processed_images = [
    torch.cat([image, trimap], dim=0) for image, trimap in zip(processed_images, processed_trimaps)
]

# do the padding
if do_pad:
    processed_images = [self._pad_image(image, size_divisibility) for image in processed_images]

# finish things up
grouped_processed, grouped_processed_index = group_images_by_shape(processed_images)
processed_images_grouped = {}
for shape, stacked_processed_images in grouped_processed.items():
    processed_images_grouped[shape] = stacked_processed_images
```
You can have just one loop, since images and trimaps should have the same "shape" (height and width); you can even concat images and trimaps, then pad in the loop. Something like this should work:
```diff
-grouped_images, grouped_images_index = group_images_by_shape(images)
-processed_images_grouped = {}
-for shape, stacked_images in grouped_images.items():
-    # Fused rescale and normalize
-    stacked_images = self.rescale_and_normalize(
-        stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
-    )
-    processed_images_grouped[shape] = stacked_images
-processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-# same for trimaps
-grouped_trimaps, grouped_trimaps_index = group_images_by_shape(trimaps)
-processed_trimaps_grouped = {}
-for shape, stacked_trimaps in grouped_trimaps.items():
-    # no normalization for trimaps
-    stacked_trimaps = self.rescale_and_normalize(
-        stacked_trimaps, do_rescale, rescale_factor, False, image_mean, image_std
-    )
-    processed_trimaps_grouped[shape] = stacked_trimaps
-processed_trimaps = reorder_images(processed_trimaps_grouped, grouped_trimaps_index)
-# concatenate images and trimaps
-processed_images = [
-    torch.cat([image, trimap], dim=0) for image, trimap in zip(processed_images, processed_trimaps)
-]
-# do the padding
-if do_pad:
-    processed_images = [self._pad_image(image, size_divisibility) for image in processed_images]
-# finish things up
-grouped_processed, grouped_processed_index = group_images_by_shape(processed_images)
-processed_images_grouped = {}
-for shape, stacked_processed_images in grouped_processed.items():
-    processed_images_grouped[shape] = stacked_processed_images
+grouped_images, grouped_images_index = group_images_by_shape(images)
+grouped_trimaps, grouped_trimaps_index = group_images_by_shape(trimaps)
+processed_images_grouped = {}
+for shape in grouped_images:
+    stacked_images = grouped_images[shape]
+    stacked_trimaps = grouped_trimaps[shape]
+    # Fused rescale and normalize
+    stacked_images = self.rescale_and_normalize(
+        stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+    )
+    stacked_trimaps = self.rescale_and_normalize(
+        stacked_trimaps, do_rescale, rescale_factor, False, image_mean, image_std
+    )
+    stacked_images = torch.cat([stacked_images, stacked_trimaps], dim=1)
+    # do padding by batch....
+    processed_images_grouped[shape] = stacked_padded_images
+processed_images = reorder_images(processed_images_grouped, grouped_images_index)
```
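To make the group/reorder round trip concrete, here is a hedged NumPy sketch of what `group_images_by_shape` / `reorder_images` do (the names come from the PR; this reimplementation is an illustrative assumption, not the transformers source): images of the same spatial shape are stacked for batched ops, and an index remembers how to restore the original order.

```python
import numpy as np

def group_images_by_shape(images):
    """Stack images sharing (height, width); return stacks plus a restore index."""
    grouped, index = {}, []
    for image in images:
        shape = image.shape[-2:]  # spatial (height, width) as the group key
        grouped.setdefault(shape, []).append(image)
        index.append((shape, len(grouped[shape]) - 1))
    return {shape: np.stack(group) for shape, group in grouped.items()}, index

def reorder_images(grouped, index):
    """Flatten the per-shape stacks back into the original input order."""
    return [grouped[shape][pos] for shape, pos in index]

images = [np.zeros((3, 4, 4)), np.ones((3, 2, 2)), 2 * np.ones((3, 4, 4))]
grouped, index = group_images_by_shape(images)
restored = reorder_images(grouped, index)
assert all((a == b).all() for a, b in zip(images, restored))
```

This is why a single loop over `grouped_images` suffices: each image and its trimap share a spatial shape, so they land in the same group and can be processed as one stacked batch.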
```python
# finish things up
grouped_processed, grouped_processed_index = group_images_by_shape(processed_images)
processed_images_grouped = {}
for shape, stacked_processed_images in grouped_processed.items():
    processed_images_grouped[shape] = stacked_processed_images
processed_images = reorder_images(processed_images_grouped, grouped_processed_index)
```
I don't think this does anything in this state
```python
image_std: Optional[Union[float, list[float]]] = IMAGENET_STANDARD_STD
do_pad: bool = True
size_divisibility: int = 32
# input_data_format = ChannelDimension.FIRST
```
```diff
-# input_data_format = ChannelDimension.FIRST
```
```python
def test_slow_fast_equivalence(self):
    if not self.test_slow_image_processor or not self.test_fast_image_processor:
        self.skipTest(reason="Skipping slow/fast equivalence test")

    if self.image_processing_class is None or self.fast_image_processing_class is None:
        self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")

    dummy_image = Image.open(
        requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
    )
    dummy_trimap = np.random.randint(0, 3, size=dummy_image.size[::-1])
    image_processor_slow = self.image_processing_class(**self.image_processor_dict)
    image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

    encoding_slow = image_processor_slow(dummy_image, trimaps=dummy_trimap, return_tensors="pt")
    encoding_fast = image_processor_fast(dummy_image, trimaps=dummy_trimap, return_tensors="pt")
    self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
    self.assertLessEqual(
        torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
    )

def test_slow_fast_equivalence_batched(self):
    # this only checks on equal resolution, since the slow processor doesn't work otherwise
    if not self.test_slow_image_processor or not self.test_fast_image_processor:
        self.skipTest(reason="Skipping slow/fast equivalence test")

    if self.image_processing_class is None or self.fast_image_processing_class is None:
        self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")

    if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
        self.skipTest(
            reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
        )

    dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
    dummy_trimaps = [np.random.randint(0, 3, size=image.shape[1:]) for image in dummy_images]
    image_processor_slow = self.image_processing_class(**self.image_processor_dict)
    image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

    encoding_slow = image_processor_slow(dummy_images, trimaps=dummy_trimaps, return_tensors="pt")
    encoding_fast = image_processor_fast(dummy_images, trimaps=dummy_trimaps, return_tensors="pt")

    self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
    self.assertLessEqual(
        torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
    )
```
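The two-level tolerance used in these tests can be sketched in plain NumPy (the arrays and the 1e-4 drift are illustrative assumptions): a loose elementwise `allclose` catches any single badly-off pixel, while a much tighter bound on the mean absolute error catches small systematic drift between the slow and fast processors.

```python
import numpy as np

slow = np.random.rand(1, 4, 32, 32).astype(np.float32)
fast = slow + np.float32(1e-4)  # simulated slow/fast numerical drift

# No individual value may differ by more than 0.1 ...
assert np.allclose(slow, fast, atol=1e-1)
# ... and the average difference must stay below 1e-3.
assert np.mean(np.abs(slow - fast)) <= 1e-3
```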
Great thanks for testing thoroughly!
@yonigozlan thanks a lot for the detailed review. Seems like I got sidetracked by the bug a bit and lost sight of other things. Also, I hope I force-pushed correctly, following the contribution guide.
yonigozlan
left a comment
Thanks for iterating @henrikm11 ! Mostly formatting/styling issues left, then LGTM!
```python
def _pad_image(
    self,
    image: "torch.tensor",
```
Maybe we can make it clearer that we can have multiple stacked images here and after:
```diff
-    image: "torch.tensor",
+    images: "torch.tensor",
```
```diff
-        Returns:
-            'torch.tensor':
-                padded tensor
         """
```
```python
stacked_padded_images = (
    self._pad_image(stacked_images, self.size_divisibility) if do_pad else stacked_images
)
```
Let's just have a plain `if do_pad:` block here, otherwise it's difficult to read.
```diff
-# Group images by size for further processing
-# Needed in case do_resize is False, or resize returns images with different sizes
```
```python
    stacked_trimaps, do_rescale, rescale_factor, False, image_mean, image_std
)
stacked_images = torch.cat([stacked_images, stacked_trimaps], dim=1)
# do padding by batch....
```

```diff
-# do padding by batch....
```
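The `torch.cat(..., dim=1)` step above can be sketched with NumPy stand-in shapes (illustrative, not the PR's tensors): an RGB batch `(B, 3, H, W)` and a one-channel trimap batch `(B, 1, H, W)` concatenated along the channel dim give the 4-channel `(B, 4, H, W)` input that VitMatte expects.

```python
import numpy as np

stacked_images = np.zeros((4, 3, 32, 32))   # batch of RGB images
stacked_trimaps = np.zeros((4, 1, 32, 32))  # matching batch of trimaps

# Concatenate along the channel dimension (axis 1 for batched NCHW).
stacked = np.concatenate([stacked_images, stacked_trimaps], axis=1)
assert stacked.shape == (4, 4, 32, 32)
```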
```python
@require_vision
class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = VitMatteImageProcessor if is_vision_available() else None
    fast_image_processing_class = VitMatteImageProcessorFast if is_torch_available else None
```

```diff
-    fast_image_processing_class = VitMatteImageProcessorFast if is_torch_available else None
+    fast_image_processing_class = VitMatteImageProcessorFast if is_torchvision_available() else None
```
…ts, fixed a bug in the slow image processor that processed images incorrectly for input format ChannelDimension.FIRST, in which case the trimaps were not added in the correct dimension; this bug was also reflected in the tests through incorrectly shaped trimaps being passed
yonigozlan
left a comment
Looks good, LGTM! Thanks for contributing :)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
…huggingface#37616) * added fast image processor for VitMatte including updated and new tests, fixed a bug in the slow image processor that processed images incorrectly for input format ChannelDimension.FIRST, in which case the trimaps were not added in the correct dimension; this bug was also reflected in the tests through incorrectly shaped trimaps being passed * final edits for fast vitmatte image processor and tests * final edits for fast vitmatte image processor and tests --------- Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
- Added a fast image processor for VitMatte
- Fixed a bug in the slow image processor that processed images incorrectly for input format ChannelDimension.FIRST, in which case the trimaps were not added in the correct dimension; this bug was also reflected in the tests through incorrectly shaped trimaps being passed to the preprocess function.
- Adjusted old tests to also cover the fast image processor
- Expanded tests