52 changes: 24 additions & 28 deletions colossalai/shardformer/modeling/gpt2.py
@@ -4,14 +4,11 @@
 
 __all__ = ['get_gpt2_forward']
 
 
 def get_gpt2_forward():
-
-    try:
-        from xformers.ops import memory_efficient_attention as me_attention
-        from xformers.ops.fmha.attn_bias import LowerTriangularMask
-    except:
-        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
-
+    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
 
     def gpt2_flash_attention_forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
@@ -30,8 +27,7 @@ def gpt2_flash_attention_forward(
             if not hasattr(self, "q_attn"):
                 raise ValueError(
                     "If class is used as cross attention, the weights `q_attn` have to be defined. "
-                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
-                )
+                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.")
 
             query = self.q_attn(hidden_states)
             key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
@@ -52,42 +48,42 @@ def gpt2_flash_attention_forward(
             present = (key, value)
         else:
             present = None
 
-        attn_bias = None
-
         if not self.is_cross_attention:
-            attn_bias = LowerTriangularMask()
+            attn_mask_type = AttnMaskType.causal
+            flash_attention_mask = None
         if attention_mask != None:
-            if attn_bias:
-                attn_bias.add_bias(attention_mask)
+            if attn_mask_type == AttnMaskType.causal:
+                attn_mask_type = AttnMaskType.paddedcausal
             else:
-                batch_size, _, tgt_len, src_len = attention_mask.size()
-                attn_bias = attention_mask.expand(batch_size, self.num_heads, tgt_len, src_len).contiguous()
-        scale = value.size(-1) ** -0.5
+                attn_mask_type = AttnMaskType.padding
+            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
 
+        scale = value.size(-1)**-0.5
         if self.scale_attn_by_inverse_layer_idx:
             scale = scale * (1 / float(self.layer_idx + 1))
-        attn_output = me_attention(query=query, key=key, value=value, attn_bias=attn_bias, p=self.attn_dropout.p, scale=scale)
 
-        attn_output = merge_heads(attn_output, self.num_heads, self.head_dim)
 
+        # use coloattention
+        attention = ColoAttention(embed_dim=self.embed_dim,
+                                  num_heads=self.num_heads,
+                                  dropout=self.attn_dropout.p,
+                                  scale=scale)
+
+        attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
 
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
         outputs = (attn_output, present, None)
 
         return outputs
 
     return gpt2_flash_attention_forward
 
 
 def split_heads(tensor, num_heads, attn_head_size):
     """
     Splits hidden_size dim into attn_head_size and num_heads
     """
     new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
     tensor = tensor.view(new_shape)
     return tensor
-
-def merge_heads(tensor, num_heads, attn_head_size):
-    """
-    Merges attn_head_size dim and num_attn_heads dim into hidden_size
-    """
-    new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
-    return tensor.view(new_shape)
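
For reference, the sketch below shows one way the forward returned by get_gpt2_forward() could be attached to a Hugging Face GPT2Attention module. It is only a minimal illustration assuming a manual monkey-patch; in ColossalAI the replacement is normally wired up by the shardformer policy, which is not part of this diff.

# Minimal sketch (not part of the PR): bind the flash-attention forward onto a
# Hugging Face GPT2Attention module. The manual patch below is an illustrative
# assumption; shardformer policies perform the actual substitution.
import types

from transformers import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

from colossalai.shardformer.modeling.gpt2 import get_gpt2_forward

# Stand-in attention module, only to demonstrate the patch.
config = GPT2Config(n_embd=768, n_head=12)
attn_module = GPT2Attention(config)

# get_gpt2_forward() returns gpt2_flash_attention_forward; binding it as the
# module's forward routes attention through ColoAttention instead of the
# stock GPT-2 implementation.
attn_module.forward = types.MethodType(get_gpt2_forward(), attn_module)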