Merged
74 changes: 74 additions & 0 deletions colossalai/shardformer/modeling/bloom.py
@@ -1,6 +1,9 @@
from typing import Optional, Tuple

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import functional as F


def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
@@ -67,3 +70,74 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int,
        return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)

    return build_bloom_alibi_tensor

def get_bloom_forward():
    try:
        from xformers.ops import memory_efficient_attention as me_attention
    except ImportError:
        raise ImportError("xformers is not installed. Please install it to use flash attention.")
    from transformers.models.bloom.modeling_bloom import dropout_add

    def bloom_flash_attention_forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        fused_qkv = self.query_key_value(hidden_states)
        # each of query/key/value: [batch_size, seq_length, num_heads, head_dim]
        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
        batch_size, tgt_len, _ = hidden_states.size()
        assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length must be a multiple of 4."

        _, kv_length, _, _ = key_layer.size()

        proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim)
        query_layer = query_layer.contiguous().view(*proj_shape)
        key_layer = key_layer.contiguous().view(*proj_shape)
        value_layer = value_layer.contiguous().view(*proj_shape)

        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along the sequence dimension (dim 1):
            # - key: [batch_size, kv_length, num_heads, head_dim]
            # - value: [batch_size, kv_length, num_heads, head_dim]
            key_layer = torch.cat((past_key, key_layer), dim=1)
            value_layer = torch.cat((past_value, value_layer), dim=1)

        if use_cache is True:
            present = (key_layer, value_layer)
        else:
            present = None

        tgt_len = key_layer.size()[1]

        # build the additive bias expected by xformers: ALiBi slopes plus a
        # large negative value at masked positions
        attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length),
                                               dtype=torch.float32,
                                               device=query_layer.device,
                                               requires_grad=True)
        attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1,
                                                                         kv_length) * self.beta
        attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask,
                                                     torch.finfo(torch.float32).min)

        context_layer = me_attention(query_layer,
                                     key_layer,
                                     value_layer,
                                     attn_bias=attention_numerical_mask,
                                     scale=self.inv_norm_factor,
                                     p=self.attention_dropout.p)
        context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)

        if self.pretraining_tp > 1 and self.slow_but_exact:
            slices = self.hidden_size / self.pretraining_tp
            output_tensor = torch.zeros_like(context_layer)
            for i in range(self.pretraining_tp):
                output_tensor = output_tensor + F.linear(
                    context_layer[:, :, int(i * slices):int((i + 1) * slices)],
                    self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
                )
        else:
            output_tensor = self.dense(context_layer)

        # TODO: replace with the jit bias_dropout_add function
        output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
        outputs = (output_tensor, present, None)

        return outputs

    return bloom_flash_attention_forward
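For context, a minimal standalone sketch (not part of the PR) of the xformers primitive the new forward is built on: memory_efficient_attention takes query/key/value in the [batch, seq_len, num_heads, head_dim] layout that bloom_flash_attention_forward produces above, plus an additive bias broadcastable to [batch, num_heads, q_len, kv_len]; the multiple-of-4 asserts presumably come from xformers' alignment requirements on such dense biases. All shapes and tolerances here are illustrative.

import torch
from xformers.ops import memory_efficient_attention

B, M, H, K = 2, 8, 4, 16    # batch, seq len (a multiple of 4), heads, head dim
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
bias = torch.randn(B, H, M, M, device="cuda", dtype=torch.float16)

out = memory_efficient_attention(q, k, v, attn_bias=bias, p=0.0)    # [B, M, H, K]

# reference: ordinary softmax attention with the same additive bias
qh, kh, vh = (t.permute(0, 2, 1, 3) for t in (q, k, v))    # [B, H, M, K]
scores = qh @ kh.transpose(-1, -2) * K**-0.5 + bias
ref = (scores.softmax(dim=-1) @ vh).permute(0, 2, 1, 3)
print(torch.allclose(out, ref, atol=1e-2))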
1 change: 1 addition & 0 deletions colossalai/shardformer/modeling/llama.py
@@ -40,6 +40,7 @@ def llama_flash_attention_forward(
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()
+    assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."

    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
1 change: 1 addition & 0 deletions colossalai/shardformer/modeling/opt.py
@@ -27,6 +27,7 @@ def opt_flash_attention_forward(
    # for the decoder
    is_cross_attention = key_value_states is not None
    bsz, tgt_len, _ = hidden_states.size()
+    assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."

    attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
    # get query proj
9 changes: 7 additions & 2 deletions colossalai/shardformer/policies/bloom.py
@@ -3,7 +3,7 @@
import colossalai.shardformer.layer as col_nn

from .._utils import getattr_, setattr_
-from ..modeling.bloom import build_bloom_alibi_tensor_fn
+from ..modeling.bloom import build_bloom_alibi_tensor_fn, get_bloom_forward
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


@@ -25,7 +25,7 @@ def preprocess(self):
        return self.model

    def module_policy(self):
-        from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel
+        from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, BloomAttention

        policy = {}

@@ -101,6 +101,11 @@ def module_policy(self):
        ],
                                                    policy=policy,
                                                    target_key=BloomBlock)

+        if self.shard_config.enable_flash_attention:
+            policy[BloomAttention] = ModulePolicyDescription(method_replacement={
+                'forward': get_bloom_forward(),
+            })

        return policy

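The method_replacement entry above swaps each BloomAttention module's forward for the flash-attention version returned by get_bloom_forward(). A hedged sketch of what that substitution amounts to (ShardFormer's actual sharder applies the ModulePolicyDescription for us; replace_forward is illustrative, not a ShardFormer API):

import types

def replace_forward(module, new_forward):
    # bind the plain function as an instance method so that `self` inside
    # bloom_flash_attention_forward resolves to this BloomAttention module
    module.forward = types.MethodType(new_forward, module)

# e.g., for every BloomAttention submodule of the sharded model:
# replace_forward(attn_module, get_bloom_forward())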
10 changes: 5 additions & 5 deletions tests/kit/model_zoo/transformers/bloom.py
@@ -16,8 +16,8 @@ def data_gen():
    # tokenized_input = tokenizer(input, return_tensors='pt')
    # input_ids = tokenized_input['input_ids']
    # attention_mask = tokenized_input['attention_mask']
-    input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595]], dtype=torch.int64)
-    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595, 632, 207595]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


@@ -33,7 +33,7 @@ def data_gen_for_token_classification():
    # token classification data gen
    # `labels` is the type not the token id for token classification, 0 or 1
    data = data_gen()
-    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64)
+    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
    return data


@@ -53,8 +53,8 @@ def data_gen_for_question_answering():
    # inputs = tokenizer(question, text, return_tensors="pt")

    input_ids = torch.tensor(
-        [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64)
-    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+        [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)
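All three fixtures above were lengthened (6 → 8 and 14 → 16 tokens) so that the new seq_len % 4 == 0 asserts hold. For arbitrary inputs, right-padding is one hedged way to satisfy the same constraint (the helper name and the BLOOM pad id of 3 are illustrative, not part of this PR):

import torch
from torch.nn import functional as F

def pad_to_multiple_of_4(input_ids, attention_mask, pad_token_id=3):
    # right-pad both tensors until the sequence length divides by 4;
    # padded positions are switched off via attention_mask = 0
    pad = (-input_ids.size(1)) % 4
    if pad:
        input_ids = F.pad(input_ids, (0, pad), value=pad_token_id)
        attention_mask = F.pad(attention_mask, (0, pad), value=0)
    return input_ids, attention_mask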


5 changes: 3 additions & 2 deletions tests/test_shardformer/test_model/test_shard_bloom.py
Expand Up @@ -69,10 +69,11 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
-def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
        torch.cuda.empty_cache()

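The new enable_flash_attention argument presumably reaches the policy's shard_config check through the test helper's ShardConfig. A hedged sketch of that plumbing (the real build_model lives in the shardformer test utilities and may differ in signature and return handling):

from colossalai.shardformer import ShardConfig, ShardFormer

def build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
    org_model = model_fn().cuda()
    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                               enable_tensor_parallelism=enable_tensor_parallelism,
                               enable_flash_attention=enable_flash_attention)
    # optimize() applies the registered policies, including the
    # BloomAttention forward replacement when the flag is set
    sharded_model = ShardFormer(shard_config=shard_config).optimize(model_fn())
    return org_model, sharded_model.cuda()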