src/transformers/integrations/ggml.py (+15 −0)
@@ -89,6 +89,21 @@
"expert_count": "num_experts",
"expert_used_count": "num_experts_per_tok",
},
"gpt_oss": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
"feed_forward_length": "intermediate_size",
"embedding_length": "hidden_size",
"rope.dimension_count": None,
"rope.freq_base": "rope_theta",
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
"expert_count": "num_local_experts",
"expert_used_count": "num_experts_per_tok",
"sliding_window": "sliding_window",
},
"lfm2": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
src/transformers/modeling_gguf_pytorch_utils.py (+145 −0)
@@ -171,6 +171,107 @@ def _set_moe_expert_tensor(self, weights: np.ndarray, parsed_parameters: dict[st
        out.copy_(torch_weights)


class GptOssTensorProcessor(TensorProcessor):
    """
    Tensor processor for GPT-OSS models (MoE with up to 128 experts).
    Handles:
    - Splitting stacked expert tensors (down_projs, gate_projs, up_projs) into individual experts.
    - Splitting a combined gate+up tensor (gate_up_projs) into separate gate and up projections.
    - Passing bias tensors (1D) through without transpose.
    """

    # Regex for separate expert tensors: e.g., blk.0.ffn_down_projs.weight
    GGUF_MOE_WEIGHTS_PATTERN = re.compile(r"blk\.(?P<bid>\d+)\.ffn_(?P<proj>down|gate|up)_projs\.weight$")
    # Regex for combined gate+up tensor: e.g., blk.0.ffn_gate_up_projs.weight
    GGUF_MOE_COMBINED_PATTERN = re.compile(r"blk\.(?P<bid>\d+)\.ffn_gate_up_projs\.weight$")
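    # Illustrative matches:
    #   GGUF_MOE_WEIGHTS_PATTERN.match("blk.3.ffn_down_projs.weight")     -> bid="3", proj="down"
    #   GGUF_MOE_COMBINED_PATTERN.match("blk.3.ffn_gate_up_projs.weight") -> bid="3"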

    def __init__(self, config=None):
        super().__init__(config=config)

    def process(self, weights, name: str, **kwargs):
        # 1. Handle separate MoE expert tensors (down, gate, up)
        if m := self.GGUF_MOE_WEIGHTS_PATTERN.match(name):
            tensor_key_mapping = kwargs.get("tensor_key_mapping")
            parsed_parameters = kwargs.get("parsed_parameters")
            if tensor_key_mapping and parsed_parameters:
                self._split_moe_expert_tensor(weights, parsed_parameters, m["bid"], m["proj"], tensor_key_mapping)
            return GGUFTensor(weights, None, {})  # name=None signals the tensor was fully handled here

        # 2. Handle combined gate+up tensor
        if m := self.GGUF_MOE_COMBINED_PATTERN.match(name):
            tensor_key_mapping = kwargs.get("tensor_key_mapping")
            parsed_parameters = kwargs.get("parsed_parameters")
            if tensor_key_mapping and parsed_parameters:
                self._interleave_gate_up_tensor(weights, parsed_parameters, m["bid"], tensor_key_mapping)
            return GGUFTensor(weights, None, {})

        # 3. Bias tensors (1D) are returned unchanged, so they are never transposed
        if ".bias" in name and len(weights.shape) == 1:
            return GGUFTensor(weights, name, {})

        # 4. Default handling for all other tensors
        return GGUFTensor(weights, name, {})

    def _split_moe_expert_tensor(
        self,
        weights: np.ndarray,
        parsed_parameters: dict,
        bid: str,
        proj: str,
        tensor_key_mapping: dict,
    ):
        """Split a stacked MoE tensor into individual expert tensors."""
        num_experts = self.config.get("num_local_experts", 128)
        # Expected shape: [num_experts, hidden_size, intermediate_size] (or swapped).
        # We assume the stored order is already correct for the projection after splitting.
        for i in range(min(num_experts, weights.shape[0])):
            expert_weight = weights[i]  # shape: [hidden, inter] or [inter, hidden]
            # Build the HF parameter name
            hf_name = f"model.layers.{bid}.block_sparse_moe.experts.{i}.{proj}_proj.weight"
            # Apply any user-provided tensor key mapping
            for key, mapped_key in tensor_key_mapping.items():
                if key in hf_name:
                    hf_name = hf_name.replace(key, mapped_key)
            # Store the tensor (torch.tensor always copies numpy input)
            parsed_parameters["tensors"][hf_name] = torch.tensor(expert_weight)

    def _interleave_gate_up_tensor(
        self,
        weights: np.ndarray,
        parsed_parameters: dict,
        bid: str,
        tensor_key_mapping: dict,
    ):
        """
        Process a combined gate+up tensor.
        Expected shape: [num_experts, intermediate_size, hidden_size].
        Despite the name, the layout assumed here is a half split rather than a
        row-wise interleave: gate occupies the first half of the intermediate
        dimension and up the second half. Each expert slice is transposed to
        [hidden, half_inter].
        """
        num_experts = self.config.get("num_local_experts", 128)
        inter_size = weights.shape[1]
        half_inter = inter_size // 2
        gate_part = weights[:, :half_inter, :]  # [E, half_inter, hidden]
        up_part = weights[:, half_inter:, :]  # [E, half_inter, hidden]

        for i in range(min(num_experts, weights.shape[0])):
            gate_weight = gate_part[i].T  # [hidden, half_inter]
            up_weight = up_part[i].T  # [hidden, half_inter]

            gate_name = f"model.layers.{bid}.block_sparse_moe.experts.{i}.gate_proj.weight"
            up_name = f"model.layers.{bid}.block_sparse_moe.experts.{i}.up_proj.weight"

            # Apply any user-provided tensor key mapping
            for key, mapped_key in tensor_key_mapping.items():
                if key in gate_name:
                    gate_name = gate_name.replace(key, mapped_key)
                if key in up_name:
                    up_name = up_name.replace(key, mapped_key)

            # torch.tensor copies, so the transposed numpy views are materialized
            parsed_parameters["tensors"][gate_name] = torch.tensor(gate_weight)
            parsed_parameters["tensors"][up_name] = torch.tensor(up_weight)


class BloomTensorProcessor(TensorProcessor):
    def __init__(self, config=None):
        super().__init__(config=config)
@@ -355,6 +456,7 @@ def _set_moe_expert_tensor(self, weights: np.ndarray, parsed_parameters: dict[st
TENSOR_PROCESSORS = {
    "llama": LlamaTensorProcessor,
    "qwen2moe": Qwen2MoeTensorProcessor,
    "gpt_oss": GptOssTensorProcessor,
    "qwen3moe": Qwen2MoeTensorProcessor,
    "bloom": BloomTensorProcessor,
    "t5": T5TensorProcessor,
@@ -416,6 +518,8 @@ def get_gguf_hf_weights_map(
model_type = "t5"
elif model_type == "minimax_m2":
model_type = "minimax-m2"
elif model_type == "gpt_oss":
model_type = "gpt-oss"
arch = None
for key, value in MODEL_ARCH_NAMES.items():
if value == model_type:
@@ -522,6 +626,8 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_lo

if "qwen2moe" in architecture:
updated_architecture = "qwen2_moe"
elif "gpt_oss" in architecture or "gpt-oss" in architecture:
updated_architecture = "gpt_oss"
elif "qwen3moe" in architecture:
updated_architecture = "qwen3_moe"
elif "minimax-m2" in architecture:
@@ -609,6 +715,45 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_lo
            i for i, num_kv_heads in enumerate(gguf_num_key_value_heads) if num_kv_heads > 0
        ]

if updated_architecture == "gpt_oss":
# Helper to read keys with the correct prefix
def read_gpt_key(reader, suffix, default=None):
key = f"gpt-oss.{suffix}"
if key in reader.fields:
val = reader.fields[key].parts[0]
if isinstance(val, bytes):
val = val.decode("utf-8")
return val
return default
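        # Illustrative use (the actual value depends on the checkpoint's metadata):
        #   read_gpt_key(reader, "rope.scaling.type")  # e.g. "yarn"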

        # Reconstruct rope_scaling from GGUF metadata
        rope_type = read_gpt_key(reader, "rope.scaling.type")
        if rope_type is not None:
            rope_scaling = {"rope_type": rope_type}

            # Collect all rope.scaling.* keys dynamically
            prefix = "gpt-oss.rope.scaling."
            for key, field in reader.fields.items():
                if not key.startswith(prefix):
                    continue
                suffix = key[len(prefix) :]
                if suffix == "type":
                    continue
                value = _gguf_parse_value(field.parts[field.data[-1]], field.types)
                # Convert to the type HF expects
                if suffix in ("factor", "attention_factor", "beta_fast", "beta_slow"):
                    value = float(value)
                elif suffix in ("original_context_length", "original_max_position_embeddings"):
                    # Map GGUF's original_context_length to HF's original_max_position_embeddings
                    suffix = "original_max_position_embeddings"
                    value = int(value)
                rope_scaling[suffix] = value

            parsed_parameters["config"]["rope_scaling"] = rope_scaling

    # retrieve config vocab_size from tokenizer
    # Please refer to https://github.com/huggingface/transformers/issues/32526 for more details
    if "vocab_size" not in parsed_parameters["config"]:
tests/quantization/ggml/test_ggml.py (+16 −0)
@@ -351,6 +351,8 @@ class GgufModelTests(unittest.TestCase):
    q4_k_m_qwen3moe_model_id = "Qwen3-30B-A3B-Q4_K_M.gguf"
    q8_0_umt5_encoder_model_id = "umt5-xxl-encoder-Q8_0.gguf"
    q4_k_m_lfm2_model_id = "LFM2-1.2B-Q4_K_M.gguf"
    gpt_oss_model_id = "unsloth/gpt-oss-20b-GGUF"
    gpt_oss_gguf_file = "gpt-oss-20b-Q5_K_M.gguf"

    example_text = "Hello"

@@ -384,6 +386,20 @@ def test_qwen2_q4_0(self):
EXPECTED_TEXT = "Hello.jsoup\n\nI am a beginner"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_gpt_oss_q5_k_m(self):
        tokenizer = AutoTokenizer.from_pretrained(self.gpt_oss_model_id, gguf_file=self.gpt_oss_gguf_file)
        model = AutoModelForCausalLM.from_pretrained(
            self.gpt_oss_model_id,
            gguf_file=self.gpt_oss_gguf_file,
            device_map="auto",
            dtype=torch.float16,
        )

        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
        out = model.generate(**text, max_new_tokens=10)
        EXPECTED_TEXT = "Hello, I just want to say that I am just"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_qwen2moe_q8(self):
        tokenizer = AutoTokenizer.from_pretrained(self.qwen2moe_model_id, gguf_file=self.q8_qwen2moe_model_id)
        model = AutoModelForCausalLM.from_pretrained(