
[Bug Report] Unable to load Llama 3 70b on multi-GPU in 4-bit #569


Description

@winglian

Unable to load Llama 3 70b on multi-GPU in 4-bit:

import torch
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer

base_model = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-70B-Instruct', torch_dtype=torch.bfloat16, device_map="auto", load_in_4bit=True)
model = HookedTransformer.from_pretrained(
    'meta-llama/Meta-Llama-3-70B-Instruct',
    hf_model=base_model,
    fold_ln=False,
    fold_value_biases=False,
)

This errors with:

  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict                                                                                                                                 
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(                                                                                                                                                                                
RuntimeError: Error(s) in loading state_dict for HookedTransformer:                                                                                                                                                                                          
        size mismatch for blocks.0.attn._W_K: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]). 
        size mismatch for blocks.0.attn._W_V: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]). 
        size mismatch for blocks.1.attn._W_K: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]). 
        size mismatch for blocks.1.attn._W_V: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]). 
        size mismatch for blocks.2.attn._W_K: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]). 
        size mismatch for blocks.2.attn._W_V: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]). 
        size mismatch for blocks.3.attn._W_K: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]). 
        size mismatch for blocks.3.attn._W_V: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]). 
        ... (etc)
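
For context on those numbers: the expected shape [8, 8192, 128] is [n_key_value_heads, d_model, d_head] for Llama 3 70B, i.e. 8 * 8192 * 128 = 8,388,608 elements, and bitsandbytes stores 4-bit weights packed two values per uint8 byte as a flattened [n/2, 1] tensor, which gives exactly the torch.Size([4194304, 1]) seen in the checkpoint. A minimal sketch of that arithmetic (my reading of the shapes in the traceback, not something confirmed against the configs):

n_elements = 8 * 8192 * 128     # expected _W_K / _W_V shape: 8 KV heads x 8192 d_model x 128 d_head
packed_bytes = n_elements // 2  # bitsandbytes packs two 4-bit values into each uint8
print(packed_bytes)             # 4194304 -> the torch.Size([4194304, 1]) from the checkpoint

So HookedTransformer.from_pretrained appears to be handed the still-packed 4-bit state dict rather than dequantized weights. As a stopgap, dropping load_in_4bit=True and loading hf_model in plain bfloat16 with device_map="auto" should avoid the shape mismatch, assuming the GPUs have enough combined memory for the unquantized 70B weights.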

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working), complexity-high (Very complicated changes for people to address who are quite familiar with the code), multi-gpu
