diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index a3d774ff2abb..4b6458ae0627 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -1,6 +1,9 @@ +from typing import Optional, Tuple + import torch import torch.distributed as dist from torch.distributed import ProcessGroup +from torch.nn import functional as F def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: @@ -67,3 +70,74 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) return build_bloom_alibi_tensor + +def get_bloom_forward(): + + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + from transformers.models.bloom.modeling_bloom import dropout_add + + def bloom_flash_attention_forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False,): + + fused_qkv = self.query_key_value(hidden_states) + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + batch_size, tgt_len, _ = hidden_states.size() + assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." + + _, kv_length, _, _ = key_layer.size() + + proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim) + query_layer = query_layer.contiguous().view(*proj_shape) + key_layer = key_layer.contiguous().view(*proj_shape) + value_layer = value_layer.contiguous().view(*proj_shape) + + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, head_dim, kv_length] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + if use_cache is True: + present = (key_layer, value_layer) + else: + present = None + + tgt_len = key_layer.size()[1] + + attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length), dtype=torch.float32, device=query_layer.device, requires_grad=True) + attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta + attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min) + + context_layer = me_attention(query_layer, key_layer, value_layer, attn_bias=attention_numerical_mask, scale=self.inv_norm_factor, p=self.attention_dropout.p) + context_layer = context_layer.reshape(-1, kv_length, self.hidden_size) + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + # TODO to replace with the bias_dropout_add function in jit + output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + outputs = (output_tensor, present, None) + + return outputs + + return bloom_flash_attention_forward \ No newline at end of file diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index b4b9b21a807c..aab56f056436 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -40,6 +40,7 @@ def llama_flash_attention_forward( use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() + assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 733f86944bc3..c02c17d023d7 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -27,6 +27,7 @@ def opt_flash_attention_forward( # for the decoder is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() + assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) # get query proj diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index a0b5340f72bc..daae3cf881a2 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -3,7 +3,7 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ -from ..modeling.bloom import build_bloom_alibi_tensor_fn +from ..modeling.bloom import build_bloom_alibi_tensor_fn, get_bloom_forward from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -25,7 +25,7 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel + from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, BloomAttention policy = {} @@ -101,6 +101,11 @@ def module_policy(self): ], policy=policy, target_key=BloomBlock) + + if self.shard_config.enable_flash_attention: + policy[BloomAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_bloom_forward(), + }) return policy diff --git a/tests/kit/model_zoo/transformers/bloom.py b/tests/kit/model_zoo/transformers/bloom.py index 71146c0b9819..e9dc25378e63 100644 --- a/tests/kit/model_zoo/transformers/bloom.py +++ b/tests/kit/model_zoo/transformers/bloom.py @@ -16,8 +16,8 @@ def data_gen(): # tokenized_input = tokenizer(input, return_tensors='pt') # input_ids = tokenized_input['input_ids'] # attention_mask = tokenized_input['attention_mask'] - input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64) + input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595, 632, 207595]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) return dict(input_ids=input_ids, attention_mask=attention_mask) @@ -33,7 +33,7 @@ def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64) + data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) return data @@ -53,8 +53,8 @@ def data_gen_for_question_answering(): # inputs = tokenizer(question, text, return_tensors="pt") input_ids = torch.tensor( - [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) return dict(input_ids=input_ids, attention_mask=attention_mask) diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index a3389652269c..63438bbe9a91 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -69,10 +69,11 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_flash_attention', [True, False]) +def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache()