From c62b2aad0a4de647ef73f82fa4ec33352abe610a Mon Sep 17 00:00:00 2001 From: Runyu Lu Date: Mon, 10 Jun 2024 04:45:31 +0000 Subject: [PATCH 1/2] refactor baichuan --- .../modeling/layers/baichuan_tp_linear.py | 32 +++--- .../modeling/models/nopadding_baichuan.py | 101 +++--------------- .../modeling/policy/nopadding_baichuan.py | 10 +- 3 files changed, 36 insertions(+), 107 deletions(-) diff --git a/colossalai/inference/modeling/layers/baichuan_tp_linear.py b/colossalai/inference/modeling/layers/baichuan_tp_linear.py index e050dd71c8b2..fcced85e22cb 100644 --- a/colossalai/inference/modeling/layers/baichuan_tp_linear.py +++ b/colossalai/inference/modeling/layers/baichuan_tp_linear.py @@ -25,19 +25,19 @@ def from_native_module( ) -class BaichuanWpackLinear1D_Col(Linear1D_Col): - @staticmethod - def from_native_module( - module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs - ) -> ParallelModule: - in_features = module.in_features * 3 - out_features = module.out_features // 3 - module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features) - module.bias = None - - return Linear1D_Col.from_native_module( - module, - process_group, - *args, - **kwargs, - ) +# class BaichuanWpackLinear1D_Col(Linear1D_Col): +# @staticmethod +# def from_native_module( +# module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs +# ) -> ParallelModule: +# in_features = module.in_features * 3 +# out_features = module.out_features // 3 +# module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features) +# module.bias = None + +# return Linear1D_Col.from_native_module( +# module, +# process_group, +# *args, +# **kwargs, +# ) diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py index b50e73d6fcf4..ec2abd3f08e2 100644 --- a/colossalai/inference/modeling/models/nopadding_baichuan.py +++ b/colossalai/inference/modeling/models/nopadding_baichuan.py @@ -1,5 +1,4 @@ # This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py -import itertools import math from typing import List, Optional, Tuple, Union @@ -7,6 +6,7 @@ import torch.nn as nn from torch.distributed import ProcessGroup +from colossalai.accelerator import get_accelerator from colossalai.inference.flash_decoding_utils import FDIntermTensors from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP from colossalai.kernel.kernel_loader import InferenceOpsLoader @@ -20,7 +20,7 @@ ) from colossalai.logging import get_dist_logger from colossalai.shardformer.layer.parallel_module import ParallelModule -from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor +from colossalai.tensor.d_tensor import is_distributed_tensor logger = get_dist_logger(__name__) @@ -96,23 +96,18 @@ class NopadBaichuanAttention(ParallelModule): def __init__( self, config, - attn_qproj_w: torch.Tensor = None, - attn_kproj_w: torch.Tensor = None, - attn_vproj_w: torch.Tensor = None, + W_pack: ParallelModule = None, attn_oproj: ParallelModule = None, num_heads: int = None, hidden_size: int = None, process_group: ProcessGroup = None, - helper_layout: Layout = None, ): """This layer will replace the BaichuanAttention. Args: config (BaichuanConfig): Holding the Baichuan model config. - attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None. - attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None. - attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None. - attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None. + W_pack (ParallelModule, optional): The packed weight. Defaults to None. + attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None. """ ParallelModule.__init__(self) self.o_proj = attn_oproj @@ -122,10 +117,7 @@ def __init__( self.hidden_size = hidden_size self.head_dim = self.hidden_size // self.num_heads self.process_group = process_group - qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)] - self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0)) - - self.helper_layout = helper_layout + self.W_pack = W_pack self.alibi_slopes = None self.use_alibi_attn = False @@ -133,9 +125,9 @@ def __init__( if config.hidden_size == 5120: slopes_start = self.process_group.rank() * num_heads self.use_alibi_attn = True - self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[ - slopes_start : slopes_start + num_heads - ].contiguous() + self.alibi_slopes = get_alibi_slopes( + config.num_attention_heads, device=get_accelerator().get_current_device() + )[slopes_start : slopes_start + num_heads].contiguous() self.alibi_slopes = nn.Parameter(self.alibi_slopes) @staticmethod @@ -149,76 +141,20 @@ def from_native_module( """ config = module.config - q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((module.hidden_size, 3, -1)).transpose(0, 1) - - attn_qproj_w = q_proj_w - attn_kproj_w = k_proj_w - attn_vproj_w = v_proj_w + W_pack = module.W_pack attn_oproj = module.o_proj - helper_layout = ( - module.W_pack.weight.dist_layout - ) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict) - attn_layer = NopadBaichuanAttention( config=config, - attn_qproj_w=attn_qproj_w, - attn_kproj_w=attn_kproj_w, - attn_vproj_w=attn_vproj_w, + W_pack=W_pack, attn_oproj=attn_oproj, num_heads=module.num_heads, hidden_size=module.hidden_size, process_group=process_group, - helper_layout=helper_layout, ) return attn_layer - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - for hook in self._load_state_dict_pre_hooks.values(): - hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - - persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} - local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) - local_state = {k: v for k, v in local_name_params if v is not None} - - key = "qkv_weight" - qkv_w = state_dict[prefix + "W_pack.weight"] - - in_features = qkv_w.size(1) - out_features = qkv_w.size(0) // 3 - - qkv_w.data = qkv_w.view((3, out_features, -1)).transpose(0, 1).reshape(out_features, in_features * 3) - - device_mesh = self.helper_layout.device_mesh - sharding_spec = self.helper_layout.sharding_spec - qkv_w = distribute_tensor(qkv_w, device_mesh, sharding_spec) - - qkv_w = qkv_w.transpose(0, 1).reshape(3, in_features, -1) - input_param = nn.Parameter( - qkv_w - ) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param) - - param = local_state[key] - - try: - with torch.no_grad(): - param.copy_(input_param) - except Exception as ex: - error_msgs.append( - 'While copying the parameter named "{}", ' - "whose dimensions in the model are {} and " - "whose dimensions in the checkpoint are {}, " - "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) - ) - - strict = False # to avoid unexpected_keys - super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) - def forward( self, hidden_states: torch.Tensor, @@ -257,13 +193,13 @@ def forward( cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length. high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. """ - token_nums = hidden_states.size(0) - # fused qkv - hidden_states = hidden_states.expand(3, -1, -1) - query_states, key_states, value_states = ( - torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0) - ) + + proj = self.W_pack(hidden_states) + proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) + query_states = proj[0].view(token_nums, self.num_heads, self.head_dim) + key_states = proj[1].view(token_nums, self.num_heads, self.head_dim) + value_states = proj[2].view(token_nums, self.num_heads, self.head_dim) block_size = k_cache.size(-2) @@ -388,9 +324,6 @@ def forward( return attn_output - def extra_repr(self) -> str: - return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False" - # NOTE This will cause difference as out length increases. class NopadBaichuanMLP(NopadLlamaMLP): diff --git a/colossalai/inference/modeling/policy/nopadding_baichuan.py b/colossalai/inference/modeling/policy/nopadding_baichuan.py index 78268d6e7e85..5b32ef536a50 100644 --- a/colossalai/inference/modeling/policy/nopadding_baichuan.py +++ b/colossalai/inference/modeling/policy/nopadding_baichuan.py @@ -1,8 +1,5 @@ from colossalai.inference.config import RPC_PARAM -from colossalai.inference.modeling.layers.baichuan_tp_linear import ( - BaichuanLMHeadLinear1D_Col, - BaichuanWpackLinear1D_Col, -) +from colossalai.inference.modeling.layers.baichuan_tp_linear import BaichuanLMHeadLinear1D_Col from colossalai.inference.modeling.models.nopadding_baichuan import ( NopadBaichuanAttention, NopadBaichuanMLP, @@ -14,7 +11,7 @@ llama_model_forward, ) from colossalai.inference.utils import init_to_get_rotary -from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import FusedLinear1D_Col, Linear1D_Col, Linear1D_Row from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy @@ -60,8 +57,7 @@ def module_policy(self): target_module=NopadBaichuanMLP, ), SubModuleReplacementDescription( - suffix="self_attn.W_pack", - target_module=BaichuanWpackLinear1D_Col, + suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3} ), SubModuleReplacementDescription( suffix="self_attn.o_proj", From 90064bc70cc3cb905b045c4c9b194173ddbfb8b0 Mon Sep 17 00:00:00 2001 From: Runyu Lu Date: Tue, 11 Jun 2024 02:41:19 +0000 Subject: [PATCH 2/2] remove unused code and add TODO for lazyinit --- .../modeling/layers/baichuan_tp_linear.py | 23 ++++--------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/colossalai/inference/modeling/layers/baichuan_tp_linear.py b/colossalai/inference/modeling/layers/baichuan_tp_linear.py index fcced85e22cb..50806a14b9e8 100644 --- a/colossalai/inference/modeling/layers/baichuan_tp_linear.py +++ b/colossalai/inference/modeling/layers/baichuan_tp_linear.py @@ -15,7 +15,10 @@ def from_native_module( module.in_features = module.weight.size(1) module.out_features = module.weight.size(0) module.bias = None - module.weight.data = nn.functional.normalize(module.weight) + module.weight.data = nn.functional.normalize( + module.weight + ) # TODO(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight. + # So we should rewrite our own load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue. return Linear1D_Col.from_native_module( module, @@ -23,21 +26,3 @@ def from_native_module( *args, **kwargs, ) - - -# class BaichuanWpackLinear1D_Col(Linear1D_Col): -# @staticmethod -# def from_native_module( -# module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs -# ) -> ParallelModule: -# in_features = module.in_features * 3 -# out_features = module.out_features // 3 -# module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features) -# module.bias = None - -# return Linear1D_Col.from_native_module( -# module, -# process_group, -# *args, -# **kwargs, -# )