diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index 258e94f3f298..94079dbf1e0d 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -137,7 +137,6 @@ We will follow this roadmap to develop Shardformer:
 | swin    | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
 | swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
 | qwen    | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| mistral | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
 
 ## 💡 API Design
 
diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index 81033c429fc7..19b973be8679 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -109,8 +109,8 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
             )
         LazyInitContext.materialize(module)
-        # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
-        if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
+        # to check if it is huggingface LlamaRMSNorm
+        if module.__class__.__name__ == "LlamaRMSNorm":
             normalized_shape = module.weight.shape[0]
             eps = module.variance_epsilon
             elementwise_affine = True
diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py
deleted file mode 100644
index 02bd72ea3dda..000000000000
--- a/colossalai/shardformer/modeling/mistral.py
+++ /dev/null
@@ -1,77 +0,0 @@
-import warnings
-from typing import List, Optional, Tuple
-
-import torch
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-    SequenceClassifierOutputWithPast,
-)
-from transformers.utils import logging
-
-
-def get_mistral_flash_attention_forward():
-    from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv
-
-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
-
-    def forward(
-        self: MistralAttention,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
-        assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
-
-        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-        if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        past_key_value = (key_states, value_states) if use_cache else None
-
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
-        query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
-        key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
-        value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)
-
-        flash_attention_mask = None
-        attn_mask_type = AttnMaskType.causal
-        if attention_mask != None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
-
-        attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
-        attn_output = attention(
-            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
-        )
-
-        attn_output = self.o_proj(attn_output)
-
-        return attn_output, None, past_key_value
-
-    return forward
diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py
deleted file mode 100644
index e1552f7c6d82..000000000000
--- a/colossalai/shardformer/policies/mistral.py
+++ /dev/null
@@ -1,175 +0,0 @@
-import warnings
-from functools import partial
-from typing import Callable, Dict, List, Union
-
-import torch.nn as nn
-from torch import Tensor
-from torch.nn import Module
-
-from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
-
-from ..modeling.mistral import get_mistral_flash_attention_forward
-from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
-
-__all__ = ["MistralPolicy", "MistralModelPolicy", "MistralForCausalLMPolicy", "MistralForSequenceClassificationPolicy"]
-
-
-class MistralPolicy(Policy):
-    def config_sanity_check(self):
-        pass
-
-    def preprocess(self):
-        if self.shard_config.enable_tensor_parallelism:
-            # Resize embedding
-            vocab_size = self.model.config.vocab_size
-            world_size = self.shard_config.tensor_parallel_size
-
-            if vocab_size % world_size != 0:
-                new_vocab_size = vocab_size + world_size - vocab_size % world_size
-                self.model.resize_token_embeddings(new_vocab_size)
-
-        return self.model
-
-    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel
-
-        policy = {}
-
-        if self.shard_config.enable_sequence_parallelism:
-            self.shard_config.enable_sequence_parallelism = False
-            warnings.warn("Mistral dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
-
-        if self.shard_config.enable_tensor_parallelism:
-            decoder_attribute_replacement = {
-                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-                "self_attn.num_key_value_heads": self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
-            }
-
-            policy[MistralDecoderLayer] = ModulePolicyDescription(
-                attribute_replacement=decoder_attribute_replacement,
-                sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.q_proj",
-                        target_module=Linear1D_Col,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.k_proj",
-                        target_module=Linear1D_Col,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.v_proj",
-                        target_module=Linear1D_Col,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.o_proj",
-                        target_module=Linear1D_Row,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.gate_proj",
-                        target_module=Linear1D_Col,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.up_proj",
-                        target_module=Linear1D_Col,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.down_proj",
-                        target_module=Linear1D_Row,
-                    ),
-                ],
-            )
-
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(
-                    suffix="embed_tokens",
-                    target_module=VocabParallelEmbedding1D,
-                ),
-                policy=policy,
-                target_key=MistralModel,
-            )
-
-        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="input_layernorm",
-                        target_module=FusedRMSNorm,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="post_attention_layernorm",
-                        target_module=FusedRMSNorm,
-                    ),
-                ],
-                policy=policy,
-                target_key=MistralDecoderLayer,
-            )
-
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(
-                    suffix="norm",
-                    target_module=FusedRMSNorm,
-                ),
-                policy=policy,
-                target_key=MistralModel,
-            )
-
-        if self.shard_config.enable_flash_attention:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_mistral_flash_attention_forward(),
-                },
-                policy=policy,
-                target_key=MistralAttention,
-            )
-
-        return policy
-
-    def postprocess(self):
-        return self.model
-
-class MistralModelPolicy(MistralPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
-class MistralForCausalLMPolicy(MistralPolicy):
-    def module_policy(self):
-        from transformers import MistralForCausalLM
-
-        policy = super().module_policy()
-
-        if self.shard_config.enable_tensor_parallelism:
-            # add a new item for casual lm
-            new_item = {
-                MistralForCausalLM: ModulePolicyDescription(
-                    sub_module_replacement=[
-                        SubModuleReplacementDescription(
-                            suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
-                        )
-                    ]
-                )
-            }
-            policy.update(new_item)
-
-        return policy
-
-class MistralForSequenceClassificationPolicy(MistralPolicy):
-    def module_policy(self):
-        from transformers import MistralForSequenceClassification
-
-        policy = super().module_policy()
-
-        if self.shard_config.enable_tensor_parallelism:
-            # add a new item for sequence classification
-            new_item = {
-                MistralForSequenceClassification: ModulePolicyDescription(
-                    sub_module_replacement=[
-                        SubModuleReplacementDescription(
-                            suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
-                        )
-                    ]
-                )
-            }
-            policy.update(new_item)
-        return policy