137 changes: 50 additions & 87 deletions colossalai/shardformer/modeling/command.py
@@ -1,19 +1,19 @@
import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
CohereAttention,
CohereForCausalLM,
CohereModel,
StaticCache,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.processing_utils import Unpack
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -27,6 +27,8 @@

_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]

logger = logging.get_logger(__name__)


class CommandPipelineForwards:
"""
@@ -37,22 +39,23 @@ class CommandPipelineForwards:
@staticmethod
def command_model_forward(
self: CohereModel,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
force_sp_output_gather: bool = True,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
):

logger = logging.get_logger(__name__)

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -67,8 +70,6 @@ def command_model_forward(
)
use_cache = False

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
@@ -122,6 +123,7 @@ def command_model_forward(

# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
shard_config.enable_flash_attention = True
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
@@ -133,7 +135,8 @@ def command_model_forward(
is_causal=True,
)
else:
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
# v4.51.3 transformers attention_mask calculation
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values)

if self.gradient_checkpointing and self.training and use_cache:
if use_cache:
@@ -163,6 +166,8 @@ def command_model_forward(
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
# v4.51.3 transformers position_embeddings calculation
position_embeddings = self.rotary_emb(hidden_states, position_ids)

start_idx, end_idx = stage_index[0], stage_index[1]
num_ckpt_layers = 0
@@ -193,6 +198,7 @@ def command_model_forward(
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
@@ -203,6 +209,7 @@ def command_model_forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)

hidden_states = layer_outputs[0]
@@ -224,17 +231,6 @@ def command_model_forward(
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if stage_manager.is_last_stage():
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
@@ -352,27 +348,22 @@ def command_for_causal_lm_forward(

def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
def forward(
self,
self: CohereAttention,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if sp_mode is not None:
assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet"
assert (sp_size is not None) and (
sp_group is not None
), "Must specify sp_size and sp_group for sequence parallel"
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)

bsz, q_len, _ = hidden_states.size()

# sp: modify sp_len when sequence parallel mode is ring
if sp_mode in ["split_gather", "ring"]:
q_len *= sp_size
@@ -388,60 +379,36 @@ def forward(
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = position_embeddings

cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = None

shard_config.enable_flash_attention = True

if shard_config.enable_flash_attention:
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)

# attn_weights and attn_output calculation is adapted from v4.51.3 of transformers.models.cohere.modeling_cohere.CohereAttention.forward.
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
dropout = 0.0 if not self.training else self.attention_dropout
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()

attn_output = attn_output.transpose(1, 2).contiguous()
# sp: all-to-all communication when introducing sequence parallel
@@ -451,13 +418,11 @@ def forward(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()

attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights

return forward

@@ -467,24 +432,23 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode

def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
force_sp_output_gather: bool = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
@@ -516,6 +480,8 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

shard_config.enable_flash_attention = True

# in this case, attention_mask is a dict rather than a tensor
if shard_config.enable_flash_attention:
mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
@@ -527,7 +493,8 @@ def forward(
is_causal=True,
)
else:
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
# v4.51.3 transformers attention_mask calculation
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)

if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(
@@ -544,6 +511,9 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

# v4.51.3 transformers position_embeddings calculation
position_embeddings = self.rotary_emb(hidden_states, position_ids)

for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -557,8 +527,8 @@ def forward(
output_attentions,
use_cache,
cache_position,
position_embeddings,
)

else:
layer_outputs = decoder_layer(
hidden_states,
@@ -568,16 +538,11 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)

hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]

if output_attentions:
all_self_attns += (layer_outputs[1],)

hidden_states = self.norm(hidden_states)

# Cases that don't support parallelizing cross entropy computation along sequence
@@ -594,8 +559,6 @@ def forward(
next_cache = (
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
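Note for readers: the sketch below is not part of the diff. It is a minimal, self-contained illustration of the calling pattern the updated forwards adopt from transformers v4.51.x: rotary position embeddings are computed once per model forward (position_embeddings = self.rotary_emb(hidden_states, position_ids)) and passed to every decoder layer, and the attention forward returns (attn_output, attn_weights) instead of a 3-tuple. ToyRotary and ToyDecoderLayer are hypothetical stand-ins, not the actual Cohere or ColossalAI modules.

import torch
from torch import nn


class ToyRotary(nn.Module):
    """Hypothetical stand-in for the model-level rotary embedding module."""

    def __init__(self, head_dim: int = 8):
        super().__init__()
        self.head_dim = head_dim

    def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor):
        # Returns (cos, sin); the real module derives these from inverse
        # frequencies, here we only mimic the shapes.
        angles = position_ids[..., None].float().expand(*position_ids.shape, self.head_dim)
        return torch.cos(angles), torch.sin(angles)


class ToyDecoderLayer(nn.Module):
    """Hypothetical stand-in for a decoder layer that receives precomputed
    position embeddings instead of recomputing them from position_ids."""

    def forward(self, hidden_states, position_embeddings, attention_mask=None):
        cos, sin = position_embeddings  # consumed by the layer's attention
        attn_output, attn_weights = hidden_states, None  # 2-tuple, as in the new attention forward
        return (attn_output, attn_weights)


hidden_states = torch.randn(2, 5, 16)
position_ids = torch.arange(5).unsqueeze(0).expand(2, -1)

rotary_emb = ToyRotary()
layers = nn.ModuleList([ToyDecoderLayer() for _ in range(3)])

# Pattern used by the updated model forwards: compute the rotary embeddings
# once, then reuse them for every decoder layer.
position_embeddings = rotary_emb(hidden_states, position_ids)
for decoder_layer in layers:
    layer_outputs = decoder_layer(hidden_states, position_embeddings=position_embeddings)
    hidden_states = layer_outputs[0]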