Cast multimodal forward_kwargs to compute dtype for bf16/fp16 training #5073
Open
akshan-main wants to merge 1 commit into huggingface:main
Conversation
Author
@qgallouedec hey mate! Can this be reviewed? Let me know if you want any changes.
Author
@codex review
qgallouedec reviewed Feb 16, 2026
tests/test_grpo_trainer.py (Outdated)
Comment on lines 2290 to 2322
```python
import torch


def test_forward_kwargs_dtype_casting():
    # One floating-point tensor (should be cast) and one integer tensor (should be left alone).
    forward_kwargs = {
        "pixel_values": torch.randn(1, 3, 224, 224, dtype=torch.float32),
        "image_grid_thw": torch.tensor([[1, 14, 14]]),
    }

    for bf16, fp16, expected_dtype in [
        (True, False, torch.bfloat16),
        (False, True, torch.float16),
        (False, False, None),
    ]:
        if bf16:
            compute_dtype = torch.bfloat16
        elif fp16:
            compute_dtype = torch.float16
        else:
            compute_dtype = None

        if compute_dtype is not None:
            # Cast only floating-point tensors to the compute dtype.
            result = {
                k: v.to(compute_dtype) if isinstance(v, torch.Tensor) and torch.is_floating_point(v) else v
                for k, v in forward_kwargs.items()
            }
        else:
            result = forward_kwargs

        if expected_dtype is not None:
            assert result["pixel_values"].dtype == expected_dtype
        else:
            assert result["pixel_values"].dtype == torch.float32
        # Integer tensors keep their dtype in every case.
        assert result["image_grid_thw"].dtype == torch.int64
```
Member
what are we testing here? just that the `to` works? if so I think it's out of the scope of the trl tests
Author
You're right, I'll drop the test.
Force-pushed from 697ba32 to 68b0c28
What does this PR do?
When training VLMs with bf16=True or fp16=True, the pixel_values returned by the processor stay float32 after _prepare_inputs (the existing dtype casting is DeepSpeed-specific).
If the vision encoder weights are bfloat16/float16, this can crash inside torch.layer_norm with a dtype mismatch error.
This is the next failure reported in the #4451 thread after the prompt-format TypeError.
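For context, here is a minimal sketch (my illustration, not code from the PR or the #4451 thread) of the kind of mismatch described above: bfloat16 layer-norm weights applied to a float32 input, as happens when pixel_values are not cast to the compute dtype.

```python
import torch

# Illustrative only: bf16 LayerNorm weights + float32 input, standing in for
# a bf16 vision encoder receiving uncast float32 pixel_values.
ln = torch.nn.LayerNorm(8).to(torch.bfloat16)
x = torch.randn(2, 8, dtype=torch.float32)
try:
    ln(x)
except RuntimeError as e:
    print(e)  # dtype mismatch surfaced by torch.layer_norm
```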
Note: Fixed by casting floating-point tensors in forward_kwargs to the compute dtype when bf16=True or fp16=True. This is consistent with how the trainer already handles model dtype casting. If the model is loaded in bf16 via torch_dtype without setting the training flag, this path won't trigger, but neither does the existing model casting.
Changes made
In the multimodal path, cast only floating-point tensors in forward_kwargs to the active compute dtype (bf16/fp16); a sketch of this logic follows the list below.
Leave non-floating-point tensors (for example image_grid_thw) unchanged.
No prompt-format behavior changes (those are proposed in #5064 and #5067).
No reward-function behavior changes (proposed in #5064).
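A minimal sketch of that casting logic, written as a standalone helper (cast_forward_kwargs is an illustrative name, not the trainer's actual code; the dict comprehension mirrors the one from the dropped test above):

```python
from typing import Optional

import torch


def cast_forward_kwargs(forward_kwargs: dict, compute_dtype: Optional[torch.dtype]) -> dict:
    # Hypothetical helper: cast floating-point tensors to the compute dtype,
    # leave integer tensors (e.g. image_grid_thw) untouched.
    if compute_dtype is None:
        return forward_kwargs
    return {
        k: v.to(compute_dtype) if isinstance(v, torch.Tensor) and torch.is_floating_point(v) else v
        for k, v in forward_kwargs.items()
    }
```

With bf16=True this casts pixel_values to torch.bfloat16 while image_grid_thw stays int64.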
Before submitting
Did you read the contributor guideline, Pull Request section?
Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
@qgallouedec @kashif @albertvillanova