Fix qwen3_vl mixed precision dtype #41701
Conversation
[For maintainers] Suggested jobs to run (before merge): run-slow: qwen3_omni_moe, qwen3_vl, qwen3_vl_moe
  pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
- hidden_states = hidden_states + pos_embeds
+ hidden_states = (hidden_states + pos_embeds).to(input_dtype)
I think we should cast only pos_embeds to the input dtype here.
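A minimal sketch of that variant with toy tensors (not the actual Qwen3-VL module), assuming hidden_states runs in BF16 while pos_embeds comes back in FP32:

```python
import torch

# Toy sketch of "cast only pos_embeds" (illustrative shapes, not the real module).
input_dtype = torch.bfloat16
hidden_states = torch.randn(4, 8, dtype=input_dtype)
pos_embeds = torch.randn(4, 8, dtype=torch.float32)  # FP32, like the master weights

hidden_states = hidden_states + pos_embeds.to(input_dtype)
print(hidden_states.dtype)  # torch.bfloat16 -- the activations are not upcast
```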
Agreed. Further, couldn't we fix this in fast_pos_embed_interpolate instead of recasting? To avoid too many conversions, we could for instance pass the input_dtype.
In h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h), passing the wanted dtype should be enough, no?
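A quick sketch of what passing the dtype through could look like; num_grid_per_side and h are illustrative values here, not the model's config:

```python
import torch

# Sketch only: torch.linspace takes a dtype keyword, so the index grids inside
# fast_pos_embed_interpolate could be built in the activation dtype directly.
num_grid_per_side, h = 16, 24
h_idxs = torch.linspace(0, num_grid_per_side - 1, h)                             # FP32 by default
h_idxs_bf16 = torch.linspace(0, num_grid_per_side - 1, h, dtype=torch.bfloat16)  # built in BF16
print(h_idxs.dtype, h_idxs_bf16.dtype)  # torch.float32 torch.bfloat16
```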
Thought the same thing, but I am not sure if positional embedding was intentionally done in full precision for better performance 🤔
Yeah, this change intentionally keeps the addition running in FP32 (the same dtype as the master weights) for now, without changing the numerical dynamics.
Casting only pos_embeds or fixing this inside fast_pos_embed_interpolate would have numerical implications, which would require ablations with training results if we want to be careful.
Happy to discuss; I'm leaning towards not changing model behavior for now.
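For illustration only, a toy comparison of the two orderings (random tensors, not the real model): rounding pos_embeds to BF16 before the add versus adding in FP32 and rounding the sum afterwards, as this PR does, generally does not give bit-identical results.

```python
import torch

torch.manual_seed(0)
hidden_states = torch.randn(1024, dtype=torch.bfloat16)
pos_embeds = torch.randn(1024, dtype=torch.float32)

sum_then_cast = (hidden_states + pos_embeds).to(torch.bfloat16)  # add in FP32, then round the sum
cast_then_add = hidden_states + pos_embeds.to(torch.bfloat16)    # round pos_embeds first, add in BF16
print((sum_then_cast != cast_then_add).sum().item())  # typically nonzero, i.e. not bit-identical
```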
fast_pos_embed_interpolate returns pos_embeds in the same dtype as the master weights. Therefore, when the master weights are in FP32 but the forward pass runs in BF16,
hidden_states will be upcast to FP32, causing dtype mismatches with other activations. CC @yonigozlan @molbap @zucchini-nlp
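A minimal reproduction of that failure mode with toy tensors (not the actual Qwen3-VL vision tower): the FP32 pos_embeds silently promote the BF16 hidden states, and a downstream BF16 layer then rejects the FP32 input.

```python
import torch

hidden_states = torch.randn(2, 8, dtype=torch.bfloat16)
pos_embeds = torch.randn(2, 8, dtype=torch.float32)   # same dtype as the FP32 master weights
proj = torch.nn.Linear(8, 8).to(torch.bfloat16)       # downstream layer running in BF16

mixed = hidden_states + pos_embeds                    # silently promoted to torch.float32
try:
    proj(mixed)
except RuntimeError as err:
    print(err)  # dtype mismatch between the FP32 input and the BF16 weights

fixed = (hidden_states + pos_embeds).to(hidden_states.dtype)  # cast the sum back, as in this PR
print(proj(fixed).dtype)  # torch.bfloat16
```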