From 7c40ac95aa7407feed9a22196f6955db50e9563f Mon Sep 17 00:00:00 2001 From: minsangkim Date: Mon, 9 Mar 2026 17:55:08 +0900 Subject: [PATCH 1/2] Add A.X K1 --- src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 3 + src/transformers/models/auto/modeling_auto.py | 4 + src/transformers/models/ax_k1/__init__.py | 27 + .../models/ax_k1/configuration_ax_k1.py | 177 +++++ .../models/ax_k1/modeling_ax_k1.py | 731 ++++++++++++++++++ .../models/ax_k1/modular_ax_k1.py | 389 ++++++++++ src/transformers/utils/auto_docstring.py | 1 + 8 files changed, 1333 insertions(+) create mode 100644 src/transformers/models/ax_k1/__init__.py create mode 100644 src/transformers/models/ax_k1/configuration_ax_k1.py create mode 100644 src/transformers/models/ax_k1/modeling_ax_k1.py create mode 100644 src/transformers/models/ax_k1/modular_ax_k1.py diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 292b66f3546a..a9e36cf220c5 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -31,6 +31,7 @@ from .auto import * from .autoformer import * from .aya_vision import * + from .ax_k1 import * from .bamba import * from .bark import * from .bart import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index aa037408c937..45f4e5cbc665 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -49,6 +49,7 @@ ("audioflamingo3_encoder", "AudioFlamingo3EncoderConfig"), ("autoformer", "AutoformerConfig"), ("aya_vision", "AyaVisionConfig"), + ("axk1", "AXK1Config"), ("bamba", "BambaConfig"), ("bark", "BarkConfig"), ("bart", "BartConfig"), @@ -531,6 +532,7 @@ ("audioflamingo3_encoder", "AudioFlamingo3Encoder"), ("autoformer", "Autoformer"), ("aya_vision", "AyaVision"), + ("axk1", "A.X-K1"), ("bamba", "Bamba"), ("bark", "Bark"), ("bart", "BART"), @@ -1034,6 +1036,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str]( [ ("audioflamingo3_encoder", "audioflamingo3"), + ("axk1", "ax_k1"), ("openai-gpt", "openai"), ("blip-2", "blip_2"), ("data2vec-audio", "data2vec"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 757d772bb968..ca84395f4bab 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -56,6 +56,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("audioflamingo3_encoder", "AudioFlamingo3Encoder"), ("autoformer", "AutoformerModel"), ("aya_vision", "AyaVisionModel"), + ("axk1", "AXK1Model"), ("bamba", "BambaModel"), ("bark", "BarkModel"), ("bart", "BartModel"), @@ -589,6 +590,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("apertus", "ApertusForCausalLM"), ("arcee", "ArceeForCausalLM"), ("aria_text", "AriaTextForCausalLM"), + ("axk1", "AXK1ForCausalLM"), ("bamba", "BambaForCausalLM"), ("bart", "BartForCausalLM"), ("bert", "BertLMHeadModel"), @@ -1188,6 +1190,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("data2vec-text", "Data2VecTextForSequenceClassification"), ("deberta", "DebertaForSequenceClassification"), ("deberta-v2", "DebertaV2ForSequenceClassification"), + ("axk1", "AXK1ForSequenceClassification"), ("deepseek_v2", "DeepseekV2ForSequenceClassification"), ("deepseek_v3", "DeepseekV3ForSequenceClassification"), ("diffllama", "DiffLlamaForSequenceClassification"), @@ -1410,6 +1413,7 @@ class 
_BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("convbert", "ConvBertForTokenClassification"), ("data2vec-text", "Data2VecTextForTokenClassification"), ("deberta", "DebertaForTokenClassification"), + ("axk1", "AXK1ForTokenClassification"), ("deberta-v2", "DebertaV2ForTokenClassification"), ("deepseek_v3", "DeepseekV3ForTokenClassification"), ("diffllama", "DiffLlamaForTokenClassification"), diff --git a/src/transformers/models/ax_k1/__init__.py b/src/transformers/models/ax_k1/__init__.py new file mode 100644 index 000000000000..f4412e3853c6 --- /dev/null +++ b/src/transformers/models/ax_k1/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_ax_k1 import * + from .modeling_ax_k1 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/ax_k1/configuration_ax_k1.py b/src/transformers/models/ax_k1/configuration_ax_k1.py new file mode 100644 index 000000000000..903a4db4553c --- /dev/null +++ b/src/transformers/models/ax_k1/configuration_ax_k1.py @@ -0,0 +1,177 @@ +# Copyright 2025 SK Telecom and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A.X-K1 model configuration""" + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring + + +AXK1_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +@auto_docstring(checkpoint="skt/A.X-K1") +class AXK1Config(PreTrainedConfig): + r""" + n_group (`int`, *optional*, defaults to 8): + Number of groups for routed experts. + first_k_dense_replace (`int`, *optional*, defaults to 1): + Number of dense layers in shallow layers (embed->dense->moe->moe...->lm_head). + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + rope_interleave (`bool`, *optional*, defaults to `True`): + Whether to interleave the rotary position embeddings. 
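+            When `True`, the rotary dimensions of the query/key projections are stored as
+            interleaved even/odd pairs (the pretraining weight layout) and are de-interleaved
+            before the rotation is applied (see `apply_rotary_pos_emb_interleave`).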
+
+    Example:
+
+    ```python
+    >>> from transformers import AXK1Model, AXK1Config
+
+    >>> # Initializing an A.X-K1 style configuration
+    >>> configuration = AXK1Config()
+
+    >>> # Initializing a model from the A.X-K1 style configuration
+    >>> model = AXK1Model(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "axk1"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    base_model_tp_plan = {
+        "layers.*.mlp.experts.gate_up_proj": "packed_colwise",
+        "layers.*.mlp.experts.down_proj": "rowwise",
+        "layers.*.mlp.experts": "moe_tp_experts",
+        "layers.*.mlp.shared_experts.gate_proj": "colwise",
+        "layers.*.mlp.shared_experts.up_proj": "colwise",
+        "layers.*.mlp.shared_experts.down_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+    attribute_map = {
+        "num_local_experts": "n_routed_experts",
+    }
+
+    def __init__(
+        self,
+        vocab_size: int | None = 163840,
+        hidden_size: int | None = 7168,
+        intermediate_size: int | None = 18432,
+        moe_intermediate_size: int | None = 2048,
+        num_hidden_layers: int | None = 61,
+        num_attention_heads: int | None = 64,
+        num_key_value_heads: int | None = 64,
+        n_shared_experts: int | None = 1,
+        n_routed_experts: int | None = 256,
+        routed_scaling_factor: float | None = 2.5,
+        kv_lora_rank: int | None = 512,
+        q_lora_rank: int | None = 1536,
+        qk_rope_head_dim: int | None = 64,
+        v_head_dim: int | None = 128,
+        qk_nope_head_dim: int | None = 128,
+        n_group: int | None = 8,
+        topk_group: int | None = 4,
+        num_experts_per_tok: int | None = 8,
+        first_k_dense_replace: int | None = 1,
+        moe_layer_freq: int | None = 1,
+        norm_topk_prob: bool | None = True,
+        hidden_act: str | None = "silu",
+        max_position_embeddings: int | None = 4096,
+        initializer_range: float | None = 0.02,
+        rms_norm_eps: float | None = 1e-6,
+        use_cache: bool | None = True,
+        pad_token_id: int | None = None,
+        bos_token_id: int | None = 163691,
+        eos_token_id: int | None = 163691,
+        pretraining_tp: int | None = 1,
+        tie_word_embeddings: bool | None = False,
+        rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
+        rope_interleave: bool | None = True,
+        attention_bias: bool | None = False,
+        attention_dropout: float | None = 0.0,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.moe_intermediate_size = moe_intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.max_position_embeddings = max_position_embeddings
+
+        # Multi-head Latent Attention (MLA)
+        self.kv_lora_rank = kv_lora_rank
+        self.q_lora_rank = q_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.head_dim = qk_rope_head_dim
+
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+
+        # Mixture of Experts (MoE)
+        self.n_shared_experts = n_shared_experts
+        self.n_routed_experts = n_routed_experts
+        self.routed_scaling_factor = routed_scaling_factor
+        self.n_group = n_group
+        self.topk_group = topk_group
+        self.num_experts_per_tok = num_experts_per_tok
+        self.first_k_dense_replace = first_k_dense_replace
+        self.moe_layer_freq
= moe_layer_freq + self.norm_topk_prob = norm_topk_prob + + # RoPE + self.rope_parameters = rope_parameters + self.rope_interleave = rope_interleave + + # General + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: set | None = None, **kwargs): + rope_scaling = kwargs.pop("rope_scaling", None) + self.rope_parameters = rope_scaling or self.rope_parameters + self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else {} + + self.rope_parameters.setdefault("rope_theta", kwargs.pop("rope_theta", self.default_theta)) + self.standardize_rope_params() + self.validate_rope(ignore_keys=ignore_keys_at_rope_validation) + + for key in ["beta_fast", "beta_slow", "factor"]: + if key in self.rope_parameters: + self.rope_parameters[key] = float(self.rope_parameters[key]) + return kwargs + + +__all__ = ["AXK1Config"] diff --git a/src/transformers/models/ax_k1/modeling_ax_k1.py b/src/transformers/models/ax_k1/modeling_ax_k1.py new file mode 100644 index 000000000000..aa5c1b4c6df6 --- /dev/null +++ b/src/transformers/models/ax_k1/modeling_ax_k1.py @@ -0,0 +1,731 @@ +# Copyright 2025 SK Telecom and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch A.X-K1 model.""" + +import math +from collections.abc import Callable +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ... 
import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernel_func_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import ( + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from .configuration_ax_k1 import AXK1Config + + +@use_kernel_forward_from_hub("RMSNorm") +class AXK1RMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class AXK1RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: AXK1Config, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: AXK1Config | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). 
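+
+        Example (a minimal sketch with toy values, independent of the model defaults):
+
+        ```python
+        >>> import torch
+        >>> base, dim = 10000.0, 8
+        >>> inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        >>> inv_freq.shape
+        torch.Size([4])
+        ```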
+ """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class AXK1MLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class AXK1TopkRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.n_routed_experts = config.n_routed_experts + + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) + return router_logits + + +@use_experts_implementation +class AXK1NaiveMoe(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + + def __init__(self, config): + super().__init__() + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = 
torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states + + +class AXK1MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.experts = AXK1NaiveMoe(config) + self.gate = AXK1TopkRouter(config) + self.shared_experts = AXK1MLP( + config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts + ) + self.n_routed_experts = config.n_routed_experts + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + self.top_k = config.num_experts_per_tok + + def route_tokens_to_experts(self, router_logits): + router_logits = router_logits.sigmoid() + router_logits_for_choice = router_logits + self.gate.e_score_correction_bias + group_scores = ( + router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) + scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), float("-inf")) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = router_logits.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights + + def forward(self, hidden_states): + residuals = hidden_states + orig_shape = hidden_states.shape + router_logits = self.gate(hidden_states) + topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies interleaved Rotary Position Embedding to the query and key tensors. + + Unlike the standard RoPE which splits the hidden dim into two halves, this variant interleaves + even/odd dimensions before applying the rotation, matching the weight layout used during pretraining. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Unused. Kept for API compatibility. + unsqueeze_dim (`int`, *optional*, defaults to 1): + Dimension along which to unsqueeze cos/sin for broadcasting. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
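+
+    A toy sketch of the de-interleaving step on its own (illustrative only, d=6): even-indexed
+    dimensions are gathered into the first half and odd-indexed ones into the second half, so the
+    usual half-split rotation can be applied afterwards.
+
+    ```python
+    >>> import torch
+    >>> x = torch.arange(6.0).view(1, 1, 1, 6)
+    >>> x.view(1, 1, 1, 3, 2).transpose(4, 3).reshape(1, 1, 1, 6)
+    tensor([[[[0., 2., 4., 1., 3., 5.]]]])
+    ```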
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class AXK1Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: AXK1Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.attention_dropout = config.attention_dropout + self.num_heads = config.num_attention_heads + + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_head_dim = config.qk_head_dim + + self.is_causal = True + if self.q_lora_rank is None: + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = AXK1RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) + + self.kv_a_proj_with_mqa = nn.Linear( + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = AXK1RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + self.scaling = self.qk_head_dim ** (-0.5) + if self.config.rope_parameters.get("rope_type", "default") != "default": + mscale_all_dim = self.config.rope_parameters.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_parameters["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scaling = self.scaling * mscale * mscale + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + + if self.q_lora_rank is None: + q_states = self.q_proj(hidden_states) + else: + q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q_states = q_states.view(query_shape).transpose(1, 2) + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + k_pass, value_states = torch.split(k_pass, 
[self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + + cos, sin = position_embeddings + if self.config.rope_interleave: # support using interleaved weights for efficiency + q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) + else: + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + k_rot = k_rot.expand(*k_pass.shape[:-1], -1) + + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class AXK1DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: AXK1Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + + self.self_attn = AXK1Attention(config=config, layer_idx=layer_idx) + + self.is_moe_layer = ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + if self.is_moe_layer: + self.mlp = AXK1MoE(config) + else: + self.mlp = AXK1MLP(config) + + self.input_layernorm = AXK1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = AXK1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_mlp_layernorm = AXK1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if self.is_moe_layer: + hidden_states = self.post_mlp_layernorm(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class AXK1PreTrainedModel(PreTrainedModel): + config: AXK1Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["AXK1DecoderLayer"] + 
_skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": AXK1DecoderLayer, + "attentions": AXK1Attention, + } + _keep_in_fp32_modules_strict = ["e_score_correction_bias"] + _keys_to_ignore_on_load_unexpected = [r"model\.layers\.\d+\.self_attn\.rotary_emb\.inv_freq"] + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, AXK1TopkRouter): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + init.zeros_(module.e_score_correction_bias) + elif isinstance(module, AXK1NaiveMoe): + init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) + init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) + + +@auto_docstring +class AXK1Model(AXK1PreTrainedModel): + def __init__(self, config: AXK1Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [AXK1DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = AXK1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = AXK1RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_embeddings=position_embeddings, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class AXK1ForCausalLM(AXK1PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + 
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = AXK1Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, AXK1ForCausalLM + + >>> model = AXK1ForCausalLM.from_pretrained("skt/A.X-K1") + >>> tokenizer = AutoTokenizer.from_pretrained("skt/A.X-K1") + + >>> prompt = "SKT의 A.X 모델에 대해 알려줘." + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "SKT의 A.X 모델에 대해 알려줘.\nA.X는 SK텔레콤이 개발한 대규모 언어 모델입니다." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class AXK1ForSequenceClassification(GenericForSequenceClassification, AXK1PreTrainedModel): + pass + + +class AXK1ForTokenClassification(GenericForTokenClassification, AXK1PreTrainedModel): + pass + + +__all__ = [ + "AXK1PreTrainedModel", + "AXK1Model", + "AXK1ForCausalLM", + "AXK1ForSequenceClassification", + "AXK1ForTokenClassification", +] diff --git a/src/transformers/models/ax_k1/modular_ax_k1.py b/src/transformers/models/ax_k1/modular_ax_k1.py new file mode 100644 index 000000000000..32069aff9c9b --- /dev/null +++ b/src/transformers/models/ax_k1/modular_ax_k1.py @@ -0,0 +1,389 @@ +# Copyright 2025 SK Telecom and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch A.X-K1 model (modular).""" + +import math +from collections.abc import Callable + +import torch +import torch.nn.functional as F +from torch import nn + +from ... import initialization as init +from ...cache_utils import Cache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GenericForSequenceClassification, GenericForTokenClassification +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, logging +from ...utils.generic import is_flash_attention_requested +from ..llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, + rotate_half, +) +from ..mixtral.modeling_mixtral import MixtralExperts +from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeMLP +from .configuration_ax_k1 import AXK1Config + + +logger = logging.get_logger(__name__) + + +class AXK1RMSNorm(LlamaRMSNorm): + pass + + +class AXK1RotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class AXK1MLP(Qwen2MoeMLP): + pass + + +def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies interleaved Rotary Position Embedding to the query and key tensors. + + Unlike the standard RoPE which splits the hidden dim into two halves, this variant interleaves + even/odd dimensions before applying the rotation, matching the weight layout used during pretraining. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Unused. Kept for API compatibility. + unsqueeze_dim (`int`, *optional*, defaults to 1): + Dimension along which to unsqueeze cos/sin for broadcasting. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class AXK1TopkRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.n_routed_experts = config.n_routed_experts + + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) + return router_logits + + +class AXK1NaiveMoe(MixtralExperts): + def __init__(self, config): + super().__init__(config) + self.num_experts = config.num_local_experts + self.intermediate_dim = config.moe_intermediate_size + + +class AXK1MoE(nn.Module): + """ + A mixed expert module containing shared experts. 
+ """ + + def __init__(self, config): + super().__init__() + self.config = config + self.experts = AXK1NaiveMoe(config) + self.gate = AXK1TopkRouter(config) + self.shared_experts = AXK1MLP( + config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts + ) + self.n_routed_experts = config.n_routed_experts + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + self.top_k = config.num_experts_per_tok + + def route_tokens_to_experts(self, router_logits): + router_logits = router_logits.sigmoid() + router_logits_for_choice = router_logits + self.gate.e_score_correction_bias + group_scores = ( + router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) + scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = router_logits.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights + + def forward(self, hidden_states): + residuals = hidden_states + orig_shape = hidden_states.shape + router_logits = self.gate(hidden_states) + topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states + + +class AXK1Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: AXK1Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.attention_dropout = config.attention_dropout + self.num_heads = config.num_attention_heads + + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_head_dim = config.qk_head_dim + + self.is_causal = True + if self.q_lora_rank is None: + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = AXK1RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) + + self.kv_a_proj_with_mqa = nn.Linear( + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = AXK1RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + 
self.num_heads * self.v_head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + self.scaling = self.qk_head_dim ** (-0.5) + if self.config.rope_parameters.get("rope_type", "default") != "default": + mscale_all_dim = self.config.rope_parameters.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_parameters["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scaling = self.scaling * mscale * mscale + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + + if self.q_lora_rank is None: + q_states = self.q_proj(hidden_states) + else: + q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q_states = q_states.view(query_shape).transpose(1, 2) + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + + cos, sin = position_embeddings + if self.config.rope_interleave: + q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) + else: + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + k_rot = k_rot.expand(*k_pass.shape[:-1], -1) + + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class AXK1DecoderLayer(LlamaDecoderLayer): + def __init__(self, config: AXK1Config, layer_idx: int): + nn.Module.__init__(self) + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + + self.self_attn = AXK1Attention(config=config, layer_idx=layer_idx) + + self.is_moe_layer = ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + if self.is_moe_layer: + self.mlp = AXK1MoE(config) + else: + 
self.mlp = AXK1MLP(config) + + self.input_layernorm = AXK1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = AXK1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_mlp_layernorm = AXK1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if self.is_moe_layer: + hidden_states = self.post_mlp_layernorm(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class AXK1PreTrainedModel(LlamaPreTrainedModel): + _keep_in_fp32_modules_strict = ["e_score_correction_bias"] + _keys_to_ignore_on_load_unexpected = [r"model\.layers\.\d+\.self_attn\.rotary_emb\.inv_freq"] + + @torch.no_grad() + def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) + if isinstance(module, AXK1TopkRouter): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + init.zeros_(module.e_score_correction_bias) + elif isinstance(module, AXK1NaiveMoe): + init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) + init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) + + +class AXK1Model(LlamaModel): + pass + + +class AXK1ForCausalLM(LlamaForCausalLM): + pass + + +class AXK1ForSequenceClassification(GenericForSequenceClassification, AXK1PreTrainedModel): + pass + + +class AXK1ForTokenClassification(GenericForTokenClassification, AXK1PreTrainedModel): + pass + + +__all__ = [ + "AXK1PreTrainedModel", + "AXK1Model", + "AXK1ForCausalLM", + "AXK1ForSequenceClassification", + "AXK1ForTokenClassification", +] diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index 73f83e3a7f5c..4db043453b06 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -79,6 +79,7 @@ "parakeet": "ParakeetCTCConfig", "lasr": "LasrCTCConfig", "wav2vec2-with-lm": "Wav2Vec2Config", + "ax-k1": "AXK1Config", } _re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)") From 76df201e93100943cc8ea22cba2993721edfae2c Mon Sep 17 00:00:00 2001 From: 1113778 Date: Tue, 10 Mar 2026 17:14:53 +0900 Subject: [PATCH 2/2] match axk1 type name --- src/transformers/models/auto/configuration_auto.py | 6 +++--- src/transformers/models/auto/modeling_auto.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 45f4e5cbc665..fc73f71cd81a 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -49,7 +49,7 @@ ("audioflamingo3_encoder", "AudioFlamingo3EncoderConfig"), 
("autoformer", "AutoformerConfig"), ("aya_vision", "AyaVisionConfig"), - ("axk1", "AXK1Config"), + ("AXK1", "AXK1Config"), ("bamba", "BambaConfig"), ("bark", "BarkConfig"), ("bart", "BartConfig"), @@ -532,7 +532,7 @@ ("audioflamingo3_encoder", "AudioFlamingo3Encoder"), ("autoformer", "Autoformer"), ("aya_vision", "AyaVision"), - ("axk1", "A.X-K1"), + ("AXK1", "A.X-K1"), ("bamba", "Bamba"), ("bark", "Bark"), ("bart", "BART"), @@ -1036,7 +1036,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str]( [ ("audioflamingo3_encoder", "audioflamingo3"), - ("axk1", "ax_k1"), + ("AXK1", "ax_k1"), ("openai-gpt", "openai"), ("blip-2", "blip_2"), ("data2vec-audio", "data2vec"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index ca84395f4bab..fb43142d6f97 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -56,7 +56,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("audioflamingo3_encoder", "AudioFlamingo3Encoder"), ("autoformer", "AutoformerModel"), ("aya_vision", "AyaVisionModel"), - ("axk1", "AXK1Model"), + ("AXK1", "AXK1Model"), ("bamba", "BambaModel"), ("bark", "BarkModel"), ("bart", "BartModel"), @@ -590,7 +590,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("apertus", "ApertusForCausalLM"), ("arcee", "ArceeForCausalLM"), ("aria_text", "AriaTextForCausalLM"), - ("axk1", "AXK1ForCausalLM"), + ("AXK1", "AXK1ForCausalLM"), ("bamba", "BambaForCausalLM"), ("bart", "BartForCausalLM"), ("bert", "BertLMHeadModel"), @@ -1190,7 +1190,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("data2vec-text", "Data2VecTextForSequenceClassification"), ("deberta", "DebertaForSequenceClassification"), ("deberta-v2", "DebertaV2ForSequenceClassification"), - ("axk1", "AXK1ForSequenceClassification"), + ("AXK1", "AXK1ForSequenceClassification"), ("deepseek_v2", "DeepseekV2ForSequenceClassification"), ("deepseek_v3", "DeepseekV3ForSequenceClassification"), ("diffllama", "DiffLlamaForSequenceClassification"), @@ -1413,7 +1413,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("convbert", "ConvBertForTokenClassification"), ("data2vec-text", "Data2VecTextForTokenClassification"), ("deberta", "DebertaForTokenClassification"), - ("axk1", "AXK1ForTokenClassification"), + ("AXK1", "AXK1ForTokenClassification"), ("deberta-v2", "DebertaV2ForTokenClassification"), ("deepseek_v3", "DeepseekV3ForTokenClassification"), ("diffllama", "DiffLlamaForTokenClassification"),