From ec1a922235926af583366d0c3e5993efd5be69d4 Mon Sep 17 00:00:00 2001
From: Wang Binluo <2538539015@qq.com>
Date: Wed, 27 Mar 2024 16:24:29 +0800
Subject: [PATCH] update_falcon

---
 colossalai/shardformer/modeling/falcon.py | 286 +++++++++++++++-------
 colossalai/shardformer/policies/falcon.py |   6 -
 2 files changed, 201 insertions(+), 91 deletions(-)

diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py
index 4e271dfe0fa2..49e9564d8773 100644
--- a/colossalai/shardformer/modeling/falcon.py
+++ b/colossalai/shardformer/modeling/falcon.py
@@ -1,9 +1,14 @@
 from typing import List, Optional, Tuple, Union
-
+import math
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_attn_mask_utils import (
+    AttentionMaskConverter,
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -18,11 +23,13 @@
     FalconForTokenClassification,
     FalconModel,
     build_alibi_tensor,
+    apply_rotary_pos_emb,
 )
 from transformers.utils import logging
-
+import warnings
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
+from torch.nn import functional as F
 
 
 def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
@@ -99,11 +106,17 @@ def forward(
         hidden_states: torch.Tensor,
         alibi: Optional[torch.Tensor],
         attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        **kwargs,
     ):
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
         residual = hidden_states
 
         if self.config.new_decoder_architecture:
@@ -117,10 +130,12 @@ def forward(
             attention_layernorm_out,
             layer_past=layer_past,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             alibi=alibi,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
+            **kwargs,
         )
 
         attention_output = attn_outputs[0]
@@ -166,11 +181,17 @@ def forward(
         hidden_states: torch.Tensor,
         alibi: Optional[torch.Tensor],
         attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        **kwargs,
     ):
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
         fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
         num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
         # 3 x [batch_size, seq_length, num_heads, head_dim]
@@ -178,59 +199,111 @@ def forward(
         batch_size, query_length, _, _ = query_layer.shape
 
-        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
-        key_layer = key_layer.transpose(1, 2).reshape(
-            batch_size * num_kv_heads,
-            query_length,
-            self.head_dim,
-        )
-        value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
+        query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
+        key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
+        value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
 
-        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+        kv_seq_len = key_layer.shape[-2]
+        if layer_past is not None:
+            kv_seq_len += layer_past[0].shape[-2]
+        if alibi is None:
+            cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
+            query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
 
         if layer_past is not None:
             past_key, past_value = layer_past
             # concatenate along seq_length dimension:
-            #  - key: [batch_size * self.num_heads, kv_length, head_dim]
-            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
-            key_layer = torch.cat((past_key, key_layer), dim=1)
-            value_layer = torch.cat((past_value, value_layer), dim=1)
+            #  - key: [batch_size, self.num_heads, kv_length, head_dim]
+            #  - value: [batch_size, self.num_heads, kv_length, head_dim]
+            key_layer = torch.cat((past_key, key_layer), dim=-2)
+            value_layer = torch.cat((past_value, value_layer), dim=-2)
 
-        _, kv_length, _ = key_layer.shape
+        kv_length = key_layer.shape[-2]
         if use_cache:
             present = (key_layer, value_layer)
         else:
             present = None
 
-        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
+        if alibi is None:
+            if self._use_sdpa and not output_attentions:
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer,
+                    key_layer,
+                    value_layer,
+                    attention_mask,
+                    0.0,
+                    # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
+                    is_causal=self.is_causal and attention_mask is None and query_length > 1,
+                )
+                attention_scores = None
+            else:
+                attention_scores = query_layer @ key_layer.transpose(-1, -2)
+                attention_scores /= math.sqrt(self.head_dim)
+
+                attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
+                # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
+                attn_output = attention_scores @ value_layer
 
-        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).contiguous()
-        key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
-        value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
+            attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
+            attn_output = attn_output.permute(0, 2, 1, 3)
+            attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
 
-        if alibi is not None:
-            attention_mask_float = (
-                attention_mask_float + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
-            )
+            attn_output = self.dense(attn_output)
 
-        batch_size, src_len = query_layer_.size()[0], query_layer_.size()[1]
-        tgt_len = key_layer_.size()[1]
-        attention_mask_float = attention_mask_float.expand(batch_size, self.num_heads, src_len, tgt_len).contiguous()
-        context_layer = me_attention(
-            query_layer_,
-            key_layer_,
-            value_layer_,
-            attn_bias=attention_mask_float,
-            scale=self.inv_norm_factor,
-            p=self.attention_dropout.p,
-        )
-        batch_size, seq_length, _, _ = context_layer.shape
-        context_layer = context_layer.reshape(batch_size, seq_length, -1)
+            if output_attentions:
+                return attn_output, present, attention_scores
+            else:
+                return attn_output, present
+        else:
+            if self._use_sdpa and not output_attentions and head_mask is None:
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer,
+                    key_layer,
+                    value_layer,
+                    attn_mask=attention_mask,
+                    dropout_p=self.attention_dropout.p if self.training else 0.0,
+                    is_causal=self.is_causal and attention_mask is None and query_length > 1,
+                )
+                attn_output = attn_output.transpose(1, 2)
+                attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
+
+                attn_output = self.dense(attn_output)
+            else:
+                matmul_result = query_layer @ key_layer.transpose(-1, -2)
 
-        output_tensor = self.dense(context_layer)
+                # change view to [batch_size, num_heads, q_length, kv_length]
+                attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
 
-        return output_tensor, present
+                # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
+                input_dtype = attention_scores.dtype
+                # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
+                if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
+                    attention_scores = attention_scores.to(torch.float32)
+
+                attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
+                attention_logits *= self.inv_norm_factor
+                attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
+                # [batch_size, num_heads, q_length, kv_length]
+                attention_probs = self.attention_dropout(attention_probs)
+
+                if head_mask is not None:
+                    attention_probs = attention_probs * head_mask
+
+                # change view [batch_size, num_heads, q_length, kv_length]
+                attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
+
+                # matmul: [batch_size * num_heads, q_length, head_dim]
+                attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1)
+
+                # change view [batch_size, q_length, num_heads * head_dim]
+                attn_output = self._merge_heads(attn_output)
+
+                attn_output = self.dense(attn_output)
+
+                if output_attentions:
+                    return attn_output, present, attention_probs
+                else:
+                    return attn_output, present
 
     return forward
@@ -246,6 +319,7 @@ def falcon_model_forward(
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -273,18 +347,7 @@ def falcon_model_forward(
             past_key_values = None
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if past_key_values is None:
-            past_key_values = tuple([None] * len(self.h))
-        else:
-            past_key_values = self._convert_to_rw_cache(past_key_values)
-
-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape batch_size x num_heads x N x N
-        # head_mask has shape n_layer x batch x num_heads x N x N
-        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-
+
         # case: First stage of training
         if stage_manager.is_first_stage():
             if input_ids is not None and inputs_embeds is not None:
@@ -295,16 +358,22 @@ def falcon_model_forward(
                 batch_size, seq_length, _ = inputs_embeds.shape
             else:
                 raise ValueError("You have to specify either input_ids or inputs_embeds")
-
             if inputs_embeds is None:
                 inputs_embeds = self.word_embeddings(input_ids)
-
             hidden_states = inputs_embeds
-
         else:
             input_shape = hidden_states.shape[:-1]
            batch_size, seq_length = input_shape
+        if past_key_values is None:
+            past_key_values = tuple([None] * len(self.h))
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
 
         presents = () if use_cache else None
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
@@ -312,22 +381,80 @@ def falcon_model_forward(
         # Compute alibi tensor: check build_alibi_tensor documentation
         past_key_values_length = 0
         if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[1]  # 1 because RW-cache, not standard format
-        if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
-        else:
-            attention_mask = attention_mask.to(hidden_states.device)
+            past_key_values_length = past_key_values[0][0].shape[-2]
 
         if self.use_alibi:
-            alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
+            mask = (
+                torch.ones(
+                    (batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long
+                )
+                if attention_mask is None
+                else attention_mask
+            )
+            alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
         else:
             alibi = None
+            if position_ids is None:
+                device = input_ids.device if input_ids is not None else inputs_embeds.device
+                position_ids = torch.arange(
+                    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+                )
+                position_ids = position_ids.unsqueeze(0)
+
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._use_sdpa and not output_attentions:
+            # output_attentions=True can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            if alibi is None:
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
+            elif head_mask is None:
+                alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
 
-        causal_mask = self._prepare_attn_mask(
-            attention_mask,
-            input_shape=(batch_size, seq_length),
-            past_key_values_length=past_key_values_length,
-        )
+                attention_mask_2d = attention_mask
+                # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+                )
+
+                # We take care to integrate alibi bias in the attention_mask here.
+                if attention_mask_2d is None:
+                    attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
+                else:
+                    attention_mask = torch.masked_fill(
+                        alibi / math.sqrt(self.config.hidden_size // self.num_heads),
+                        attention_mask < -1,
+                        torch.finfo(alibi.dtype).min,
+                    )
+
+                    # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
+                    # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
+                    if seq_length > 1:
+                        attention_mask = AttentionMaskConverter._unmask_unattended(
+                            attention_mask, attention_mask_2d, unmasked_value=0.0
+                        )
+            else:
+                # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+                )
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+            )
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape batch_size x num_heads x N x N
+        # head_mask has shape n_layer x batch x num_heads x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
         start_idx, end_idx = stage_index[0], stage_index[1]
         for i, (block, layer_past) in enumerate(
@@ -337,31 +464,23 @@ def falcon_model_forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
-                    causal_mask,
+                    attention_mask,
+                    position_ids,
                     head_mask[i],
+                    layer_past,
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
                     hidden_states,
                     layer_past=layer_past,
-                    attention_mask=causal_mask,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
@@ -382,9 +501,6 @@ def custom_forward(*inputs):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        if presents is not None:
-            presents = self._convert_cache_to_standard_format(presents, batch_size)
-
         if stage_manager.is_last_stage():
             if not return_dict:
                 return tuple(
diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py
index 5c148880f980..1d9eb81c9e65 100644
--- a/colossalai/shardformer/policies/falcon.py
+++ b/colossalai/shardformer/policies/falcon.py
@@ -21,12 +21,6 @@ class FalconPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
-        import transformers
-        from packaging.version import Version
-
-        assert Version(transformers.__version__) <= Version(
-            "4.33.0"
-        ), "The Falcon model should run on a transformers version not greater than 4.33.0."
 
     def config_sanity_check(self):
         pass
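
Reviewer note (illustration only, not part of the applied diff): the attention hunks above move the KV cache from the old fused [batch_size * num_heads, seq, head_dim] layout to the standard 4-D [batch_size, num_heads, seq, head_dim] layout, concatenate past keys/values along dim=-2, and dispatch to F.scaled_dot_product_attention when possible. The standalone sketch below mirrors that control flow under stated assumptions; toy_attention and all shapes are hypothetical and not part of the ColossalAI or transformers API.

import math
import torch
import torch.nn.functional as F

def toy_attention(query, key, value, layer_past=None, attention_mask=None, use_sdpa=True):
    # Hypothetical helper, not the patched module. query/key/value:
    # [batch_size, num_heads, seq_len, head_dim] -- the 4-D layout the
    # patch adopts (the old code used [batch * heads, seq, head_dim]).
    if layer_past is not None:
        # Cache entries are concatenated along the sequence axis (dim=-2).
        past_key, past_value = layer_past
        key = torch.cat((past_key, key), dim=-2)
        value = torch.cat((past_value, value), dim=-2)
    present = (key, value)

    if use_sdpa:
        # Mirrors the SDPA branch: is_causal only when no explicit mask is
        # given and more than one query token is present (query_length > 1),
        # so single-token decoding attends to the whole cache.
        attn_output = F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attention_mask,
            is_causal=attention_mask is None and query.shape[-2] > 1,
        )
    else:
        # Mirrors the eager fallback: scaled matmul + fp32 softmax; any
        # causal structure is assumed to arrive via attention_mask here.
        scores = query @ key.transpose(-1, -2) / math.sqrt(query.shape[-1])
        if attention_mask is not None:
            scores = scores + attention_mask
        attn_output = F.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype) @ value
    return attn_output, present

# Usage: prefill 5 tokens, then decode 1 token reusing the cache.
q = torch.randn(2, 4, 5, 16)
k = torch.randn(2, 4, 5, 16)
v = torch.randn(2, 4, 5, 16)
out, present = toy_attention(q, k, v)
out1, _ = toy_attention(torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16),
                        torch.randn(2, 4, 1, 16), layer_past=present)
print(out.shape, out1.shape)  # torch.Size([2, 4, 5, 16]) torch.Size([2, 4, 1, 16])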
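
A second sketch for the trickier hunk in falcon_model_forward: F.scaled_dot_product_attention has no separate bias argument, so with alibi the patch folds the pre-scaled alibi slopes into the additive 4-D mask and pushes masked slots to the dtype minimum. The helper below reproduces just that arithmetic; fold_alibi_into_mask and the toy shapes are assumptions for illustration (the patch itself relies on torch.masked_fill broadcasting the [batch, heads, 1, kv] alibi against the [batch, 1, q, kv] mask; the sketch expands explicitly instead).

import math
import torch

def fold_alibi_into_mask(alibi, mask_4d, hidden_size, num_heads):
    # Hypothetical helper. alibi: [batch, num_heads, 1, kv_len] additive
    # slopes (unscaled). mask_4d: [batch, 1, q_len, kv_len] additive causal
    # mask, 0 where attending and a large negative value where masked.
    scaled = alibi / math.sqrt(hidden_size // num_heads)
    if mask_4d is None:
        return scaled
    scaled = scaled.expand(-1, -1, mask_4d.shape[-2], -1)  # broadcast over q_len
    # `mask_4d < -1` marks masked slots; set them to the dtype minimum so
    # softmax still zeroes them out after the alibi bias is added.
    return scaled.masked_fill(mask_4d < -1, torch.finfo(alibi.dtype).min)

# Usage with toy shapes, feeding the folded bias into SDPA as attn_mask.
batch, heads, q_len, kv_len, hidden = 2, 4, 5, 5, 64
alibi = torch.randn(batch, heads, 1, kv_len)
causal = torch.full((q_len, kv_len), torch.finfo(torch.float32).min).triu(1)
mask_4d = causal.expand(batch, 1, q_len, kv_len)
bias = fold_alibi_into_mask(alibi, mask_4d, hidden, heads)
q = torch.randn(batch, heads, q_len, hidden // heads)
k = torch.randn(batch, heads, kv_len, hidden // heads)
v = torch.randn(batch, heads, kv_len, hidden // heads)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias)
print(out.shape)  # torch.Size([2, 4, 5, 16])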