diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index b1a5c4143646..5106d97cf4bc 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -349,7 +349,7 @@ def forward( value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) - dropout_p = self.dropout.p if self.training else 0.0 + dropout_p = self.dropout_prob if self.training else 0.0 context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, dropout_p=dropout_p) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index cf925983be4e..619bbc98e3a0 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -82,6 +82,7 @@ def forward( attention_mask: Optional[dict] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention" @@ -172,6 +173,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 7b7dbf5557aa..420ea286fd0a 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -93,10 +93,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "use_zbv": use_zbv, }, ), - SubModuleReplacementDescription( - suffix="attention.attention.dropout", - target_module=col_nn.DropoutForParallelInput, - ), SubModuleReplacementDescription( suffix="attention.output.dense", target_module=col_nn.Linear1D_Row, diff --git a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py index 57a82647d49b..ab3f04c05951 100644 --- a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py +++ b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py @@ -1,7 +1,7 @@ import numpy as np import pytest import torch -from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb +from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb from colossalai.kernel.kernel_loader import InferenceOpsLoader @@ -33,7 +33,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype): position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN)) - emb = LlamaRotaryEmbedding(D) + config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D) + emb = LlamaRotaryEmbedding(config) cos, sin = emb(x0, position_ids) embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)