From 32029cad135b0c65192f855a941744c5b9c207c0 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Tue, 21 Nov 2023 15:36:48 +0800
Subject: [PATCH 1/2] fix flash attn

---
 colossalai/shardformer/modeling/chatglm2.py | 3 ++-
 colossalai/shardformer/modeling/gpt2.py     | 9 +++++----
 colossalai/shardformer/modeling/llama.py    | 3 ++-
 colossalai/shardformer/modeling/opt.py      | 3 ++-
 colossalai/shardformer/modeling/whisper.py  | 3 ++-
 5 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py
index 8934068d609c..c8a311df7c6d 100644
--- a/colossalai/shardformer/modeling/chatglm2.py
+++ b/colossalai/shardformer/modeling/chatglm2.py
@@ -51,7 +51,8 @@ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_
             attn_mask_type = AttnMaskType.causal
         else:
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_mask_type = AttnMaskType.paddedcausal
 
         attention = ColoAttention(
             embed_dim=self.hidden_size_per_partition,
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 21f06393071d..8f456353742c 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -771,11 +771,12 @@ def forward(
         attn_mask_type = AttnMaskType.causal
         flash_attention_mask = None
         if attention_mask != None:
-            if attn_mask_type == AttnMaskType.causal:
-                attn_mask_type == AttnMaskType.paddedcausal
-            else:
-                attn_mask_type = AttnMaskType.padding
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+            if not torch.all(flash_attention_mask):
+                if attn_mask_type == AttnMaskType.causal:
+                    attn_mask_type = AttnMaskType.paddedcausal
+                else:
+                    attn_mask_type = AttnMaskType.padding
 
         scale = value.size(-1) ** -0.5
         if self.scale_attn_by_inverse_layer_idx:
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 4bfef45297ea..8006bb3c00e4 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -465,7 +465,8 @@ def forward(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_mask_type = AttnMaskType.paddedcausal
 
         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
         attn_output = attention(
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index e0978d38e110..71f2ca3353bc 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -581,7 +581,8 @@ def forward(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                 )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_mask_type = AttnMaskType.paddedcausal
 
         attention = ColoAttention(
             embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
index ef59dbcee680..9f7c0d2cddd7 100644
--- a/colossalai/shardformer/modeling/whisper.py
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -106,7 +106,8 @@ def forward(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                 )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
-            attn_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_type = AttnMaskType.paddedcausal
 
         attention = ColoAttention(
             embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling

From 0de7dfa74440eb35c30ea7f1c801511483088c5a Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Tue, 21 Nov 2023 17:20:30 +0800
Subject: [PATCH 2/2] fix fix

---
 colossalai/shardformer/modeling/whisper.py | 2 ++
 examples/language/llama2/pretrain.py       | 1 +
 2 files changed, 3 insertions(+)

diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
index 9f7c0d2cddd7..9827d4801f8d 100644
--- a/colossalai/shardformer/modeling/whisper.py
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -108,6 +108,8 @@ def forward(
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
             if not torch.all(flash_attention_mask):
                 attn_type = AttnMaskType.paddedcausal
+            else:
+                attn_type = AttnMaskType.causal
 
         attention = ColoAttention(
             embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py
index 6cc73b6265a4..bb10f7a00e8a 100644
--- a/examples/language/llama2/pretrain.py
+++ b/examples/language/llama2/pretrain.py
@@ -76,6 +76,7 @@ def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = Non
 
 
 def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
     dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+    tensor = tensor.data
     tensor.div_(dist.get_world_size())
     return tensor
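
Note on [PATCH 1/2]: all five hunks apply the same rule. The padded-causal kernel is now selected only when the incoming attention mask actually masks out at least one position; an all-valid mask leaves the default causal type in place. Below is a minimal standalone sketch of that rule. The AttnMaskType enum and the choose_mask_type helper are illustrative stand-ins rather than the ColossalAI API, and the sketch assumes the HuggingFace-style additive mask of shape (bsz, 1, tgt_len, src_len) that the patched forwards receive (0.0 where attention is allowed, a large negative value where it is masked).

from enum import Enum, auto

import torch


class AttnMaskType(Enum):
    # Illustrative stand-in for colossalai's AttnMaskType.
    causal = auto()
    padding = auto()
    paddedcausal = auto()


def choose_mask_type(attention_mask: torch.Tensor) -> AttnMaskType:
    # attention_mask: (bsz, 1, tgt_len, src_len) additive mask, 0.0 where
    # attention is allowed, a large negative value where it is masked.
    # Taking the last query row gives one row per sequence in which nonzero
    # entries mark masked key positions; inverting yields True exactly at
    # positions that may be attended to.
    flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
    if not torch.all(flash_attention_mask):
        # Some key positions are padding: the kernel must combine the causal
        # structure with per-sequence valid lengths.
        return AttnMaskType.paddedcausal
    # Nothing is padded: stay on the cheaper purely causal path.
    return AttnMaskType.causal


# A batch of two sequences where only the second has a padded key position:
mask = torch.zeros(2, 1, 4, 4)
mask[1, :, :, -1] = torch.finfo(torch.float32).min
print(choose_mask_type(mask))                     # AttnMaskType.paddedcausal
print(choose_mask_type(torch.zeros(2, 1, 4, 4)))  # AttnMaskType.causal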
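
Note on the pretrain.py hunk in [PATCH 2/2]: rebinding to tensor.data before div_ makes the in-place division run on the underlying storage, outside autograd. The patch does not state its motivation; one plausible reading (an assumption, not confirmed by the source) is that it guards against autograd rejecting or tracking the in-place op when the reduced tensor still requires grad, as the sketch below shows for a leaf tensor.

import torch

loss = torch.tensor(4.0, requires_grad=True)
try:
    loss.div_(2)  # in-place op on a leaf tensor that requires grad
except RuntimeError as err:
    print("rejected by autograd:", err)

loss.data.div_(2)  # mutates the storage without autograd involvement
print(loss)  # tensor(2., requires_grad=True)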