
[Bug Report] Error in loading Quantized Llama 3.2 Model from HuggingFace #930

@aditeyabaral

Description

Describe the bug
The Llama 3.2 3B model (meta-llama/Llama-3.2-3B-Instruct) fails to load when using quantized weights. Attempting to load the model with quantization enabled (a BitsAndBytesConfig with just load_in_4bit=True) raises a RuntimeError during initialization due to parameter size mismatches.

Code example
Implemented using the LLaMA2 GPU Quantized Notebook as a reference. Below is a minimal working snippet that reproduces the error.

import torch
from transformer_lens import HookedTransformer
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer

model_path = "meta-llama/Llama-3.2-3B-Instruct"

# 4-bit quantization via bitsandbytes
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Load the quantized model through Hugging Face first...
hf_model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb,
    torch_dtype=torch.float32,
    device_map="cuda:0",
)

# ...then wrap it in a HookedTransformer. This call raises the RuntimeError below.
model = HookedTransformer.from_pretrained(
    model_path,
    hf_model=hf_model,
    dtype=torch.float32,
    fold_ln=False,
    fold_value_biases=False,
    center_writing_weights=False,
    center_unembed=False,
    tokenizer=tokenizer,
)
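
For reference, the packed shapes are already visible on the Hugging Face model before HookedTransformer is involved. A small diagnostic sketch (using the hf_model created above; the printed shape is an assumption based on how bitsandbytes stores 4-bit weights):

# Inspect the quantized K/V projection parameters directly.
for name, param in hf_model.named_parameters():
    if "k_proj" in name or "v_proj" in name:
        print(name, tuple(param.shape), param.dtype)
        break  # one layer is enough to see the pattern

# Presumably prints a packed uint8 shape such as (1572864, 1)
# rather than the logical (1024, 3072) projection shape.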

Stack Trace

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-14-a91e03fbaed1> in <cell line: 0>()
----> 1 model = HookedTransformer.from_pretrained(
      2     model_path,
      3     hf_model=hf_model,
      4     dtype=torch.float32,
      5     fold_ln=False,

2 frames
/usr/local/lib/python3.11/dist-packages/transformer_lens/HookedTransformer.py in from_pretrained(cls, model_name, fold_ln, center_writing_weights, center_unembed, refactor_factored_attn_matrices, checkpoint_index, checkpoint_value, hf_model, device, n_devices, tokenizer, move_to_device, fold_value_biases, default_prepend_bos, default_padding_side, dtype, first_n_layers, **from_pretrained_kwargs)
   1369         )
   1370 
-> 1371         model.load_and_process_state_dict(
   1372             state_dict,
   1373             fold_ln=fold_ln,

/usr/local/lib/python3.11/dist-packages/transformer_lens/HookedTransformer.py in load_and_process_state_dict(self, state_dict, fold_ln, center_writing_weights, center_unembed, fold_value_biases, refactor_factored_attn_matrices)
   1632             # with quantization, parameters should be assigned
   1633             # so that quantization settings are not lost
-> 1634             self.load_state_dict(state_dict, assign=True, strict=False)
   1635         else:
   1636             state_dict_keys = list(state_dict.keys())

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign)
   2579 
   2580         if len(error_msgs) > 0:
-> 2581             raise RuntimeError(
   2582                 "Error(s) in loading state_dict for {}:\n\t{}".format(
   2583                     self.__class__.__name__, "\n\t".join(error_msgs)
Error(s) in loading state_dict for HookedTransformer:
	size mismatch for blocks.0.attn._W_K: copying a param with shape torch.Size([1572864, 1]) from checkpoint, the shape in current model is torch.Size([8, 3072, 128]).
	size mismatch for blocks.0.attn._W_V: copying a param with shape torch.Size([1572864, 1]) from checkpoint, the shape in current model is torch.Size([8, 3072, 128]).
	[... the same _W_K and _W_V size mismatch is reported for blocks.1 through blocks.26 ...]
	size mismatch for blocks.27.attn._W_K: copying a param with shape torch.Size([1572864, 1]) from checkpoint, the shape in current model is torch.Size([8, 3072, 128]).
	size mismatch for blocks.27.attn._W_V: copying a param with shape torch.Size([1572864, 1]) from checkpoint, the shape in current model is torch.Size([8, 3072, 128]).
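
The mismatch is identical for every layer's _W_K and _W_V, and the numbers are consistent with 4-bit packing: the expected [8, 3072, 128] tensor holds 3,145,728 elements, and storing two 4-bit values per byte yields exactly the flattened [1572864, 1] tensor reported from the checkpoint. A minimal sketch of that arithmetic (my reading of the error, not a confirmed diagnosis):

import torch

# Shape HookedTransformer expects for a grouped-query K/V weight on Llama 3.2 3B:
# (n_key_value_heads, d_model, d_head)
expected = torch.Size([8, 3072, 128])
n_elements = expected.numel()  # 8 * 3072 * 128 = 3,145,728

# bitsandbytes packs two 4-bit values per uint8 byte and stores the result
# as a flattened column vector, which would explain the checkpoint shape.
packed = torch.Size([n_elements // 2, 1])
print(packed)  # torch.Size([1572864, 1]) -- matches the error above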

System Info
Characteristics of the environment:

  • transformer_lens version 2.15.4 installed via pip
  • Running on Google Colab
  • Python version = 3.11.12

Additional context
I hit the same error on Kaggle and on my local system with the same package versions. It also occurs with other Llama 3.2 models (e.g., the 1B variant) when loading their quantized weights.
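
For comparison, a baseline load without any quantization config is expected to work, since it takes the ordinary (non-assign) state-dict path shown in the traceback. A minimal sketch, assuming the fp32 weights fit on the device:

# Baseline without quantization -- loads without the size mismatch.
model = HookedTransformer.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct",
    dtype=torch.float32,
    device="cuda",
)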

Checklist

  • I have checked that there is no similar issue in the repo (required)
