Add ColQwen2 to 🤗 transformers #35778
Conversation
Feel free to ping us once this is ready for review!
(force-pushed from 025ca25 to 48e0aa5)
Feel free to ping @Cyrilvallez once this is ready for review! 🤗
(force-pushed from 5ec9758 to 7cfd9dc)
yonigozlan left a comment
Hey @tonywu71! Thanks for contributing 🤗. Looks almost ready to go to me; I just pointed out a few nits to change.
```python
)
self.query_prefix = query_prefix or "Query: "

self.tokenizer.padding_side = "left"
```
This should be set when saving the tokenizer/processor
Fixed! The Hugging Face Hub commit with the new processor_config.json can be found here for reference.
Update: after discussion with @yonigozlan, I have realized it makes much more sense to let tokenizer_config.json handle padding_side. I've just applied the necessary changes!
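For reference, a minimal sketch of what this resolution means in practice (repo id from this PR; the local save path is hypothetical): `padding_side` set before `save_pretrained` is serialized into `tokenizer_config.json`, so the processor no longer needs to force it at runtime.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("vidore/colqwen2-v1.0-hf")

# ColQwen2 pads query batches on the left; persisting the setting here means
# tokenizer_config.json carries it instead of the processor hardcoding it.
tokenizer.padding_side = "left"
tokenizer.save_pretrained("./colqwen2-tokenizer")  # hypothetical local path
```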
yonigozlan left a comment
Nice, thanks for iterating! I see two small things left to change, then LGTM!
Cyrilvallez left a comment
Hey! Sorry for the delay! This is pretty clean, great work! 🤗 I just left a few last comments!
```python
loss = None
if not return_dict:
    output = (embeddings,) + outputs[2:]
    output[2] = output[2] if output_hidden_states is not None else None
    output[-1] = (outputs.image_hidden_states if pixel_values is not None else None,)
    return (loss,) + output if loss is not None else output

return ColPaliForRetrievalOutput(
    loss=loss,
    embeddings=embeddings,
```
Why are we removing the loss here? 👀
The loss wasn't strictly speaking removed:

- it used to be set to `None`;
- the default value for `loss` in `ColPaliForRetrievalOutput` is `None`.

So I have removed the unneeded lines to make the code clearer.
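A sketch of the equivalence being described (import path assumed from the ColPali modeling file; the dummy tensor is only for illustration):

```python
import torch
from transformers.models.colpali.modeling_colpali import ColPaliForRetrievalOutput

embeddings = torch.zeros(1, 8, 128)  # dummy multi-vector embeddings

# Before: loss was assigned and passed explicitly, but it was always None here.
out_explicit = ColPaliForRetrievalOutput(loss=None, embeddings=embeddings)

# After: ColPaliForRetrievalOutput already defaults loss to None, so omitting
# the argument builds an identical output object with less code.
out_implicit = ColPaliForRetrievalOutput(embeddings=embeddings)

assert out_explicit.loss is None and out_implicit.loss is None
```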
```python
    visual_prompt_prefix: str = "Describe the image.",
    query_prefix: str = "Question: ",
):
    super().__init__(image_processor=image_processor, tokenizer=tokenizer, chat_template=chat_template)
    self.visual_prompt_prefix = visual_prompt_prefix
    self.query_prefix = query_prefix
```
This kind of prefix should be part of the chat_template directly, not hardcoded here 🤗
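For illustration only, a hypothetical sketch of this suggestion, moving the query prefix into a Jinja chat template (this template string is not the one shipped with the PR, and the thread below ultimately keeps the plain attributes to align with ColPali):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("vidore/colqwen2-v1.0-hf")

# Hypothetical template: the query prefix lives in the template instead of a
# hardcoded processor attribute, so it travels with the saved tokenizer.
tok.chat_template = (
    "{% for message in messages %}{{ 'Question: ' + message['content'] }}{% endfor %}"
)

prompt = tok.apply_chat_template(
    [{"role": "user", "content": "What is shown in the image?"}], tokenize=False
)
print(prompt)  # -> "Question: What is shown in the image?"
```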
```python
if is_torch_available():
    import torch
    from torch import nn
```
No need to protect the torch import here!
```python
raise AttributeError(
    "The `initializer_range` attribute is not set in the configuration. Please set it before using the model."
)
```
Let's make sure it is correctly defined in the Config with some default value instead of raising here
The ColQwen2Config already has a default value for initializer_range, so I'll just remove the raise 👍🏼
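A minimal sketch of this resolution, a config that always carries a default so the model never has to raise (class name marked as a sketch; the 0.02 value is an assumption matching the usual transformers default):

```python
from transformers import PretrainedConfig


class ColQwen2ConfigSketch(PretrainedConfig):
    model_type = "colqwen2"

    def __init__(self, initializer_range: float = 0.02, **kwargs):
        # A default here guarantees the attribute exists on every instance,
        # so _init_weights can read it without an AttributeError guard.
        self.initializer_range = initializer_range
        super().__init__(**kwargs)
```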
```python
if inputs_embeds is None:
    inputs_embeds = self.vlm.model.embed_tokens(input_ids)

    if pixel_values is not None:
        pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
        image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
        image_mask = (
            (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
        )
        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

    if attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.device)

outputs = self.vlm.model(
    input_ids=None,
    position_ids=position_ids,
    attention_mask=attention_mask,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
    cache_position=cache_position,
)
return outputs
```
If you don't mind, I think it would help readability to have this block directly in the main forward instead of splitting it across two functions (due to the large signatures, we have to go back and forth).
No I don't mind, I think it's actually a good idea! 🤗
```python
if visual_prompt_prefix is None:
    visual_prompt_prefix = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
self.visual_prompt_prefix = visual_prompt_prefix

if query_prefix is None:
    query_prefix = "Query: "
self.query_prefix = query_prefix
```
These should be incorporated into the chat_template 🤗
Not sure if we should have a chat template here since this is not really a chat model. We had the same issue with GOT-OCR and ended up not using a chat template. wdyt?
Hmm, indeed I was a bit fast here; let's keep it as is, especially as it aligns with ColPali!
```python
if text is not None and images is not None:
    raise ValueError("Only one of text or images can be processed at a time")
```
Alright, let's keep it then!
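A usage sketch of that contract (repo id from this PR; assumes a transformers version with this PR merged, and the dummy image is only for illustration):

```python
from PIL import Image
from transformers import ColQwen2Processor

processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0-hf")
image = Image.new("RGB", (448, 448), color="white")

image_batch = processor(images=image, return_tensors="pt")  # fine
query_batch = processor(text="organizational chart", return_tensors="pt")  # fine

try:
    processor(text="organizational chart", images=image)
except ValueError as err:
    print(err)  # Only one of text or images can be processed at a time
```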
```python
def process_images(
    self,
    images: ImageInput = None,
    **kwargs: Unpack[ColQwen2ProcessorKwargs],
) -> BatchFeature:
    """
    Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColQwen2Processor's
    [`ColQwen2Processor.__call__`].

    This method forwards the `images` and `kwargs` arguments to Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`].

    Args:
        images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
            The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
            tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
            number of channels, H and W are image height and width.
        return_tensors (`str` or [`~utils.TensorType`], *optional*):
            If set, will return tensors of a particular framework. Acceptable values are:

            - `'tf'`: Return TensorFlow `tf.constant` objects.
            - `'pt'`: Return PyTorch `torch.Tensor` objects.
            - `'np'`: Return NumPy `np.ndarray` objects.
            - `'jax'`: Return JAX `jnp.ndarray` objects.

    Returns:
        [`BatchFeature`]: A [`BatchFeature`] with the following fields:

        - **input_ids** -- List of token ids to be fed to a model.
        - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
          `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
          `None`).
        - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
    """
    return self.__call__(images=images, **kwargs)

def process_queries(
    self,
    text: Union[TextInput, List[TextInput]],
    **kwargs: Unpack[ColQwen2ProcessorKwargs],
) -> BatchFeature:
    """
    Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColQwen2Processor's
    [`ColQwen2Processor.__call__`].
```
Why do we redefine them here? They will be inherited directly!
Oh, you're right! However, I think the docstring will be inherited from ColPaliProcessor's and will thus reference ColPali. Is there a way to simply override the docstring here? If not, should we keep the code as it is?
I don't see a clean way to do this, but we can just remove specific references to the tokenizer and image processor in the docstring imo
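For completeness, a usage sketch of the two inherited wrappers end to end (repo id from this PR; `score_retrieval` is assumed to be inherited from the ColPali processor, and the dummy image and query are only for illustration):

```python
import torch
from PIL import Image
from transformers import ColQwen2ForRetrieval, ColQwen2Processor

model_name = "vidore/colqwen2-v1.0-hf"
processor = ColQwen2Processor.from_pretrained(model_name)
model = ColQwen2ForRetrieval.from_pretrained(model_name, torch_dtype=torch.bfloat16)

batch_images = processor.process_images(images=[Image.new("RGB", (448, 448))])
batch_queries = processor.process_queries(text=["What does the figure show?"])

with torch.no_grad():
    image_embeddings = model(**batch_images).embeddings
    query_embeddings = model(**batch_queries).embeddings

# MaxSim late-interaction scoring between each query and each image.
scores = processor.score_retrieval(query_embeddings, image_embeddings)
```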
@Cyrilvallez taking over for the final push on this PR as Tony is quite busy. I pushed some necessary updates after the refactoring of Qwen2VL (so nice to have btw), all should be good now and we use modular much more, including for the modeling code 🤗. @tonywu71 you'll still have to run the updated convert_weights script and push to your repo :), but apart from that we should be ready to merge!
@yonigozlan Done, the model repo is updated! 🤗 I've also pushed a commit to fix the HF model path for ColQwen2 integration tests. Lmk if there's anything left to do before merging!
Cyrilvallez left a comment
All right! Amazing work, congrats to you both @tonywu71 @yonigozlan! Super clean 🤗 I left 2 ultra small comments as my job here is to be very picky 🙃, but that's it! Feel free to merge @yonigozlan!
Thanks for the great addition 🤗
```python
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
)
```
This is a super nit, feel free to disregard if you're too annoyed by the review process 😆 But passing None is a bit misleading for an example IMO, even if it's equivalent
```diff
- attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
+ attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
  )
```
Agreed! It's been addressed 👌🏼
Actually it seems sdpa doesn't work out-of-the-box for ColQwen2 as I get this error when loading the model on MPS.
❌ Code:

```python
model_name = "vidore/colqwen2-v1.0-hf"

# Load model
model = ColQwen2ForRetrieval.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # "cpu", "cuda", or "mps" for Apple Silicon
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
)
```

Note: Leaving attn_implementation=None works.
The error:

```
ValueError: ColQwen2ForRetrieval does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`
```
✅ However, I managed to load Qwen2VL with SDPA:

```python
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",  # "cpu", "cuda", or "mps" for Apple Silicon
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
)
```

@Cyrilvallez @yonigozlan I read the instructions for enabling SDPA on ColQwen2, but the next steps are a bit unclear as ColQwen2 essentially piggybacks on Qwen2VL thanks to modular. Any ideas about the right fix? 🤗
I believe it's only because the flags are not set in the PreTrainedModel. Adding

```python
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
```

should solve it.
Thanks so much, the fix is working like a charm! And as you expected, ColQwen2 works with attn_implementation="flex_attention" too 👌🏼
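For reference, a sketch of where those flags live (class name marked as a sketch and deliberately abridged; the real base class carries more than shown):

```python
from transformers import PreTrainedModel


class ColQwen2PreTrainedModelSketch(PreTrainedModel):
    base_model_prefix = "model"

    # from_pretrained consults these class attributes before honoring
    # attn_implementation="sdpa", "flash_attention_2", or "flex_attention",
    # which is why the ValueError above went away once they were set.
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
```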
```python
if is_torch_available():
    import torch
```
Let's not protect, simply import it 🤗
Problem is we need to protect the import for the processor :(
Oh I see. Not a big issue anyway, you can disregard (it's just that torch.nn is imported without protection anyway, so it's a bit weird), but really not a big concern.
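For context, the guard pattern under discussion, a minimal sketch of how processing files typically keep torch optional (the helper function is hypothetical):

```python
from transformers.utils import is_torch_available

# Processing files must import cleanly in torch-free environments, hence the
# guard; modeling files, by contrast, can import torch and torch.nn directly.
if is_torch_available():
    import torch


def maybe_to_float(scores):
    # Hypothetical helper: only touch torch symbols on paths that need them.
    if is_torch_available() and isinstance(scores, torch.Tensor):
        return scores.float()
    return scores
```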
What does this PR do?
Add ColQwen2 in 🤗 transformers. ColQwen2 is a model that uses the ColPali architecture with a Qwen2-VL backbone.

Who can review?

Additional details
- colpali-engine==v0.3.6
- vidore/colqwen2-v1.0-hf

Progress checklist