Merged
93 changes: 93 additions & 0 deletions colossalai/shardformer/modeling/gpt2.py
@@ -0,0 +1,93 @@
from typing import Optional, Tuple, Union

import torch

__all__ = ['get_gpt2_forward']

def get_gpt2_forward():

    try:
        from xformers.ops import memory_efficient_attention as me_attention
        from xformers.ops.fmha.attn_bias import LowerTriangularMask
    except ImportError:
        raise ImportError("xformers is not installed. Please install it to use flash attention.")

    def gpt2_flash_attention_forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        _, tgt_len, _ = hidden_states.size()
        assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."

        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        # xformers expects (batch, seq_len, num_heads, head_dim), so no transpose here.
        query = split_heads(query, self.num_heads, self.head_dim)
        key = split_heads(key, self.num_heads, self.head_dim)
        value = split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=1)
            value = torch.cat((past_value, value), dim=1)

        if use_cache:
            present = (key, value)
        else:
            present = None

        attn_bias = None
        if not self.is_cross_attention:
            attn_bias = LowerTriangularMask()
        if attention_mask is not None:
            if attn_bias is not None:
                # add_bias returns a new bias object instead of mutating in place,
                # so the result must be assigned back.
                attn_bias = attn_bias.add_bias(attention_mask)
            else:
                batch_size, _, tgt_len, src_len = attention_mask.size()
                attn_bias = attention_mask.expand(batch_size, self.num_heads, tgt_len, src_len).contiguous()

        scale = value.size(-1)**-0.5
        if self.scale_attn_by_inverse_layer_idx:
            scale = scale * (1 / float(self.layer_idx + 1))
        attn_output = me_attention(query=query, key=key, value=value, attn_bias=attn_bias, p=self.attn_dropout.p, scale=scale)

        attn_output = merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
        # memory-efficient attention never materializes the attention weights,
        # so None stands in for them in the output tuple.
        outputs = (attn_output, present, None)

        return outputs

    return gpt2_flash_attention_forward

def split_heads(tensor, num_heads, attn_head_size):
    """
    Splits hidden_size dim into attn_head_size and num_heads
    """
    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
    tensor = tensor.view(new_shape)
    return tensor


def merge_heads(tensor, num_heads, attn_head_size):
    """
    Merges attn_head_size dim and num_heads dim into hidden_size
    """
    new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
    return tensor.view(new_shape)
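
Note that, unlike Hugging Face's GPT2Attention._split_heads, split_heads above reshapes without transposing: xformers' memory_efficient_attention expects tensors laid out as (batch, seq_len, num_heads, head_dim) rather than (batch, num_heads, seq_len, head_dim). A minimal standalone shape check, with toy dimensions assumed purely for illustration:

import torch

# Toy sizes; only the layout matters here.
batch, seq_len, num_heads, head_dim = 2, 8, 12, 64
hidden = torch.randn(batch, seq_len, num_heads * head_dim)

# split_heads: a pure reshape -- the head axis appears after seq_len.
split = hidden.view(hidden.size()[:-1] + (num_heads, head_dim))
assert split.shape == (batch, seq_len, num_heads, head_dim)

# merge_heads: the inverse reshape restores the hidden size.
merged = split.view(split.size()[:-2] + (num_heads * head_dim,))
assert merged.shape == hidden.shape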
8 changes: 7 additions & 1 deletion colossalai/shardformer/policies/gpt2.py
@@ -4,6 +4,7 @@

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from ..modeling.gpt2 import get_gpt2_forward

__all__ = [
    'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy',
@@ -29,7 +30,7 @@ def preprocess(self):
        return self.model

    def module_policy(self):
-        from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
+        from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model, GPT2Attention

        policy = {}

@@ -106,6 +107,11 @@ def module_policy(self):
            ],
            policy=policy,
            target_key=GPT2Block)

+        if self.shard_config.enable_flash_attention:
+            policy[GPT2Attention] = ModulePolicyDescription(method_replacement={
+                'forward': get_gpt2_forward(),
+            })
        return policy

    def postprocess(self):
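
For context on how a method_replacement entry takes effect: get_gpt2_forward() returns a plain function whose first parameter is self, so one plausible way to apply it is to bind it onto each GPT2Attention instance. The sketch below is an illustrative assumption, not shardformer's actual dispatch code, and apply_method_replacement is a hypothetical helper.

from types import MethodType

def apply_method_replacement(module, replacement):
    # Hypothetical sketch: bind each replacement function as an instance
    # method, shadowing the method defined on the module's class.
    for name, fn in replacement.items():
        setattr(module, name, MethodType(fn, module))

# Hypothetical usage, given a transformers GPT2Attention instance `attn`:
#     apply_method_replacement(attn, {'forward': get_gpt2_forward()})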
6 changes: 3 additions & 3 deletions tests/kit/model_zoo/transformers/gpt.py
@@ -16,8 +16,8 @@ def data_gen():
    # tokenized_input = tokenizer(input, return_tensors='pt')
    # input_ids = tokenized_input['input_ids']
    # attention_mask = tokenized_input['attention_mask']
-    input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64)
-    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


@@ -33,7 +33,7 @@ def data_gen_for_token_classification():
    # token classification data gen
    # `labels` holds the class label (0 or 1) per token, not a token id
    data = data_gen()
-    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64)
+    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
    return data


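The test inputs grow from 6 to 8 tokens (and `labels` with them) because the new flash-attention forward asserts tgt_len % 4 == 0: a 6-token sequence would fail that check, while 8 passes.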
5 changes: 3 additions & 2 deletions tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -69,10 +69,11 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn

@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
-def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
    torch.cuda.empty_cache()

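
With the third flag, the test sweep doubles from four to eight configurations, assuming the stacked @parameterize decorators compose as a cross product (their usual behavior in colossalai.testing). An illustrative enumeration of the resulting grid:

from itertools import product

# The eight flag combinations the stacked decorators are expected to cover.
for fused_norm, tp, flash in product([True, False], repeat=3):
    print(f"fused_normalization={fused_norm}, tensor_parallelism={tp}, flash_attention={flash}")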