diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
index fd3dd9ea9820..0869467f0bb4 100644
--- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
+++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
@@ -27,20 +27,14 @@
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
-from ...models.modernbert.modeling_modernbert import (
-    ModernBertEmbeddings,
-    ModernBertMLP,
-    ModernBertPredictionHead,
-    ModernBertPreTrainedModel,
-    ModernBertRotaryEmbedding,
-    apply_rotary_pos_emb,
-)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.deprecation import deprecate_kwarg
@@ -51,6 +45,126 @@
 logger = logging.get_logger(__name__)
 
 
+class ModernBertDecoderEmbeddings(nn.Module):
+    """
+    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+    """
+
+    def __init__(self, config: ModernBertDecoderConfig):
+        super().__init__()
+        self.config = config
+        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+        self.drop = nn.Dropout(config.embedding_dropout)
+
+    @torch.compile(dynamic=True)
+    def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
+        return self.drop(self.norm(self.tok_embeddings(input_ids)))
+
+    def forward(
+        self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        if inputs_embeds is not None:
+            hidden_states = self.drop(self.norm(inputs_embeds))
+        else:
+            hidden_states = (
+                self.compiled_embeddings(input_ids)
+                if self.config.reference_compile
+                else self.drop(self.norm(self.tok_embeddings(input_ids)))
+            )
+        return hidden_states
+
+
+class ModernBertDecoderMLP(nn.Module):
+    """Applies the GLU at the end of each ModernBertDecoder layer.
+
+    Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
+    and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
+    """
+
+    def __init__(self, config: ModernBertDecoderConfig):
+        super().__init__()
+        self.config = config
+        self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias)
+        self.act = ACT2FN[config.hidden_activation]
+        self.drop = nn.Dropout(config.mlp_dropout)
+        self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
+        return self.Wo(self.drop(self.act(input) * gate))
+
+
+class ModernBertDecoderRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: ModernBertDecoderConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
 def eager_attention_forward(
     module: "ModernBertDecoderAttention",
     query: torch.Tensor,
@@ -173,7 +287,7 @@ def __init__(self, config: ModernBertDecoderConfig, layer_idx: Optional[int] = N
         )
         self.attn = ModernBertDecoderAttention(config=config, layer_idx=layer_idx)
         self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
-        self.mlp = ModernBertMLP(config)
+        self.mlp = ModernBertDecoderMLP(config)
 
     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
@@ -218,12 +332,28 @@ def forward(
         return hidden_states
 
 
+class ModernBertDecoderPredictionHead(nn.Module):
+    def __init__(self, config: ModernBertDecoderConfig):
+        super().__init__()
+        self.config = config
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
+        self.act = ACT2FN[config.classifier_activation]
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.norm(self.act(self.dense(hidden_states)))
+
+
 @auto_docstring
-class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
+class ModernBertDecoderPreTrainedModel(PreTrainedModel):
     config: ModernBertDecoderConfig
-    _skip_keys_device_placement = ["past_key_values"]
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
     _no_split_modules = ["ModernBertDecoderLayer"]
-    _can_compile_fullgraph = False
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+    _skip_keys_device_placement = ["past_key_values"]
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": ModernBertDecoderLayer,
@@ -255,9 +385,9 @@ def init_weight(module: nn.Module, std: float):
             "final_out": self.config.hidden_size**-0.5,
         }
 
-        if isinstance(module, ModernBertEmbeddings):
+        if isinstance(module, ModernBertDecoderEmbeddings):
             init_weight(module.tok_embeddings, stds["embedding"])
-        elif isinstance(module, ModernBertMLP):
+        elif isinstance(module, ModernBertDecoderMLP):
             init_weight(module.Wi, stds["in"])
             init_weight(module.Wo, stds["out"])
         elif isinstance(module, ModernBertDecoderAttention):
@@ -265,46 +395,32 @@ def init_weight(module: nn.Module, std: float):
             init_weight(module.k_proj, stds["in"])
             init_weight(module.v_proj, stds["in"])
             init_weight(module.Wo, stds["out"])
-        elif isinstance(module, ModernBertPredictionHead):
+        elif isinstance(module, ModernBertDecoderPredictionHead):
             init_weight(module.dense, stds["out"])
-        elif isinstance(module, ModernBertDecoderForSequenceClassification):
+        elif module.__class__.__name__ == "ModernBertDecoderForSequenceClassification":
             init_weight(module.classifier, stds["final_out"])
-        elif isinstance(module, ModernBertDecoderForCausalLM):
+        elif module.__class__.__name__ == "ModernBertDecoderForCausalLM":
             init_weight(module.decoder, stds["out"])
         elif isinstance(module, nn.LayerNorm):
             module.weight.data.fill_(1.0)
             if module.bias is not None:
                 module.bias.data.zero_()
 
-    def _check_and_adjust_attn_implementation(
-        self, attn_implementation: Optional[str], is_init_check: bool = False
-    ) -> str:
-        """We overwrite this to make sdpa the first selection again if nothing was requested."""
-
-        try:
-            attn_implementation = (
-                "sdpa" if attn_implementation is None and self._sdpa_can_dispatch() else attn_implementation
-            )
-        except (ValueError, ImportError):
-            pass
-
-        return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check)
-
 
 @auto_docstring
 class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
     def __init__(self, config: ModernBertDecoderConfig):
         super().__init__(config)
         self.config = config
-        self.embeddings = ModernBertEmbeddings(config)
+        self.embeddings = ModernBertDecoderEmbeddings(config)
         self.layers = nn.ModuleList(
             [ModernBertDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
         self.gradient_checkpointing = False
-        self.global_rotary_emb = ModernBertRotaryEmbedding(config=config)
-        self.local_rotary_emb = ModernBertRotaryEmbedding(config=config)
+        self.global_rotary_emb = ModernBertDecoderRotaryEmbedding(config=config)
+        self.local_rotary_emb = ModernBertDecoderRotaryEmbedding(config=config)
 
         self.post_init()
 
@@ -407,7 +523,7 @@ def __init__(self, config: ModernBertDecoderConfig):
         super().__init__(config)
         self.config = config
         self.model = ModernBertDecoderModel(config)
-        self.lm_head = ModernBertPredictionHead(config)
+        self.lm_head = ModernBertDecoderPredictionHead(config)
         self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
 
         # Initialize weights and apply final processing
@@ -515,7 +631,7 @@ def __init__(self, config: ModernBertDecoderConfig):
 
         self.num_labels = config.num_labels
         self.model = ModernBertDecoderModel(config)
-        self.head = ModernBertPredictionHead(config)
+        self.head = ModernBertDecoderPredictionHead(config)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=config.classifier_bias)
         self.drop = torch.nn.Dropout(config.classifier_dropout)
 
diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
index 823248f6b40f..0f0c752a4d63 100644
--- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
+++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
@@ -28,7 +28,11 @@
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
-from ...models.modernbert.modeling_modernbert import (
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from ..modernbert.modeling_modernbert import (
     ModernBertEmbeddings,
     ModernBertMLP,
     ModernBertPredictionHead,
@@ -36,10 +40,6 @@
     ModernBertRotaryEmbedding,
     apply_rotary_pos_emb,
 )
-from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.deprecation import deprecate_kwarg
-from ...utils.generic import check_model_inputs
 
 
 logger = logging.get_logger(__name__)
@@ -228,6 +228,18 @@ def __init__(
         self.sliding_window = local_attention // 2 if local_attention else -1
 
 
+class ModernBertDecoderEmbeddings(ModernBertEmbeddings):
+    pass
+
+
+class ModernBertDecoderMLP(ModernBertMLP):
+    pass
+
+
+class ModernBertDecoderRotaryEmbedding(ModernBertRotaryEmbedding):
+    pass
+
+
 def eager_attention_forward(
     module: "ModernBertDecoderAttention",
     query: torch.Tensor,
@@ -350,7 +362,7 @@ def __init__(self, config: ModernBertDecoderConfig, layer_idx: Optional[int] = N
         )
         self.attn = ModernBertDecoderAttention(config=config, layer_idx=layer_idx)
         self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
-        self.mlp = ModernBertMLP(config)
+        self.mlp = ModernBertDecoderMLP(config)
 
     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
@@ -395,12 +407,15 @@ def forward(
         return hidden_states
 
 
+class ModernBertDecoderPredictionHead(ModernBertPredictionHead):
+    pass
+
+
 @auto_docstring
 class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
-    config: ModernBertDecoderConfig
     _skip_keys_device_placement = ["past_key_values"]
     _no_split_modules = ["ModernBertDecoderLayer"]
-    _can_compile_fullgraph = False
+    _supports_flex_attn = True
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": ModernBertDecoderLayer,
@@ -432,9 +447,9 @@ def init_weight(module: nn.Module, std: float):
             "final_out": self.config.hidden_size**-0.5,
         }
 
-        if isinstance(module, ModernBertEmbeddings):
+        if isinstance(module, ModernBertDecoderEmbeddings):
             init_weight(module.tok_embeddings, stds["embedding"])
-        elif isinstance(module, ModernBertMLP):
+        elif isinstance(module, ModernBertDecoderMLP):
             init_weight(module.Wi, stds["in"])
             init_weight(module.Wo, stds["out"])
         elif isinstance(module, ModernBertDecoderAttention):
@@ -442,30 +457,25 @@ def init_weight(module: nn.Module, std: float):
             init_weight(module.k_proj, stds["in"])
             init_weight(module.v_proj, stds["in"])
             init_weight(module.Wo, stds["out"])
-        elif isinstance(module, ModernBertPredictionHead):
+        elif isinstance(module, ModernBertDecoderPredictionHead):
             init_weight(module.dense, stds["out"])
-        elif isinstance(module, ModernBertDecoderForSequenceClassification):
+        elif module.__class__.__name__ == "ModernBertDecoderForSequenceClassification":
             init_weight(module.classifier, stds["final_out"])
-        elif isinstance(module, ModernBertDecoderForCausalLM):
+        elif module.__class__.__name__ == "ModernBertDecoderForCausalLM":
             init_weight(module.decoder, stds["out"])
         elif isinstance(module, nn.LayerNorm):
             module.weight.data.fill_(1.0)
             if module.bias is not None:
                 module.bias.data.zero_()
 
-    def _check_and_adjust_attn_implementation(
-        self, attn_implementation: Optional[str], is_init_check: bool = False
-    ) -> str:
-        """We overwrite this to make sdpa the first selection again if nothing was requested."""
+    def _check_and_adjust_attn_implementation(self, attn_implementation, is_init_check):
+        raise AttributeError("No need to inherit!")
 
-        try:
-            attn_implementation = (
-                "sdpa" if attn_implementation is None and self._sdpa_can_dispatch() else attn_implementation
-            )
-        except (ValueError, ImportError):
-            pass
+    def _maybe_set_compile(self):
+        raise AttributeError("No need to inherit!")
 
-        return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check)
+    def resize_token_embeddings(self, *args, **kwargs):
+        raise AttributeError("No need to inherit!")
 
 
 @auto_docstring
@@ -473,15 +483,15 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
     def __init__(self, config: ModernBertDecoderConfig):
         super().__init__(config)
         self.config = config
-        self.embeddings = ModernBertEmbeddings(config)
+        self.embeddings = ModernBertDecoderEmbeddings(config)
         self.layers = nn.ModuleList(
             [ModernBertDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
         self.gradient_checkpointing = False
-        self.global_rotary_emb = ModernBertRotaryEmbedding(config=config)
-        self.local_rotary_emb = ModernBertRotaryEmbedding(config=config)
+        self.global_rotary_emb = ModernBertDecoderRotaryEmbedding(config=config)
+        self.local_rotary_emb = ModernBertDecoderRotaryEmbedding(config=config)
 
         self.post_init()
 
@@ -584,7 +594,7 @@ def __init__(self, config: ModernBertDecoderConfig):
         super().__init__(config)
         self.config = config
         self.model = ModernBertDecoderModel(config)
-        self.lm_head = ModernBertPredictionHead(config)
+        self.lm_head = ModernBertDecoderPredictionHead(config)
         self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
 
         # Initialize weights and apply final processing
@@ -692,7 +702,7 @@ def __init__(self, config: ModernBertDecoderConfig):
 
         self.num_labels = config.num_labels
         self.model = ModernBertDecoderModel(config)
-        self.head = ModernBertPredictionHead(config)
+        self.head = ModernBertDecoderPredictionHead(config)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=config.classifier_bias)
         self.drop = torch.nn.Dropout(config.classifier_dropout)
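
A couple of stand-alone sketches for reviewers who want to poke at the new modules without building the full model. First, the gated-linear-unit forward pass of `ModernBertDecoderMLP`: `Wi` projects to twice the intermediate width, and the activated half gates the other half before `Wo` projects back down. The sizes and the GELU activation below are made-up stand-ins for the config values, not the model's defaults.

```python
# Minimal sketch of the ModernBertDecoderMLP forward pass (GLU variant).
# hidden_size / intermediate_size are arbitrary assumptions for illustration.
import torch
from torch import nn

hidden_size, intermediate_size = 64, 128

Wi = nn.Linear(hidden_size, intermediate_size * 2, bias=False)  # fused input/gate projection
Wo = nn.Linear(intermediate_size, hidden_size, bias=False)      # down projection
act = nn.GELU()  # stand-in for ACT2FN[config.hidden_activation]

x = torch.randn(2, 10, hidden_size)   # (batch, seq_len, hidden)
inp, gate = Wi(x).chunk(2, dim=-1)    # split the fused projection in two
y = Wo(act(inp) * gate)               # activated half gates the linear half
assert y.shape == x.shape
```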
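Likewise, `rotate_half` and the `apply_rotary_pos_emb` math can be checked in isolation. `rotate_half` is copied verbatim from the diff; the `cos`/`sin` tables below use the plain base-10000 inverse-frequency formula as an assumption (the real values come from `ROPE_INIT_FUNCTIONS` via the config, possibly rescaled by `attention_scaling`).

```python
# Shape and invariance check for the RoPE helpers added in this diff.
import torch

def rotate_half(x):
    """Rotates half the hidden dims of the input (verbatim from the diff)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

batch, heads, seq_len, head_dim = 2, 4, 16, 32
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

# Build cos/sin the same way the rotary module does: freqs duplicated along
# the last dim, giving (batch, seq_len, head_dim) tables.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)[None]                 # (1, seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()

# unsqueeze_dim=1 broadcasts the tables over the heads dimension, as in the docstring.
q_embed = (q * cos.unsqueeze(1)) + (rotate_half(q) * sin.unsqueeze(1))
k_embed = (k * cos.unsqueeze(1)) + (rotate_half(k) * sin.unsqueeze(1))

# The rotation is norm-preserving, so only relative phases enter the attention logits.
assert torch.allclose(q_embed.norm(dim=-1), q.norm(dim=-1), atol=1e-4)
```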
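Finally, a tiny end-to-end smoke test of the decoupled classes. This assumes `ModernBertDecoderConfig` and `ModernBertDecoderForCausalLM` are exported from the top-level `transformers` namespace; the shrunken sizes are arbitrary so the randomly initialized model runs quickly on CPU.

```python
# Smoke test: forward a batch through a tiny, randomly initialized decoder.
import torch
from transformers import ModernBertDecoderConfig, ModernBertDecoderForCausalLM

config = ModernBertDecoderConfig(
    vocab_size=128,
    hidden_size=64,          # must be divisible by num_attention_heads
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
)
model = ModernBertDecoderForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
logits = model(input_ids=input_ids).logits
print(logits.shape)  # torch.Size([1, 8, 128])
```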