@@ -27,20 +27,14 @@
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...models.modernbert.modeling_modernbert import (
    ModernBertEmbeddings,
    ModernBertMLP,
    ModernBertPredictionHead,
    ModernBertPreTrainedModel,
    ModernBertRotaryEmbedding,
    apply_rotary_pos_emb,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
@@ -51,6 +45,126 @@
logger = logging.get_logger(__name__)


class ModernBertDecoderEmbeddings(nn.Module):
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
"""

def __init__(self, config: ModernBertDecoderConfig):
super().__init__()
self.config = config
self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
self.drop = nn.Dropout(config.embedding_dropout)

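    # Compiled fast path for the token-embedding -> LayerNorm -> dropout sequence; only used when
    # `config.reference_compile` is enabled (see `forward` below).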
    @torch.compile(dynamic=True)
    def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
        return self.drop(self.norm(self.tok_embeddings(input_ids)))

    def forward(
        self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = self.drop(self.norm(inputs_embeds))
        else:
            hidden_states = (
                self.compiled_embeddings(input_ids)
                if self.config.reference_compile
                else self.drop(self.norm(self.tok_embeddings(input_ids)))
            )
        return hidden_states


class ModernBertDecoderMLP(nn.Module):
"""Applies the GLU at the end of each ModernBertDecoder layer.

Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
"""

def __init__(self, config: ModernBertDecoderConfig):
super().__init__()
self.config = config
self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias)
self.act = ACT2FN[config.hidden_activation]
self.drop = nn.Dropout(config.mlp_dropout)
self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
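        # Wi projects to 2 * intermediate_size; the two halves act as the value and the gate of the GLU.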
        input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
        return self.Wo(self.drop(self.act(input) * gate))


class ModernBertDecoderRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: ModernBertDecoderConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
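        # Expand to (batch, dim/2, 1) and (batch, 1, seq_len) so the matmul below yields one angle per
        # (inverse frequency, position) pair.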
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
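# Illustrative example (assuming the [batch_size, num_heads, seq_len, head_dim] layout described in the
# docstring above): cos/sin of shape [batch_size, seq_len, head_dim] with the default unsqueeze_dim=1
# gain a singleton head axis and broadcast against q and k.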


def eager_attention_forward(
module: "ModernBertDecoderAttention",
query: torch.Tensor,
@@ -173,7 +287,7 @@ def __init__(self, config: ModernBertDecoderConfig, layer_idx: Optional[int] = N
        )
        self.attn = ModernBertDecoderAttention(config=config, layer_idx=layer_idx)
        self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.mlp = ModernBertMLP(config)
        self.mlp = ModernBertDecoderMLP(config)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
@@ -218,12 +332,28 @@ def forward(
        return hidden_states


class ModernBertDecoderPredictionHead(nn.Module):
    def __init__(self, config: ModernBertDecoderConfig):
        super().__init__()
        self.config = config
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
        self.act = ACT2FN[config.classifier_activation]
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)

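    # dense -> activation -> LayerNorm; used as `lm_head` in the causal LM and as `head` in the sequence classifier below.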
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.norm(self.act(self.dense(hidden_states)))


@auto_docstring
class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
class ModernBertDecoderPreTrainedModel(PreTrainedModel):
    config: ModernBertDecoderConfig
    _skip_keys_device_placement = ["past_key_values"]
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ModernBertDecoderLayer"]
    _can_compile_fullgraph = False
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _skip_keys_device_placement = ["past_key_values"]
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": ModernBertDecoderLayer,
@@ -255,56 +385,42 @@ def init_weight(module: nn.Module, std: float):
"final_out": self.config.hidden_size**-0.5,
}

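        # Re-initialize each submodule with the std that matches its role in the network, taken from the
        # `stds` mapping above.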
        if isinstance(module, ModernBertEmbeddings):
        if isinstance(module, ModernBertDecoderEmbeddings):
            init_weight(module.tok_embeddings, stds["embedding"])
        elif isinstance(module, ModernBertMLP):
        elif isinstance(module, ModernBertDecoderMLP):
            init_weight(module.Wi, stds["in"])
            init_weight(module.Wo, stds["out"])
        elif isinstance(module, ModernBertDecoderAttention):
            init_weight(module.q_proj, stds["in"])
            init_weight(module.k_proj, stds["in"])
            init_weight(module.v_proj, stds["in"])
            init_weight(module.Wo, stds["out"])
        elif isinstance(module, ModernBertPredictionHead):
        elif isinstance(module, ModernBertDecoderPredictionHead):
            init_weight(module.dense, stds["out"])
        elif isinstance(module, ModernBertDecoderForSequenceClassification):
        elif module.__class__.__name__ == "ModernBertDecoderForSequenceClassification":
            init_weight(module.classifier, stds["final_out"])
        elif isinstance(module, ModernBertDecoderForCausalLM):
        elif module.__class__.__name__ == "ModernBertDecoderForCausalLM":
            init_weight(module.decoder, stds["out"])
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()

    def _check_and_adjust_attn_implementation(
        self, attn_implementation: Optional[str], is_init_check: bool = False
    ) -> str:
        """We overwrite this to make sdpa the first selection again if nothing was requested."""

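        # Prefer SDPA when no backend was explicitly requested; if its availability cannot be determined,
        # fall back to the parent class resolution below.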
        try:
            attn_implementation = (
                "sdpa" if attn_implementation is None and self._sdpa_can_dispatch() else attn_implementation
            )
        except (ValueError, ImportError):
            pass

        return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check)


@auto_docstring
class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
    def __init__(self, config: ModernBertDecoderConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = ModernBertEmbeddings(config)
        self.embeddings = ModernBertDecoderEmbeddings(config)
        self.layers = nn.ModuleList(
            [ModernBertDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.gradient_checkpointing = False

        self.global_rotary_emb = ModernBertRotaryEmbedding(config=config)
        self.local_rotary_emb = ModernBertRotaryEmbedding(config=config)
        self.global_rotary_emb = ModernBertDecoderRotaryEmbedding(config=config)
        self.local_rotary_emb = ModernBertDecoderRotaryEmbedding(config=config)
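        # Separate rotary embeddings for global-attention layers and local (sliding-window) attention layers.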

        self.post_init()

@@ -407,7 +523,7 @@ def __init__(self, config: ModernBertDecoderConfig):
        super().__init__(config)
        self.config = config
        self.model = ModernBertDecoderModel(config)
        self.lm_head = ModernBertPredictionHead(config)
        self.lm_head = ModernBertDecoderPredictionHead(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
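        # `lm_head` applies dense -> activation -> norm; `decoder` projects hidden states to vocabulary logits.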

        # Initialize weights and apply final processing
@@ -515,7 +631,7 @@ def __init__(self, config: ModernBertDecoderConfig):
        self.num_labels = config.num_labels
        self.model = ModernBertDecoderModel(config)

        self.head = ModernBertPredictionHead(config)
        self.head = ModernBertDecoderPredictionHead(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=config.classifier_bias)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
