Fix llama4 bnb mode #44588
Conversation
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Reproduce the error:

```python
#!/usr/bin/env python3
"""
Quick sanity check for a quantized (or non-quantized) Llama-4 model.
Sends a simple prompt and prints the generated response.
"""
import argparse

import torch
from transformers import AutoModelForImageTextToText, AutoProcessor


def main():
    parser = argparse.ArgumentParser(description="Verify Llama-4 model output")
    parser.add_argument("--model", required=True, help="Model name or local path")
    parser.add_argument("--max_new_tokens", type=int, default=256)
    args = parser.parse_args()

    print(f"Loading model: {args.model}")
    processor = AutoProcessor.from_pretrained(args.model)
    model = AutoModelForImageTextToText.from_pretrained(
        args.model, device_map="auto", dtype=torch.bfloat16
    )

    prompts = [
        "Explain what a neural network is in 2 sentences.",
        "Translate to French: The weather is beautiful today.",
        "What is 15 * 37?",
    ]
    for prompt in prompts:
        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
        text = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=text, return_tensors="pt").to(model.device)
        output_ids = model.generate(**inputs, max_new_tokens=args.max_new_tokens)
        # Decode only the newly generated tokens
        generated = output_ids[0, inputs["input_ids"].shape[1]:]
        response = processor.decode(generated, skip_special_tokens=True)
        print(f"\n{'=' * 60}")
        print(f"Q: {prompt}")
        print(f"A: {response}")
    print(f"\n{'=' * 60}")
    print("Done.")


if __name__ == "__main__":
    main()
```

Before this PR output:

After this PR output:

(Old Llama4 bnb checkpoints like https://huggingface.co/unsloth/Llama-4-Scout-17B-16E-Instruct-unsloth-bnb-4bit are out of date because they also quantized the router, which is incompatible with the latest transformers, as I mentioned in the PR description.)
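To check whether an existing Hub checkpoint quantized its router, one option is to scan the safetensors index for bnb quant-state entries. A minimal sketch, assuming the repo ships a sharded `model.safetensors.index.json` and that bitsandbytes serializes its 4-bit metadata under `quant_state.bitsandbytes__*` keys:

```python
# Sketch: look for bnb 4-bit quant-state entries on the router weights.
# Assumes a sharded safetensors index and the usual bnb key pattern.
import json

from huggingface_hub import hf_hub_download

repo = "unsloth/Llama-4-Scout-17B-16E-Instruct-unsloth-bnb-4bit"
index_path = hf_hub_download(repo, "model.safetensors.index.json")
with open(index_path) as f:
    weight_map = json.load(f)["weight_map"]

router_quant_keys = [
    k for k in weight_map if "router" in k and "quant_state.bitsandbytes__" in k
]
print(f"router quant-state entries: {len(router_quant_keys)}")  # > 0 => router was quantized
```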
cc @SunMarc for quants, rip mohammed
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Hi @SunMarc. The conflict has been resolved. Please review the PR. Thanks!
cc @ArthurZucker, is it fine to put the router in fp32?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker left a comment
It's fine, but maybe we should just have a `_no_quantization` attribute? Because this will affect a normal run of the model if you use bf16, no? (It will cast to fp32.) Which will change results, I am afraid!
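To make the concern concrete, a minimal sketch (mine, not from the PR) showing that the same layer run in fp32 vs. bf16 gives different outputs:

```python
# Minimal sketch: the same linear layer evaluated in fp32 vs. bf16 produces
# slightly different outputs, which is the worry with force-casting the
# router via _keep_in_fp32_modules on an otherwise-bf16 run.
import torch

torch.manual_seed(0)
linear = torch.nn.Linear(4096, 128)  # hypothetical router-sized projection
x = torch.randn(2, 4096)

out_fp32 = linear(x)
out_bf16 = linear.to(torch.bfloat16)(x.to(torch.bfloat16)).float()

print((out_fp32 - out_bf16).abs().max())  # non-zero: the dtype changes results
```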
cc @Marcsun!
maybe we can just change how `get_modules_to_not_convert` is computed
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Hi @SunMarc. I have added `_modules_to_not_quantize`.
Hi @SunMarc. Would you please review this PR? Thanks!
SunMarc left a comment
I meant changing how `get_modules_to_not_convert` is computed. I'd prefer not to add `_modules_to_not_quantize` to a model's modeling code for now. Just add `router` to `modules_to_not_convert` if we are dealing with a Llama4 model.

```python
self.modules_to_not_convert = self.get_modules_to_not_convert(
    model, self.quantization_config.llm_int8_skip_modules, model._keep_in_fp32_modules
)
```
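A sketch of what that could look like (the helper name and the `model_type` check are mine for illustration, not the merged code):

```python
# Hypothetical sketch of the suggestion above, not the merged change:
# build the bnb skip list and force Llama4's MoE router onto it.
def compute_modules_to_not_convert(model, skip_modules, keep_in_fp32_modules):
    modules = list(skip_modules or []) + list(keep_in_fp32_modules or [])
    if getattr(model.config, "model_type", None) == "llama4":
        # Llama4Router subclasses nn.Linear, so bnb would otherwise quantize it
        modules.append("router")
    return modules
```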
[For maintainers] Suggested jobs to run (before merge): run-slow: llama4
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Hi @SunMarc, thanks for the feedback! I've updated the approach: the quantizer now skips `nn.Linear` subclasses that override `forward` (e.g., MoE routers) instead of relying on a model attribute.
This avoids adding any model-specific attribute and handles the issue generically at the quantizer level. If you think the scope is too broad, I can narrow it down to only target Llama4. Please take a look!
Hi @SunMarc. Would you please review this PR? Thanks!
```python
# Skip nn.Linear subclasses with custom forward methods (e.g., MoE routers)
# as they would lose their custom behavior when replaced by quantized modules
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and type(module) is not torch.nn.Linear:
        if type(module).forward is not torch.nn.Linear.forward:
            modules_to_not_convert.append(name)
```
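A toy illustration of what that check catches (the `ToyRouter` stand-in below is hypothetical, not Llama4 code):

```python
# Toy model: only the Linear subclass with a custom forward gets flagged.
import torch


class ToyRouter(torch.nn.Linear):
    def forward(self, x):
        logits = super().forward(x)
        return torch.topk(logits, k=2, dim=-1)  # custom routing behavior


model = torch.nn.Sequential(torch.nn.Linear(8, 8), ToyRouter(8, 4))

flagged = [
    name
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear)
    and type(module) is not torch.nn.Linear
    and type(module).forward is not torch.nn.Linear.forward
]
print(flagged)  # ['1'] -- the plain Linear at index 0 is still quantizable
```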
Thanks, this is actually a nice idea. Maybe instead of putting that here, we can just update it in the replacement logic (`replace_with_bnb_linear`) in each quantizer by adding `type(module) is torch.nn.Linear` as an extra condition? I'm saying this because some libs might be able to correctly subclass Linear on their own.
Thanks for bearing with me ;)
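Roughly, the suggested guard inside the replacement loop could look like this simplified sketch (not the actual `replace_with_bnb_linear` source; `make_quantized_linear` is a hypothetical factory supplied by the caller):

```python
# Simplified sketch of the suggested guard: replace only exact nn.Linear
# instances, leaving subclasses (e.g., routers) untouched.
import torch


def replace_linears(model, make_quantized_linear, modules_to_not_convert=()):
    targets = [
        name
        for name, module in model.named_modules()
        if name not in modules_to_not_convert
        # the extra condition: exact nn.Linear only, so subclasses with a
        # custom forward keep their behavior
        and type(module) is torch.nn.Linear
    ]
    for name in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, make_quantized_linear(model.get_submodule(name)))
    return model
```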
Fixed, please review the new change. Thanks!
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Sidebar. In the future, I agree we need more …
Hi @SunMarc. Is it okay to merge this PR? Thanks.
* check float before using normal op
* fix llama4 weight
* add bnb quant skip module for llama4
* revert bnb integration
* revert initialization.py
* total revert init
* fix _keep_in_fp32_modules
* add _modules_to_not_quantize
* fix modules_to_not_convert
* update bnb quantize condition

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Fixes Llama4 model loading under BitsAndBytes (BNB) quantization mode.
Quantizing the router causes a shape mismatch: `Llama4Router` inherits from `nn.Linear`, so BNB packs its weight into a quantized format. However, `super().forward()` calls `torch.nn.Linear.forward` → `F.linear`, which hits BNB's torch-function hook and raises `RuntimeError: mat1 and mat2 shapes cannot be multiplied` because the packed weight shape is incompatible with the input. Adding `"router"` (along with other sensitive modules) to `_keep_in_fp32_modules` skips quantization and fixes the error.
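For intuition, a pure-PyTorch mock of the failure mode (no bitsandbytes required; the packed buffer below only mimics 4-bit packing, two weights per byte, and is kept float so the shape error is the one raised):

```python
# Mock of the failure mode: a router subclass calls F.linear via
# super().forward(), but its weight has been replaced by a packed buffer,
# so the matmul shapes no longer line up.
import torch


class MockRouter(torch.nn.Linear):
    def forward(self, x):
        return super().forward(x)  # -> F.linear(x, self.weight, self.bias)


router = MockRouter(128, 16, bias=False)
# Simulate 4-bit packing: 16*128 weights squeezed into half as many slots.
# (Real bnb stores uint8; float keeps this mock focused on the shape error.)
router.weight = torch.nn.Parameter(torch.zeros(16 * 128 // 2, 1))

try:
    router(torch.randn(2, 128))
except RuntimeError as e:
    print(e)  # mat1 and mat2 shapes cannot be multiplied (2x128 and 1x1024)
```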