[PixtralLarge] Update Pixtral conversion script to support large format! #34801

ArthurZucker merged 46 commits into main
This should be just about ready! Quick summary of the changes:

TODO:
```python
def _recursive_to(obj, device, *args, **kwargs):
    # Lists can be nested, so keep digging until we hit tensors
    if isinstance(obj, list):
        return [_recursive_to(o, device, *args, **kwargs) for o in obj]
    # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
    elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
        # cast and send to device
        return obj.to(*args, **kwargs)
    elif isinstance(obj, torch.Tensor) and device is not None:
        # only send to device, don't cast
        return obj.to(device=device)
    else:
        return obj
```
Note to reviewer: the previous `BatchFeature.to()` flattened the structure of nested inputs, which created several bugs! This fix preserves the nested structure.
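To see the fix concretely, here is a self-contained sketch (the function body mirrors the snippet above; the sample tensors are made up):

```python
import torch

def _recursive_to(obj, device, *args, **kwargs):
    # Lists can be nested, so keep digging until we hit tensors
    if isinstance(obj, list):
        return [_recursive_to(o, device, *args, **kwargs) for o in obj]
    # Only cast floating-point tensors, so integer tensors keep their dtype
    elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
        return obj.to(*args, **kwargs)
    elif isinstance(obj, torch.Tensor) and device is not None:
        return obj.to(device=device)
    else:
        return obj

# One sample with two differently-sized images -> a nested list of tensors
pixel_values = [[torch.randn(3, 32, 32), torch.randn(3, 64, 48)]]
out = _recursive_to(pixel_values, None, dtype=torch.float16)

assert isinstance(out, list) and isinstance(out[0], list)  # nesting preserved
assert out[0][0].dtype == torch.float16                    # floats were cast

input_ids = torch.tensor([[1, 2, 3]])
# Integer tensors pass through untouched instead of being cast to float
assert _recursive_to(input_ids, None, dtype=torch.float16).dtype == torch.long
```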
```python
    if isinstance(text, str) or isinstance(text, list) and len(text) == 1:
        # If there's a single sample, the image must belong to it
        images = [[images]]
    else:
        raise ValueError(
            "You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample."
        )
elif isinstance(images, list) and is_image_or_image_url(images[0]):
    if isinstance(text, str) or isinstance(text, list) and len(text) == 1:
        # If there's a single sample, all images must belong to it
        images = [images]
    else:
        raise ValueError(
            "You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample."
        )
```
Note to reviewer: previously there were many edge cases when users passed a single flat list of images. In some cases, the processor interpreted this as one image per sample rather than a list of images for one sample. This code avoids those error-prone inferences.
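For illustration, a simplified standalone sketch of the normalization rule described above; `normalize_images` and its string stand-ins for images are hypothetical, not the processor's actual API:

```python
def normalize_images(images, text):
    """Normalize `images` to one list of images per text sample (simplified sketch)."""
    def is_image(x):  # stand-in for the real is_image_or_image_url check
        return isinstance(x, str)

    single_sample = isinstance(text, str) or (isinstance(text, list) and len(text) == 1)
    if is_image(images):
        # A bare image: wrap twice -> one sample containing one image
        if single_sample:
            return [[images]]
        raise ValueError("Multiple text samples need a nested list of images, one list per sample.")
    if isinstance(images, list) and is_image(images[0]):
        # A flat list of images: all of them must belong to the single sample
        if single_sample:
            return [images]
        raise ValueError("Multiple text samples need a nested list of images, one list per sample.")
    return images  # already nested

assert normalize_images("img.png", "describe this") == [["img.png"]]
assert normalize_images(["a.png", "b.png"], "compare these") == [["a.png", "b.png"]]
try:
    normalize_images(["a.png", "b.png"], ["caption one", "caption two"])
except ValueError:
    pass  # a flat list is ambiguous with multiple samples, so the processor refuses to guess
```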
```diff
-        patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values]
+        if len(pixel_values) > 1:
+            raise ValueError("Batching/padding not supported yet!")
+        patch_embeds_list = [self.patch_conv(img.to(self.dtype)) for sample in pixel_values for img in sample]

         # flatten to a single sequence
-        patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
+        patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)
         patch_embeds = self.ln_pre(patch_embeds)

         # positional embeddings
         position_ids = position_ids_in_meshgrid(
             patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
         ).to(self.device)

         position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
```
Note to reviewer: these changes handle images now being passed in as a list of lists. Previously, images were passed in as a flat list even though the processor output a list of lists. The only reason this didn't cause an error was that the bug in `BatchFeature.to()` silently flattened the list structure and made it match the modeling code 😓
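A quick shape check of the new flattening path, using made-up patch-embedding tensors (hidden size 4) in place of real `patch_conv` outputs:

```python
import torch

hidden = 4
# patch_conv output per image: (hidden, h_patches, w_patches); two images in one sample
patch_embeds_list = [torch.randn(hidden, 2, 3), torch.randn(hidden, 4, 4)]

# Flatten each image to (num_patches, hidden), then concatenate into one sequence
patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)

assert patch_embeds.shape == (1, 2 * 3 + 4 * 4, hidden)  # (batch=1, 22 patches, hidden)
```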
This should be ready for final review @ArthurZucker! I did ablation testing and reverted some of the dtype changes in …
ArthurZucker left a comment:

Let's roll! A todo is to add another test for the new model 😉 Good to go otherwise.
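A sketch of what that follow-up test might look like, assuming the list-of-lists input convention introduced in this PR; the tiny config values are arbitrary:

```python
import torch
from transformers import PixtralVisionConfig, PixtralVisionModel

def test_two_images_one_sample():
    # Hypothetical sketch of the requested test; a real version would live in
    # the Pixtral test suite and may need updating as the input API evolves.
    config = PixtralVisionConfig(
        hidden_size=32, intermediate_size=64, num_hidden_layers=2,
        num_attention_heads=4, image_size=64, patch_size=16,
    )
    model = PixtralVisionModel(config).eval()
    pixel_values = [[torch.randn(3, 64, 48), torch.randn(3, 32, 32)]]  # one sample, two images
    with torch.no_grad():
        out = model(pixel_values)
    num_patches = (64 // 16) * (48 // 16) + (32 // 16) * (32 // 16)
    assert out.last_hidden_state.shape == (1, num_patches, 32)
```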
On the `_recursive_to` snippet above:
Should probably be fixed on the parent class
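For instance, moving the recursion into the parent class might look roughly like this (a simplified stand-in, not the actual `BatchFeature` implementation):

```python
import torch

class BatchFeature(dict):  # simplified stand-in for the real transformers class
    def __init__(self, data=None):
        super().__init__(data or {})
        self.data = dict(data or {})

    def to(self, *args, **kwargs):
        device = kwargs.get("device")

        def _recursive_to(obj):
            # Lists can be nested, so keep digging until we hit tensors
            if isinstance(obj, list):
                return [_recursive_to(o) for o in obj]
            if isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
                return obj.to(*args, **kwargs)   # cast and move
            if isinstance(obj, torch.Tensor) and device is not None:
                return obj.to(device=device)     # move only, keep integer dtypes
            return obj

        self.data = {k: _recursive_to(v) for k, v in self.data.items()}
        return self
```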
Linked PR: …nges (#1116)

## Purpose ##
* In transformers==4.48.0, the Pixtral processor was updated to no longer add an extra layer of wrapping around `pixel_values` (huggingface/transformers#34801). This is more in line with how other processors handle multimodal inputs.
* Because the data collator was previously relied on to unwrap that unnecessary wrapping, attempting to quantize Pixtral with transformers>=4.48.0 fails.

## Changes ##
* Update the Pixtral data collator to match the latest transformers version.
* Add a comment for those who want to use transformers<4.48.0.

## Testing ##
* Ran the Pixtral example to completion; @shubhra ran Pixtral Large.

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
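For example, the collator change described there amounts to dropping the unwrap step; a hypothetical sketch (the actual collator code may differ):

```python
import torch

def pixtral_data_collator(batch):
    # Hypothetical sketch. With transformers < 4.48.0 the processor emitted an
    # extra wrapping list, so collators unwrapped it via batch[0]["pixel_values"][0].
    # With transformers >= 4.48.0 (after huggingface/transformers#34801) that
    # layer is gone, so the values pass through unchanged.
    assert len(batch) == 1, "this sketch collates one sample at a time"
    sample = batch[0]
    return {
        "input_ids": torch.LongTensor(sample["input_ids"]),
        "attention_mask": torch.tensor(sample["attention_mask"]),
        "pixel_values": torch.tensor(sample["pixel_values"]),  # no [0] unwrap anymore
    }
```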
What does this PR do?

Updates the conversion script to support the Pixtral Large format.
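As a rough illustration of the kind of work such a conversion script does; all key patterns below are hypothetical, not the real Pixtral Large mapping:

```python
import re

# Hypothetical examples of original -> transformers key renames; the real
# mapping for Pixtral Large lives in the updated conversion script.
KEY_MAPPING = {
    r"^vision_encoder\.": "vision_tower.",
    r"^layers\.(\d+)\.attention\.wq\.": r"language_model.layers.\1.self_attn.q_proj.",
}

def convert_state_dict(original_state_dict):
    # Rename every checkpoint key according to the mapping, keeping tensors as-is
    new_state_dict = {}
    for key, tensor in original_state_dict.items():
        for pattern, replacement in KEY_MAPPING.items():
            key = re.sub(pattern, replacement, key)
        new_state_dict[key] = tensor
    return new_state_dict
```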