From c3e821582dbfda3c59721c2818c30e10322d4fb5 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 12 Apr 2024 13:14:14 +0800 Subject: [PATCH 1/2] [shardformer] fix llama modeling --- colossalai/shardformer/modeling/llama.py | 99 ++++++++++++++++-------- 1 file changed, 66 insertions(+), 33 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index dd2caefc5054..53332726d10e 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -7,6 +7,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.cache_utils import Cache from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -16,11 +17,12 @@ LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, apply_rotary_pos_emb, repeat_kv, ) from transformers.utils import logging -from transformers.cache_utils import Cache from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import ( @@ -32,8 +34,6 @@ from ..layer import ColoAttention, cross_entropy_1d -from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa - class LlamaPipelineForwards: """ @@ -107,7 +107,10 @@ def llama_model_forward( if position_ids is None: position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, ) position_ids = position_ids.unsqueeze(0) @@ -117,26 +120,33 @@ def llama_model_forward( # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) attention_mask = ColoAttention.prepare_attn_kwargs( - mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, ) else: if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length - ) + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -159,7 +169,7 @@ def llama_model_forward( num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( stage=stage_manager.stage, num_layers=end_idx - start_idx, - model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0, + model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), ) assert num_ckpt_layers <= end_idx - start_idx @@ -203,7 +213,16 @@ def llama_model_forward( next_cache = next_decoder_cache if use_cache else None if stage_manager.is_last_stage(): if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + ] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -307,7 +326,9 @@ def llama_for_causal_lm_forward( new_vocab_size = logits.shape[-1] shift_logits = shift_logits.view(-1, new_vocab_size) loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, ) else: shift_logits = shift_logits.view(-1, self.config.vocab_size) @@ -446,12 +467,10 @@ def llama_for_sequence_classification_forward( def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - llama_version = 2 try: from transformers.models.llama.modeling_llama import repeat_kv except: warnings.warn("using llamav1, llamav1 hasn't repeat_kv function") - llama_version = 1 def forward( self: LlamaAttention, @@ -494,8 +513,8 @@ def forward( raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) + "with a layer index." + ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) @@ -567,7 +586,10 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: @@ -581,7 +603,11 @@ def forward( # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) attention_mask = ColoAttention.prepare_attn_kwargs( - mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, ) if self.gradient_checkpointing and self.training: @@ -736,7 +762,9 @@ def forward( new_vocab_size = logits.shape[-1] shift_logits = shift_logits.view(-1, new_vocab_size) loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, ) if not return_dict: @@ -910,7 +938,10 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: @@ -926,7 +957,9 @@ def forward( if attention_mask is None: attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, ) attention_mask = self._prepare_decoder_attention_mask( From 8b72eabfe4547b759cfce157119c0a63a3d1931d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Apr 2024 05:15:30 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/shardformer/modeling/bloom.py | 3 ++- colossalai/shardformer/modeling/falcon.py | 16 +++++++++------- colossalai/shardformer/modeling/mistral.py | 19 +++++++++---------- colossalai/shardformer/modeling/opt.py | 9 +++++---- colossalai/shardformer/modeling/t5.py | 7 ++----- colossalai/shardformer/modeling/vit.py | 10 +++++----- colossalai/shardformer/modeling/whisper.py | 8 ++++++-- colossalai/shardformer/policies/mistral.py | 10 ++++------ 8 files changed, 42 insertions(+), 40 deletions(-) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 2b2bf89a06e2..c4f326364596 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -6,6 +6,7 @@ from torch.distributed import ProcessGroup from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import functional as F +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -21,7 +22,7 @@ BloomModel, ) from transformers.utils import logging -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py index 49e9564d8773..34754ecdbac9 100644 --- a/colossalai/shardformer/modeling/falcon.py +++ b/colossalai/shardformer/modeling/falcon.py @@ -1,9 +1,12 @@ -from typing import List, Optional, Tuple, Union import math +import warnings +from typing import List, Optional, Tuple, Union + import torch import torch.distributed as dist from torch.distributed import ProcessGroup from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import functional as F from transformers.modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_causal_attention_mask, @@ -22,14 +25,13 @@ FalconForSequenceClassification, FalconForTokenClassification, FalconModel, - build_alibi_tensor, apply_rotary_pos_emb, + build_alibi_tensor, ) from transformers.utils import logging -import warnings + from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig -from torch.nn import functional as F def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: @@ -171,7 +173,7 @@ def forward( def get_falcon_flash_attention_forward(): try: - from xformers.ops import memory_efficient_attention as me_attention + pass except: raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") from transformers.models.falcon.modeling_falcon import FalconAttention @@ -347,7 +349,7 @@ def falcon_model_forward( past_key_values = None return_dict = return_dict if return_dict is not None else self.config.use_return_dict - + # case: First stage of training if stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: @@ -449,7 +451,7 @@ def falcon_model_forward( attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) - + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index c325cb284c22..3b876bcab96a 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -1,21 +1,20 @@ -from typing import Optional, Tuple +import warnings +from typing import List, Optional, Tuple, Union import torch +from transformers.cache_utils import Cache +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import BaseModelOutputWithPast -from typing import List, Optional, Tuple, Union -import warnings from transformers.models.mistral.modeling_mistral import MistralModel -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.utils import logging -from transformers.cache_utils import Cache logger = logging.get_logger(__name__) + class MistralForwards: - @staticmethod def mistral_model_forward( - self:MistralModel, + self: MistralModel, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -94,7 +93,6 @@ def mistral_model_forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -123,7 +121,7 @@ def mistral_model_forward( hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] + layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -145,6 +143,7 @@ def mistral_model_forward( attentions=all_self_attns, ) + def get_mistral_flash_attention_forward(): from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv @@ -218,7 +217,7 @@ def forward( if not output_attentions: attn_weights = None - + return attn_output, attn_weights, past_key_value return forward diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 0a31820876ad..de5b1a267cd7 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -1,7 +1,9 @@ import random from typing import List, Optional, Tuple, Union + import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -15,7 +17,7 @@ OPTModel, ) from transformers.utils import logging -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.shard import ShardConfig @@ -55,7 +57,7 @@ class OPTPipelineForwards: This class serves as a micro library for forward function substitution of OPT models under pipeline setting. """ - + @staticmethod def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ @@ -70,7 +72,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - @staticmethod def opt_model_forward( self: OPTModel, @@ -125,7 +126,7 @@ def opt_model_forward( if decoder.project_in is not None: inputs_embeds = decoder.project_in(inputs_embeds) device = input_ids.device if input_ids is not None else inputs_embeds.device - _dtype = inputs_embeds.dtype + inputs_embeds.dtype hidden_states = inputs_embeds else: if hidden_states is None: diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index 94f4fce74501..b35bb6b94991 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -3,7 +3,6 @@ import torch from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -121,7 +120,7 @@ def t5_stack_forward( # initialize past_key_values with `None` if past does not exist if past_key_values is None: past_key_values = [None] * len(self.block) - + if attention_mask is None: attention_mask = torch.ones(batch_size, mask_seq_length, device=device) @@ -135,9 +134,7 @@ def t5_stack_forward( encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: - encoder_attention_mask = torch.ones( - encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long - ) + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 401973ce4dfe..67b10988d100 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -26,11 +26,11 @@ def _encoder_forward( if encoder.gradient_checkpointing and encoder.training: layer_outputs = encoder._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 6997f181c9ee..509fc3dac86f 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -5,6 +5,10 @@ import torch from torch import nn from torch.nn import CrossEntropyLoss +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -21,7 +25,7 @@ shift_tokens_right, ) from transformers.utils import logging -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa + from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.shard import ShardConfig @@ -695,7 +699,7 @@ def whisper_decoder_forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - + if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 31ce160463a6..3645cf3694fa 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -1,12 +1,12 @@ import warnings from functools import partial -from typing import Dict, Union, Callable +from typing import Callable, Dict, Union import torch.nn as nn from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from ..modeling.mistral import get_mistral_flash_attention_forward, MistralForwards +from ..modeling.mistral import MistralForwards, get_mistral_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["MistralPolicy", "MistralModelPolicy", "MistralForCausalLMPolicy", "MistralForSequenceClassificationPolicy"] @@ -129,11 +129,9 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def postprocess(self): return self.model - + def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: - method_replacement = { - "forward": partial(new_forward) - } + method_replacement = {"forward": partial(new_forward)} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)