diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py
index c9ba021c54db..2971f4fdfbdb 100644
--- a/src/transformers/integrations/ggml.py
+++ b/src/transformers/integrations/ggml.py
@@ -320,6 +320,23 @@
         "vocab_size": "vocab_size",
         "expert_gating_func": "scoring_func",
     },
+    "llama4": {
+        "context_length": "max_position_embeddings",
+        "block_count": "num_hidden_layers",
+        "feed_forward_length": "intermediate_size_mlp",
+        "expert_feed_forward_length": "intermediate_size",
+        "embedding_length": "hidden_size",
+        "rope.dimension_count": None,
+        "rope.freq_base": "rope_theta",
+        "attention.key_length": "head_dim",
+        "attention.head_count": "num_attention_heads",
+        "attention.head_count_kv": "num_key_value_heads",
+        "attention.layer_norm_rms_epsilon": "rms_norm_eps",
+        "vocab_size": "vocab_size",
+        "expert_count": "num_local_experts",
+        "expert_used_count": "num_experts_per_tok",
+        "interleave_moe_layer_step": "interleave_moe_layer_step",
+    },
 }
 
 GGUF_TOKENIZER_MAPPING = {
@@ -787,6 +804,7 @@ def converted(self) -> Tokenizer:
 
 GGUF_TO_FAST_CONVERTERS = {
     "llama": GGUFLlamaConverter,
+    "llama4_text": GGUFLlamaConverter,
     "qwen2": GGUFQwen2Converter,
     "qwen2_moe": GGUFQwen2Converter,
     "qwen3": GGUFQwen2Converter,
diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index 2de6cc13fc85..b24e4fcd38bb 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -453,8 +453,57 @@ def _set_moe_expert_tensor(self, weights: np.ndarray, parsed_parameters: dict[st
         out.copy_(torch_weights)
 
 
+class Llama4TensorProcessor(TensorProcessor):
+    HF_MOE_GATE_UP_PATTERN = re.compile(r"(?:model\.)?layers\.(?P<bid>\d+)\.feed_forward\.experts\.gate_up_proj$")
+    HF_MOE_DOWN_PATTERN = re.compile(r"(?:model\.)?layers\.(?P<bid>\d+)\.feed_forward\.experts\.down_proj$")
+    GGUF_MOE_WEIGHTS_PATTERN = re.compile(r".*\.ffn_(?P<w>gate|up|down)_exps\.weight$")
+
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def perform_fallback_tensor_mapping(
+        self, gguf_to_hf_name_map: dict[str, str], suffix: str, qual_name: str, hf_name: str
+    ):
+        if m := re.fullmatch(self.HF_MOE_GATE_UP_PATTERN, hf_name):
+            full_hf_name = qual_name + hf_name
+            gguf_to_hf_name_map[f"blk.{m['bid']}.ffn_gate_exps.weight"] = full_hf_name
+            gguf_to_hf_name_map[f"blk.{m['bid']}.ffn_up_exps.weight"] = full_hf_name
+        elif m := re.fullmatch(self.HF_MOE_DOWN_PATTERN, hf_name):
+            full_hf_name = qual_name + hf_name
+            gguf_to_hf_name_map[f"blk.{m['bid']}.ffn_down_exps.weight"] = full_hf_name
+
+    def process(self, weights, name: str, **kwargs):
+        if m := re.fullmatch(self.GGUF_MOE_WEIGHTS_PATTERN, name):
+            tensor_key_mapping = kwargs.get("tensor_key_mapping")
+            parsed_parameters = kwargs.get("parsed_parameters")
+            if tensor_key_mapping and name in tensor_key_mapping:
+                self._set_moe_expert_tensor(weights, parsed_parameters, tensor_key_mapping[name], m["w"])
+                return GGUFTensor(weights, None, {})
+        return GGUFTensor(weights, name, {})
+
+    def _set_moe_expert_tensor(self, weights: np.ndarray, parsed_parameters: dict[str, dict], hf_name: str, w: str):
+        torch_weights = torch.from_numpy(np.ascontiguousarray(np.swapaxes(weights, -1, -2)))
+        if w == "down":
+            parsed_parameters["tensors"][hf_name] = torch_weights
+            return
+        # Merge gate and up into gate_up_proj: [E, hidden, 2*expert_dim], gate first then up.
+        shape = list(torch_weights.shape)
+        shard_dim = -1
+        shard_size = shape[shard_dim]
+        shape[shard_dim] = shard_size * 2
+        if hf_name not in parsed_parameters["tensors"]:
+            parsed_parameters["tensors"][hf_name] = torch.zeros(shape, dtype=torch_weights.dtype)
+        out: torch.Tensor = parsed_parameters["tensors"][hf_name]
+        if w == "gate":
+            out = out.narrow(shard_dim, 0, shard_size)
+        else:  # w == "up"
+            out = out.narrow(shard_dim, shard_size, shard_size)
+        out.copy_(torch_weights)
+
+
 TENSOR_PROCESSORS = {
     "llama": LlamaTensorProcessor,
+    "llama4": Llama4TensorProcessor,
     "qwen2moe": Qwen2MoeTensorProcessor,
     "gpt_oss": GptOssTensorProcessor,
     "qwen3moe": Qwen2MoeTensorProcessor,
@@ -518,6 +567,10 @@ def get_gguf_hf_weights_map(
         model_type = "t5"
     elif model_type == "minimax_m2":
         model_type = "minimax-m2"
+    elif model_type == "llama4_text":
+        # GGUF Llama 4 files only contain text weights; the text-only config
+        # uses `llama4_text` in transformers but the GGUF arch key is `llama4`.
+        model_type = "llama4"
     elif model_type == "gpt_oss":
         model_type = "gpt-oss"
     arch = None
@@ -695,6 +748,18 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_lo
     if parsed_parameters["config"]["model_type"] == "gemma3":
         parsed_parameters["config"]["model_type"] = "gemma3_text"
 
+    # Llama 4 GGUF checkpoints only contain the text backbone. Rewrite the model_type to
+    # the text-only config and nest rope_theta under rope_parameters (Llama4TextConfig is
+    # @strict and stores rope params in a nested dict rather than a top-level field).
+    if parsed_parameters["config"]["model_type"] == "llama4":
+        parsed_parameters["config"]["model_type"] = "llama4_text"
+        rope_theta = parsed_parameters["config"].pop("rope_theta", None)
+        if rope_theta is not None:
+            parsed_parameters["config"]["rope_parameters"] = {
+                "rope_type": "default",
+                "rope_theta": float(rope_theta),
+            }
+
     # MiniMax-M2: convert expert_gating_func integer to scoring_func string
     if parsed_parameters["config"].get("model_type") == "minimax_m2":
         _gating_func_map = {0: "none", 1: "softmax", 2: "sigmoid"}
diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py
index aa5cdbc7adc6..c38191c5a680 100644
--- a/tests/quantization/ggml/test_ggml.py
+++ b/tests/quantization/ggml/test_ggml.py
@@ -311,6 +311,7 @@ class GgufModelTests(unittest.TestCase):
     qwen3moe_model_id = "Qwen/Qwen3-30B-A3B-GGUF"
     umt5_encoder_model_id = "city96/umt5-xxl-encoder-gguf"
     lfm2_model_id = "LiquidAI/LFM2-1.2B-GGUF"
+    llama4_model_id = "unsloth/Llama-4-Scout-17B-16E-Instruct-GGUF"
 
     q4_0_phi3_model_id = "Phi-3-mini-4k-instruct-q4.gguf"
     q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
@@ -351,6 +352,7 @@ class GgufModelTests(unittest.TestCase):
     q4_k_m_qwen3moe_model_id = "Qwen3-30B-A3B-Q4_K_M.gguf"
     q8_0_umt5_encoder_model_id = "umt5-xxl-encoder-Q8_0.gguf"
     q4_k_m_lfm2_model_id = "LFM2-1.2B-Q4_K_M.gguf"
+    q2_k_l_llama4_model_id = "Llama-4-Scout-17B-16E-Instruct-Q2_K_L.gguf"
 
     gpt_oss_model_id = "unsloth/gpt-oss-20b-GGUF"
     gpt_oss_gguf_file = "gpt-oss-20b-Q5_K_M.gguf"
@@ -1145,3 +1147,28 @@ def test_lfm2_q4_k_m(self):
 
         EXPECTED_TEXT = "Hello Atari 2600! es un videoj"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    @unittest.skipUnless(is_gguf_available("0.17.0"), "test requires gguf version >= 0.17.0")
+    def test_llama4_q2_k_l_tokenizer(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.llama4_model_id, gguf_file=self.q2_k_l_llama4_model_id)
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tokenizer.save_pretrained(tmpdirname)
+            tokenizer = AutoTokenizer.from_pretrained(tmpdirname)
+        special_sentence = "สวัสดี"
+        predicted_text = tokenizer.decode(tokenizer.encode(special_sentence, return_tensors="pt")[0])
+        self.assertEqual(predicted_text, "<|begin_of_text|>" + special_sentence)
+
+    @unittest.skipUnless(is_gguf_available("0.17.0"), "test requires gguf version >= 0.17.0")
+    def test_llama4_q2_k_l(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.llama4_model_id, gguf_file=self.q2_k_l_llama4_model_id)
+        model = AutoModelForCausalLM.from_pretrained(
+            self.llama4_model_id,
+            gguf_file=self.q2_k_l_llama4_model_id,
+            dtype=torch.float16,
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt")["input_ids"]
+        out = model.generate(text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello, I'm here to help. What"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
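
For reviewers: a minimal standalone sketch (not part of the diff) of the gate/up merge that `Llama4TensorProcessor._set_moe_expert_tensor` performs. The toy shapes are assumptions: GGUF expert tensors are taken as `[E, expert_dim, hidden]`, and the merged HF `gate_up_proj` as `[E, hidden, 2*expert_dim]` with gate occupying the first half of the last dimension and up the second.

```python
# Hypothetical toy reproduction of the merge logic; shapes/sizes are assumed, not from the PR.
import numpy as np
import torch

E, hidden, expert_dim = 2, 4, 3  # experts, hidden size, per-expert FFN dim (toy values)
gate = np.random.rand(E, expert_dim, hidden).astype(np.float32)  # stands in for ffn_gate_exps
up = np.random.rand(E, expert_dim, hidden).astype(np.float32)    # stands in for ffn_up_exps

merged = torch.zeros(E, hidden, 2 * expert_dim)
for w, src in (("gate", gate), ("up", up)):
    # Same transform as the processor: swap the last two axes, then copy into the
    # matching half of the preallocated gate_up_proj tensor.
    t = torch.from_numpy(np.ascontiguousarray(np.swapaxes(src, -1, -2)))  # -> [E, hidden, expert_dim]
    offset = 0 if w == "gate" else expert_dim
    merged.narrow(-1, offset, expert_dim).copy_(t)

# Gate fills the first half of the last dim, up the second half.
assert torch.equal(merged[..., :expert_dim], torch.from_numpy(gate).transpose(-1, -2))
assert torch.equal(merged[..., expert_dim:], torch.from_numpy(up).transpose(-1, -2))
```

Because the two GGUF tensors arrive in separate `process` calls, the processor preallocates the doubled tensor on first touch and fills whichever half corresponds to the shard it is handling, so the result is order-independent.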