[BUG]: ShardConfig does not shard num_key_value_heads and num_heads equally when running the llama2 7b and 13b models #4565

@wangbluo

Description

🐛 Describe the bug

Test script: https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/llama2/scripts/benchmark_7B/gemini.sh

The plugin used is 3d.

Test arguments: Namespace(config='13b', plugin='3d', batch_size=8, num_steps=5, ignore_steps=2, grad_checkpoint=True, max_length=4096, warmup_ratio=0.8, memory_limit=None, xformers=True, shard_param_frac=1.0, offload_optim_frac=0.0, offload_param_frac=0.0, tp=2, pp=4, mbs=1, zero=0)

An earlier test of main/colossalai/shardformer/examples/convergence_benchmark.sh hit the same problem; both failures are caused by ShardConfig.

Cause of the error: llama 70b does not distinguish self.num_key_value_heads from self.num_heads, but with llama 7b and 13b these two head counts turn out to be different: shardformer appears to shard only num_heads, leaving num_key_value_heads untouched. I suggest shardformer add sharding for this case as well; otherwise these models simply cannot run.
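For illustration, a minimal sketch of the suggested fix, assuming shardformer rewrites these attention attributes through an attribute-replacement mapping (the helper name and dict keys below are hypothetical, not the actual ColossalAI policy code):

```python
# Hypothetical helper sketching the suggested fix: every head-count attribute
# read by the attention forward must be divided by the tensor-parallel size
# when q_proj/k_proj/v_proj are column-sharded.
def llama_attention_attribute_replacement(config, tp_size: int) -> dict:
    return {
        "self_attn.hidden_size": config.hidden_size // tp_size,
        "self_attn.num_heads": config.num_attention_heads // tp_size,
        # The piece this issue reports as missing: shard the KV head count too,
        # so the .view() on the k_proj/v_proj output stays consistent.
        "self_attn.num_key_value_heads": config.num_key_value_heads // tp_size,
    }
```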

I tried adding self.num_key_value_heads = self.num_heads at line 35 of main/examples/language/llama2/attn.py, and it resolves the error.
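In context, that workaround sits just before the reshape that fails in the traceback below; the surrounding function body is reconstructed from the traceback and is an assumption, not the actual contents of attn.py:

```python
# Reconstructed sketch of the patched region of examples/language/llama2/attn.py.
# Only the k_proj line is known (from the traceback); the rest is assumed.
def llama_flash_attention(self, hidden_states, *args, **kwargs):
    bsz, q_len, _ = hidden_states.size()
    # Workaround: after tp=2 sharding, self.num_heads has been halved but
    # self.num_key_value_heads has not. For 7b/13b (plain multi-head attention,
    # where the two counts are equal before sharding), forcing them equal
    # restores a view shape that matches the sharded k_proj output.
    self.num_key_value_heads = self.num_heads
    key_states = self.k_proj(hidden_states).view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    ...
```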

Error traceback:

File "/usr/local/lib/python3.9/site-packages/colossalai-0.3.1-py3.9.egg/colossalai/shardformer/modeling/llama.py", line 129, in llama_model_forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/usr/local/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/usr/local/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 107, in forward
    outputs = run_function(*args)
  File "/usr/local/lib/python3.9/site-packages/colossalai-0.3.1-py3.9.egg/colossalai/shardformer/modeling/llama.py", line 125, in custom_forward
    return module(*inputs, output_attentions, None)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 415, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/workspace/workfile/ColossalAI-main/examples/language/llama2/attn.py", line 37, in llama_flash_attention
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
RuntimeError: shape '[8, 4096, 40, 128]' is invalid for input of size 83886080
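The numbers in the RuntimeError are consistent with this diagnosis: with tp=2, the column-sharded k_proj emits exactly half the elements that the unsharded view shape expects (llama2-13b has 40 heads of head_dim 128; here batch 8, sequence length 4096):

```python
# Shape arithmetic behind the error message.
expected = 8 * 4096 * 40 * 128  # 167,772,160 -- view() uses the unsharded num_key_value_heads = 40
actual   = 8 * 4096 * 20 * 128  #  83,886,080 -- k_proj output after tp=2 sharding (20 heads per rank)
assert actual == 83886080 and actual * 2 == expected
```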

Environment

No response
