From 69f669e862c14c0d66b9ba33231731d73db08412 Mon Sep 17 00:00:00 2001
From: sirzechs66
Date: Mon, 30 Mar 2026 15:12:22 +0530
Subject: [PATCH 1/9] Add GPT-OSS GGUF support with YaRN rope scaling
 reconstruction

---
 src/transformers/integrations/ggml.py |  15 ++
 .../modeling_gguf_pytorch_utils.py    | 136 ++++++++++++++++++
 2 files changed, 151 insertions(+)

diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py
index 29ec365e7ce2..84ec36dcced3 100644
--- a/src/transformers/integrations/ggml.py
+++ b/src/transformers/integrations/ggml.py
@@ -89,6 +89,21 @@
         "expert_count": "num_experts",
         "expert_used_count": "num_experts_per_tok",
     },
+    "gpt_oss": {
+        "context_length": "max_position_embeddings",
+        "block_count": "num_hidden_layers",
+        "feed_forward_length": "intermediate_size",
+        "embedding_length": "hidden_size",
+        "rope.dimension_count": None, 
+        "rope.freq_base": "rope_theta",
+        "attention.head_count": "num_attention_heads",
+        "attention.head_count_kv": "num_key_value_heads",
+        "attention.layer_norm_rms_epsilon": "rms_norm_eps",
+        "vocab_size": "vocab_size",
+        "expert_count": "num_local_experts",
+        "expert_used_count": "num_experts_per_tok",
+        "sliding_window": "sliding_window",
+    },
     "lfm2": {
         "context_length": "max_position_embeddings",
         "block_count": "num_hidden_layers",
diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index 66306b6f71f6..c135461f0ea5 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -169,6 +169,113 @@ def _set_moe_expert_tensor(self, weights: np.ndarray, parsed_parameters: dict[st
         else:  # w == "up"
             out = out.narrow(shard_dim, shard_size, shard_size)
         out.copy_(torch_weights)
+class GptOssTensorProcessor(TensorProcessor):
+    """
+    Tensor processor for GPT-OSS models (MoE with 128 experts).
+    Handles:
+    - Splitting stacked expert tensors (down_proj, gate_proj, up_proj) into individual experts.
+    - Interleaving gate and up projections if stored in a combined tensor (gate_up_projs).
+    - Bias tensors (1D) are passed through without transpose.
+    """
+    # Regex for separate expert tensors: e.g., blk.0.ffn_down_projs.weight
+    GGUF_MOE_WEIGHTS_PATTERN = re.compile(
+        r"blk\.(?P<bid>\d+)\.ffn_(?P<proj>down|gate|up)_projs\.weight$"
+    )
+    # Regex for combined gate+up tensor: e.g., blk.0.ffn_gate_up_projs.weight
+    GGUF_MOE_COMBINED_PATTERN = re.compile(
+        r"blk\.(?P<bid>\d+)\.ffn_gate_up_projs\.weight$"
+    )
+
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name: str, **kwargs):
+        # 1. Handle separate MoE expert tensors (down, gate, up)
+        if m := self.GGUF_MOE_WEIGHTS_PATTERN.match(name):
+            tensor_key_mapping = kwargs.get("tensor_key_mapping")
+            parsed_parameters = kwargs.get("parsed_parameters")
+            if tensor_key_mapping and parsed_parameters:
+                self._split_moe_expert_tensor(
+                    weights, parsed_parameters, m["bid"], m["proj"], tensor_key_mapping
+                )
+            return GGUFTensor(weights, None, {})  # signal handled
+
+        # 2. Handle combined gate+up tensor
+        if m := self.GGUF_MOE_COMBINED_PATTERN.match(name):
+            tensor_key_mapping = kwargs.get("tensor_key_mapping")
+            parsed_parameters = kwargs.get("parsed_parameters")
+            if tensor_key_mapping and parsed_parameters:
+                self._interleave_gate_up_tensor(
+                    weights, parsed_parameters, m["bid"], tensor_key_mapping
+                )
+            return GGUFTensor(weights, None, {})
+
+        # 3. Bias tensors (1D) → no transpose
+        if ".bias" in name and len(weights.shape) == 1:
+            return GGUFTensor(weights, name, {})
+
+        # 4. Default handling for all other tensors
+        return GGUFTensor(weights, name, {})
+
+    def _split_moe_expert_tensor(
+        self,
+        weights: np.ndarray,
+        parsed_parameters: dict,
+        bid: str,
+        proj: str,
+        tensor_key_mapping: dict,
+    ):
+        """Split a stacked MoE tensor into individual expert tensors."""
+        num_experts = self.config.get("num_local_experts", 128)
+        # Expected shape: [num_experts, hidden_size, intermediate_size] (or swapped).
+        # We assume the stored order is correct for the projection after splitting.
+        for i in range(min(num_experts, weights.shape[0])):
+            expert_weight = weights[i]  # shape: [hidden, inter] or [inter, hidden]
+            # Build HF parameter name
+            hf_name = f"model.layers.{bid}.block_sparse_moe.experts.{i}.{proj}_proj.weight"
+            # Apply any user‑provided tensor key mapping
+            for key, mapped_key in tensor_key_mapping.items():
+                if key in hf_name:
+                    hf_name = hf_name.replace(key, mapped_key)
+            # Store the tensor
+            parsed_parameters["tensors"][hf_name] = torch.from_numpy(np.copy(expert_weight))
+
+    def _interleave_gate_up_tensor(
+        self,
+        weights: np.ndarray,
+        parsed_parameters: dict,
+        bid: str,
+        tensor_key_mapping: dict,
+    ):
+        """
+        Process a combined gate+up tensor.
+        Expected shape: [num_experts, intermediate_size, hidden_size].
+        Interleaving: gate occupies first half of intermediate dimension,
+        up occupies second half. Transpose to [hidden, half_inter] per expert.
+        """
+        num_experts = self.config.get("num_local_experts", 128)
+        inter_size = weights.shape[1]
+        hidden_size = weights.shape[2]
+        half_inter = inter_size // 2
+        gate_part = weights[:, :half_inter, :] # [E, half_inter, hidden]
+        up_part = weights[:, half_inter:, :] # [E, half_inter, hidden]
+
+        for i in range(min(num_experts, weights.shape[0])):
+            gate_weight = gate_part[i].T # [hidden, half_inter]
+            up_weight = up_part[i].T # [hidden, half_inter]
+
+            gate_name = f"model.layers.{bid}.block_sparse_moe.experts.{i}.gate_proj.weight"
+            up_name = f"model.layers.{bid}.block_sparse_moe.experts.{i}.up_proj.weight"
+
+            # Apply mapping
+            for key, mapped_key in tensor_key_mapping.items():
+                if key in gate_name:
+                    gate_name = gate_name.replace(key, mapped_key)
+                if key in up_name:
+                    up_name = up_name.replace(key, mapped_key)
+
+            parsed_parameters["tensors"][gate_name] = torch.from_numpy(np.copy(gate_weight))
+            parsed_parameters["tensors"][up_name] = torch.from_numpy(np.copy(up_weight))
 
 
 class BloomTensorProcessor(TensorProcessor):
@@ -355,6 +462,7 @@ def _set_moe_expert_tensor(self, weights: np.ndarray, parsed_parameters: dict[st
 TENSOR_PROCESSORS = {
     "llama": LlamaTensorProcessor,
     "qwen2moe": Qwen2MoeTensorProcessor,
+    "gpt_oss": GptOssTensorProcessor,
     "qwen3moe": Qwen2MoeTensorProcessor,
     "bloom": BloomTensorProcessor,
     "t5": T5TensorProcessor,
@@ -416,6 +524,8 @@ def get_gguf_hf_weights_map(
         model_type = "t5"
     elif model_type == "minimax_m2":
         model_type = "minimax-m2"
+    elif model_type == "gpt_oss":
+        model_type = "gpt-oss"
     arch = None
     for key, value in MODEL_ARCH_NAMES.items():
         if value == model_type:
@@ -516,6 +626,8 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_lo
 
     if "qwen2moe" in architecture:
         updated_architecture = "qwen2_moe"
+    elif "gpt_oss" in architecture or "gpt-oss" in architecture:
+        updated_architecture = "gpt_oss"
     elif "qwen3moe" in architecture:
         updated_architecture = "qwen3_moe"
     elif "minimax-m2" in architecture:
@@ -602,6 +714,30 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_lo
         parsed_parameters["config"]["full_attn_idxs"] = [
             i for i, num_kv_heads in enumerate(gguf_num_key_value_heads) if num_kv_heads > 0
         ]
+    
+    if updated_architecture == "gpt_oss":
+        # Helper to read keys with the correct prefix
+        def read_gpt_key(reader, suffix, default=None):
+            key = f"gpt-oss.{suffix}"
+            if key in reader.fields:
+                val = reader.fields[key].parts[0]
+                if isinstance(val, bytes):
+                    val = val.decode("utf-8")
+                return val
+            return default 
+
+# Reconstruct YaRN rope_scaling (only if type is "yarn")
+        rope_type = read_gpt_key(reader, "rope.scaling.type")
+        if rope_type == "yarn":
+            rope_scaling = {
+                "rope_type": rope_type,
+                "factor": float(read_gpt_key(reader, "rope.scaling.factor", 1.0)),
+                "original_max_position_embeddings": int(read_gpt_key(reader, "rope.scaling.original_context_length", 4096)),
+                "attention_factor": 1.0,
+                "beta_fast": 32.0,
+                "beta_slow": 1.0,
+            }
+            parsed_parameters["config"]["rope_scaling"] = rope_scaling
 
     # retrieve config vocab_size from tokenizer
     # Please refer to https://github.com/huggingface/transformers/issues/32526 for more details

From 073b3d3c291fb26832c52ce30b268e46b840e90d Mon Sep 17 00:00:00 2001
From: sirzechs66
Date: Mon, 30 Mar 2026 17:03:15 +0530
Subject: [PATCH 2/9] =?UTF-8?q?Add=20GGUF=20loading=20test=20suite=20for?=
 =?UTF-8?q?=20GPT=E2=80=91OSS?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
 tests/models/gpt_oss/test_modeling_gpt_oss.py | 55 +++++++++++++++++++
 1 file changed, 55 insertions(+)

diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py
index 45b6b0e87b37..5c34e021005c 100644
--- a/tests/models/gpt_oss/test_modeling_gpt_oss.py
+++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py
@@ -623,3 +623,58 @@ def test_model_matches_original_120b(self):
 
         decoded_string = tokenizer.decode(tokens)
         self.assertTrue(original_output.startswith(decoded_string))
+
+@require_torch
+class GptOssGgufLoadingTest(unittest.TestCase):
+    """Test loading GPT‑OSS from GGUF files."""
+
+    def test_gguf_registration(self):
+        """Check that the GGUF loader recognises GPT‑OSS architecture and processor."""
+        from transformers.modeling_gguf_pytorch_utils import TENSOR_PROCESSORS, GGUF_SUPPORTED_ARCHITECTURES
+        from transformers.models.gpt_oss import GptOssTensorProcessor
+
+        self.assertIn("gpt_oss", GGUF_SUPPORTED_ARCHITECTURES)
+        self.assertIn("gpt_oss", TENSOR_PROCESSORS)
+        self.assertEqual(TENSOR_PROCESSORS["gpt_oss"], GptOssTensorProcessor)
+
+    @slow
+    def test_load_20b_gguf(self):
+        """Load the 20B GGUF file from the Hub and verify config reconstruction."""
+        import os
+        from huggingface_hub import hf_hub_download
+
+        repo_id = "unsloth/gpt-oss-20b-GGUF"
+        gguf_file = "gpt-oss-20b-Q5_K_M.gguf"
+
+        # Download the GGUF file to a local path (cached by huggingface_hub)
+        try:
+            local_path = hf_hub_download(repo_id, gguf_file, cache_dir=None)
+        except Exception as e:
+            self.skipTest(f"Could not download {gguf_file} from {repo_id}: {e}")
+
+        # Loading model and tokenizer from the GGUF
+        model = AutoModelForCausalLM.from_pretrained(
+            repo_id,
+            gguf_file=gguf_file,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+            offload_folder="./offload",
+            low_cpu_mem_usage=True,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(repo_id, gguf_file=gguf_file)
+
+
+        config = model.config
+        self.assertEqual(config.model_type, "gpt_oss")
+        self.assertIsInstance(config.rope_scaling, dict)
+        self.assertEqual(config.rope_scaling.get("rope_type"), "yarn")
+        self.assertEqual(config.sliding_window, 128) 
+        self.assertEqual(config.num_hidden_layers, 24) 
+        self.assertEqual(config.num_local_experts, 32)
+
+
+        inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
+        with torch.no_grad():
+            outputs = model(**inputs)
+        self.assertIn("logits", outputs)
+        self.assertEqual(outputs.logits.shape[0], 1)
\ No newline at end of file

From 8c33e373d8cb40f635ec3d8d6970b0b9b3c41b16 Mon Sep 17 00:00:00 2001
From: sirzechs66
Date: Mon, 30 Mar 2026 17:43:32 +0530
Subject: [PATCH 3/9] docs: add GGUF loading section to gpt_oss.md

---
 docs/source/en/model_doc/gpt_oss.md           | 17 +++++++++++++++++
 .../models/gpt_oss/modeling_gpt_oss.py        |  1 +
 tests/models/gpt_oss/test_modeling_gpt_oss.py | 14 ++++++--------
 3 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/docs/source/en/model_doc/gpt_oss.md b/docs/source/en/model_doc/gpt_oss.md
index 513710f35f95..ad6ec558a322 100644
--- a/docs/source/en/model_doc/gpt_oss.md
+++ b/docs/source/en/model_doc/gpt_oss.md
@@ -70,6 +70,23 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
 - SDPA is not supported because attention sinks require direct access to the full attention logits before softmax. Use Flash Attention or Flex Attention instead.
 - When using Flex Attention, attention sinks require special handling. The `score_mod` function operates on individual score elements rather than the full attention matrix, so sink renormalization is applied after computation using the log-sum-exp (LSE) values returned by Flex Attention.
 
+## Loading GGUF files
+
+GPT‑OSS models are also available as quantised GGUF files (e.g., from [Unsloth](https://huggingface.co/unsloth) or [ggml‑org](https://huggingface.co/ggml-org)). You can load them directly with `from_pretrained` by passing the `gguf_file` argument:
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model = AutoModelForCausalLM.from_pretrained(
+    "path/to/local/directory",
+    gguf_file="gpt-oss-20b-Q5_K_M.gguf",
+    device_map="auto",
+)
+tokenizer = AutoTokenizer.from_pretrained(
+    "path/to/local/directory",
+    gguf_file="gpt-oss-20b-Q5_K_M.gguf",
+)
+
 ## GptOssConfig
 
 [[autodoc]] GptOssConfig
diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
index 18f31ea90379..2a1dfcbd1a00 100644
--- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -702,4 +702,5 @@ class GptOssForTokenClassification(GenericForTokenClassification, GptOssPreTrain
     "GptOssForTokenClassification",
     "GptOssModel",
     "GptOssPreTrainedModel",
+    "GptOssTensorProcessor",
 ]
diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py
index 5c34e021005c..ba88ec62400b 100644
--- a/tests/models/gpt_oss/test_modeling_gpt_oss.py
+++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py
@@ -624,13 +624,14 @@ def test_model_matches_original_120b(self):
         decoded_string = tokenizer.decode(tokens)
         self.assertTrue(original_output.startswith(decoded_string))
 
+
 @require_torch
 class GptOssGgufLoadingTest(unittest.TestCase):
     """Test loading GPT‑OSS from GGUF files."""
 
     def test_gguf_registration(self):
         """Check that the GGUF loader recognises GPT‑OSS architecture and processor."""
-        from transformers.modeling_gguf_pytorch_utils import TENSOR_PROCESSORS, GGUF_SUPPORTED_ARCHITECTURES
+        from transformers.modeling_gguf_pytorch_utils import GGUF_SUPPORTED_ARCHITECTURES, TENSOR_PROCESSORS
         from transformers.models.gpt_oss import GptOssTensorProcessor
 
         self.assertIn("gpt_oss", GGUF_SUPPORTED_ARCHITECTURES)
@@ -640,7 +641,6 @@ def test_gguf_registration(self):
     @slow
     def test_load_20b_gguf(self):
         """Load the 20B GGUF file from the Hub and verify config reconstruction."""
-        import os
         from huggingface_hub import hf_hub_download
 
         repo_id = "unsloth/gpt-oss-20b-GGUF"
@@ -648,7 +648,7 @@ def test_load_20b_gguf(self):
 
         # Download the GGUF file to a local path (cached by huggingface_hub)
         try:
-            local_path = hf_hub_download(repo_id, gguf_file, cache_dir=None)
+            hf_hub_download(repo_id, gguf_file, cache_dir=None)
         except Exception as e:
             self.skipTest(f"Could not download {gguf_file} from {repo_id}: {e}")
 
@@ -663,18 +663,16 @@ def test_load_20b_gguf(self):
         )
         tokenizer = AutoTokenizer.from_pretrained(repo_id, gguf_file=gguf_file)
 
-
         config = model.config
         self.assertEqual(config.model_type, "gpt_oss")
         self.assertIsInstance(config.rope_scaling, dict)
         self.assertEqual(config.rope_scaling.get("rope_type"), "yarn")
-        self.assertEqual(config.sliding_window, 128) 
-        self.assertEqual(config.num_hidden_layers, 24) 
+        self.assertEqual(config.sliding_window, 128)
+        self.assertEqual(config.num_hidden_layers, 24)
         self.assertEqual(config.num_local_experts, 32)
 
-
         inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
         with torch.no_grad():
             outputs = model(**inputs)
         self.assertIn("logits", outputs)
-        self.assertEqual(outputs.logits.shape[0], 1)
\ No newline at end of file
+        self.assertEqual(outputs.logits.shape[0], 1)

From 53b7efbe86587182b6d4941f31ec5b3d958875a5 Mon Sep 17 00:00:00 2001
From: sirzechs66
Date: Mon, 30 Mar 2026 18:32:01 +0530
Subject: [PATCH 4/9] fix: correct import of GptOssTensorProcessor in test;
 remove from model __all__

---
 src/transformers/models/gpt_oss/modeling_gpt_oss.py | 1 -
 tests/models/gpt_oss/test_modeling_gpt_oss.py       | 7 +++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
index 2a1dfcbd1a00..18f31ea90379 100644
--- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -702,5 +702,4 @@ class GptOssForTokenClassification(GenericForTokenClassification, GptOssPreTrain
     "GptOssForTokenClassification",
     "GptOssModel",
     "GptOssPreTrainedModel",
-    "GptOssTensorProcessor",
 ]
diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py
index ba88ec62400b..c35cf08851be 100644
--- a/tests/models/gpt_oss/test_modeling_gpt_oss.py
+++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py
@@ -631,8 +631,11 @@ class GptOssGgufLoadingTest(unittest.TestCase):
 
     def test_gguf_registration(self):
         """Check that the GGUF loader recognises GPT‑OSS architecture and processor."""
-        from transformers.modeling_gguf_pytorch_utils import GGUF_SUPPORTED_ARCHITECTURES, TENSOR_PROCESSORS
-        from transformers.models.gpt_oss import GptOssTensorProcessor
+        from transformers.modeling_gguf_pytorch_utils import (
+            GGUF_SUPPORTED_ARCHITECTURES,
+            TENSOR_PROCESSORS,
+            GptOssTensorProcessor,
+        )
 
         self.assertIn("gpt_oss", GGUF_SUPPORTED_ARCHITECTURES)
         self.assertIn("gpt_oss", TENSOR_PROCESSORS)

From fcde5f8dbbc6d3a0fe5e3396d5e7cdb7f7744da0 Mon Sep 17 00:00:00 2001
From: sirzechs66
Date: Tue, 31 Mar 2026 23:41:05 +0530
Subject: [PATCH 5/9] =?UTF-8?q?Finalize=20GPT=E2=80=91OSS=20GGUF=20support?=
 =?UTF-8?q?:=20move=20test,=20adjust=20config=20reconstruction?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
 src/transformers/integrations/ggml.py         |  2 +-
 .../modeling_gguf_pytorch_utils.py            | 67 +++++++++++--------
 tests/models/gpt_oss/test_modeling_gpt_oss.py | 56 ----------------
 tests/quantization/ggml/test_ggml.py          | 17 +++++
 4 files changed, 56 insertions(+), 86 deletions(-)

diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py
index 84ec36dcced3..c9ba021c54db 100644
--- a/src/transformers/integrations/ggml.py
+++ b/src/transformers/integrations/ggml.py
@@ -94,7 +94,7 @@
         "block_count": "num_hidden_layers",
         "feed_forward_length": "intermediate_size",
         "embedding_length": "hidden_size",
-        "rope.dimension_count": None, 
+        "rope.dimension_count": None,
         "rope.freq_base": "rope_theta",
         "attention.head_count": "num_attention_heads",
         "attention.head_count_kv": "num_key_value_heads",
diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index c135461f0ea5..5b26f8e65201 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -169,6 +169,8 @@ def _set_moe_expert_tensor(self, weights: np.ndarray, parsed_parameters: dict[st
         else:  # w == "up"
             out = out.narrow(shard_dim, shard_size, shard_size)
         out.copy_(torch_weights)
+
+
 class GptOssTensorProcessor(TensorProcessor):
     """
     Tensor processor for GPT-OSS models (MoE with 128 experts).
@@ -177,14 +179,11 @@ class GptOssTensorProcessor(TensorProcessor):
     - Interleaving gate and up projections if stored in a combined tensor (gate_up_projs).
     - Bias tensors (1D) are passed through without transpose.
     """
+
     # Regex for separate expert tensors: e.g., blk.0.ffn_down_projs.weight
-    GGUF_MOE_WEIGHTS_PATTERN = re.compile(
-        r"blk\.(?P<bid>\d+)\.ffn_(?P<proj>down|gate|up)_projs\.weight$"
-    )
+    GGUF_MOE_WEIGHTS_PATTERN = re.compile(r"blk\.(?P<bid>\d+)\.ffn_(?P<proj>down|gate|up)_projs\.weight$")
     # Regex for combined gate+up tensor: e.g., blk.0.ffn_gate_up_projs.weight
-    GGUF_MOE_COMBINED_PATTERN = re.compile(
-        r"blk\.(?P<bid>\d+)\.ffn_gate_up_projs\.weight$"
-    )
+    GGUF_MOE_COMBINED_PATTERN = re.compile(r"blk\.(?P<bid>\d+)\.ffn_gate_up_projs\.weight$")
 
     def __init__(self, config=None):
         super().__init__(config=config)
@@ -195,9 +194,7 @@ def process(self, weights, name: str, **kwargs):
             tensor_key_mapping = kwargs.get("tensor_key_mapping")
             parsed_parameters = kwargs.get("parsed_parameters")
             if tensor_key_mapping and parsed_parameters:
-                self._split_moe_expert_tensor(
-                    weights, parsed_parameters, m["bid"], m["proj"], tensor_key_mapping
-                )
+                self._split_moe_expert_tensor(weights, parsed_parameters, m["bid"], m["proj"], tensor_key_mapping)
             return GGUFTensor(weights, None, {})  # signal handled
 
         # 2. Handle combined gate+up tensor
@@ -205,9 +202,7 @@ def process(self, weights, name: str, **kwargs):
             tensor_key_mapping = kwargs.get("tensor_key_mapping")
             parsed_parameters = kwargs.get("parsed_parameters")
             if tensor_key_mapping and parsed_parameters:
-                self._interleave_gate_up_tensor(
-                    weights, parsed_parameters, m["bid"], tensor_key_mapping
-                )
+                self._interleave_gate_up_tensor(weights, parsed_parameters, m["bid"], tensor_key_mapping)
             return GGUFTensor(weights, None, {})
 
         # 3. Bias tensors (1D) → no transpose
@@ -255,14 +250,13 @@ def _interleave_gate_up_tensor(
         """
         num_experts = self.config.get("num_local_experts", 128)
         inter_size = weights.shape[1]
-        hidden_size = weights.shape[2]
         half_inter = inter_size // 2
-        gate_part = weights[:, :half_inter, :] # [E, half_inter, hidden]
-        up_part = weights[:, half_inter:, :] # [E, half_inter, hidden]
+        gate_part = weights[:, :half_inter, :]  # [E, half_inter, hidden]
+        up_part = weights[:, half_inter:, :]  # [E, half_inter, hidden]
 
         for i in range(min(num_experts, weights.shape[0])):
-            gate_weight = gate_part[i].T # [hidden, half_inter]
-            up_weight = up_part[i].T # [hidden, half_inter]
+            gate_weight = gate_part[i].T  # [hidden, half_inter]
+            up_weight = up_part[i].T  # [hidden, half_inter]
 
             gate_name = f"model.layers.{bid}.block_sparse_moe.experts.{i}.gate_proj.weight"
             up_name = f"model.layers.{bid}.block_sparse_moe.experts.{i}.up_proj.weight"
@@ -714,7 +708,7 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_lo
         parsed_parameters["config"]["full_attn_idxs"] = [
             i for i, num_kv_heads in enumerate(gguf_num_key_value_heads) if num_kv_heads > 0
         ]
-    
+
     if updated_architecture == "gpt_oss":
         # Helper to read keys with the correct prefix
         def read_gpt_key(reader, suffix, default=None):
@@ -724,19 +718,34 @@ def read_gpt_key(reader, suffix, default=None):
                 if isinstance(val, bytes):
                     val = val.decode("utf-8")
                 return val
-            return default 
+            return default
 
-# Reconstruct YaRN rope_scaling (only if type is "yarn")
+        # Reconstruct rope_scaling from GGUF metadata
         rope_type = read_gpt_key(reader, "rope.scaling.type")
-        if rope_type == "yarn":
-            rope_scaling = {
-                "rope_type": rope_type,
-                "factor": float(read_gpt_key(reader, "rope.scaling.factor", 1.0)),
-                "original_max_position_embeddings": int(read_gpt_key(reader, "rope.scaling.original_context_length", 4096)),
-                "attention_factor": 1.0,
-                "beta_fast": 32.0,
-                "beta_slow": 1.0,
-            }
+        if rope_type is not None:
+            rope_scaling = {"rope_type": rope_type}
+
+            # Collect all rope.scaling keys dynamically
+            for key in reader.fields:
+                if not key.startswith("gpt-oss.rope.scaling."):
+                    continue
+                suffix = key[len("gpt-oss.rope.scaling.") :]
+                if suffix == "type":
+                    continue
+                value = reader.fields[key].parts[0]
+                if isinstance(value, bytes):
+                    value = value.decode("utf-8")
+                # Convert to appropriate type
+                if suffix in ("factor", "attention_factor", "beta_fast", "beta_slow"):
+                    value = float(value)
+                elif suffix in ("original_context_length", "original_max_position_embeddings"):
+                    # Map GGUF's original_context_length to HF's original_max_position_embeddings
+                    suffix = "original_max_position_embeddings"
+                    value = int(value)
+                else:
+                    pass
+                rope_scaling[suffix] = value
+
             parsed_parameters["config"]["rope_scaling"] = rope_scaling
 
     # retrieve config vocab_size from tokenizer
diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py
index c35cf08851be..45b6b0e87b37 100644
--- a/tests/models/gpt_oss/test_modeling_gpt_oss.py
+++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py
@@ -623,59 +623,3 @@ def test_model_matches_original_120b(self):
 
         decoded_string = tokenizer.decode(tokens)
         self.assertTrue(original_output.startswith(decoded_string))
-
-
-@require_torch
-class GptOssGgufLoadingTest(unittest.TestCase):
-    """Test loading GPT‑OSS from GGUF files."""
-
-    def test_gguf_registration(self):
-        """Check that the GGUF loader recognises GPT‑OSS architecture and processor."""
-        from transformers.modeling_gguf_pytorch_utils import (
-            GGUF_SUPPORTED_ARCHITECTURES,
-            TENSOR_PROCESSORS,
-            GptOssTensorProcessor,
-        )
-
-        self.assertIn("gpt_oss", GGUF_SUPPORTED_ARCHITECTURES)
-        self.assertIn("gpt_oss", TENSOR_PROCESSORS)
-        self.assertEqual(TENSOR_PROCESSORS["gpt_oss"], GptOssTensorProcessor)
-
-    @slow
-    def test_load_20b_gguf(self):
-        """Load the 20B GGUF file from the Hub and verify config reconstruction."""
-        from huggingface_hub import hf_hub_download
-
-        repo_id = "unsloth/gpt-oss-20b-GGUF"
-        gguf_file = "gpt-oss-20b-Q5_K_M.gguf"
-
-        # Download the GGUF file to a local path (cached by huggingface_hub)
-        try:
-            hf_hub_download(repo_id, gguf_file, cache_dir=None)
-        except Exception as e:
-            self.skipTest(f"Could not download {gguf_file} from {repo_id}: {e}")
-
-        # Loading model and tokenizer from the GGUF
-        model = AutoModelForCausalLM.from_pretrained(
-            repo_id,
-            gguf_file=gguf_file,
-            device_map="auto",
-            torch_dtype=torch.bfloat16,
-            offload_folder="./offload",
-            low_cpu_mem_usage=True,
-        )
-        tokenizer = AutoTokenizer.from_pretrained(repo_id, gguf_file=gguf_file)
-
-        config = model.config
-        self.assertEqual(config.model_type, "gpt_oss")
-        self.assertIsInstance(config.rope_scaling, dict)
-        self.assertEqual(config.rope_scaling.get("rope_type"), "yarn")
-        self.assertEqual(config.sliding_window, 128)
-        self.assertEqual(config.num_hidden_layers, 24)
-        self.assertEqual(config.num_local_experts, 32)
-
-        inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
-        with torch.no_grad():
-            outputs = model(**inputs)
-        self.assertIn("logits", outputs)
-        self.assertEqual(outputs.logits.shape[0], 1)
diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py
index 763f8ac40502..3fb1199188a2 100644
--- a/tests/quantization/ggml/test_ggml.py
+++ b/tests/quantization/ggml/test_ggml.py
@@ -351,6 +351,8 @@ class GgufModelTests(unittest.TestCase):
     q4_k_m_qwen3moe_model_id = "Qwen3-30B-A3B-Q4_K_M.gguf"
     q8_0_umt5_encoder_model_id = "umt5-xxl-encoder-Q8_0.gguf"
     q4_k_m_lfm2_model_id = "LFM2-1.2B-Q4_K_M.gguf"
+    gpt_oss_model_id = "unsloth/gpt-oss-20b-GGUF"
+    gpt_oss_gguf_file = "gpt-oss-20b-Q5_K_M.gguf"
 
     example_text = "Hello"
 
@@ -384,6 +386,21 @@ def test_qwen2_q4_0(self):
         EXPECTED_TEXT = "Hello.jsoup\n\nI am a beginner"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
+    def test_gpt_oss_q5_k_m(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.gpt_oss_model_id, gguf_file=self.gpt_oss_gguf_file)
+        model = AutoModelForCausalLM.from_pretrained(
+            self.gpt_oss_model_id,
+            gguf_file=self.gpt_oss_gguf_file,
+            device_map="auto",
+            dtype=torch.float16,
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+        
+        EXPECTED_TEXT = "Hello, I just want to say that I am just"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
     def test_qwen2moe_q8(self):
         tokenizer = AutoTokenizer.from_pretrained(self.qwen2moe_model_id, gguf_file=self.q8_qwen2moe_model_id)
         model = AutoModelForCausalLM.from_pretrained(

From c6945b3feb0beb4b23cd65b35b5464be8480e7d0 Mon Sep 17 00:00:00 2001
From: sirzechs66
Date: Wed, 15 Apr 2026 18:06:37 +0530
Subject: [PATCH 6/9] fixed docs not closing example bracket

---
 docs/source/en/model_doc/gpt_oss.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/source/en/model_doc/gpt_oss.md b/docs/source/en/model_doc/gpt_oss.md
index ad6ec558a322..9df8db220ef0 100644
--- a/docs/source/en/model_doc/gpt_oss.md
+++ b/docs/source/en/model_doc/gpt_oss.md
@@ -86,6 +86,7 @@ tokenizer = AutoTokenizer.from_pretrained(
     "path/to/local/directory",
     gguf_file="gpt-oss-20b-Q5_K_M.gguf",
 )
+```
 
 ## GptOssConfig
 

From af5ad575b3a4f049e97cf98cc31556b090343f4e Mon Sep 17 00:00:00 2001
From: sirzechs66
Date: Sat, 18 Apr 2026 14:33:04 +0530
Subject: [PATCH 7/9] Fix lint: remove trailing whitespace

---
 tests/quantization/ggml/test_ggml.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py
index 3fb1199188a2..aa5cdbc7adc6 100644
--- a/tests/quantization/ggml/test_ggml.py
+++ b/tests/quantization/ggml/test_ggml.py
@@ -397,7 +397,6 @@ def test_gpt_oss_q5_k_m(self):
 
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
         out = model.generate(**text, max_new_tokens=10)
-        
         EXPECTED_TEXT = "Hello, I just want to say that I am just"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 

From a3b7e4ec9f117e09256959558db4bf980325396d Mon Sep 17 00:00:00 2001
From: sirzechs66
Date: Mon, 20 Apr 2026 13:22:14 +0530
Subject: [PATCH 8/9] Fix tensor construction consistency

---
 src/transformers/modeling_gguf_pytorch_utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index 5b26f8e65201..7513b5c43b38 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -233,7 +233,7 @@ def _split_moe_expert_tensor(
             if key in hf_name:
                 hf_name = hf_name.replace(key, mapped_key)
             # Store the tensor
-            parsed_parameters["tensors"][hf_name] = torch.from_numpy(np.copy(expert_weight))
+            parsed_parameters["tensors"][hf_name] = torch.tensor(expert_weight)
 
     def _interleave_gate_up_tensor(
         self,
@@ -268,8 +268,8 @@ def _interleave_gate_up_tensor(
                 if key in up_name:
                     up_name = up_name.replace(key, mapped_key)
 
-            parsed_parameters["tensors"][gate_name] = torch.from_numpy(np.copy(gate_weight))
-            parsed_parameters["tensors"][up_name] = torch.from_numpy(np.copy(up_weight))
+            parsed_parameters["tensors"][gate_name] = torch.tensor(gate_weight)
+            parsed_parameters["tensors"][up_name] = torch.tensor(up_weight)
 
 
 class BloomTensorProcessor(TensorProcessor):
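Note on patch 8: `torch.tensor` has no `copy` keyword, and it always copies its input, so the construction above simply drops the redundant argument. The split-and-transpose that feeds these calls in `_interleave_gate_up_tensor` can be sanity-checked in isolation; a minimal numpy sketch, with illustrative shapes (not taken from a real checkpoint):

```python
import numpy as np

# Illustrative sizes only; real GPT-OSS checkpoints are much larger.
num_experts, half_inter, hidden = 4, 8, 16

# Combined gate+up tensor as assumed by the processor: [E, 2 * half_inter, hidden]
weights = np.random.randn(num_experts, 2 * half_inter, hidden).astype(np.float32)

gate_part = weights[:, :half_inter, :]  # gate rows: first half of the intermediate dim
up_part = weights[:, half_inter:, :]    # up rows: second half

for i in range(num_experts):
    gate_weight = gate_part[i].T  # per-expert transpose -> [hidden, half_inter]
    up_weight = up_part[i].T      # per-expert transpose -> [hidden, half_inter]
    assert gate_weight.shape == (hidden, half_inter)
    assert up_weight.shape == (hidden, half_inter)
```

The per-expert `[i]` indexing followed by `.T` is exactly what the processor applies before handing each slice to `torch.tensor`.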
From e08938e58e7e6afb77421e544b8dd8d09a2cff3d Mon Sep 17 00:00:00 2001
From: sirzechs66
Date: Mon, 20 Apr 2026 21:36:15 +0530
Subject: [PATCH 9/9] reverting to original docs

---
 docs/source/en/model_doc/gpt_oss.md | 18 ------------------
 1 file changed, 18 deletions(-)

diff --git a/docs/source/en/model_doc/gpt_oss.md b/docs/source/en/model_doc/gpt_oss.md
index 9df8db220ef0..513710f35f95 100644
--- a/docs/source/en/model_doc/gpt_oss.md
+++ b/docs/source/en/model_doc/gpt_oss.md
@@ -70,24 +70,6 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
 - SDPA is not supported because attention sinks require direct access to the full attention logits before softmax. Use Flash Attention or Flex Attention instead.
 - When using Flex Attention, attention sinks require special handling. The `score_mod` function operates on individual score elements rather than the full attention matrix, so sink renormalization is applied after computation using the log-sum-exp (LSE) values returned by Flex Attention.
 
-## Loading GGUF files
-
-GPT‑OSS models are also available as quantised GGUF files (e.g., from [Unsloth](https://huggingface.co/unsloth) or [ggml‑org](https://huggingface.co/ggml-org)). You can load them directly with `from_pretrained` by passing the `gguf_file` argument:
-
-```python
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-model = AutoModelForCausalLM.from_pretrained(
-    "path/to/local/directory",
-    gguf_file="gpt-oss-20b-Q5_K_M.gguf",
-    device_map="auto",
-)
-tokenizer = AutoTokenizer.from_pretrained(
-    "path/to/local/directory",
-    gguf_file="gpt-oss-20b-Q5_K_M.gguf",
-)
-```
-
 ## GptOssConfig
 
 [[autodoc]] GptOssConfig
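With the series applied, loading a GPT-OSS GGUF end to end looks like the sketch below. The repo and filename are the ones used in `tests/quantization/ggml/test_ggml.py`; any GGUF export of a GPT-OSS checkpoint carrying the `gpt-oss.*` metadata keys should behave the same:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "unsloth/gpt-oss-20b-GGUF"
gguf_file = "gpt-oss-20b-Q5_K_M.gguf"

tokenizer = AutoTokenizer.from_pretrained(repo_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(
    repo_id, gguf_file=gguf_file, device_map="auto", dtype=torch.float16
)

# rope_scaling is rebuilt from the "gpt-oss.rope.scaling.*" GGUF metadata
# by load_gguf_checkpoint, e.g. {"rope_type": "yarn", "factor": ..., ...}
print(model.config.rope_scaling)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```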