diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 8a6a7cf17e08..528419654784 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -493,7 +493,6 @@ def forward(
         if sp_mode in ["split_gather", "ring"]:
             q_len *= sp_size
-            assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
 
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)