5 changes: 5 additions & 0 deletions docs/source/en/model_doc/gemma4.md
@@ -313,3 +313,8 @@ print(processor.decode(outputs[0][input_len:], skip_special_tokens=False))

[[autodoc]] Gemma4ForConditionalGeneration
- forward

## Gemma4ForSequenceClassification

[[autodoc]] Gemma4ForSequenceClassification
- forward
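For context, a minimal usage sketch of the new head (the checkpoint id below is hypothetical and only for illustration; any Gemma 4 checkpoint with a compatible config and tokenizer would do):

```python
import torch

from transformers import AutoTokenizer, Gemma4ForSequenceClassification

# Hypothetical checkpoint id, for illustration only.
model_id = "google/gemma-4-4b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = Gemma4ForSequenceClassification.from_pretrained(model_id, num_labels=2)

inputs = tokenizer("The soundtrack was fantastic.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (batch_size, num_labels)
print(logits.argmax(dim=-1))
```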
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -1260,6 +1260,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("gemma2", "Gemma2ForSequenceClassification"),
("gemma3", "Gemma3ForSequenceClassification"),
("gemma3_text", "Gemma3TextForSequenceClassification"),
("gemma4", "Gemma4ForSequenceClassification"),
("glm", "GlmForSequenceClassification"),
("glm4", "Glm4ForSequenceClassification"),
("gpt-sw3", "GPT2ForSequenceClassification"),
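With this mapping registered, the auto classes should resolve a `gemma4` config to the new head. A quick sketch, assuming a config-only instantiation so no weights are downloaded (the model is randomly initialized, which is enough to check the dispatch):

```python
from transformers import AutoConfig, AutoModelForSequenceClassification

# Build a gemma4 config from its registered defaults; no checkpoint is needed for this check.
config = AutoConfig.for_model("gemma4", num_labels=3)
model = AutoModelForSequenceClassification.from_config(config)
print(type(model).__name__)  # expected: Gemma4ForSequenceClassification
```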
101 changes: 100 additions & 1 deletion src/transformers/models/gemma4/modeling_gemma4.py
@@ -42,7 +42,12 @@
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
from ...modeling_outputs import (
BaseModelOutputWithPast,
BaseModelOutputWithPooling,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
@@ -52,6 +57,7 @@
auto_docstring,
can_return_tuple,
is_accelerate_available,
logging,
torch_compilable_check,
)
from ...utils.generic import maybe_autocast, merge_with_config_defaults
@@ -64,6 +70,9 @@
from accelerate.hooks import add_hook_to_module


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
custom_intro="""
@@ -2646,6 +2655,95 @@ def create_masks_for_generate(
)


class Gemma4ForSequenceClassification(Gemma4PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = Gemma4Model(config)
self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.get_input_embeddings()

def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)

@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
pixel_values: torch.FloatTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
token_type_ids: torch.LongTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss); if
`config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
"""

transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
token_type_ids=token_type_ids,
use_cache=use_cache,
return_dict=True,
**kwargs,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)

if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]

if self.config.text_config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.text_config.pad_token_id is None:
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right-padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.text_config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)


__all__ = [
"Gemma4AudioModel",
"Gemma4ForCausalLM",
@@ -2654,4 +2752,5 @@ def create_masks_for_generate(
"Gemma4PreTrainedModel",
"Gemma4TextModel",
"Gemma4VisionModel",
"Gemma4ForSequenceClassification",
]
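The pooling in `forward` above picks the rightmost non-pad token per sequence by taking an `argmax` over masked position indices. A small standalone sketch of that trick with made-up values (`pad_token_id` and the tensors are illustrative only):

```python
import torch

pad_token_id = 0  # illustrative value
input_ids = torch.tensor([
    [5, 7, 9, 0, 0],  # right-padded: last real token at index 2
    [0, 0, 3, 4, 6],  # left-padded: last real token at index 4
])

non_pad_mask = (input_ids != pad_token_id).to(torch.int32)
token_indices = torch.arange(input_ids.shape[-1], dtype=torch.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # tensor([2, 4])

# One logit vector is then gathered per sequence, mirroring the model's pooling.
logits = torch.randn(2, 5, 3)  # (batch_size, seq_len, num_labels)
pooled_logits = logits[torch.arange(2), last_non_pad_token]  # (batch_size, num_labels)
```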
8 changes: 8 additions & 0 deletions src/transformers/models/gemma4/modular_gemma4.py
@@ -52,8 +52,10 @@
Gemma3Attention,
Gemma3DecoderLayer,
Gemma3ForCausalLM,
Gemma3ForSequenceClassification,
Gemma3MLP,
Gemma3RotaryEmbedding,
Gemma3TextForSequenceClassification,
Gemma3TextModel,
Gemma3TextScaledWordEmbedding,
)
@@ -2203,6 +2205,11 @@ def prepare_inputs_for_generation(
return model_inputs


class Gemma4ForSequenceClassification(Gemma3ForSequenceClassification):
pass


__all__ = [
"Gemma4AudioModel",
"Gemma4ForCausalLM",
@@ -2211,4 +2218,5 @@ def prepare_inputs_for_generation(
"Gemma4PreTrainedModel",
"Gemma4TextModel",
"Gemma4VisionModel",
"Gemma4ForSequenceClassification",
]
5 changes: 3 additions & 2 deletions tests/models/gemma4/test_modeling_gemma4.py
@@ -48,6 +48,7 @@
AutoModelForCausalLM,
Gemma4ForCausalLM,
Gemma4ForConditionalGeneration,
Gemma4ForSequenceClassification,
Gemma4Model,
Gemma4Processor,
Gemma4TextModel,
@@ -226,7 +227,7 @@ def prepare_config_and_inputs_for_common(self):

@require_torch
class Gemma4Audio2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Gemma4Model, Gemma4ForConditionalGeneration) if is_torch_available() else ()
all_model_classes = (
    (Gemma4Model, Gemma4ForConditionalGeneration, Gemma4ForSequenceClassification) if is_torch_available() else ()
)
all_generative_model_classes = (Gemma4ForConditionalGeneration,) if is_torch_available() else ()

def setUp(self):
@@ -378,7 +379,7 @@ def prepare_config_and_inputs_for_common(self):

@require_torch
class Gemma4Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Gemma4Model, Gemma4ForConditionalGeneration) if is_torch_available() else ()
all_model_classes = (
    (Gemma4Model, Gemma4ForConditionalGeneration, Gemma4ForSequenceClassification) if is_torch_available() else ()
)
all_generative_model_classes = (Gemma4ForConditionalGeneration,) if is_torch_available() else ()
additional_model_inputs = ["mm_token_type_ids"]
