diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index a9a38bce235f..fbb1691ec595 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -4,14 +4,11 @@
 
 
 __all__ = ['get_gpt2_forward']
 
+
 def get_gpt2_forward():
-    try:
-        from xformers.ops import memory_efficient_attention as me_attention
-        from xformers.ops.fmha.attn_bias import LowerTriangularMask
-    except:
-        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
-
+    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+
     def gpt2_flash_attention_forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
@@ -30,8 +27,7 @@ def gpt2_flash_attention_forward(
             if not hasattr(self, "q_attn"):
                 raise ValueError(
                     "If class is used as cross attention, the weights `q_attn` have to be defined. "
-                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
-                )
+                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.")
 
             query = self.q_attn(hidden_states)
             key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
@@ -52,31 +48,38 @@ def gpt2_flash_attention_forward(
             present = (key, value)
         else:
             present = None
-
-        attn_bias = None
+
         if not self.is_cross_attention:
-            attn_bias = LowerTriangularMask()
+            attn_mask_type = AttnMaskType.causal
+        flash_attention_mask = None
         if attention_mask != None:
-            if attn_bias:
-                attn_bias.add_bias(attention_mask)
+            if attn_mask_type == AttnMaskType.causal:
+                attn_mask_type = AttnMaskType.paddedcausal
             else:
-                batch_size, _, tgt_len, src_len = attention_mask.size()
-                attn_bias = attention_mask.expand(batch_size, self.num_heads, tgt_len, src_len).contiguous()
-
-        scale = value.size(-1) ** -0.5
+                attn_mask_type = AttnMaskType.padding
+            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+
+        scale = value.size(-1)**-0.5
         if self.scale_attn_by_inverse_layer_idx:
             scale = scale * (1 / float(self.layer_idx + 1))
-        attn_output = me_attention(query=query, key=key, value=value, attn_bias=attn_bias, p=self.attn_dropout.p, scale=scale)
-
-        attn_output = merge_heads(attn_output, self.num_heads, self.head_dim)
+
+        # use coloattention
+        attention = ColoAttention(embed_dim=self.embed_dim,
+                                  num_heads=self.num_heads,
+                                  dropout=self.attn_dropout.p,
+                                  scale=scale)
+
+        attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
+
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
 
         outputs = (attn_output, present, None)
 
         return outputs
-
+
     return gpt2_flash_attention_forward
+
 def split_heads(tensor, num_heads, attn_head_size):
     """
     Splits hidden_size dim into attn_head_size and num_heads
@@ -84,10 +87,3 @@ def split_heads(tensor, num_heads, attn_head_size):
     new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
     tensor = tensor.view(new_shape)
     return tensor
-
-def merge_heads(tensor, num_heads, attn_head_size):
-    """
-    Merges attn_head_size dim and num_attn_heads dim into hidden_size
-    """
-    new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
-    return tensor.view(new_shape)
\ No newline at end of file
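
A minimal usage sketch of how the forward returned by get_gpt2_forward() might be installed on a HuggingFace GPT-2 model. The patch above only defines the factory; the model name, the iteration over model.modules(), and the MethodType binding below are illustrative assumptions rather than the shardformer policy's own wiring.

# Hedged sketch: one possible way to bind the ColoAttention-based forward
# onto every GPT2Attention module of a loaded model. Assumes HuggingFace
# `transformers` is installed; not part of this patch.
import types

from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model

from colossalai.shardformer.modeling.gpt2 import get_gpt2_forward

model = GPT2Model.from_pretrained("gpt2")
flash_forward = get_gpt2_forward()

for module in model.modules():
    if isinstance(module, GPT2Attention):
        # Replace the eager attention forward with the flash-attention one.
        module.forward = types.MethodType(flash_forward, module)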