diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index fb82ea374c92..00870f152136 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -713,6 +713,8 @@ title: MegatronBERT - local: model_doc/megatron_gpt2 title: MegatronGPT2 + - local: model_doc/minicpm3 + title: MiniCPM3 - local: model_doc/minimax title: MiniMax - local: model_doc/minimax_m2 diff --git a/docs/source/en/model_doc/minicpm3.md b/docs/source/en/model_doc/minicpm3.md new file mode 100644 index 000000000000..e812e594ac4c --- /dev/null +++ b/docs/source/en/model_doc/minicpm3.md @@ -0,0 +1,45 @@ + + +# MiniCPM3 + +## Overview + +The MiniCPM3 model was proposed in [MiniCPM: Unveiling the Potential of Small Language Models with Scalable Training Strategies](https://huggingface.co/papers/2404.06395) by OpenBMB. + +MiniCPM3-4B is a dense language model that uses Multi-head Latent Attention (MLA) for efficient KV cache compression, combined with embedding scaling, depth-dependent residual scaling, and logit scaling for stable training. Despite its compact 4B parameter size, it achieves performance comparable to larger 7B-9B models. + +This model was contributed by [aliyevaladddin](https://github.com/aliyevaladddin). +The original code can be found [here](https://huggingface.co/openbmb/MiniCPM3-4B). + +## MiniCPM3Config + +[[autodoc]] MiniCPM3Config + +## MiniCPM3Model + +[[autodoc]] MiniCPM3Model + - forward + +## MiniCPM3ForCausalLM + +[[autodoc]] MiniCPM3ForCausalLM + - forward + +## MiniCPM3ForSequenceClassification + +[[autodoc]] MiniCPM3ForSequenceClassification + - forward diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 03af4c0819bd..71394c6ae319 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -699,6 +699,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("marian", "MarianForCausalLM"), ("mbart", "MBartForCausalLM"), ("megatron-bert", "MegatronBertForCausalLM"), + ("minicpm3", "MiniCPM3ForCausalLM"), ("minimax", "MiniMaxForCausalLM"), ("minimax_m2", "MiniMaxM2ForCausalLM"), ("ministral", "MinistralForCausalLM"), @@ -1299,6 +1300,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("markuplm", "MarkupLMForSequenceClassification"), ("mbart", "MBartForSequenceClassification"), ("megatron-bert", "MegatronBertForSequenceClassification"), + ("minicpm3", "MiniCPM3ForSequenceClassification"), ("minimax", "MiniMaxForSequenceClassification"), ("ministral", "MinistralForSequenceClassification"), ("ministral3", "Ministral3ForSequenceClassification"), diff --git a/src/transformers/models/minicpm3/__init__.py b/src/transformers/models/minicpm3/__init__.py new file mode 100644 index 000000000000..405741de6116 --- /dev/null +++ b/src/transformers/models/minicpm3/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
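+# This package uses the lazy-import pattern standard across `transformers`:
+# `configuration_minicpm3` and `modeling_minicpm3` are only imported on first
+# attribute access, keeping top-level `import transformers` cheap.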
+ + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_minicpm3 import * + from .modeling_minicpm3 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/minicpm3/configuration_minicpm3.py b/src/transformers/models/minicpm3/configuration_minicpm3.py new file mode 100644 index 000000000000..ad4645318054 --- /dev/null +++ b/src/transformers/models/minicpm3/configuration_minicpm3.py @@ -0,0 +1,126 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/minicpm3/modular_minicpm3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_minicpm3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring + + +@auto_docstring(checkpoint="openbmb/MiniCPM3-4B") +@strict +class MiniCPM3Config(PreTrainedConfig): + r""" + kv_lora_rank (`int`, *optional*, defaults to 256): + Rank of the low-rank KV projection in multi-head latent attention. + q_lora_rank (`int`, *optional*, defaults to 768): + Rank of the low-rank query projection in multi-head latent attention. + qk_nope_head_dim (`int`, *optional*, defaults to 64): + Dimension of the non-RoPE part of each query/key head. + qk_rope_head_dim (`int`, *optional*, defaults to 32): + Dimension of the RoPE part of each query/key head. + v_head_dim (`int`, *optional*, defaults to 128): + Dimension of each value head. + scale_emb (`int`, *optional*, defaults to 1): + Scaling factor applied to input embeddings. + scale_depth (`float`, *optional*, defaults to 1.0): + Scaling factor for residual connections, applied as `scale_depth / sqrt(num_hidden_layers)`. + dim_model_base (`int`, *optional*, defaults to 1): + Base model dimension used to scale logits before the language model head. 
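+
+        Together these three knobs implement MiniCPM3's stability scaling: input embeddings are
+        multiplied by `scale_emb`, every residual branch is multiplied by
+        `scale_depth / sqrt(num_hidden_layers)`, and hidden states are divided by
+        `hidden_size / dim_model_base` before the language-model head.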
+ + Example: + + ```python + >>> from transformers import MiniCPM3Model, MiniCPM3Config + >>> configuration = MiniCPM3Config() + >>> model = MiniCPM3Model(configuration) + >>> print(model.config) + ``` + """ + + model_type = "minicpm3" + keys_to_ignore_at_inference = ["past_key_values"] + + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.q_b_proj": "colwise", + "layers.*.self_attn.kv_a_proj_with_mqa": "mla_kv_a_proj", + "layers.*.self_attn.kv_b_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + vocab_size: int = 73448 + hidden_size: int = 2560 + intermediate_size: int = 6400 + num_hidden_layers: int = 62 + num_attention_heads: int = 40 + num_key_value_heads: int | None = 40 + hidden_act: str = "silu" + max_position_embeddings: int = 32768 + initializer_range: float = 0.1 + rms_norm_eps: float = 1e-5 + use_cache: bool = True + pad_token_id: int | None = None + bos_token_id: int | None = 1 + eos_token_id: int | list[int] | None = 2 + pretraining_tp: int | None = 1 + tie_word_embeddings: bool = False + rope_parameters: RopeParameters | dict | None = None + attention_bias: bool = False + attention_dropout: float | None = 0.0 + mlp_bias: bool = False + head_dim: int | None = None + kv_lora_rank: int = 256 + q_lora_rank: int | None = 768 + qk_nope_head_dim: int = 64 + qk_rope_head_dim: int = 32 + v_head_dim: int = 128 + scale_emb: int = 1 + scale_depth: float = 1.0 + dim_model_base: int = 1 + + def __post_init__(self, **kwargs): + self.head_dim = self.qk_rope_head_dim + if self.head_dim is None: + self.head_dim = self.hidden_size // self.num_attention_heads + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + super().__post_init__(**kwargs) + + def validate_architecture(self): + """Part of `@strict`-powered validation. Validates the architecture of the config.""" + if self.hidden_size % self.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention " + f"heads ({self.num_attention_heads})." + ) + + +__all__ = ["MiniCPM3Config"] diff --git a/src/transformers/models/minicpm3/modeling_minicpm3.py b/src/transformers/models/minicpm3/modeling_minicpm3.py new file mode 100644 index 000000000000..850140c782ee --- /dev/null +++ b/src/transformers/models/minicpm3/modeling_minicpm3.py @@ -0,0 +1,522 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/minicpm3/modular_minicpm3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_minicpm3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Callable +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from .configuration_minicpm3 import MiniCPM3Config + + +@use_kernel_forward_from_hub("RMSNorm") +class MiniCPM3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + MiniCPM3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class MiniCPM3RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: MiniCPM3Config, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: MiniCPM3Config | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. 
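+
+        The frequencies follow the original RoPE formulation, `inv_freq[i] = 1 / base**(2 * i / dim)`
+        for `i` in `[0, dim // 2)`, where `dim` is the RoPE head dimension.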
+ Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation + freqs_cis = freqs_cis * self.attention_scaling + + return freqs_cis + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + + # Broadcast to [1, 1, seq_len, dim // 2] + freqs_cis = freqs_cis.unsqueeze(1).to(xq_.device) + + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) + return xq_out, xk_out + + +class MiniCPM3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MiniCPM3Config, layer_idx: int | None = None): + super().__init__() + 
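+        # Multi-head Latent Attention (MLA): the query path is optionally low-rank
+        # (`q_lora_rank`), keys/values are compressed to `kv_lora_rank` before being
+        # expanded per head, and a small `qk_rope_head_dim` slice carries the rotary
+        # signal (shared across heads on the key side, MQA-style). `forward` consumes
+        # the complex `freqs_cis` tensor produced by `MiniCPM3RotaryEmbedding`.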
self.config = config + self.layer_idx = layer_idx + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.max_position_embeddings = config.max_position_embeddings + + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = MiniCPM3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = MiniCPM3RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_heads * (self.qk_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + + self.scaling = self.qk_head_dim ** (-0.5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(query_shape).transpose(1, 2) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + k_nope, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + k_nope = self.kv_b_proj(self.kv_a_layernorm(k_nope)).view(key_shape).transpose(1, 2) + k_nope, value_states = torch.split(k_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_pe = k_pe.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + + q_pe, k_pe = apply_rotary_emb(q_pe, k_pe, position_embeddings.to(q_pe.device)) + + k_pe = k_pe.expand(*k_nope.shape[:-1], -1) + query_states = torch.cat((q_nope, q_pe), dim=-1) + key_states = torch.cat((k_nope, k_pe), dim=-1) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + 
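+            # eager attention adds this (possibly None) mask to the raw scores;
+            # other backends consume it in their own format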
attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class MiniCPM3MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MiniCPM3DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: MiniCPM3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = MiniCPM3Attention(config=config, layer_idx=layer_idx) + self.mlp = MiniCPM3MLP(config) + self.input_layernorm = MiniCPM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MiniCPM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.scale_depth = config.scale_depth + self.num_hidden_layers = config.num_hidden_layers + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers)) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers)) + return hidden_states + + +@auto_docstring +class MiniCPM3PreTrainedModel(PreTrainedModel): + config: MiniCPM3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MiniCPM3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": MiniCPM3DecoderLayer, + "attentions": MiniCPM3Attention, + } + + +@auto_docstring +class MiniCPM3Model(MiniCPM3PreTrainedModel): + def __init__(self, config: MiniCPM3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + 
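+            # one decoder layer per hidden layer; `layer_idx` identifies this
+            # layer's slot in the KV cache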
[MiniCPM3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = MiniCPM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = MiniCPM3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) * self.config.scale_emb + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_embeddings=position_embeddings, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = MiniCPM3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, MiniCPM3ForCausalLM + + >>> model = MiniCPM3ForCausalLM.from_pretrained("openbmb/MiniCPM3-4B") + >>> tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM3-4B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" 
+ >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head( + hidden_states[:, slice_indices, :] / (self.config.hidden_size / self.config.dim_model_base) + ) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class MiniCPM3ForSequenceClassification(GenericForSequenceClassification, MiniCPM3PreTrainedModel): + pass + + +__all__ = ["MiniCPM3PreTrainedModel", "MiniCPM3Model", "MiniCPM3ForCausalLM", "MiniCPM3ForSequenceClassification"] diff --git a/src/transformers/models/minicpm3/modular_minicpm3.py b/src/transformers/models/minicpm3/modular_minicpm3.py new file mode 100644 index 000000000000..8b551853b132 --- /dev/null +++ b/src/transformers/models/minicpm3/modular_minicpm3.py @@ -0,0 +1,342 @@ +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +from huggingface_hub.dataclasses import strict +from torch import nn + +from ...cache_utils import Cache, DynamicCache +from ...masking_utils import create_causal_mask +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import RopeParameters +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from ..deepseek_v2.modeling_deepseek_v2 import ( + DeepseekV2Attention, + DeepseekV2RotaryEmbedding, +) +from ..llama.configuration_llama import LlamaConfig +from ..llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaMLP, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, +) + + +logger = logging.get_logger(__name__) + + +@auto_docstring(checkpoint="openbmb/MiniCPM3-4B") +@strict +class MiniCPM3Config(LlamaConfig): + r""" + kv_lora_rank (`int`, *optional*, defaults to 256): + Rank of the low-rank KV projection in multi-head latent attention. 
+ q_lora_rank (`int`, *optional*, defaults to 768): + Rank of the low-rank query projection in multi-head latent attention. + qk_nope_head_dim (`int`, *optional*, defaults to 64): + Dimension of the non-RoPE part of each query/key head. + qk_rope_head_dim (`int`, *optional*, defaults to 32): + Dimension of the RoPE part of each query/key head. + v_head_dim (`int`, *optional*, defaults to 128): + Dimension of each value head. + scale_emb (`int`, *optional*, defaults to 1): + Scaling factor applied to input embeddings. + scale_depth (`float`, *optional*, defaults to 1.0): + Scaling factor for residual connections, applied as `scale_depth / sqrt(num_hidden_layers)`. + dim_model_base (`int`, *optional*, defaults to 1): + Base model dimension used to scale logits before the language model head. + + Example: + + ```python + >>> from transformers import MiniCPM3Model, MiniCPM3Config + >>> configuration = MiniCPM3Config() + >>> model = MiniCPM3Model(configuration) + >>> print(model.config) + ``` + """ + + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.q_b_proj": "colwise", + "layers.*.self_attn.kv_a_proj_with_mqa": "mla_kv_a_proj", + "layers.*.self_attn.kv_b_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + + model_type = "minicpm3" + keys_to_ignore_at_inference = ["past_key_values"] + + vocab_size: int = 73448 + hidden_size: int = 2560 + intermediate_size: int = 6400 + num_hidden_layers: int = 62 + num_attention_heads: int = 40 + num_key_value_heads: int | None = 40 + hidden_act: str = "silu" + max_position_embeddings: int = 32768 + initializer_range: float = 0.1 + rms_norm_eps: float = 1e-5 + use_cache: bool = True + pad_token_id: int | None = None + bos_token_id: int | None = 1 + eos_token_id: int | list[int] | None = 2 + tie_word_embeddings: bool = False + rope_parameters: RopeParameters | dict | None = None + attention_bias: bool = False + attention_dropout: float | None = 0.0 + mlp_bias: bool = False + kv_lora_rank: int = 256 + q_lora_rank: int | None = 768 + qk_nope_head_dim: int = 64 + qk_rope_head_dim: int = 32 + v_head_dim: int = 128 + scale_emb: int = 1 + scale_depth: float = 1.0 + dim_model_base: int = 1 + + def __post_init__(self, **kwargs): + self.head_dim = self.qk_rope_head_dim + super().__post_init__(**kwargs) + + +class MiniCPM3RMSNorm(LlamaRMSNorm): + pass + + +class MiniCPM3RotaryEmbedding(DeepseekV2RotaryEmbedding): + pass + + +class MiniCPM3Attention(DeepseekV2Attention): + pass + + +class MiniCPM3MLP(LlamaMLP): + pass + + +class MiniCPM3DecoderLayer(LlamaDecoderLayer): + def __init__(self, config: MiniCPM3Config, layer_idx: int): + super().__init__(config, layer_idx) + self.self_attn = MiniCPM3Attention(config=config, layer_idx=layer_idx) + self.mlp = MiniCPM3MLP(config) + self.input_layernorm = MiniCPM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MiniCPM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.scale_depth = config.scale_depth + self.num_hidden_layers = config.num_hidden_layers + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = 
hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers)) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers)) + return hidden_states + + +class MiniCPM3PreTrainedModel(LlamaPreTrainedModel): + pass + + +@auto_docstring +class MiniCPM3Model(LlamaModel): + def __init__(self, config: MiniCPM3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MiniCPM3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = MiniCPM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = MiniCPM3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) * self.config.scale_emb + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_embeddings=position_embeddings, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class MiniCPM3ForCausalLM(LlamaForCausalLM): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = MiniCPM3Model(config) + self.vocab_size = config.vocab_size + self.lm_head 
= nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, MiniCPM3ForCausalLM + + >>> model = MiniCPM3ForCausalLM.from_pretrained("openbmb/MiniCPM3-4B") + >>> tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM3-4B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head( + hidden_states[:, slice_indices, :] / (self.config.hidden_size / self.config.dim_model_base) + ) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class MiniCPM3ForSequenceClassification(LlamaForSequenceClassification): + pass + + +__all__ = [ + "MiniCPM3PreTrainedModel", + "MiniCPM3Model", + "MiniCPM3ForCausalLM", + "MiniCPM3ForSequenceClassification", + "MiniCPM3Config", +] diff --git a/tests/models/minicpm3/__init__.py b/tests/models/minicpm3/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/minicpm3/test_modeling_minicpm3.py b/tests/models/minicpm3/test_modeling_minicpm3.py new file mode 100644 index 000000000000..a273d2ba93b3 --- /dev/null +++ b/tests/models/minicpm3/test_modeling_minicpm3.py @@ -0,0 +1,136 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
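+# Run locally with, e.g.:
+#   python -m pytest tests/models/minicpm3/test_modeling_minicpm3.py -v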
+"""Testing suite for the PyTorch MiniCPM3 model.""" + +import unittest + +from transformers import Cache, is_torch_available +from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +if is_torch_available(): + import torch + + from transformers import MiniCPM3ForCausalLM, MiniCPM3Model + from transformers.models.minicpm3.modeling_minicpm3 import MiniCPM3RotaryEmbedding + + +class MiniCPM3ModelTester(CausalLMModelTester): + if is_torch_available(): + base_model_class = MiniCPM3Model + + def __init__( + self, + parent, + kv_lora_rank=32, + q_lora_rank=16, + qk_nope_head_dim=64, + qk_rope_head_dim=64, + v_head_dim=128, + scale_emb=1, + scale_depth=1.4, + dim_model_base=256, + ): + super().__init__(parent=parent) + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.scale_emb = scale_emb + self.scale_depth = scale_depth + self.dim_model_base = dim_model_base + + +@require_torch +class MiniCPM3ModelTest(CausalLMModelTest, unittest.TestCase): + test_all_params_have_gradient = False + model_tester_class = MiniCPM3ModelTester + model_split_percents = [0.5, 0.7, 0.8] + + _torch_compile_train_cls = MiniCPM3ForCausalLM if is_torch_available() else None + + @unittest.skip("MiniCPM3 uses MLA attention which is incompatible with this test") + def test_sdpa_padding_matches_padding_free_with_position_ids(self): + pass + + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): + self.assertIsInstance(past_key_values, Cache) + + expected_common_shape = ( + batch_size, + getattr(config, "num_key_value_heads", config.num_attention_heads), + seq_length, + ) + expected_key_shape = expected_common_shape + (config.qk_nope_head_dim + config.qk_rope_head_dim,) + expected_value_shape = expected_common_shape + (config.v_head_dim,) + + for layer in past_key_values.layers: + self.assertEqual(layer.keys.shape, expected_key_shape) + self.assertEqual(layer.values.shape, expected_value_shape) + + def test_model_rope_scaling_frequencies(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + scaling_factor = 10 + short_input_length = 10 + long_input_length = int(config.max_position_embeddings * 1.5) + + x = torch.randn(1, dtype=torch.float32, device=torch_device) + position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device).unsqueeze(0) + position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device).unsqueeze(0) + + original_rope = MiniCPM3RotaryEmbedding(config=config).to(torch_device) + original_freqs_cis_short = original_rope(x, position_ids_short) + original_freqs_cis_long = original_rope(x, position_ids_long) + torch.testing.assert_close(original_freqs_cis_short, original_freqs_cis_long[:, :short_input_length, :]) + + config.rope_parameters = {"rope_type": "linear", "rope_theta": 10000.0, "factor": scaling_factor} + linear_scaling_rope = MiniCPM3RotaryEmbedding(config=config).to(torch_device) + linear_freqs_cis_short = linear_scaling_rope(x, position_ids_short) + linear_freqs_cis_long = linear_scaling_rope(x, position_ids_long) + torch.testing.assert_close(linear_freqs_cis_short, linear_freqs_cis_long[:, :short_input_length, :]) + + config.rope_parameters = {"rope_type": "dynamic", "rope_theta": 10000.0, "factor": scaling_factor} + 
ntk_scaling_rope = MiniCPM3RotaryEmbedding(config=config).to(torch_device) + ntk_freqs_cis_short = ntk_scaling_rope(x, position_ids_short) + ntk_freqs_cis_long = ntk_scaling_rope(x, position_ids_long) + torch.testing.assert_close(ntk_freqs_cis_short, original_freqs_cis_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_freqs_cis_long, original_freqs_cis_long) + self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + + config.rope_parameters = {"rope_type": "yarn", "rope_theta": 10000.0, "factor": scaling_factor} + yarn_scaling_rope = MiniCPM3RotaryEmbedding(config=config).to(torch_device) + yarn_freqs_cis_short = yarn_scaling_rope(x, position_ids_short) + yarn_freqs_cis_long = yarn_scaling_rope(x, position_ids_long) + torch.testing.assert_close(yarn_freqs_cis_short, yarn_freqs_cis_long[:, :short_input_length, :]) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_freqs_cis_short, original_freqs_cis_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_freqs_cis_long, original_freqs_cis_long) + + def test_tp_plan_matches_params(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + if config.q_lora_rank is not None: + config.base_model_tp_plan.pop("layers.*.self_attn.q_proj") + super().test_tp_plan_matches_params() + config.base_model_tp_plan.update({"layers.*.self_attn.q_proj": "colwise"}) + + +@slow +@require_torch_accelerator +class MiniCPM3IntegrationTest(unittest.TestCase): + pass
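As a quick end-to-end sanity check of the classes this PR adds, the sketch below builds a tiny, randomly initialized MiniCPM3 and runs one forward pass. It assumes a `transformers` build that includes this diff; the miniature hyperparameters are illustrative only and do not match the released `openbmb/MiniCPM3-4B` checkpoint.

```python
# Smoke-test sketch for this PR's classes (assumes the diff is installed).
# All sizes below are deliberately tiny and purely illustrative.
import torch

from transformers import MiniCPM3Config, MiniCPM3ForCausalLM

config = MiniCPM3Config(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    kv_lora_rank=16,
    q_lora_rank=24,
    qk_nope_head_dim=8,
    qk_rope_head_dim=8,  # becomes `head_dim` via `__post_init__`
    v_head_dim=8,
    max_position_embeddings=64,
)
model = MiniCPM3ForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 10))
with torch.no_grad():
    out = model(input_ids)

# One logit vector per position over the full (tiny) vocabulary.
print(out.logits.shape)  # torch.Size([1, 10, 128])
```

With the defaults, `dim_model_base=1` makes the pre-head scaling a plain division by `hidden_size`; a real checkpoint would supply its own value through its `config.json`.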