diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 4bfef45297ea..0d30b6f7641b 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -413,8 +413,6 @@ def get_llama_flash_attention_forward():
         warnings.warn("using llamav1, llamav1 hasn't repeat_kv function")
         llama_version = 1
 
-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
-
     def forward(
         self: LlamaAttention,
         hidden_states: torch.Tensor,