Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .auto import *
from .autoformer import *
from .aya_vision import *
from .ax_k1 import *
from .bamba import *
from .bark import *
from .bart import *
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
("audioflamingo3_encoder", "AudioFlamingo3EncoderConfig"),
("autoformer", "AutoformerConfig"),
("aya_vision", "AyaVisionConfig"),
("AXK1", "AXK1Config"),
("bamba", "BambaConfig"),
("bark", "BarkConfig"),
("bart", "BartConfig"),
Expand Down Expand Up @@ -531,6 +532,7 @@
("audioflamingo3_encoder", "AudioFlamingo3Encoder"),
("autoformer", "Autoformer"),
("aya_vision", "AyaVision"),
("AXK1", "A.X-K1"),
("bamba", "Bamba"),
("bark", "Bark"),
("bart", "BART"),
Expand Down Expand Up @@ -1034,6 +1036,7 @@
SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
[
("audioflamingo3_encoder", "audioflamingo3"),
("AXK1", "ax_k1"),
("openai-gpt", "openai"),
("blip-2", "blip_2"),
("data2vec-audio", "data2vec"),
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("audioflamingo3_encoder", "AudioFlamingo3Encoder"),
("autoformer", "AutoformerModel"),
("aya_vision", "AyaVisionModel"),
("AXK1", "AXK1Model"),
("bamba", "BambaModel"),
("bark", "BarkModel"),
("bart", "BartModel"),
Expand Down Expand Up @@ -589,6 +590,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("apertus", "ApertusForCausalLM"),
("arcee", "ArceeForCausalLM"),
("aria_text", "AriaTextForCausalLM"),
("AXK1", "AXK1ForCausalLM"),
("bamba", "BambaForCausalLM"),
("bart", "BartForCausalLM"),
("bert", "BertLMHeadModel"),
Expand Down Expand Up @@ -1188,6 +1190,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("data2vec-text", "Data2VecTextForSequenceClassification"),
("deberta", "DebertaForSequenceClassification"),
("deberta-v2", "DebertaV2ForSequenceClassification"),
("AXK1", "AXK1ForSequenceClassification"),
("deepseek_v2", "DeepseekV2ForSequenceClassification"),
("deepseek_v3", "DeepseekV3ForSequenceClassification"),
("diffllama", "DiffLlamaForSequenceClassification"),
Expand Down Expand Up @@ -1410,6 +1413,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("convbert", "ConvBertForTokenClassification"),
("data2vec-text", "Data2VecTextForTokenClassification"),
("deberta", "DebertaForTokenClassification"),
("AXK1", "AXK1ForTokenClassification"),
("deberta-v2", "DebertaV2ForTokenClassification"),
("deepseek_v3", "DeepseekV3ForTokenClassification"),
("diffllama", "DiffLlamaForTokenClassification"),
Expand Down
27 changes: 27 additions & 0 deletions src/transformers/models/ax_k1/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_ax_k1 import *
from .modeling_ax_k1 import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
177 changes: 177 additions & 0 deletions src/transformers/models/ax_k1/configuration_ax_k1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright 2025 SK Telecom and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A.X-K1 model configuration"""

from ...configuration_utils import PreTrainedConfig
from ...modeling_rope_utils import RopeParameters
from ...utils import auto_docstring


AXK1_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


@auto_docstring(checkpoint="skt/A.X-K1")
class AXK1Config(PreTrainedConfig):
r"""
n_group (`int`, *optional*, defaults to 8):
Number of groups for routed experts.
first_k_dense_replace (`int`, *optional*, defaults to 1):
Number of dense layers in shallow layers (embed->dense->moe->moe...->lm_head).
moe_layer_freq (`int`, *optional*, defaults to 1):
The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
rope_interleave (`bool`, *optional*, defaults to `True`):
Whether to interleave the rotary position embeddings.
Example:

```python
>>> from transformers import AXK1Model, AXK1Config

>>> # Initializing a A.X-K1 style configuration
>>> configuration = AXK1Config()

>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "axk1"
keys_to_ignore_at_inference = ["past_key_values"]

base_model_tp_plan = {
"layers.*.mlp.experts.gate_up_proj": "packed_colwise",
"layers.*.mlp.experts.down_proj": "rowwise",
"layers.*.mlp.experts": "moe_tp_experts",
"layers.*.mlp.shared_experts.gate_proj": "colwise",
"layers.*.mlp.shared_experts.up_proj": "colwise",
"layers.*.mlp.shared_experts.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
attribute_map = {
"num_local_experts": "n_routed_experts",
}

def __init__(
self,
vocab_size: int | None = 163840,
hidden_size: int | None = 7168,
intermediate_size: int | None = 18432,
moe_intermediate_size: int | None = 2048,
num_hidden_layers: int | None = 61,
num_attention_heads: int | None = 64,
num_key_value_heads: int | None = 64,
n_shared_experts: int | None = 1,
n_routed_experts: int | None = 256,
routed_scaling_factor: float | None = 2.5,
kv_lora_rank: int | None = 512,
q_lora_rank: int | None = 1536,
qk_rope_head_dim: int | None = 64,
v_head_dim: int | None = 128,
qk_nope_head_dim: int | None = 128,
n_group: int | None = 8,
topk_group: int | None = 4,
num_experts_per_tok: int | None = 8,
first_k_dense_replace: int | None = 1,
moe_layer_freq: int | None = 1,
norm_topk_prob: bool | None = True,
hidden_act: str | None = "silu",
max_position_embeddings: int | None = 4096,
initializer_range: float | None = 0.02,
rms_norm_eps: int | None = 1e-6,
use_cache: bool | None = True,
pad_token_id: int | None = None,
bos_token_id: int | None = 163691,
eos_token_id: int | None = 163691,
pretraining_tp: int | None = 1,
tie_word_embeddings: bool | None = False,
rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
rope_interleave: bool | None = True,
attention_bias: bool | None = False,
attention_dropout: float | None = 0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.max_position_embeddings = max_position_embeddings

# Multi-head Latent Attention (MLA)
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.head_dim = qk_rope_head_dim

if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads

# Mixture of Experts (MoE)
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.routed_scaling_factor = routed_scaling_factor
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.first_k_dense_replace = first_k_dense_replace
self.moe_layer_freq = moe_layer_freq
self.norm_topk_prob = norm_topk_prob

# RoPE
self.rope_parameters = rope_parameters
self.rope_interleave = rope_interleave

# General
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: set | None = None, **kwargs):
rope_scaling = kwargs.pop("rope_scaling", None)
self.rope_parameters = rope_scaling or self.rope_parameters
self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else {}

self.rope_parameters.setdefault("rope_theta", kwargs.pop("rope_theta", self.default_theta))
self.standardize_rope_params()
self.validate_rope(ignore_keys=ignore_keys_at_rope_validation)

for key in ["beta_fast", "beta_slow", "factor"]:
if key in self.rope_parameters:
self.rope_parameters[key] = float(self.rope_parameters[key])
return kwargs


__all__ = ["AXK1Config"]
Loading
Loading