From ef6eaec80027c47842ebb48cd523f3f4ff353bbc Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 28 Dec 2023 10:46:44 +0800
Subject: [PATCH] fix flash attn

---
 colossalai/shardformer/modeling/llama.py | 7 +++----
 colossalai/shardformer/policies/llama.py | 4 +++-
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 286852899dc1..1b53ce4afebb 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -414,7 +414,7 @@ def llama_for_sequence_classification_forward(
         return {"hidden_states": hidden_states}
 
 
-def get_llama_flash_attention_forward():
+def get_llama_flash_attention_forward(shard_config: ShardConfig):
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
 
     from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
@@ -470,14 +470,13 @@ def forward(
 
         flash_attention_mask = None
         attn_mask_type = AttnMaskType.causal
-        if attention_mask != None:
+        if not getattr(shard_config, "causal_lm", False) and attention_mask != None:
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            if not torch.all(flash_attention_mask):
-                attn_mask_type = AttnMaskType.paddedcausal
+            attn_mask_type = AttnMaskType.paddedcausal
 
         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
         attn_output = attention(
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index eee2259f2c56..08b99ee526cd 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -126,7 +126,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         if self.shard_config.enable_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_flash_attention_forward(),
+                    "forward": get_llama_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
                 target_key=LlamaAttention,
@@ -206,6 +206,8 @@ def module_policy(self):
 
         policy = super().module_policy()
 
+        setattr(self.shard_config, "causal_lm", True)
+
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
             new_item = {
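
Notes on the change: the causal-LM policy now marks its ShardConfig with causal_lm = True, and the flash-attention forward uses that flag to skip attention-mask inspection entirely, always taking the pure AttnMaskType.causal path; for other model heads, a 4-D HuggingFace attention mask is now unconditionally folded into an AttnMaskType.paddedcausal key mask (the old torch.all short-circuit is dropped). The sketch below mirrors that dispatch in isolation. It is illustrative only: ShardConfig, AttnMaskType, and pick_mask_type here are stand-in definitions, not ColossalAI's actual classes; only the causal_lm flag and the masking expression come from the diff above.

    from dataclasses import dataclass
    from enum import Enum, auto
    from typing import Optional, Tuple

    import torch


    class AttnMaskType(Enum):
        # stand-in for colossalai.kernel.cuda_native.AttnMaskType
        causal = auto()
        paddedcausal = auto()


    @dataclass
    class ShardConfig:
        # stand-in for colossalai.shardformer.ShardConfig; the causal-LM
        # policy sets the flag via setattr(self.shard_config, "causal_lm", True)
        enable_flash_attention: bool = True
        causal_lm: bool = False


    def pick_mask_type(
        shard_config: ShardConfig, attention_mask: Optional[torch.Tensor]
    ) -> Tuple[Optional[torch.Tensor], AttnMaskType]:
        # Mirrors the patched branch in get_llama_flash_attention_forward:
        # causal-LM models always take the pure causal path; other heads fold
        # a (bsz, 1, q_len, kv_seq_len) HF mask into a padded-causal key mask.
        flash_attention_mask = None
        attn_mask_type = AttnMaskType.causal
        if not getattr(shard_config, "causal_lm", False) and attention_mask is not None:
            # take the last query row of the 4-D mask -> (bsz, kv_seq_len),
            # inverted so True marks padded key positions
            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
            attn_mask_type = AttnMaskType.paddedcausal
        return flash_attention_mask, attn_mask_type


    # e.g. a causal-LM config ignores the mask, any other config does not:
    # pick_mask_type(ShardConfig(causal_lm=True), mask)  -> (None, AttnMaskType.causal)
    # pick_mask_type(ShardConfig(), mask)                -> (key_mask, AttnMaskType.paddedcausal)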