diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py
index 06b8f93d20b5..ac7845400d8d 100644
--- a/colossalai/shardformer/modeling/mistral.py
+++ b/colossalai/shardformer/modeling/mistral.py
@@ -2,12 +2,22 @@
 from typing import List, Optional, Tuple, Union

 import torch
-from transformers.cache_utils import Cache
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.mistral.modeling_mistral import MistralModel
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
+)
+from transformers.models.mistral.modeling_mistral import MistralForCausalLM, MistralModel
 from transformers.utils import logging

+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.shard import ShardConfig
+
+from ..layer import ColoAttention
+
 logger = logging.get_logger(__name__)

@@ -24,6 +34,10 @@ def mistral_model_forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        stage_manager: Optional[PipelineStageManager] = None,
+        hidden_states: Optional[torch.FloatTensor] = None,
+        stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         if use_cache:
             logger.warning_once("use_cache=True is not supported for Mistral models at the moment.")
@@ -35,6 +49,376 @@ def mistral_model_forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        # retrieve input_ids and inputs_embeds
+        if stage_manager.is_first_stage():
+            if input_ids is not None and inputs_embeds is not None:
+                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+            elif input_ids is not None:
+                batch_size, seq_length = input_ids.shape
+            elif inputs_embeds is not None:
+                batch_size, seq_length, _ = inputs_embeds.shape
+            else:
+                raise ValueError("You have to specify either input_ids or inputs_embeds")
+            inputs_embeds = self.embed_tokens(input_ids)
+            hidden_states = inputs_embeds
+        else:
+            input_shape = hidden_states.shape[:-1]
+            batch_size, seq_length = input_shape
+            device = hidden_states.device
+
+        past_key_values_length = 0
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
+            if is_padding_right:
+                raise ValueError(
+                    "You are attempting to perform batched generation with padding_side='right'"
+                    " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
+                    " call `tokenizer.padding_side = 'left'` before tokenizing the input.
" + ) + + if shard_config.enable_flash_attention: + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length, seq_length) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + start_idx, end_idx = stage_index[0], stage_index[1] + num_ckpt_layers = 0 + if self.gradient_checkpointing and self.training: + num_ckpt_layers = end_idx - start_idx + # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer + if shard_config.gradient_checkpoint_config is not None: + num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( + stage=stage_manager.stage, + num_layers=end_idx - start_idx, + model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), + ) + assert num_ckpt_layers <= end_idx - start_idx + + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if idx - start_idx < num_ckpt_layers: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + else: + return {"hidden_states": hidden_states} + + @staticmethod + def mistral_for_causal_lm_forward( + self: MistralForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + 
output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = MistralForwards.mistral_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def mistral_for_sequence_classification_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + 
past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = MistralForwards.mistral_model_forward( + self.model, + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + if input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + batch_size = hidden_states.shape[0] + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( + logits.device + ) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + 
past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig): + logger = logging.get_logger(__name__) + assert shard_config.enable_flash_attention, "Flash Attention is not enabled." + + def forward( + self: MistralModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") @@ -47,6 +431,12 @@ def mistral_model_forward( past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( @@ -67,19 +457,29 @@ def mistral_model_forward( " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
" ) - - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + if shard_config.enable_flash_attention: + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length, seq_length) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + inputs_embeds.dtype, + inputs_embeds.device, + q_padding_mask=attention_mask, + is_causal=True, ) + else: + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) hidden_states = inputs_embeds @@ -93,6 +493,7 @@ def mistral_model_forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -121,7 +522,7 @@ def mistral_model_forward( hidden_states = layer_outputs[0] if use_cache: - layer_outputs[2 if output_attentions else 1] + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -133,6 +534,8 @@ def mistral_model_forward( all_hidden_states += (hidden_states,) next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -143,11 +546,11 @@ def mistral_model_forward( attentions=all_self_attns, ) + return forward -def get_mistral_flash_attention_forward(): - from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention +def get_mistral_flash_attention_forward(shard_config: ShardConfig): + from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv def forward( self: MistralAttention, @@ -164,15 +567,14 @@ def forward( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) bsz, q_len, _ = hidden_states.size() - assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." 
- query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = ( - self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - ) - value_states = ( - self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - ) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -190,31 +592,18 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) - query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) - key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) - value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape) + assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." + attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) - flash_attention_mask = None - attn_mask_type = AttnMaskType.causal - if attention_mask != None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - attn_mask_type = AttnMaskType.paddedcausal - - attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) - attn_output = attention( - query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type - ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) - return attn_output, past_key_value + return attn_output, None, past_key_value return forward diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index b82840576974..e72a97e4bfc0 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -137,6 +137,7 @@ def module_policy(self): if self.shard_config.enable_flash_attention: warnings.warn("Falcon doesn't support flash attention now, fallback to transformers attention.") + return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 36d081c17991..6f4f835a8dbe 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -48,6 +48,7 @@ def module_policy(self): policy = {} attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = col_nn.VocabParallelEmbedding1D diff 
--git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py
index f40a17ad9487..1280efaec921 100644
--- a/colossalai/shardformer/policies/gptj.py
+++ b/colossalai/shardformer/policies/gptj.py
@@ -43,6 +43,7 @@ def module_policy(self):
         policy = {}
         attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]

+        embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = col_nn.VocabParallelEmbedding1D
diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py
index b4d25a2fe85e..b5018e47d65d 100644
--- a/colossalai/shardformer/policies/mistral.py
+++ b/colossalai/shardformer/policies/mistral.py
@@ -1,8 +1,10 @@
 import warnings
 from functools import partial
-from typing import Callable, Dict, Union
+from typing import Callable, Dict, List, Union

 import torch.nn as nn
+from torch import Tensor
+from torch.nn import Module

 from colossalai.shardformer.layer import (
     FusedRMSNorm,
@@ -14,7 +16,11 @@
     VocabParallelLMHead1D,
 )

-from ..modeling.mistral import MistralForwards, get_mistral_flash_attention_forward
+from ..modeling.mistral import (
+    MistralForwards,
+    get_mistral_flash_attention_forward,
+    get_mistral_model_forward_for_flash_attn,
+)
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ["MistralPolicy", "MistralModelPolicy", "MistralForCausalLMPolicy", "MistralForSequenceClassificationPolicy"]
@@ -45,6 +51,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         policy = {}
         attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]

+        embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -145,16 +152,83 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                 policy=policy,
                 target_key=attn_cls,
             )
+            if self.pipeline_stage_manager is None:
+                # replace mistral model forward method
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_mistral_model_forward_for_flash_attn(self.shard_config),
+                    },
+                    policy=policy,
+                    target_key=MistralModel,
+                )

         return policy

     def postprocess(self):
         return self.model

-    def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
-        method_replacement = {"forward": partial(new_forward)}
+    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+        """If under the pipeline-parallel setting, replace the original Hugging Face forward method
+        with a customized forward method, and register this change in the policy."""
+        if self.pipeline_stage_manager is None:
+            return
+
+        stage_manager = self.pipeline_stage_manager
+        if self.model.__class__.__name__ == "MistralModel":
+            module = self.model
+        else:
+            module = self.model.model
+
+        if stage_manager.is_interleave:
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
+            method_replacement = {
+                "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
+            }
+
+        else:
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_index = stage_manager.get_stage_index(layers_per_stage)
+            method_replacement = {
+                "forward": partial(
+                    new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+                )
+            }
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
+
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        assert self.pipeline_stage_manager is not None
+
+        if self.model.__class__.__name__ == "MistralModel":
+            module = self.model
+        else:
+            module = self.model.model
+        stage_manager = self.pipeline_stage_manager
+
+        held_layers = []
+        if stage_manager.is_interleave:
+            assert stage_manager.num_model_chunks is not None
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_indices = stage_manager.get_stage_index(layers_per_stage)
+            if stage_manager.is_first_stage(ignore_chunk=True):
+                held_layers.append(module.embed_tokens)
+            for start_idx, end_idx in stage_indices:
+                held_layers.extend(module.layers[start_idx:end_idx])
+            if stage_manager.is_last_stage(ignore_chunk=True):
+                held_layers.append(module.norm)
+
+        else:
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            if stage_manager.is_first_stage():
+                held_layers.append(module.embed_tokens)
+            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
+            held_layers.extend(module.layers[start_idx:end_idx])
+            if stage_manager.is_last_stage():
+                held_layers.append(module.norm)
+        return held_layers
+

 class MistralModelPolicy(MistralPolicy):
     def __init__(self) -> None:
@@ -164,17 +238,28 @@ def module_policy(self):
         policy = super().module_policy()
         from transformers.models.mistral.modeling_mistral import MistralModel

-        self.set_forward(model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy)
+        if self.pipeline_stage_manager:
+            self.set_pipeline_forward(
+                model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy
+            )
+
         return policy

+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        held_layers = super().get_held_layers()
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in mistral model"""
+        return []
+

 class MistralForCausalLMPolicy(MistralPolicy):
     def module_policy(self):
         from transformers import MistralForCausalLM

         policy = super().module_policy()

-        if self.pipeline_stage_manager:
-            warnings.warn("Mistral doesn't support pipeline parallelism now.")

         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
@@ -207,8 +292,38 @@ def module_policy(self):
             policy.update(new_item)

+        if self.pipeline_stage_manager:
+            # set None as default
+            self.set_pipeline_forward(
+                model_cls=MistralForCausalLM, new_forward=MistralForwards.mistral_for_causal_lm_forward, policy=policy
+            )
+
         return policy

+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        stage_manager = self.pipeline_stage_manager
+        held_layers = super().get_held_layers()
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.lm_head)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        mistral_model = self.model.model
+        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+            if (
+                id(mistral_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+                and self.pipeline_stage_manager.num_stages > 1
+            ):
+                # tie weights
+                return [
+                    {
+                        0: mistral_model.embed_tokens.weight,
+                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+                    }
+                ]
+        return []
+

 class MistralForSequenceClassificationPolicy(MistralPolicy):
     def module_policy(self):
@@ -227,9 +342,26 @@ def module_policy(self):
                     ]
                 )
             }
+            policy.update(new_item)

-        if self.pipeline_stage_manager:
-            warnings.warn("Mistral doesn't support pipeline parallelism now.")
+        if self.pipeline_stage_manager:
+            # set None as default
+            self.set_pipeline_forward(
+                model_cls=MistralForSequenceClassification,
+                new_forward=MistralForwards.mistral_for_sequence_classification_forward,
+                policy=policy,
+            )

-        policy.update(new_item)
         return policy
+
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        stage_manager = self.pipeline_stage_manager
+        held_layers = super().get_held_layers()
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.score)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in mistral for sequence classification model"""
+        return []
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index e589c674088f..2f6eabd5fef9 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -58,6 +58,7 @@ def module_policy(self):
         policy = {}
         attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]

+        embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index b0352230788a..d307312ded8e 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -16,4 +16,4 @@ ray
 sentencepiece
 google
 protobuf
-transformers==4.36.0
+transformers==4.36.2
diff --git a/tests/kit/model_zoo/transformers/mistral.py b/tests/kit/model_zoo/transformers/mistral.py
index 37f87585759e..ae5a9700240a 100644
--- a/tests/kit/model_zoo/transformers/mistral.py
+++ b/tests/kit/model_zoo/transformers/mistral.py
@@ -52,6 +52,9 @@ def data_gen_for_sequence_classification():
     hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258
 )

+if hasattr(config, "pad_token_id"):
+    config.pad_token_id = config.eos_token_id
+
 model_zoo.register(
     name="transformers_mistral",
     model_fn=lambda: transformers.MistralModel(config),
diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py
index f127472aee0b..05c1998146b6 100644
--- a/tests/test_shardformer/test_model/test_shard_mistral.py
+++ b/tests/test_shardformer/test_model/test_shard_mistral.py
@@ -91,7 +91,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check weights
     if stage_manager is None or stage_manager.is_first_stage():
         if test_config["precision"] == "fp32":
-            atol, rtol = 1e-4, 1e-3
+            atol, rtol = 2e-4, 1e-3
         else:
             atol, rtol = 5e-3, 5e-3
         check_weight(
@@ -114,6 +114,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "enable_all_optimization": True,
+            "use_lazy_init": False,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 4,
             "pp_size": 1,
@@ -156,7 +174,6 @@ def check_mistral(rank, world_size, port):
     run_mistral_test()


-@pytest.mark.skip("something wrong with pipeline parallelism")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()