From b0ef71e2dc49a2b2d17641c4f8db17ac642cb0d5 Mon Sep 17 00:00:00 2001 From: ver217 Date: Mon, 29 Jul 2024 12:00:47 +0800 Subject: [PATCH] [shardformer] hotfix attn mask --- colossalai/shardformer/modeling/command.py | 2 +- colossalai/shardformer/modeling/llama.py | 2 +- colossalai/shardformer/modeling/mistral.py | 2 +- colossalai/shardformer/modeling/qwen2.py | 8 ++++++-- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 759c8d7b8d59..5b36fc7db3b9 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -116,7 +116,7 @@ def command_model_forward( # for the other stages, hidden_states is the output of the previous stage if shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor - mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + mask_shape = (batch_size, 1, seq_length, seq_length_with_past) attention_mask = ColoAttention.prepare_attn_kwargs( mask_shape, hidden_states.dtype, diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 54ff8e321e06..9ffbca517d4c 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -643,7 +643,7 @@ def forward( # in this case, attention_mask is a dict rather than a tensor if shard_config.enable_flash_attention: - mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len) + mask_shape = (inputs_embeds.shape[0], 1, seq_len, past_seen_tokens + seq_len) attention_mask = ColoAttention.prepare_attn_kwargs( mask_shape, inputs_embeds.dtype, diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 82e8ef5f9af7..ec1a8a00a58a 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -91,7 +91,7 @@ def mistral_model_forward( if shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor - mask_shape = (batch_size, 1, seq_length, seq_length) + mask_shape = (batch_size, 1, seq_length, seq_length + past_key_values_length) attention_mask = ColoAttention.prepare_attn_kwargs( mask_shape, hidden_states.dtype, diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 55822b1505f1..538e96c32c6d 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -136,7 +136,7 @@ def qwen2_model_forward( # for the other stages, hidden_states is the output of the previous stage if shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor - mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + mask_shape = (batch_size, 1, seq_length, seq_length_with_past) attention_mask = ColoAttention.prepare_attn_kwargs( mask_shape, hidden_states.dtype, @@ -651,6 +651,10 @@ def forward( seq_length_with_past = seq_length past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( @@ -668,7 +672,7 @@ def forward( if shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor - mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + mask_shape = (batch_size, 1, seq_length, seq_length_with_past) attention_mask = ColoAttention.prepare_attn_kwargs( mask_shape, hidden_states.dtype,