23 changes: 4 additions & 19 deletions colossalai/inference/modeling/layers/baichuan_tp_linear.py
@@ -15,25 +15,10 @@ def from_native_module(
module.in_features = module.weight.size(1)
module.out_features = module.weight.size(0)
module.bias = None
module.weight.data = nn.functional.normalize(module.weight)

return Linear1D_Col.from_native_module(
module,
process_group,
*args,
**kwargs,
)


class BaichuanWpackLinear1D_Col(Linear1D_Col):
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
in_features = module.in_features * 3
out_features = module.out_features // 3
module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features)
module.bias = None
module.weight.data = nn.functional.normalize(
module.weight
) # TODO(lry89757) This behavior may not apply to lazy init. With lazy init, the shardformer weight is not the real weight,
# so we should override load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue.

return Linear1D_Col.from_native_module(
module,
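Note (editor sketch): the deleted BaichuanWpackLinear1D_Col re-laid-out the packed W_pack weight so the old bmm-based attention path could consume it; the policy change in the last file instead keeps W_pack as a plain fused linear sharded by FusedLinear1D_Col(n_fused=3). A minimal sketch of the two layouts with toy dimensions (illustrative only, not code from this PR; hidden is a made-up size):

import torch

hidden = 4                                    # toy hidden size, not a real config value
w_pack = torch.arange(3 * hidden * hidden, dtype=torch.float32).view(3 * hidden, hidden)

# What the deleted class did: (3H, H) -> (3, H, H) -> (H, 3, H) -> (H, 3H),
# interleaving q/k/v so a single stacked qkv_weight bmm could be used.
interleaved = w_pack.view(3, hidden, -1).transpose(0, 1).reshape(hidden, 3 * hidden)

# What the fused-linear path keeps: the plain packed layout, with the three
# projection blocks simply stacked along the output dimension.
q_w, k_w, v_w = w_pack.chunk(3, dim=0)        # each (hidden, hidden)
print(interleaved.shape, q_w.shape)           # torch.Size([4, 12]) torch.Size([4, 4])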
101 changes: 17 additions & 84 deletions colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -1,11 +1,11 @@
# This code is adapted from huggingface baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
import itertools
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.accelerator import get_accelerator
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
@@ -16,7 +16,7 @@
from colossalai.kernel.triton import rms_layernorm
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
from colossalai.tensor.d_tensor import is_distributed_tensor

inference_ops = InferenceOpsLoader().load()
logger = get_dist_logger(__name__)
@@ -55,24 +55,19 @@ class NopadBaichuanAttention(ParallelModule):
def __init__(
self,
config,
attn_qproj_w: torch.Tensor = None,
attn_kproj_w: torch.Tensor = None,
attn_vproj_w: torch.Tensor = None,
W_pack: ParallelModule = None,
attn_oproj: ParallelModule = None,
num_heads: int = None,
hidden_size: int = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
process_group: ProcessGroup = None,
helper_layout: Layout = None,
):
"""This layer will replace the BaichuanAttention.

Args:
config (BaichuanConfig): Holding the Baichuan model config.
attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None.
W_pack (ParallelModule, optional): The fused q/k/v projection module. Defaults to None.
attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None.
"""
ParallelModule.__init__(self)
self.o_proj = attn_oproj
@@ -82,10 +77,7 @@ def __init__(
self.hidden_size = hidden_size
self.head_dim = self.hidden_size // self.num_heads
self.process_group = process_group
qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))

self.helper_layout = helper_layout
self.W_pack = W_pack
self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
self.attention_backend = get_attention_backend(model_shard_infer_config)
self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
@@ -96,9 +88,9 @@
if config.hidden_size == 5120:
slopes_start = self.process_group.rank() * num_heads
self.use_alibi_attn = True
self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
slopes_start : slopes_start + num_heads
].contiguous()
self.alibi_slopes = get_alibi_slopes(
config.num_attention_heads, device=get_accelerator().get_current_device()
)[slopes_start : slopes_start + num_heads].contiguous()
self.alibi_slopes = nn.Parameter(self.alibi_slopes)

@staticmethod
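Note (editor sketch): the alibi slope slicing above gives each tensor-parallel rank only its own num_heads slice of the full slope table. For reference, a common formulation of ALiBi slopes (hypothetical stand-in following the ALiBi paper; colossalai's get_alibi_slopes may differ in detail):

import math

def alibi_slopes_ref(num_heads: int):
    # Geometric slope schedule starting at 2**(-8/n) for power-of-two head
    # counts, with the usual interleaved fallback for other head counts.
    def pow2_slopes(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]

    if math.log2(num_heads).is_integer():
        return pow2_slopes(num_heads)
    closest = 2 ** math.floor(math.log2(num_heads))
    return pow2_slopes(closest) + pow2_slopes(2 * closest)[0::2][: num_heads - closest]

# Per-rank slice, mirroring the slopes_start logic above (toy values).
rank, local_heads = 1, 10
local_slopes = alibi_slopes_ref(40)[rank * local_heads : (rank + 1) * local_heads]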
@@ -112,78 +104,22 @@ def from_native_module(
"""

config = module.config
q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((module.hidden_size, 3, -1)).transpose(0, 1)

attn_qproj_w = q_proj_w
attn_kproj_w = k_proj_w
attn_vproj_w = v_proj_w
W_pack = module.W_pack
attn_oproj = module.o_proj
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)

helper_layout = (
module.W_pack.weight.dist_layout
) # NOTE this is a hack for the right load/shard of qkv_weight (used in _load_from_state_dict)

attn_layer = NopadBaichuanAttention(
config=config,
attn_qproj_w=attn_qproj_w,
attn_kproj_w=attn_kproj_w,
attn_vproj_w=attn_vproj_w,
W_pack=W_pack,
attn_oproj=attn_oproj,
model_shard_infer_config=model_shard_infer_config,
num_heads=module.num_heads,
hidden_size=module.hidden_size,
process_group=process_group,
helper_layout=helper_layout,
)

return attn_layer

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}

key = "qkv_weight"
qkv_w = state_dict[prefix + "W_pack.weight"]

in_features = qkv_w.size(1)
out_features = qkv_w.size(0) // 3

qkv_w.data = qkv_w.view((3, out_features, -1)).transpose(0, 1).reshape(out_features, in_features * 3)

device_mesh = self.helper_layout.device_mesh
sharding_spec = self.helper_layout.sharding_spec
qkv_w = distribute_tensor(qkv_w, device_mesh, sharding_spec)

qkv_w = qkv_w.transpose(0, 1).reshape(3, in_features, -1)
input_param = nn.Parameter(
qkv_w
) # NOTE qkv_weight doesn't have to be a distributed tensor, e.g. input_param = sharded_tensor_to_param(input_param)

param = local_state[key]

try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
)

strict = False # to avoid unexpected_keys
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)

def forward(
self,
hidden_states: torch.Tensor,
@@ -220,13 +156,13 @@ def forward(
cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""

token_nums = hidden_states.size(0)
# fused qkv
hidden_states = hidden_states.expand(3, -1, -1)
query_states, key_states, value_states = (
torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
)

proj = self.W_pack(hidden_states)
proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
query_states = proj[0].view(token_nums, self.num_heads, self.head_dim)
key_states = proj[1].view(token_nums, self.num_heads, self.head_dim)
value_states = proj[2].view(token_nums, self.num_heads, self.head_dim)

block_size = k_cache.size(-2)

@@ -279,9 +215,6 @@ def forward(

return attn_output

def extra_repr(self) -> str:
return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"


# NOTE This will cause differences as the output length increases.
class NopadBaichuanMLP(NopadLlamaMLP):
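Note (editor sketch): shape-wise, the new W_pack-based forward is just a fused-output split into q/k/v. A minimal sketch with toy shapes (illustrative only, not code from this PR; token_nums, hidden_size, num_heads are made-up values):

import torch

token_nums, hidden_size, num_heads = 5, 8, 2
head_dim = hidden_size // num_heads
proj = torch.randn(token_nums, 3 * hidden_size)   # stand-in for self.W_pack(hidden_states)

# (T, 3H) -> (T, 3, H) -> (1, T, 3, H) -> (3, T, 1, H) -> (3, T, H)
split = proj.unflatten(-1, (3, hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
q = split[0].view(token_nums, num_heads, head_dim)
k = split[1].view(token_nums, num_heads, head_dim)
v = split[2].view(token_nums, num_heads, head_dim)

# The unflatten/transpose chain is equivalent to a plain last-dim chunk.
q2, k2, v2 = (t.view(token_nums, num_heads, head_dim) for t in proj.chunk(3, dim=-1))
assert torch.equal(q, q2) and torch.equal(k, k2) and torch.equal(v, v2)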
10 changes: 3 additions & 7 deletions colossalai/inference/modeling/policy/nopadding_baichuan.py
@@ -1,8 +1,5 @@
from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.layers.baichuan_tp_linear import (
BaichuanLMHeadLinear1D_Col,
BaichuanWpackLinear1D_Col,
)
from colossalai.inference.modeling.layers.baichuan_tp_linear import BaichuanLMHeadLinear1D_Col
from colossalai.inference.modeling.models.nopadding_baichuan import (
NopadBaichuanAttention,
NopadBaichuanMLP,
@@ -14,7 +11,7 @@
llama_model_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import FusedLinear1D_Col, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

@@ -60,8 +57,7 @@ def module_policy(self):
target_module=NopadBaichuanMLP,
),
SubModuleReplacementDescription(
suffix="self_attn.W_pack",
target_module=BaichuanWpackLinear1D_Col,
suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
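Note (editor sketch): replacing BaichuanWpackLinear1D_Col with the generic FusedLinear1D_Col and n_fused=3 tells shardformer that W_pack packs three projections along its output dimension. A minimal sketch of the sharding this is expected to produce (assumed behaviour with toy dimensions, not the actual FusedLinear1D_Col implementation):

import torch

hidden, tp_size, n_fused = 4, 2, 3
w_pack = torch.arange(n_fused * hidden * hidden).view(n_fused * hidden, hidden)

# A naive column-parallel split would hand rank 0 all of q plus half of k.
naive_rank0 = w_pack.chunk(tp_size, dim=0)[0]

# A fused split shards each q/k/v block separately and re-packs per rank,
# so every rank holds its own slice of q, k and v.
blocks = w_pack.chunk(n_fused, dim=0)                       # q, k, v blocks
fused_rank0 = torch.cat([b.chunk(tp_size, dim=0)[0] for b in blocks], dim=0)
print(naive_rank0.shape, fused_rank0.shape)                  # both (6, 4), different rows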