Add Gemma4ForSequenceClassification #45438

Open
Charly21r wants to merge 3 commits into huggingface:main from Charly21r:gemma4-sequence-classification
Conversation

@Charly21r
Contributor

@Charly21r Charly21r commented Apr 14, 2026

What does this PR do?

Fixes #45373
Adds Gemma4TextForSequenceClassification and Gemma4ForSequenceClassification to transformers.models.gemma4, following the same pattern established by Gemma 3.

Prior to this change, AutoModelForSequenceClassification.from_pretrained("google/gemma-4-E4B", num_labels=N) raised a ValueError because neither Gemma4Config nor Gemma4TextConfig were registered in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES. This is inconsistent with every prior Gemma release (gemma, gemma2, gemma3, gemma3n all export sequence classification variants).
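For illustration, the call that should work once these classes are registered looks roughly like this (the checkpoint name is the one mentioned above; num_labels=3 is just a placeholder):

    # Sketch of the intended usage once Gemma4ForSequenceClassification is registered
    # with AutoModelForSequenceClassification; num_labels=3 is a placeholder.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model = AutoModelForSequenceClassification.from_pretrained("google/gemma-4-E4B", num_labels=3)
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-4-E4B")

    inputs = tokenizer("This PR adds sequence classification support.", return_tensors="pt")
    logits = model(**inputs).logits  # shape: (batch_size, num_labels)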

Code Agent Policy

  • I confirm that this is not a pure code agent PR.

Before submitting

Who can review?

@ArthurZucker @zucchini-nlp

@Charly21r Charly21r marked this pull request as draft April 14, 2026 17:33
Comment on lines +100 to +101
def test_load_with_mismatched_shapes(self):
    pass
Member

why? Let's find the root reason it's failing; ideally it shouldn't fail

Contributor Author

The root cause is that Gemma4PreTrainedModel doesn't set base_model_prefix = "model" (unlike Gemma3PreTrainedModel and Gemma3nPreTrainedModel which both do).

The test saves a Gemma4TextForSequenceClassification checkpoint (keys prefixed with model.), then loads it via AutoModel -> Gemma4TextModel with vocab_size=10. Since Gemma4TextModel inherits base_model_prefix = "" from Gemma4PreTrainedModel, the loading code can't match model.embed_tokens.weight to embed_tokens.weight; the keys end up as unexpected/missing rather than mismatched, so no RuntimeError is raised.
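Roughly, the scenario is the following (a paraphrase of the common test, not the exact code):

    # Paraphrase of the failing scenario described above; `config` stands in for the small
    # Gemma4TextConfig the common tests build.
    import tempfile

    from transformers import AutoModel, Gemma4TextForSequenceClassification

    with tempfile.TemporaryDirectory() as tmp_dir:
        model = Gemma4TextForSequenceClassification(config)  # keys look like "model.embed_tokens.weight"
        model.save_pretrained(tmp_dir)

        # Expected: a size-mismatch RuntimeError, since vocab_size=10 disagrees with the checkpoint.
        # Observed: with base_model_prefix == "", "model.embed_tokens.weight" never lines up with
        # "embed_tokens.weight"; the keys are reported as unexpected/missing and no error is raised.
        AutoModel.from_pretrained(tmp_dir, vocab_size=10)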

This is the same reason gemma3 skips it. Fixing Gemma4PreTrainedModel.base_model_prefix would be a separate, broader change; should I include that fix here, or keep it out of scope for this PR?

Member

I think we need a base_model_prefix = "model". Classes that have a different base prefix already override it; I don't know why we didn't add it here, prob didn't notice

Contributor Author

Added base_model_prefix = "model" to Gemma4PreTrainedModel. The text-only test_load_with_mismatched_shapes now passes. The multimodal variant (Vision2Text) is still skipped, for the same reason as gemma3: loading nested configs with overwritten kwargs isn't supported yet.

@LarsKlawitter

Thanks for the quick turnaround! The GPU is currently blocked by an unrelated training job; it should free up in ~2 days. Will report back here with results (or a minimal repro if anything breaks).

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, gemma4

Comment on lines +2168 to +2173
class Gemma4ForSequenceClassification(Gemma4PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Gemma4Model(config)
        self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)
Member

looks fine to me, but before merging I want to sync with @vasqu. We have another related PR discussing how to better enable SeqClf on VLMs (that can maybe also be used as LLMs)

Contributor

I think this is fine for now; we can refactor this in the other PR(?)

Either way, I think we need to think a bit more about how we properly handle this:

  • VLMs that can be used for text only
  • Different VLM class to have proper signatures

Contributor

@vasqu vasqu left a comment

Yeah, this kind of goes in the direction of using a separate class for VLMs and for text-only models.

Imo, we need to think a bit more about this before merging. We will encounter similar things for each (mixed) modality then, so we need to decide on a proper standard. It's kind of a harder issue since we have different signatures across them.


class Gemma4PreTrainedModel(PreTrainedModel):
    config: Gemma4Config
    base_model_prefix = "model"
Contributor

This looks potentially breaking because of the audio-only model, IIRC. Any reason this change was really needed? I would rather be explicit on the different classes for now

Member

@zucchini-nlp zucchini-nlp Apr 15, 2026

the vision/audio model has overridden its prefix, hasn't it? The standard is to have model, and I suggested this rather than manually adding model to the two new classes
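A minimal sketch of the convention (the audio class name and prefix below are hypothetical, not the actual gemma4 modules):

    # Sketch only: "model" as the standard default on the shared base class, with submodels
    # that use a different prefix overriding it locally.
    class Gemma4PreTrainedModel(PreTrainedModel):
        base_model_prefix = "model"

    class Gemma4SomeAudioEncoder(Gemma4PreTrainedModel):  # hypothetical class name
        base_model_prefix = "audio_tower"                 # hypothetical prefix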

Contributor

Oh, it does? Sorry, not too familiar with gemma4. Do we have tests for each standalone model to double-check we didn't break existing stuff for those?

Member

@zucchini-nlp zucchini-nlp Apr 15, 2026

thanks for the collective efforts, gemma4 has very extensive tests per modality

(super happy about that tbh)

Comment on lines +2178 to +2182
def get_input_embeddings(self):
    return self.model.get_input_embeddings()

def set_input_embeddings(self, value):
    self.model.set_input_embeddings(value)
Contributor

Do we actually still need these? EmbeddingAccessMixin should catch these with the base model, no?

Comment on lines +394 to +396
@unittest.skip("Loading nested configs with overwritten `kwargs` isn't supported yet.")
def test_load_with_mismatched_shapes(self):
    pass
Contributor

I guess it's because of the new class; tbh I would just disable it for that one class and still test this

@zucchini-nlp
Member

We will encounter similar things for each (mixed) modality then so we need to decide on a proper standard

@vasqu yeah, I agree with it 100%. We can't just list all possible args in the signature, so if we want a "generic" class, I think the non-standard ones have to be consumed as kwargs. But again, we face the issue of different modality combinations, and personally I don't like the idea of hosting a separate class for each combination just because of the signature. Gemma4 here supports all 3 modalities, for example, while gemma3 supports only "image", etc.

So how important is it, in your opinion, to align the signature? I think we can give up a little bit to reduce duplicate LOC and instead add more info in the class-level docstring: "This is a classifier model that works with text inputs and can optionally accept images/videos/audio".

@vasqu
Contributor

vasqu commented Apr 15, 2026

All of these are united by the fact that they take the same text inputs. Could we introduce new typed dicts for kwargs that are specific to each modality and add them to the generic one? I think this would kind of lean into what you want to do while still giving some insight into possible inputs via typed kwargs. We do something similar for processors already, no? The goal would then be to have more control, though, so that we can filter out kwargs based on simple flags, e.g. accepts_images=True.

For example, the forward signature then can look like:

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[MultiModalTransformersKwargs],
    ) -> SequenceClassifierOutputWithPast:
        kwargs = process_kwargs(kwargs)  # filter based on allowed modalities 
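The modality-specific typed dicts could then look roughly like this (the TypedDict contents and the process_kwargs helper are illustrative sketches, not existing transformers APIs):

    # Illustrative sketch of the per-modality kwargs idea above; these definitions
    # don't exist in transformers as written.
    from typing import TypedDict

    import torch

    class ImageKwargs(TypedDict, total=False):
        pixel_values: torch.FloatTensor

    class AudioKwargs(TypedDict, total=False):
        input_features: torch.FloatTensor

    class MultiModalTransformersKwargs(ImageKwargs, AudioKwargs):
        pass

    def process_kwargs(kwargs, accepts_images=True, accepts_audio=False):
        """Drop kwargs for modalities the concrete class doesn't accept (sketch)."""
        if not accepts_images:
            kwargs.pop("pixel_values", None)
        if not accepts_audio:
            kwargs.pop("input_features", None)
        return kwargs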

@zucchini-nlp
Member

@Charly21r we will work internally on creating generic sequence classification classes for multimodal LLMs. Until then, let's put the current PR on hold.

I will tag you again when we can resume the PR

@LarsKlawitter

FYI — the QLoRA classification use case is ready to test whenever your generic approach lands. We've been running Gemma 4 E4B seq-clf locally via a small wrapper following the Gemma 3 pattern; happy to validate whichever direction you land on. Not urgent from our side — we can unblock with the local wrapper in the meantime.
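For reference, the local wrapper is roughly along these lines (a rough sketch following the Gemma 3 classification-head pattern; the class name, pooling, and config access are simplified, and the QLoRA/quantization setup is omitted):

    # Rough sketch of the local wrapper, in the style of the Gemma 3 / PR classes;
    # not the exact code we run, and the QLoRA setup is left out.
    import torch.nn as nn
    from transformers import AutoModel

    class LocalGemmaSeqClf(nn.Module):  # illustrative name
        def __init__(self, model_id: str, num_labels: int):
            super().__init__()
            self.backbone = AutoModel.from_pretrained(model_id)
            text_config = getattr(self.backbone.config, "text_config", self.backbone.config)
            self.score = nn.Linear(text_config.hidden_size, num_labels, bias=False)

        def forward(self, input_ids=None, attention_mask=None, **kwargs):
            hidden_states = self.backbone(
                input_ids=input_ids, attention_mask=attention_mask, **kwargs
            ).last_hidden_state
            pooled = hidden_states[:, -1]  # last-token pooling, simplified (no pad handling)
            return self.score(pooled)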

github-merge-queue Bot pushed a commit that referenced this pull request Apr 29, 2026
* chore(typing): add ty type checking for 3 pipeline files

Adds ty type checking coverage for:
- src/transformers/pipelines/feature_extraction.py
- src/transformers/pipelines/image_feature_extraction.py
- src/transformers/pipelines/video_classification.py

For Issues #45438

* fix: restore _typing.py in check_args per review feedback

---------

Co-authored-by: Tarek Ziade <tarek@ziade.org>

Development

Successfully merging this pull request may close these issues.

Add Gemma4ForSequenceClassification (missing from gemma4 module — Gemma 2/3 have it)

4 participants