I have been seeing very strange behavior when training and running Mistral or Mixtral with samples that are exactly the length of `max_position_embeddings`. The behavior manifested itself as completely broken outputs that, interestingly, resolved themselves after reloading the model and running samples with shorter lengths through it.
So the following combination always broke:
A model with `max_position_embeddings=8192`, FA2 enabled, and some samples of size `max_length=8192`.
It was resolved by either disabling FA2 or by using samples with `max_length=8191`.
After a lot of debugging, I figured out that this issue only happens with Flash Attention 2 and not with SDPA or vanilla attention.
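For reference, here is a minimal sketch of the kind of setup that triggers it for me (the checkpoint, dtype, and dummy inputs below are placeholders, not my actual training setup):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; any Mistral/Mixtral checkpoint whose config has
# max_position_embeddings equal to the sample length shows the issue for me.
model_id = "mistralai/Mistral-7B-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # "sdpa" or "eager" do not break
    device_map="auto",
)

max_len = model.config.max_position_embeddings  # 8192 in my setup
# Dummy batch whose sequence length is exactly max_position_embeddings.
input_ids = torch.full((1, max_len), tokenizer.eos_token_id, device=model.device)

with torch.no_grad():
    out = model(input_ids)  # broken outputs with FA2 at exactly max_len
```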
I suspect that this issue stems from the following lines:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L447
```python
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
```
If we have a batch with a sequence length of, say, 8192, which can be the same as `max_position_embeddings`, then `kv_seq_len` will be 8192 and is the maximum here; we then add 1, which yields 8193, and call `rotary_emb` with that value.
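Concretely, for a full-length batch the arithmetic looks like this (plain standalone snippet, just to show the off-by-one):

```python
import torch

kv_seq_len = 8192                                     # batch sequence length
position_ids = torch.arange(kv_seq_len).unsqueeze(0)  # 0 .. 8191

# Current computation: the +1 is applied after taking the max, so the
# batch length itself gets bumped past max_position_embeddings.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
print(rotary_seq_len)  # 8193
```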
Inside `rotary_emb`, we then call:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L214
and thus re-initialize the rotary cache with a sequence length longer than the supported maximum.
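Roughly, the rotary embedding forward does the following (a simplified sketch of the method, not the exact transformers code), so passing `seq_len=8193` silently rebuilds the cos/sin tables for a length the model was never configured for:

```python
# Simplified sketch of the rotary embedding forward method.
def forward(self, x, seq_len=None):
    if seq_len > self.max_seq_len_cached:
        # Re-initializes the cos/sin cache beyond max_position_embeddings
        # (8193 > 8192 in this case).
        self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
    return (
        self.cos_cached[:seq_len].to(dtype=x.dtype),
        self.sin_cached[:seq_len].to(dtype=x.dtype),
    )
```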
I think it can already be solved by changing the line to:
```python
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)
```
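With that change, the full-length case stays within bounds (same standalone example as above):

```python
import torch

kv_seq_len = 8192
position_ids = torch.arange(kv_seq_len).unsqueeze(0)  # 0 .. 8191

# Proposed: add 1 to the last position id before taking the max, so a
# full-length batch no longer overshoots max_position_embeddings.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)
print(rotary_seq_len)  # 8192, no cache re-initialization
```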
I actually noticed that this code was changed very recently for Mistral so that it no longer takes the max length and resets the cache:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L108
This was done in PR #30642
I think this might have been just a side effect, and it does not fix the Mixtral behavior.