System Info
- transformers == 4.38.1
- python == 3.10.13
- torch == 2.2.0
Who can help?
@gante
Information
Tasks
Reproduction
When using the bf16 argument for the Trainer, the resulting freqs tensor ends up in bf16 precision, even though the operands are explicitly cast to float32:
From https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L126:
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) # ---> bf16
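For reference, here is a minimal standalone sketch (not from the issue: the shapes and base are made up, and CPU autocast is used so it runs without a GPU, while the Trainer's bf16 flag enables the analogous CUDA autocast). It reproduces the downcast, with the matmul result coming out in bf16 even though both operands were cast to float32:

import torch

# Hypothetical inv_freq / position_ids, only meant to mirror the shapes used above
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 128, 2, dtype=torch.float32) / 128))
position_ids = torch.arange(16)[None, :]

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()

# Same computation as the snippet above, inside a bf16 autocast context
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)

print(freqs.dtype)  # torch.bfloat16: autocast downcasts the matmul despite the .float() casts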
Expected behavior
This produces a different output compared to v4.37.2, where the resulting freqs tensor is float32, so the position embeddings obtained for a given sequence length differ between the two versions.
From https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py#L142:
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq) # ---> float32
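Running the v4.37.2-style computation under the same autocast context keeps float32 as far as I can tell (torch.outer does not appear to be on autocast's lower-precision op list, and in practice the cache was built at __init__ time outside autocast anyway), matching the comment above:

import torch

max_seq_len_cached = 16
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 128, 2, dtype=torch.float32) / 128))  # hypothetical values

# Old-style cache computation, placed inside a bf16 autocast context for comparison
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    t = torch.arange(max_seq_len_cached, dtype=inv_freq.dtype)
    freqs = torch.outer(t, inv_freq)

print(freqs.dtype)  # torch.float32 in my understanding, since outer is not downcast by autocast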
Is using bf16 precision for freqs with bf16 models the expected behaviour?
Thank you!