Skip to content
113 changes: 113 additions & 0 deletions colossalai/shardformer/modeling/bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from typing import Optional, Tuple
import math

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import functional as F

__all__ = ['get_bert_forward']

def get_bert_forward():

try:
from xformers.ops import memory_efficient_attention as me_attention
except:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")

def bert_flash_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)

# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))

query_layer = self.transpose_for_scores(mixed_query_layer)

use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

final_attention_mask = None
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r

positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility

if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
final_attention_mask = relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
final_attention_mask = relative_position_scores_query + relative_position_scores_key

scale = 1 / math.sqrt(self.attention_head_size)
if attention_mask is not None:
if final_attention_mask != None:
final_attention_mask = final_attention_mask * scale + attention_mask
else:
final_attention_mask = attention_mask
batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
tgt_len = key_layer.size()[2]
final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, tgt_len)

query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
value_layer = value_layer.permute(0, 2, 1, 3).contiguous()

context_layer = me_attention(query_layer, key_layer, value_layer, attn_bias=final_attention_mask, p=self.dropout.p, scale=scale)
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, None)

if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs

return bert_flash_attention_forward
10 changes: 9 additions & 1 deletion colossalai/shardformer/policies/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from ..modeling.bert import get_bert_forward

__all__ = [
'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
Expand Down Expand Up @@ -31,7 +32,7 @@ def preprocess(self):
return self.model

def module_policy(self):
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertSelfAttention

policy = {}

Expand Down Expand Up @@ -120,6 +121,13 @@ def module_policy(self):
)],
policy=policy,
target_key=BertEmbeddings)

# use flash attention
if self.shard_config.enable_flash_attention:
policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={
'forward': get_bert_forward(),
})

return policy

def add_lm_head_policy(self, base_policy):
Expand Down
14 changes: 7 additions & 7 deletions tests/kit/model_zoo/transformers/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def data_gen():
# token_type_ids = tokenized_input['token_type_ids']
input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64)
token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.int64)
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)


Expand Down Expand Up @@ -69,19 +69,19 @@ def data_gen_for_mcq():
# data['labels'] = torch.tensor([0], dtype=torch.int64)
input_ids = torch.tensor([[[
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102
4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102
],
[
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
2218, 1999, 1996, 2192, 1012, 102, 0
2218, 1999, 1996, 2192, 1012, 102, 0, 0
]]])
token_type_ids = torch.tensor(
[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]])
attention_mask = torch.tensor(
[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]])
labels = torch.tensor([0], dtype=torch.int64)

return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_shardformer/test_model/test_shard_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,11 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism):
@parameterize('enable_flash_attention', [True, False])
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

torch.cuda.empty_cache()
Expand Down