diff --git a/docs/source/en/model_doc/gemma4.md b/docs/source/en/model_doc/gemma4.md
index 0aed86af4199..18dc014ff2de 100644
--- a/docs/source/en/model_doc/gemma4.md
+++ b/docs/source/en/model_doc/gemma4.md
@@ -313,3 +313,8 @@ print(processor.decode(outputs[0][input_len:], skip_special_tokens=False))
 
 [[autodoc]] Gemma4ForConditionalGeneration
     - forward
+
+## Gemma4ForSequenceClassification
+
+[[autodoc]] Gemma4ForSequenceClassification
+    - forward
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 078a910c368e..7e3f08a0b994 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -1260,6 +1260,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("gemma2", "Gemma2ForSequenceClassification"),
         ("gemma3", "Gemma3ForSequenceClassification"),
         ("gemma3_text", "Gemma3TextForSequenceClassification"),
+        ("gemma4", "Gemma4ForSequenceClassification"),
         ("glm", "GlmForSequenceClassification"),
         ("glm4", "Glm4ForSequenceClassification"),
         ("gpt-sw3", "GPT2ForSequenceClassification"),
diff --git a/src/transformers/models/gemma4/modeling_gemma4.py b/src/transformers/models/gemma4/modeling_gemma4.py
index 88c340a9414b..f8d747f38d1d 100644
--- a/src/transformers/models/gemma4/modeling_gemma4.py
+++ b/src/transformers/models/gemma4/modeling_gemma4.py
@@ -42,7 +42,12 @@
 )
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
+from ...modeling_outputs import (
+    BaseModelOutputWithPast,
+    BaseModelOutputWithPooling,
+    CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
+)
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
@@ -52,6 +57,7 @@
     auto_docstring,
     can_return_tuple,
     is_accelerate_available,
+    logging,
     torch_compilable_check,
 )
 from ...utils.generic import maybe_autocast, merge_with_config_defaults
@@ -64,6 +70,9 @@
     from accelerate.hooks import add_hook_to_module
 
 
+logger = logging.get_logger(__name__)
+
+
 @dataclass
 @auto_docstring(
     custom_intro="""
@@ -2646,6 +2655,95 @@ def create_masks_for_generate(
         )
 
 
+class Gemma4ForSequenceClassification(Gemma4PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = Gemma4Model(config)
+        self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.model.set_input_embeddings(value)
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: torch.LongTensor | None = None,
+        pixel_values: torch.FloatTensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        past_key_values: Cache | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        token_type_ids: torch.LongTensor | None = None,
+        labels: torch.LongTensor | None = None,
+        use_cache: bool | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> SequenceClassifierOutputWithPast:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+
+        transformer_outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            token_type_ids=token_type_ids,
+            use_cache=use_cache,
+            return_dict=True,
+            **kwargs,
+        )
+        hidden_states = transformer_outputs.last_hidden_state
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        else:
+            batch_size = inputs_embeds.shape[0]
+
+        if self.config.text_config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+        if self.config.text_config.pad_token_id is None:
+            last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.text_config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+        else:
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+
 __all__ = [
     "Gemma4AudioModel",
     "Gemma4ForCausalLM",
@@ -2654,4 +2752,5 @@ def create_masks_for_generate(
     "Gemma4PreTrainedModel",
     "Gemma4TextModel",
     "Gemma4VisionModel",
+    "Gemma4ForSequenceClassification",
 ]
diff --git a/src/transformers/models/gemma4/modular_gemma4.py b/src/transformers/models/gemma4/modular_gemma4.py
index 0cddf103f3bf..c5036e0ca222 100644
--- a/src/transformers/models/gemma4/modular_gemma4.py
+++ b/src/transformers/models/gemma4/modular_gemma4.py
@@ -52,8 +52,9 @@
     Gemma3Attention,
     Gemma3DecoderLayer,
     Gemma3ForCausalLM,
+    Gemma3ForSequenceClassification,
     Gemma3MLP,
     Gemma3RotaryEmbedding,
     Gemma3TextModel,
     Gemma3TextScaledWordEmbedding,
 )
@@ -2203,6 +2204,10 @@ def prepare_inputs_for_generation(
     return model_inputs
 
 
+class Gemma4ForSequenceClassification(Gemma3ForSequenceClassification):
+    pass
+
+
 __all__ = [
     "Gemma4AudioModel",
     "Gemma4ForCausalLM",
@@ -2211,4 +2216,5 @@ def prepare_inputs_for_generation(
     "Gemma4PreTrainedModel",
     "Gemma4TextModel",
     "Gemma4VisionModel",
+    "Gemma4ForSequenceClassification",
 ]
diff --git a/tests/models/gemma4/test_modeling_gemma4.py b/tests/models/gemma4/test_modeling_gemma4.py
index 91694b5c1d45..d05de52bb9ed 100644
--- a/tests/models/gemma4/test_modeling_gemma4.py
+++ b/tests/models/gemma4/test_modeling_gemma4.py
@@ -48,6 +48,7 @@
     AutoModelForCausalLM,
     Gemma4ForCausalLM,
     Gemma4ForConditionalGeneration,
+    Gemma4ForSequenceClassification,
     Gemma4Model,
     Gemma4Processor,
     Gemma4TextModel,
@@ -226,7 +227,9 @@ def prepare_config_and_inputs_for_common(self):
 
 @require_torch
 class Gemma4Audio2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
-    all_model_classes = (Gemma4Model, Gemma4ForConditionalGeneration) if is_torch_available() else ()
+    all_model_classes = (
+        (Gemma4Model, Gemma4ForConditionalGeneration, Gemma4ForSequenceClassification) if is_torch_available() else ()
+    )
     all_generative_model_classes = (Gemma4ForConditionalGeneration,) if is_torch_available() else ()
 
     def setUp(self):
@@ -378,7 +381,9 @@ def prepare_config_and_inputs_for_common(self):
 
 @require_torch
 class Gemma4Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
-    all_model_classes = (Gemma4Model, Gemma4ForConditionalGeneration) if is_torch_available() else ()
+    all_model_classes = (
+        (Gemma4Model, Gemma4ForConditionalGeneration, Gemma4ForSequenceClassification) if is_torch_available() else ()
+    )
     all_generative_model_classes = (Gemma4ForConditionalGeneration,) if is_torch_available() else ()
     additional_model_inputs = ["mm_token_type_ids"]
 