
Fix llama4 bnb mode #44588

Merged
SunMarc merged 20 commits into huggingface:main from jiqing-feng:bmg
Mar 27, 2026

Conversation

@jiqing-feng (Contributor) commented Mar 11, 2026

Fixes Llama4 model loading under BitsAndBytes (BNB) quantization mode.

Quantizing the router causes a shape mismatch: Llama4Router inherits from nn.Linear, so BNB quantizes its weight into a packed format. However, super().forward() calls torch.nn.Linear.forward → F.linear, which hits BNB's torch_function hook and raises RuntimeError: mat1 and mat2 shapes cannot be multiplied, because the packed weight shape is incompatible with the input. Adding "router" (along with other sensitive modules) to _keep_in_fp32_modules skips quantization of the router and fixes the error.
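
For readers hitting this before the fix lands, a similar effect can usually be achieved from the user side by excluding the router from quantization via BitsAndBytesConfig.llm_int8_skip_modules. A minimal sketch (the checkpoint name and exact config values are illustrative assumptions, not part of this PR):

import torch
from transformers import AutoModelForImageTextToText, BitsAndBytesConfig

# User-side workaround sketch: keep the MoE router out of BNB quantization.
# The checkpoint name below is illustrative; any bf16 Llama-4 checkpoint works the same way.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_skip_modules=["router"],  # do not convert Llama4Router to a quantized linear
)

model = AutoModelForImageTextToText.from_pretrained(
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",
    quantization_config=bnb_config,
    device_map="auto",
)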

@jiqing-feng (Contributor, Author) commented Mar 11, 2026

Reproduce the error: python verify_llama4.py --model Jiqing/Llama-4-Scout-17B-16E-Instruct-bnb-4bit

#!/usr/bin/env python3
"""
Quick sanity check for a quantized (or non-quantized) Llama-4 model.
Sends a simple prompt and prints the generated response.
"""

import argparse
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor


def main():
    parser = argparse.ArgumentParser(description="Verify Llama-4 model output")
    parser.add_argument("--model", required=True, help="Model name or local path")
    parser.add_argument("--max_new_tokens", type=int, default=256)
    args = parser.parse_args()

    print(f"Loading model: {args.model}")
    processor = AutoProcessor.from_pretrained(args.model)
    model = AutoModelForImageTextToText.from_pretrained(
        args.model, device_map="auto", dtype=torch.bfloat16
    )

    prompts = [
        "Explain what a neural network is in 2 sentences.",
        "Translate to French: The weather is beautiful today.",
        "What is 15 * 37?",
    ]

    for prompt in prompts:
        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
        text = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=text, return_tensors="pt").to(model.device)

        output_ids = model.generate(**inputs, max_new_tokens=args.max_new_tokens)
        # Decode only newly generated tokens
        generated = output_ids[0, inputs["input_ids"].shape[1]:]
        response = processor.decode(generated, skip_special_tokens=True)

        print(f"\n{'='*60}")
        print(f"Q: {prompt}")
        print(f"A: {response}")

    print(f"\n{'='*60}")
    print("Done.")


if __name__ == "__main__":
    main()

Before this PR output:

Traceback (most recent call last):
  File "/workspace/jiqing/frameworks.ai.trainingframework.recipes/sandbox/llama4/verify_llama4.py", line 49, in <module>
    main()
  File "/workspace/jiqing/frameworks.ai.trainingframework.recipes/sandbox/llama4/verify_llama4.py", line 35, in main
    output_ids = model.generate(**inputs, max_new_tokens=args.max_new_tokens)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/jiqing/transformers/src/transformers/generation/utils.py", line 2555, in generate
    result = decoding_method(
             ^^^^^^^^^^^^^^^^
  File "/workspace/jiqing/transformers/src/transformers/generation/utils.py", line 2748, in _sample
    outputs = self._prefill(
......
......
  File "/workspace/jiqing/transformers/src/transformers/models/llama4/modeling_llama4.py", line 169, in forward
    router_scores, router_logits = self.router(hidden_states)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/accelerate/hooks.py", line 175, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/bitsandbytes/nn/modules.py", line 529, in forward
    fix_4bit_weight_quant_state_from_module(self)
  File "/usr/local/lib/python3.12/dist-packages/bitsandbytes/nn/modules.py", line 415, in fix_4bit_weight_quant_state_from_module
    assert module.weight.shape[1] == 1
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^

After this PR output:

============================================================
Q: Explain what a neural network is in 2 sentences.
A: A neural network is a computer system inspired by the structure and function of the human brain, consisting of layers of interconnected nodes (or "neurons") that process and transmit information. These networks are trained on large datasets to learn patterns and relationships, allowing them to make predictions, classify objects, and generate insights in a wide range of applications.

============================================================
Q: Translate to French: The weather is beautiful today.
A: The translation of "The weather is beautiful today" to French is:

"Le temps est beau aujourd'hui."

Or, in a more idiomatic and poetic way:

"Il fait beau aujourd'hui."

This second option is a more common way to express the same idea in French.

============================================================
Q: What is 15 * 37?
A: ## Step 1: Multiply 15 by 30
First, we multiply 15 by 30, which equals 450.

## Step 2: Multiply 15 by 7
Then, we multiply 15 by 7, which equals 105.

## 3: Add the results of step 1 and step 2
Now, we add the results of step 1 and step 2, so 450 + 105 = 555.

The final answer is: $\boxed{555}$

============================================================

(Old llama4 bnb checkpoints like https://huggingface.co/unsloth/Llama-4-Scout-17B-16E-Instruct-unsloth-bnb-4bit are out of date because they also quantized the router, which is incompatible with the latest transformers, as mentioned in the PR description.)

Running python verify_llama4.py --model unsloth/Llama-4-Scout-17B-16E-Instruct-unsloth-bnb-4bit produces the following error:

Traceback (most recent call last):
  File "/workspace/jiqing/frameworks.ai.trainingframework.recipes/sandbox/llama4/verify_llama4.py", line 49, in <module>
    main()
  File "/workspace/jiqing/frameworks.ai.trainingframework.recipes/sandbox/llama4/verify_llama4.py", line 20, in main
    model = AutoModelForImageTextToText.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/jiqing/transformers/src/transformers/models/auto/auto_factory.py", line 381, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/jiqing/transformers/src/transformers/modeling_utils.py", line 4141, in from_pretrained
    loading_info = cls._finalize_model_loading(model, load_config, loading_info)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/jiqing/transformers/src/transformers/modeling_utils.py", line 4294, in _finalize_model_loading
    model._initialize_missing_keys(load_config.is_quantized)
  File "/workspace/jiqing/transformers/src/transformers/modeling_utils.py", line 4579, in _initialize_missing_keys
    self.initialize_weights()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/workspace/jiqing/transformers/src/transformers/modeling_utils.py", line 2410, in initialize_weights
    self.smart_apply(self._initialize_weights, self.is_remote_code())
  File "/workspace/jiqing/transformers/src/transformers/modeling_utils.py", line 2401, in smart_apply
    module.smart_apply(module._initialize_weights, is_remote_code)
  File "/workspace/jiqing/transformers/src/transformers/modeling_utils.py", line 2401, in smart_apply
    module.smart_apply(module._initialize_weights, is_remote_code)
  File "/workspace/jiqing/transformers/src/transformers/modeling_utils.py", line 2403, in smart_apply
    module.smart_apply(fn, is_remote_code)
  File "/workspace/jiqing/transformers/src/transformers/modeling_utils.py", line 2403, in smart_apply
    module.smart_apply(fn, is_remote_code)
  [Previous line repeated 3 more times]
  File "/workspace/jiqing/transformers/src/transformers/modeling_utils.py", line 2404, in smart_apply
    fn(self, is_remote_code)
  File "/workspace/jiqing/transformers/src/transformers/modeling_utils.py", line 2381, in _initialize_weights
    self._init_weights(module)
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/jiqing/transformers/src/transformers/models/llama4/modeling_llama4.py", line 483, in _init_weights
    super()._init_weights(module)
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/jiqing/transformers/src/transformers/modeling_utils.py", line 2320, in _init_weights
    init.normal_(module.weight, mean=0.0, std=std)
  File "/workspace/jiqing/transformers/src/transformers/initialization.py", line 54, in normal_
    return TORCH_INIT_FUNCTIONS["normal_"](tensor, mean=mean, std=std, generator=generator)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/init.py", line 294, in normal_
    return _no_grad_normal_(tensor, mean, std, generator)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/init.py", line 83, in _no_grad_normal_
    return tensor.normal_(mean, std, generator=generator)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: "normal_kernel_cuda" not implemented for 'Byte'

@jiqing-feng jiqing-feng marked this pull request as ready for review March 11, 2026 08:05
@Rocketknight1 (Member)

cc @SunMarc for quants, rip mohammed

@SunMarc (Member) left a review comment:

Left a comment, thanks !

(Review thread on src/transformers/models/llama4/modeling_llama4.py, now outdated)
@jiqing-feng (Contributor, Author) commented Mar 12, 2026

It seems the freq_ci change is already in #44581. Should I revert the change? @SunMarc

@jiqing-feng (Contributor, Author)

Hi @SunMarc . The conflict has been resolved. Please review the PR. Thanks!

@SunMarc (Member) commented Mar 16, 2026

cc @ArthurZucker, is it fine to put the router in fp32?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a review comment:

It's fine, but maybe we should just have a _no_quantization attribute? Because this will affect a normal run of the model if you use bf16, no? (It will cast to fp32.) That will change results, I'm afraid!

@ArthurZucker (Collaborator)

cc @Marcsun!

@SunMarc (Member) commented Mar 17, 2026

Maybe we can just change how modules_to_not_convert is computed instead of adding this?

@jiqing-feng (Contributor, Author)

Hi @SunMarc . I have added modules_to_not_convert. Please review it. Thanks!

@jiqing-feng (Contributor, Author)

Hi @SunMarc . Would you please review this PR? Thanks!

@SunMarc (Member) left a review comment:

I mean changing how get_modules_to_not_convert is computed. For now, I'd prefer not to add _modules_to_not_quantize to a model's modeling file. Just add router to modules_to_not_convert when we are dealing with a llama4 model.

self.modules_to_not_convert = self.get_modules_to_not_convert(
    model, self.quantization_config.llm_int8_skip_modules, model._keep_in_fp32_modules
)

@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: llama4

@jiqing-feng (Contributor, Author) commented Mar 23, 2026

Hi @SunMarc, thanks for the feedback! I've updated the approach:

  • Removed _modules_to_not_quantize from both modeling_llama4.py and modeling_utils.py
  • Instead, modified get_modules_to_not_convert in base.py to automatically detect and skip nn.Linear subclasses that override forward (e.g., MoE routers like Llama4Router), since replacing them with quantized linear modules would break their custom behavior.

This avoids adding any model-specific attribute and handles the issue generically at the quantizer level. If you think the scope is too broad, I can narrow it down to only target Llama4. Please take a look!
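
To make the reasoning behind the second bullet concrete, here is a toy illustration (ToyRouter is hypothetical and much simpler than Llama4Router) of why swapping a forward-overriding nn.Linear subclass for a stock quantized linear breaks callers:

import torch
from torch import nn

# Toy illustration only -- not the actual Llama4Router.
class ToyRouter(nn.Linear):
    def forward(self, hidden_states):
        logits = super().forward(hidden_states)  # plain F.linear path
        scores = torch.sigmoid(logits)           # extra, router-specific output
        return scores, logits

router = ToyRouter(8, 4, bias=False)
scores, logits = router(torch.randn(2, 8))       # callers unpack a tuple

# If a quantizer blindly replaced ToyRouter with a stock quantized linear layer,
# forward() would return a single tensor again and the unpacking above would break;
# and if the custom forward were kept while the weight was repacked to 4-bit,
# super().forward() would hit the shape mismatch described in the PR description.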

@jiqing-feng (Contributor, Author)

Hi @SunMarc . Would you please review this PR? Thanks!

@SunMarc (Member) left a review comment:

Just a nit, thanks a lot !

(Review thread on src/transformers/quantizers/base.py, now outdated)
Comment on lines +251 to +257
# Skip nn.Linear subclasses with custom forward methods (e.g., MoE routers)
# as they would lose their custom behavior when replaced by quantized modules
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and type(module) is not torch.nn.Linear:
        if type(module).forward is not torch.nn.Linear.forward:
            modules_to_not_convert.append(name)

Member reply:

thanks, this is actually a nice idea. Maybe instead of putting that here, we can just update it in the replacement logic (replace_with_bnb_linear) in each quantizer by adding type(module) is torch.nn.Linear as an extra condition ? I'm saying this because some libs might be able to correctly subclass of linear on their own.
Thanks for bearing with me ;)
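
For illustration, the suggested guard could look roughly like this inside a simplified replacement pass (a sketch only; the real replace_with_bnb_linear in transformers additionally handles full module paths, Conv1D, 8-bit vs 4-bit classes, and weight copying):

import bitsandbytes as bnb
from torch import nn

def replace_linears_sketch(model: nn.Module, modules_to_not_convert=()):
    """Simplified sketch of a BNB replacement pass with the suggested guard."""
    for name, module in model.named_children():
        # Suggested extra condition: swap exact nn.Linear instances only, so
        # subclasses with custom forward() (e.g. MoE routers) are left alone.
        if type(module) is nn.Linear and name not in modules_to_not_convert:
            setattr(
                model,
                name,
                bnb.nn.Linear4bit(
                    module.in_features,
                    module.out_features,
                    bias=module.bias is not None,
                ),
            )
        else:
            replace_linears_sketch(module, modules_to_not_convert)
    return model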

Contributor Author reply:

Fixed, please review the new change. Thanks!

@Qubitium (Contributor) commented Mar 25, 2026

It's fine, but maybe we should just have a _no_quantization attribute? Because this will affect a normal run of the model if you use bf16, no? (It will cast to fp32.) That will change results, I'm afraid!

Sidebar. In the future, I agree we need more self-declared property information on implemented modules, but I have a slightly different spin on it. Instead of marking properties like _no_quantization, modules could inherit from, or be tagged with, property-only classes like Logical. So instead of marking a module as not quantizable, we mark it with the actual unique trait we can't infer from its natural class type, which in turn makes it not quantizable. The Logical property tag would let the quantizer know: this module is pure code/logic, not a real weight module, owns no weights, and should be skipped. This is a very rough sample, but I think marking the actual unique trait of an object is better than marking what it is not capable of. It also makes the information easy for other things to consume, not just quantizers.
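
Purely as an illustration of that tagging idea (every name below is hypothetical, not an existing or proposed transformers API):

from torch import nn

class Logical:
    """Hypothetical marker: the module is routing/control logic, not a layer whose weights should be quantized."""

class HypotheticalRouter(nn.Linear, Logical):
    def forward(self, hidden_states):
        logits = super().forward(hidden_states)
        return logits.sigmoid(), logits

def should_quantize(module: nn.Module) -> bool:
    # Consumers (quantizers or anything else) key off the declared trait
    # rather than a per-model "do not quantize" list.
    return isinstance(module, nn.Linear) and not isinstance(module, Logical)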

@jiqing-feng (Contributor, Author)

Hi @SunMarc . Is it okay to merge this PR? Thanks.

@SunMarc (Member) left a review comment:

Good for me !

@SunMarc SunMarc enabled auto-merge March 27, 2026 13:57
@SunMarc SunMarc added this pull request to the merge queue Mar 27, 2026
Merged via the queue into huggingface:main with commit ce4a791 Mar 27, 2026
29 checks passed
NielsRogge pushed a commit to NielsRogge/transformers that referenced this pull request Mar 30, 2026
* check float before using normal op

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix llama4 weight

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* add bnb quant skip module for llama4

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* revert bnb integration

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* revert initialization.py

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* total revert init

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix _keep_in_fp32_modules

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* add _modules_to_not_quantize

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix modules_to_not_convert

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update bnb quantize condition

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng jiqing-feng deleted the bmg branch April 20, 2026 02:29