From dd992d2de0e1ba25f8e80b56998bbeee87d92753 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 15 May 2025 11:27:09 +0800
Subject: [PATCH 1/4] fix

---
 colossalai/shardformer/modeling/vit.py     | 2 +-
 colossalai/shardformer/modeling/whisper.py | 2 ++
 colossalai/shardformer/policies/vit.py     | 9 +++++----
 3 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
index b1a5c4143646..5106d97cf4bc 100644
--- a/colossalai/shardformer/modeling/vit.py
+++ b/colossalai/shardformer/modeling/vit.py
@@ -349,7 +349,7 @@ def forward(
         value_layer = self.transpose_for_scores(self.value(hidden_states))
         query_layer = self.transpose_for_scores(mixed_query_layer)
 
-        dropout_p = self.dropout.p if self.training else 0.0
+        dropout_p = self.dropout_prob if self.training else 0.0
         context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, dropout_p=dropout_p)
 
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
index cf925983be4e..619bbc98e3a0 100644
--- a/colossalai/shardformer/modeling/whisper.py
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -82,6 +82,7 @@ def forward(
         attention_mask: Optional[dict] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
         assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
@@ -172,6 +173,7 @@ def forward(
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        cache_position=None,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 7b7dbf5557aa..d8795291a7f2 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -93,10 +93,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                         "use_zbv": use_zbv,
                     },
                 ),
-                SubModuleReplacementDescription(
-                    suffix="attention.attention.dropout",
-                    target_module=col_nn.DropoutForParallelInput,
-                ),
+                # SubModuleReplacementDescription(
+                #     # suffix="attention.attention.dropout",
+                #     suffix="attention.attention.dropout_prob",
+                #     target_module=col_nn.DropoutForParallelInput,
+                # ),
                 SubModuleReplacementDescription(
                     suffix="attention.output.dense",
                     target_module=col_nn.Linear1D_Row,

From a162c0bd257da1a6127d3acf29a4bb99e42ca21d Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 15 May 2025 11:29:54 +0800
Subject: [PATCH 2/4] fix

---
 colossalai/shardformer/policies/vit.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index d8795291a7f2..420ea286fd0a 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -93,11 +93,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                         "use_zbv": use_zbv,
                     },
                 ),
-                # SubModuleReplacementDescription(
-                #     # suffix="attention.attention.dropout",
-                #     suffix="attention.attention.dropout_prob",
-                #     target_module=col_nn.DropoutForParallelInput,
-                # ),
                 SubModuleReplacementDescription(
                     suffix="attention.output.dense",
                     target_module=col_nn.Linear1D_Row,

From b853f0596c6d62569a4b4d5d8ec7ecc5aaf098c0 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 15 May 2025 16:12:23 +0800
Subject: [PATCH 3/4] fix rotate embedding test

---
 .../test_kernels/cuda/test_rotary_embdding_unpad.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
index 57a82647d49b..150bfabb41c8 100644
--- a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 import torch
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb, LlamaConfig
 
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 
@@ -33,7 +33,12 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
 
     position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
 
-    emb = LlamaRotaryEmbedding(D)
+    config = LlamaConfig(
+        max_position_embeddings=SEQ_LEN,
+        num_attention_heads=H,
+        hidden_size=H*D
+    )
+    emb = LlamaRotaryEmbedding(config)
     cos, sin = emb(x0, position_ids)
 
     embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)

From 2e7889ca66a69dc3ae047f3995715c177d0e2d3a Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 15 May 2025 08:13:28 +0000
Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../test_kernels/cuda/test_rotary_embdding_unpad.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
index 150bfabb41c8..ab3f04c05951 100644
--- a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 import torch
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb, LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb
 
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 
@@ -33,11 +33,7 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
 
     position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
 
-    config = LlamaConfig(
-        max_position_embeddings=SEQ_LEN,
-        num_attention_heads=H,
-        hidden_size=H*D
-    )
+    config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D)
     emb = LlamaRotaryEmbedding(config)
     cos, sin = emb(x0, position_ids)