Add Gemma4ForSequenceClassification #45438
Charly21r wants to merge 3 commits into huggingface:main
Conversation
    def test_load_with_mismatched_shapes(self):
        pass
Why? Let's find the root reason if it's failing; it shouldn't fail, ideally.
The root cause is that Gemma4PreTrainedModel doesn't set base_model_prefix = "model" (unlike Gemma3PreTrainedModel and Gemma3nPreTrainedModel which both do).
The test saves a Gemma4TextForSequenceClassification checkpoint (keys prefixed with `model.`), then loads it via AutoModel -> Gemma4TextModel with `vocab_size=10`. Since Gemma4TextModel inherits `base_model_prefix = ""` from Gemma4PreTrainedModel, the loading code can't match `model.embed_tokens.weight` to `embed_tokens.weight`, so the keys end up as unexpected/missing rather than mismatched and no RuntimeError is raised.
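A minimal sketch of that scenario (it mirrors the common test rather than copying it; a tiny config would be used in practice, and the class names assume this PR):

```python
import tempfile

import pytest
from transformers import AutoModel, Gemma4TextConfig, Gemma4TextForSequenceClassification

# a small/tiny config would be used in practice; defaults shown here for brevity
config = Gemma4TextConfig(num_labels=2)
model = Gemma4TextForSequenceClassification(config)  # state dict keys are prefixed with "model."

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)
    # With base_model_prefix = "", "model.embed_tokens.weight" can't be matched to
    # "embed_tokens.weight" on Gemma4TextModel, so the key is treated as unexpected
    # instead of mismatched and the expected RuntimeError is never raised.
    with pytest.raises(RuntimeError):
        AutoModel.from_pretrained(tmp_dir, vocab_size=10)
```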
This is the same reason gemma3 skips it. Fixing Gemma4PreTrainedModel.base_model_prefix would be a separate broader change, should I include that fix or keep it out of scope for this PR?
I think we need `base_model_prefix = "model"`. Classes that have a different base prefix already override it; don't know why we didn't add it here, prob didn't notice.
Added base_model_prefix = "model" to Gemma4PreTrainedModel. The text-only test_load_with_mismatched_shapes now passes. The multimodal variant (Vision2Text) is skipped with the same reason as gemma3: loading nested configs with overwritten kwargs isn't supported yet.
Thanks for the quick turnaround! The GPU is currently blocked by an unrelated training job; it should free up in ~2 days. Will report back here with results (or a minimal repro if anything breaks).
[For maintainers] Suggested jobs to run (before merge): run-slow: auto, gemma4
    class Gemma4ForSequenceClassification(Gemma4PreTrainedModel):
        def __init__(self, config):
            super().__init__(config)
            self.num_labels = config.num_labels
            self.model = Gemma4Model(config)
            self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)
Looks fine to me, but before merging I want to sync with @vasqu. We have another related PR discussing how to better enable SeqClf on VLMs (that can maybe be LLMs).
I think this is fine for now, we can refactor this in the other PR(?)
I think we need to think a bit more either way about how we properly do this:
- VLMs that can be used for text only
- Different VLM classes to have proper signatures
vasqu left a comment
Yea, this kind of goes in the direction of using a separate class for VLMs and text only.
Imo, we need to think a bit more about this before merging. We will encounter similar things for each (mixed) modality, so we need to decide on a proper standard. It is kind of a harder issue as we have different signatures across them.
    class Gemma4PreTrainedModel(PreTrainedModel):
        config: Gemma4Config
        base_model_prefix = "model"
This looks potentially breaking because of the audio-only model, iirc. Any reason this change was really needed? Would rather be explicit under the different ones for now.
The vision/audio models have overridden their prefix, haven't they? The standard is to have `model`, and I suggested this rather than manually adding it on two new classes.
Oh, they do? Sorry, not too familiar with gemma4; do we have tests for each standalone model to double-check we didn't break existing stuff for those?
thanks for the collective efforts, gemma4 has very extensive tests per modality
(super happy about that tbh)
    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)
Do we actually still need these? EmbeddingAccessMixin should catch these with the base model, no?
    @unittest.skip("Loading nested configs with overwritten `kwargs` isn't supported yet.")
    def test_load_with_mismatched_shapes(self):
        pass
Ig it's the new class, would tbh just disable that one class and still test this
@vasqu yeah, I agree with it 100%, we can't just list all possible kwargs. So how important is it to align signatures, in your opinion? I think we can give up a little bit to reduce duplicate LOC, and instead add more info in the class-level docstring.
All of these are combined by the fact that they have the same text inputs, so could we introduce new typed dicts for kwargs that are specific to each modality and add them to the generic one? I think this would kind of lean into what you want to do while giving some insight into possible inputs via typed kwargs. We do something similar for processors already, no? The goal would then be to have more control, though, so we can filter out kwargs based on a simple flag. For example, the forward signature could then look like:
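(A rough sketch of what that could look like; `Gemma4ImagesKwargs`, `Gemma4AudioKwargs`, and the field names are hypothetical rather than existing classes, mirroring the typed processor kwargs mentioned above.)

```python
from typing import Optional, TypedDict

import torch
from typing_extensions import Unpack


class Gemma4ImagesKwargs(TypedDict, total=False):
    # hypothetical image-specific inputs
    pixel_values: torch.FloatTensor


class Gemma4AudioKwargs(TypedDict, total=False):
    # hypothetical audio-specific inputs
    input_features: torch.FloatTensor
    feature_attention_mask: torch.LongTensor


class Gemma4ModalityKwargs(Gemma4ImagesKwargs, Gemma4AudioKwargs, total=False):
    """Generic bucket combining all per-modality kwargs."""


def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[Gemma4ModalityKwargs],
):
    # a text-only class could reject (or simply ignore) the modality-specific kwargs
    # based on a flag such as a hypothetical `self.accepts_images`
    ...
```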
@Charly21r we will work internally on creating generic sequence classification classes for multimodal LLMs. Until then, let's put the current PR on hold; I will tag you again when we can resume the PR.
FYI: the QLoRA classification use case is ready to test whenever your generic approach lands. We've been running Gemma 4 E4B seq-clf locally via a small wrapper following the Gemma 3 pattern; happy to validate whichever direction you land on. Not urgent from our side, we can unblock with the local wrapper in the meantime.
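For reference, a minimal sketch of such a wrapper (illustrative only: the class name, pooling, and config handling are assumptions following the Gemma 3 last-non-pad-token pattern, not this PR's code):

```python
import torch
from torch import nn
from transformers import AutoConfig, AutoModel


class SeqClfWrapper(nn.Module):
    """Illustrative wrapper: text backbone + linear score head, last non-pad token pooling."""

    def __init__(self, model_name: str, num_labels: int):
        super().__init__()
        config = AutoConfig.from_pretrained(model_name)
        self.backbone = AutoModel.from_pretrained(model_name)
        # multimodal configs keep the text hidden size in a nested text_config
        hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
        self.score = nn.Linear(hidden_size, num_labels, bias=False)

    def forward(self, input_ids, attention_mask=None, labels=None):
        hidden_states = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        logits = self.score(hidden_states)
        # pool on the last non-padding token (assumes right padding)
        if attention_mask is not None:
            last_token = (attention_mask.sum(dim=-1) - 1).to(logits.device)
        else:
            last_token = torch.full(
                (input_ids.shape[0],), input_ids.shape[1] - 1,
                dtype=torch.long, device=logits.device,
            )
        batch_indices = torch.arange(input_ids.shape[0], device=logits.device)
        pooled_logits = logits[batch_indices, last_token]
        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(pooled_logits, labels)
        return {"loss": loss, "logits": pooled_logits}
```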
What does this PR do?
Fixes #45373
Adds `Gemma4TextForSequenceClassification` and `Gemma4ForSequenceClassification` to `transformers.models.gemma4`, following the same pattern established by Gemma 3.

Prior to this change, `AutoModelForSequenceClassification.from_pretrained("google/gemma-4-E4B", num_labels=N)` raised a `ValueError` because neither `Gemma4Config` nor `Gemma4TextConfig` were registered in `MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES`. This is inconsistent with every prior Gemma release (gemma, gemma2, gemma3, and gemma3n all export sequence classification variants).
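With the new classes registered, the standard Auto API path works; a quick usage sketch (the label count and example text are illustrative):

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "google/gemma-4-E4B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=3)

inputs = tokenizer("The movie was great!", return_tensors="pt")
logits = model(**inputs).logits  # shape: (1, num_labels)
```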
Code Agent Policy

Before submitting
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@ArthurZucker @zucchini-nlp