
LLaMA RoPE precision with bf16 model #29301

@marviss

System Info

  • transformers == 4.38.1
  • python == 3.10.13
  • torch == 2.2.0

Who can help?

@gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

When using the bf16 argument for the Trainer, the resulting freqs tensor is also in bf16 precision, even though the operands are explicitly cast to float32:

From https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L126:

inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # ---> bf16
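A minimal self-contained sketch of what I suspect is happening (not confirmed above): under a torch.autocast context with dtype=torch.bfloat16, which the Trainer enables when bf16=True, the matmul inputs are re-cast to bf16 regardless of the explicit .float() calls, so freqs comes out in bf16. The tensor sizes below are placeholders rather than the real model config, and I run autocast on CPU only to keep the repro self-contained; behavior on other backends may differ:

import torch

inv_freq = torch.rand(64, dtype=torch.float32)   # stand-in for self.inv_freq
position_ids = torch.arange(8)[None, :]          # (batch=1, seq_len=8)

# Outside autocast, the matmul stays in float32 as expected.
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
print((inv_freq_expanded @ position_ids_expanded).dtype)  # torch.float32

# Under bf16 autocast, the same matmul is downcast even though both
# operands were explicitly cast to float32 beforehand.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
print(freqs.dtype)  # torch.bfloat16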

Expected behavior

This produces a different output compared to v4.37.2, where the resulting freqs tensor is float32, so the position embeddings obtained for a given sequence length differ between the two versions.

From https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py#L142:

t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)  # ---> float32
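For contrast, a minimal sketch of the v4.37.2 path as I understand it: the cache is built with torch.outer at module construction time, before any autocast context is entered, so freqs stays in float32 (the sizes below are again placeholders):

import torch

inv_freq = torch.rand(64, dtype=torch.float32)   # stand-in for self.inv_freq
max_seq_len_cached = 8                           # placeholder sequence length

# Built once at init, outside the Trainer's autocast context.
t = torch.arange(max_seq_len_cached, dtype=inv_freq.dtype)
freqs = torch.outer(t, inv_freq)
print(freqs.dtype)  # torch.float32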

Is using bf16 precision for freqs with bf16 models the expected behaviour?

Thank you!
