[BitsAndBytesConfig] Providing llm_int8_skip_modules clears the default lm_head exclusion, causing AssertionError in 4-bit inference #45674

@softguy777

Description

Environment

  • transformers: 5.5.4
  • bitsandbytes: 0.49.2
  • torch: 2.11.0+cu126
  • CUDA: 12.6
  • OS: Windows 11
  • GPU: NVIDIA RTX 3090

Bug Description

When specifying llm_int8_skip_modules in BitsAndBytesConfig, the default module exclusion list (which normally protects lm_head from being quantized) is silently cleared. This causes a crash during inference:

File "bitsandbytes/nn/modules.py", line 415, in fix_4bit_weight_quant_state_from_module
    assert module.weight.shape[1] == 1
AssertionError
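For context, bitsandbytes stores 4-bit weights packed two values per byte as an (N, 1) column vector, and the failing assert checks for exactly that layout. A weight that never went through 4-bit packing (for example, an lm_head whose weight is tied to the embedding matrix, as in many decoder models) keeps its original 2-D shape and likely trips the assertion. A minimal sketch of the expected packed layout (requires a CUDA device):

import torch
import bitsandbytes as bnb

# Quantization happens when the layer is moved to the GPU.
linear = bnb.nn.Linear4bit(256, 256, compute_dtype=torch.bfloat16).cuda()

# Packed 4-bit storage: 256 * 256 weights at two per byte, i.e. a
# (32768, 1) column vector: the shape the failing assert expects.
print(linear.weight.shape)  # torch.Size([32768, 1])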

Root Cause

In transformers/quantizers/base.py, get_modules_to_not_convert():

if skip_modules is None or add_default_skips:
    modules_to_not_convert = get_keys_to_not_convert(model)  # auto-detects lm_head
else:
    modules_to_not_convert = []  # ← cleared when user provides ANY list!

if skip_modules is not None:
    modules_to_not_convert.extend(skip_modules)

When llm_int8_skip_modules=None (default), get_keys_to_not_convert() automatically finds and excludes lm_head and other output projection layers.

The moment the user provides any list (e.g., to protect a multimodal audio/vision tower from quantization), the auto-exclusion is disabled and lm_head gets quantized → AssertionError in bitsandbytes.

This is particularly easy to hit with multimodal models like Gemma 4 E2B-IT, where users must explicitly skip the audio/vision towers to prevent quality degradation but are unlikely to realize that doing so also removes the lm_head protection.
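The auto-detected defaults are easy to inspect directly. A quick sketch (assumes get_keys_to_not_convert is importable from transformers.integrations; the exact import path may vary across versions):

from transformers import AutoModelForCausalLM
from transformers.integrations import get_keys_to_not_convert  # path may vary by version

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Auto-detected exclusions: the output head / tied weights, e.g. ['lm_head'].
print(get_keys_to_not_convert(model))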

Minimal Reproducible Example

import torch
from transformers import AutoModelForImageTextToText, BitsAndBytesConfig

# ❌ Crashes: providing any list clears the lm_head default exclusion
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_skip_modules=["model.audio_tower"],
)
model = AutoModelForImageTextToText.from_pretrained(
    "google/gemma-4-e2b-it",
    quantization_config=bnb_config,
    device_map="auto",
)
# model.generate(...) → AssertionError in lm_head

Workaround (user must manually re-add lm_head):

# ✅ Works: lm_head added explicitly
llm_int8_skip_modules=["lm_head", "model.audio_tower"]
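Until a fix lands, a small helper makes the workaround harder to forget. A sketch: safe_skip_modules is a hypothetical name, and it assumes the output projection is called lm_head, as in most decoder-style models:

import torch
from transformers import BitsAndBytesConfig

def safe_skip_modules(user_skips):
    # Hypothetical helper: merge user skips with the lm_head default so the
    # output projection is never quantized (assumes it is named "lm_head").
    return sorted({"lm_head", *user_skips})

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_skip_modules=safe_skip_modules(["model.audio_tower"]),
)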

Suggested Fix

Change the behavior so that the user-provided list is additive rather than replacing the auto-detected defaults:

# In get_modules_to_not_convert():
modules_to_not_convert = get_keys_to_not_convert(model)  # always start with defaults

if skip_modules is not None:
    modules_to_not_convert.extend(skip_modules)  # user additions merged, not replacing

modules_to_not_convert = list(set(modules_to_not_convert))
return modules_to_not_convert
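The merge semantics can also be pinned down as a pure function, which would make the fix unit-testable without loading a model (merge_skip_modules is a hypothetical name for illustration):

def merge_skip_modules(auto_detected, user_skips):
    # User entries extend the auto-detected defaults instead of replacing them.
    merged = list(auto_detected)
    if user_skips is not None:
        merged.extend(user_skips)
    return sorted(set(merged))

# Providing a user list no longer drops the lm_head default.
assert merge_skip_modules(["lm_head"], ["model.audio_tower"]) == [
    "lm_head",
    "model.audio_tower",
]
assert merge_skip_modules(["lm_head"], None) == ["lm_head"]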

At minimum, the docstring for BitsAndBytesConfig.llm_int8_skip_modules should warn that lm_head must be included manually when this parameter is set.
