FlashAttention2 issue with Mistral/Mixtral related to max length and RotaryEmbedding #31228

@psinger

Description

I have been seeing very strange behavior when training and running Mistral or Mixtral with samples that are exactly max_position_embeddings tokens long. It manifested as completely broken outputs that, interestingly, resolved themselves after reloading the model and running shorter samples through it.

The following combination always broke: a model with max_position_embeddings=8192, FA2 enabled, and samples of exactly max_length=8192 tokens.
It was resolved by either disabling FA2 or using samples with max_length=8191.

After a lot of debugging, I figured out that this issue only happens with Flash Attention 2 and not with SDPA or vanilla attention.

I am suspecting that this issue stems from the following line:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L447

```python
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
```

If we have a batch with a sequence length of, say, 8192, which can equal max_position_embeddings, then kv_seq_len will be 8192 and is the maximum here; adding 1 yields 8193, and rotary_emb is then called with that value.

There, we end up calling:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L214
which re-initializes the cos/sin cache with a sequence length longer than the supported maximum.

I think it can be solved by moving the + 1 inside the max:

```python
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)
```
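To make the off-by-one concrete, here is a minimal sketch of the two expressions using plain integers (the values are just the full-length example from above; since position ids are 0-indexed, the last position id of an 8192-token sequence is 8191):

```python
kv_seq_len = 8192        # sequence length, equal to max_position_embeddings
last_position_id = 8191  # what position_ids[:, -1].max().item() returns

# Current Mixtral code: the + 1 is applied *after* the max, so a
# full-length sequence overshoots max_position_embeddings by one.
rotary_seq_len_current = max(kv_seq_len, last_position_id) + 1  # -> 8193

# Proposed fix: the + 1 only converts the 0-indexed position id into a
# length before the max, so the result stays capped at kv_seq_len.
rotary_seq_len_fixed = max(kv_seq_len, last_position_id + 1)    # -> 8192
```

With the fix, the rotary cache is only ever asked for a sequence length the model actually supports, so the re-initialization path is not triggered for full-length batches.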

I noticed that this code was changed very recently for Mistral so that it no longer takes the max length and resets the cache:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L108

This was done in PR #30642

I think this fix might have been just a side effect of that PR, and it does not address the Mixtral behavior.
