Hi everyone,
When running a Llama3.1 training job with FSDP and BF16 mixed precision, we noticed a large gap in train/val loss between FP8 autocast enabled vs. disabled which we do not see with Llama2.
For context, the main relevant difference between Llama3.1 and Llama2 here is the rotary embedding initialization: the RoPE rotary base value is changed from 10k to 500k.
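For concreteness, RoPE's inverse frequencies follow the standard formula theta_i = rotary_base^(-2i/d); raising the base from 10k to 500k shrinks the low-frequency components by orders of magnitude. A quick sketch (the head dim `D=128` is just an illustrative value):

```python
D = 128  # head dim; illustrative value, not from the runs below

def inv_freqs(base: float, d: int) -> list[float]:
    # Standard RoPE inverse frequencies: theta_i = base^(-2i/d)
    return [base ** (-2 * i / d) for i in range(d // 2)]

for base in (10_000.0, 500_000.0):
    f = inv_freqs(base, D)
    print(f"base={base:>9.0f}  highest={f[0]:.3g}  lowest={f[-1]:.3g}")
```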
After some investigation, we found the cause: although the RoPE inverse frequencies are created in FP32 (and cast to FP32 during fused RoPE application), the query/key layers are in BF16 (due to mixed precision in our case) and are passed to `tex.fused_rope_forward(t, freqs)` as is (which I assume is intentional).
However, this behaves differently from passing both `t` and `freqs` explicitly as FP32, which I assume indicates some loss of precision in the fused RoPE kernel.
We can verify this with the following test script:
```python
# TE main (after 1.11 at time of writing) is needed for the tunable
# `rotary_base` argument in `RotaryPositionEmbedding`.
import torch

from transformer_engine.pytorch.attention import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

S = 2  # sequence length
B = 1  # batch size
H = 1  # num attention heads per partition
D = 8  # hidden size per attention head

device = torch.device("cuda:0")
rope = RotaryPositionEmbedding(dim=D, rotary_base=10000)

# FP32 RoPE frequencies
rotary_emb = rope(max_seq_len=S)

# BF16 input tensor and an FP32 copy of it
t1 = torch.rand((S, B, H, D), dtype=torch.bfloat16, device=device)
t2 = t1.detach().clone().to(torch.float32)

t1_out = apply_rotary_pos_emb(t=t1, freqs=rotary_emb, fused=True).to(torch.float32)
t2_out = apply_rotary_pos_emb(t=t2, freqs=rotary_emb, fused=True)

# Assertion fails here
assert torch.allclose(t1_out, t2_out)
```
This precision loss appears to be tolerable at the smaller `rotary_base=10k` used by Llama2, but it is exacerbated at `rotary_base=500k` and exacerbated further when FP8 is enabled.
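To see where such a gap can come from without the fused kernel, here is a minimal unfused RoPE reference (my own sketch, not TE code) that performs the rotation in the input tensor's dtype, mimicking a kernel with no internal upcast; feeding it BF16 vs. FP32 inputs already produces a measurable discrepancy:

```python
import torch

def rope_ref(t: torch.Tensor, rotary_base: float) -> torch.Tensor:
    """Unfused reference RoPE ("sbhd" layout). The rotation is computed in
    the input tensor's dtype, i.e. with no internal FP32 upcast."""
    S, _, _, D = t.shape
    # theta_i = rotary_base^(-2i/D), built in FP32 like TE's frequencies
    inv_freq = 1.0 / (rotary_base ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
    freqs = torch.outer(torch.arange(S, dtype=torch.float32), inv_freq)  # (S, D/2)
    emb = torch.cat([freqs, freqs], dim=-1)                              # (S, D)
    cos = emb.cos().to(t.dtype).view(S, 1, 1, D)
    sin = emb.sin().to(t.dtype).view(S, 1, 1, D)
    half1, half2 = t[..., : D // 2], t[..., D // 2 :]
    rotated = torch.cat([-half2, half1], dim=-1)
    return t * cos + rotated * sin

torch.manual_seed(0)
S, B, H, D = 4096, 1, 1, 128
t_fp32 = torch.rand(S, B, H, D, dtype=torch.float32)
t_bf16 = t_fp32.to(torch.bfloat16)

errs = {}
for base in (10_000, 500_000):
    ref = rope_ref(t_fp32, float(base))
    out = rope_ref(t_bf16, float(base)).to(torch.float32)
    errs[base] = (ref - out).abs().max().item()
    print(f"rotary_base={base}: max abs err vs FP32 = {errs[base]:.3e}")
```

This toy reference only shows that BF16-dtype rotation diverges from FP32; the exact error magnitudes of the fused kernel will differ.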
Llama with `rotary_base=500k`:
- blue=FP8 autocast off
- orange=FP8 autocast on

Llama with `rotary_base=500k`, with the query/key layers upcast to FP32 before `FusedRoPEFunc.apply()`:
- pink=FP8 autocast off
- green=FP8 autocast on

What I think is happening is that even though the RoPE computation itself is not done in FP8, the precision loss is amplified and accumulates across layers and iterations when FP8 is enabled.
Thus, I think the query/key layers should be explicitly upcast to FP32, so the logic in `MultiheadAttention.forward()` would look something like this:
```python
# Upcast qk to FP32 before applying RoPE.
orig_q_dtype = query_layer.dtype
orig_k_dtype = key_layer.dtype
query_layer = query_layer.to(torch.float32)
key_layer = key_layer.to(torch.float32)

query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)

# Cast qk back to their original dtypes.
query_layer = query_layer.to(orig_q_dtype)
key_layer = key_layer.to(orig_k_dtype)
```
Alternatively, this could be an `upcast_to_fp32` flag on `apply_rotary_pos_emb`; either approach works for me.
This should probably be applied in all cases, but if not, at least when FP8 is enabled, to mitigate its impact.
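For the flag variant, a rough sketch of the intended behavior (my own hypothetical wrapper; `toy_rope` is a stand-in, not the real fused kernel):

```python
import torch
from typing import Callable

def apply_rope_fp32(
    t: torch.Tensor,
    freqs: torch.Tensor,
    rope_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Run any RoPE implementation on an FP32 input and cast the result
    back to the input's original dtype (sketch of an upcast-to-FP32 path)."""
    orig_dtype = t.dtype
    out = rope_fn(t.to(torch.float32), freqs)
    return out.to(orig_dtype)

# Toy stand-in for a fused RoPE kernel: computes in whatever dtype it is given.
def toy_rope(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    return t * freqs.cos().to(t.dtype) + t * freqs.sin().to(t.dtype)

t = torch.rand(2, 1, 1, 8, dtype=torch.bfloat16)
freqs = torch.rand(2, 1, 1, 8, dtype=torch.float32)

out = apply_rope_fp32(t, freqs, toy_rope)
assert out.dtype == torch.bfloat16  # output dtype is preserved
```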
Although the setup is slightly different, HuggingFace appears to have hit similar precision issues with RoPE, which illustrates how impactful RoPE precision can be.
Note 1: I did not investigate the unfused RoPE case, as it is not used in `MultiheadAttention`.
Note 2: The Llama runs used TE v1.10, with `RotaryPositionEmbedding` patched to allow tuning `rotary_base`.
What do you guys think?
Thanks!