82 changes: 82 additions & 0 deletions colossalai/shardformer/modeling/llama.py
@@ -0,0 +1,82 @@
from typing import Optional, Tuple

import torch

__all__ = ['get_llama_forward']


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


def get_llama_forward():

try:
from xformers.ops import memory_efficient_attention as me_attention
    except ImportError as e:
        raise ImportError("xformers is not installed. Please install it to use flash attention.") from e

def llama_flash_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

        me_input_shape = (bsz, -1, self.num_heads, self.head_dim)
query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)

        if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
attention_mask = attention_mask.expand(bsz, self.num_heads, q_len, kv_seq_len).contiguous()

attn_output = me_attention(query_states, key_states, value_states, attn_bias=attention_mask)
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}")
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value

return llama_flash_attention_forward
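
A minimal sketch (not part of the patch) of the xformers call the new forward relies on: memory_efficient_attention expects [batch, seq, heads, head_dim] inputs and an additive attn_bias broadcastable to [batch, heads, q_len, kv_len], which is why the mask above is expanded before the call. Shapes, device, and dtype below are illustrative assumptions.

import torch
from xformers.ops import memory_efficient_attention as me_attention

bsz, seq_len, num_heads, head_dim = 2, 128, 8, 64
q = torch.randn(bsz, seq_len, num_heads, head_dim, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Dense additive bias, analogous to the expanded `attention_mask` in the forward above.
bias = torch.zeros(bsz, num_heads, seq_len, seq_len, device='cuda', dtype=torch.float16)

out = me_attention(q, k, v, attn_bias=bias, p=0.0, scale=head_dim**-0.5)
assert out.shape == (bsz, seq_len, num_heads, head_dim)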
152 changes: 78 additions & 74 deletions colossalai/shardformer/modeling/opt.py
@@ -2,87 +2,91 @@

import torch

__all__ = ['opt_flash_attention_forward']
__all__ = ['get_opt_forward']


def opt_flash_attention_forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""

# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()

attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
# get query proj
# query_states = self._shape(self.q_proj(hidden_states), -1, bsz)
query_states = self.q_proj(hidden_states).view(*attention_input_shape)
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape)
value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape)
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*attention_input_shape)
value_states = self.v_proj(key_value_states).view(*attention_input_shape)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*attention_input_shape)
value_states = self.v_proj(hidden_states).view(*attention_input_shape)
key_states = torch.cat([past_key_value[0], key_states], dim=1)
value_states = torch.cat([past_key_value[1], value_states], dim=1)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*attention_input_shape)
value_states = self.v_proj(hidden_states).view(*attention_input_shape)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)

src_len = key_states.size(1)
if layer_head_mask != None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}")
if attention_mask != None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}")
attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous()

def get_opt_forward():

try:
from xformers.ops import memory_efficient_attention as me_attention
    except ImportError as e:
        raise ImportError("xformers is not installed. Please install it to use flash attention.") from e

def opt_flash_attention_forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""

# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()

attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
# get query proj
# query_states = self._shape(self.q_proj(hidden_states), -1, bsz)
query_states = self.q_proj(hidden_states).view(*attention_input_shape)
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape)
value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape)
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*attention_input_shape)
value_states = self.v_proj(key_value_states).view(*attention_input_shape)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*attention_input_shape)
value_states = self.v_proj(hidden_states).view(*attention_input_shape)
key_states = torch.cat([past_key_value[0], key_states], dim=1)
value_states = torch.cat([past_key_value[1], value_states], dim=1)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*attention_input_shape)
value_states = self.v_proj(hidden_states).view(*attention_input_shape)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)

src_len = key_states.size(1)
        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                                 f" {layer_head_mask.size()}")
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}")
            attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, src_len).contiguous()

attn_output = me_attention(query_states,
key_states,
value_states,
attn_bias=attention_mask,
p=self.dropout,
scale=self.scaling)
attn_output = me_attention(query_states,
key_states,
value_states,
attn_bias=attention_mask,
p=self.dropout,
scale=self.scaling)

attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim)
attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim)

# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value

return opt_flash_attention_forward
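
A sketch of the substitution these factory functions enable (illustrative, not part of the patch): get_opt_forward() imports xformers once and returns a plain function, which can be bound onto an existing attention module as its forward. The helper name below is hypothetical; the actual wiring goes through the policies that follow.

import types

from colossalai.shardformer.modeling.opt import get_opt_forward


def replace_opt_attention_forward(attn_module):
    # MethodType binds the closure so `self` resolves to `attn_module`;
    # this mirrors, roughly, what the policies' `method_replacement` achieves per instance.
    attn_module.forward = types.MethodType(get_opt_forward(), attn_module)
    return attn_module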
8 changes: 7 additions & 1 deletion colossalai/shardformer/policies/llama.py
@@ -5,6 +5,7 @@
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from ..modeling.llama import get_llama_forward
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']

@@ -26,7 +27,7 @@ def preprocess(self):
return self.model

def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel

policy = {}

@@ -98,6 +99,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
),
policy=policy,
target_key=LlamaModel)

if self.shard_config.enable_flash_attention:
policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
'forward': get_llama_forward(),
})

return policy

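How the new branch is reached (editorial sketch, not part of the patch): the policy swaps LlamaAttention.forward only when self.shard_config.enable_flash_attention is true, so callers opt in through the shard config. A minimal example, assuming ShardConfig exposes these three boolean fields as the test changes below suggest.

from colossalai.shardformer import ShardConfig

# Flags mirror the parameters exercised in the shardformer tests below.
shard_config = ShardConfig(enable_fused_normalization=True,
                           enable_tensor_parallelism=True,
                           enable_flash_attention=True)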
4 changes: 2 additions & 2 deletions colossalai/shardformer/policies/opt.py
@@ -1,7 +1,7 @@
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from .._utils import getattr_, setattr_
from ..modeling.opt import opt_flash_attention_forward
from ..modeling.opt import get_opt_forward
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
@@ -93,7 +93,7 @@ def module_policy(self):
# use flash attention
if self.shard_config.enable_flash_attention:
policy[OPTAttention] = ModulePolicyDescription(method_replacement={
'forward': opt_flash_attention_forward,
'forward': get_opt_forward(),
})

return policy
5 changes: 3 additions & 2 deletions tests/test_shardformer/test_model/test_shard_llama.py
@@ -72,10 +72,11 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism):
@parameterize('enable_flash_attention', [True, False])
def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
                                               enable_flash_attention)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()

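The widened build_model signature implies the test helper now threads the new flag into the shard config. A rough sketch of that plumbing for illustration only; everything except the three flag names is an assumption, including the return shape of ShardFormer.optimize.

import copy

from colossalai.shardformer import ShardConfig, ShardFormer


def build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
    # Keep an untouched copy for the numerical comparison done in check_forward_backward.
    org_model = model_fn()
    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                               enable_tensor_parallelism=enable_tensor_parallelism,
                               enable_flash_attention=enable_flash_attention)
    shard_former = ShardFormer(shard_config=shard_config)
    # Assumed here: optimize() returns the sharded model directly.
    sharded_model = shard_former.optimize(copy.deepcopy(org_model))
    return org_model.cuda(), sharded_model.cuda()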