diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py
index bd55df28cd34..c0a80f3dfa8c 100644
--- a/src/transformers/models/persimmon/modeling_persimmon.py
+++ b/src/transformers/models/persimmon/modeling_persimmon.py
@@ -29,10 +29,21 @@
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_utils import PreTrainedModel
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_flash_attn_available,
+    logging,
+    replace_return_docstrings,
+)
 from .configuration_persimmon import PersimmonConfig
 
 
+if is_flash_attn_available():
+    from flash_attn import flash_attn_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "PersimmonConfig"
@@ -287,7 +298,7 @@ def forward(
         query_states = self.q_layernorm(query_states)
         key_states = self.k_layernorm(key_states)
 
-        # [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim]
+        # [batch_size, seq_length, num_heads, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
         query_states = query_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
@@ -306,10 +317,10 @@ def forward(
             key_states[..., : self.rotary_emb.dim],
             key_states[..., self.rotary_emb.dim :],
         )
-        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+        # [batch_size, num_heads, seq_length, head_dim // config.partial_rotary_factor]
         query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
 
-        # [batch_size, seq_length, num_heads, head_dim]
+        # [batch_size, num_heads, seq_length, head_dim]
         query_states = torch.cat((query_rot, query_pass), dim=-1)
         key_states = torch.cat((key_rot, key_pass), dim=-1)
 
@@ -358,11 +369,170 @@ def forward(
         return attn_output, attn_weights, past_key_value
 
 
+class PersimmonFlashAttention2(PersimmonAttention):
+    """
+    Persimmon flash attention module. This module inherits from `PersimmonAttention`, as the module's weights stay
+    untouched. The only change required is in the forward pass, which must correctly call the public API of
+    flash attention.
+ """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # FlashAttention output_attentions is unstable see https://github.com/Dao-AILab/flash-attention/blob/4c8ff9154e76c68e7114292bd527c22f45fbf586/flash_attn/flash_attn_interface.py#L506 + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + # [batch_size, seq_length, 3 x hidden_size] + fused_qkv = self.query_key_value(hidden_states) + + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_states, key_states, value_states) = self._split_heads(fused_qkv) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + # [batch_size, seq_length, num_heads, head_dim] -> [batch_size, num_heads, seq_length, head_dim] + query_states = query_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_emb.dim], + query_states[..., self.rotary_emb.dim :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_emb.dim], + key_states[..., self.rotary_emb.dim :], + ) + # [batch_size, num_heads, seq_length, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) + + # [batch_size, num_heads, seq_length, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to" + " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + " float16." 
+            )
+
+            query_states = query_states.to(torch.float16)
+            key_states = key_states.to(torch.float16)
+            value_states = value_states.to(torch.float16)
+
+        # `self.attention_dropout` is an `nn.Dropout` module; flash attention expects a float probability
+        attn_dropout = self.attention_dropout.p if self.training else 0.0
+
+        # [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim]
+        query_states = query_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+
+        attn_output = self._flash_attention_forward(
+            query_states, key_states, value_states, None, q_len, dropout=attn_dropout
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
+        attn_output = self.dense(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention. In the general recipe, if the input hidden states contain at
+        least one padding token, the input is first unpadded, the attention scores are computed, and the final
+        attention scores are padded again. Persimmon always calls this function with `padding_mask=None`, so only
+        the dense path is implemented.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to the Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to the Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to the Flash Attention API
+            padding_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens. Always `None` here.
+            query_length (`int`):
+                The length of the query sequence
+            dropout (`float`, *optional*):
+                Attention dropout probability
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Defaults to `1 / sqrt(head_dim)`.
+        """
+        # The unpad -> `flash_attn_varlen_func` -> `pad_input` path used by other models when a padding
+        # mask is present is intentionally omitted: Persimmon never passes a padding mask here.
+        attn_output = flash_attn_func(
+            query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
+        )
+
+        return attn_output
+
+
 class PersimmonDecoderLayer(nn.Module):
     def __init__(self, config: PersimmonConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.self_attn = PersimmonAttention(config=config)
+        self.self_attn = (
+            PersimmonFlashAttention2(config=config)
+            if getattr(config, "_flash_attn_2_enabled", False)
+            else PersimmonAttention(config=config)
+        )
         self.mlp = PersimmonMLP(config)
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -458,6 +628,7 @@ class PersimmonPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["PersimmonDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
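
Usage sketch (not part of the patch): a minimal smoke test for the new code path. It assumes a transformers build with this diff applied, `flash-attn` >= 2.0 installed, and a CUDA GPU; `use_flash_attention_2=True` is the `from_pretrained` flag that sets `config._flash_attn_2_enabled`, and the checkpoint name is illustrative.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "adept/persimmon-8b-base"  # illustrative checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,   # flash attention only supports fp16/bf16
    use_flash_attention_2=True,  # routes PersimmonDecoderLayer to PersimmonFlashAttention2
    device_map="auto",
)

inputs = tokenizer("A persimmon is", return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```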