Cast multimodal forward_kwargs to compute dtype for bf16/fp16 training #5073

Open

akshan-main wants to merge 1 commit into huggingface:main from akshan-main:fix_grpo_vlm_pixel_dtype
Conversation

akshan-main commented Feb 11, 2026

What does this PR do?

When training VLMs with bf16=True or fp16=True, the pixel_values returned by the processor stay float32 after _prepare_inputs (the existing dtype casting there is DeepSpeed-specific).
If the vision encoder weights are bfloat16/float16, this can crash in torch.layer_norm with:

RuntimeError: expected scalar type BFloat16 but found Float

This is the next failure reported in the #4451 thread after the prompt-format TypeError.
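
A minimal standalone sketch of the mismatch (not code from this PR), assuming a bare layer_norm call with float32 activations and bfloat16 parameters; the exact error text can vary by PyTorch version and backend:

import torch
import torch.nn.functional as F

# Processor output left in float32, vision-encoder LayerNorm parameters loaded in bf16.
pixel_values = torch.randn(2, 1024, dtype=torch.float32)
ln_weight = torch.randn(1024, dtype=torch.bfloat16)
ln_bias = torch.zeros(1024, dtype=torch.bfloat16)

try:
    F.layer_norm(pixel_values, (1024,), ln_weight, ln_bias)
except RuntimeError as err:
    print(err)  # e.g. "expected scalar type BFloat16 but found Float"
else:
    print("no error on this PyTorch build")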

Note: This is fixed by casting floating-point tensors in forward_kwargs to the compute dtype when bf16=True or fp16=True, which is consistent with how the trainer already handles model dtype casting. If the model is loaded in bf16 via torch_dtype without setting the training flag, this path won't trigger, but neither does the existing model casting.

Changes made

  • In the multimodal path, cast only floating-point tensors in forward_kwargs to the active compute dtype (bf16/fp16), as sketched below.
  • Leave non-floating tensors (for example image_grid_thw) unchanged.

  • No prompt-format behavior changes (proposed in #5064 and #5067).
  • No reward-function behavior changes (proposed in #5064).
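
A minimal sketch of the described casting behavior, assuming a standalone helper (cast_forward_kwargs is a hypothetical name for illustration, not the trainer's actual code):

import torch

def cast_forward_kwargs(forward_kwargs, compute_dtype):
    # Cast only floating-point tensors; leave integer tensors (e.g. image_grid_thw) untouched.
    if compute_dtype is None:
        return forward_kwargs
    return {
        k: v.to(compute_dtype) if isinstance(v, torch.Tensor) and torch.is_floating_point(v) else v
        for k, v in forward_kwargs.items()
    }

forward_kwargs = {
    "pixel_values": torch.randn(1, 3, 224, 224),    # float32 from the processor
    "image_grid_thw": torch.tensor([[1, 14, 14]]),  # int64 grid metadata
}
casted = cast_forward_kwargs(forward_kwargs, torch.bfloat16)  # as if bf16=True
print(casted["pixel_values"].dtype)    # torch.bfloat16
print(casted["image_grid_thw"].dtype)  # torch.int64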

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@qgallouedec @kashif @albertvillanova

akshan-main (Author) commented Feb 11, 2026

@qgallouedec hey mate! Can this be reviewed? Let me know if you want any changes.

akshan-main (Author) commented:

@codex review

Comment on lines 2290 to 2322
def test_forward_kwargs_dtype_casting():
    forward_kwargs = {
        "pixel_values": torch.randn(1, 3, 224, 224, dtype=torch.float32),
        "image_grid_thw": torch.tensor([[1, 14, 14]]),
    }

    for bf16, fp16, expected_dtype in [
        (True, False, torch.bfloat16),
        (False, True, torch.float16),
        (False, False, None),
    ]:
        if bf16:
            compute_dtype = torch.bfloat16
        elif fp16:
            compute_dtype = torch.float16
        else:
            compute_dtype = None

        if compute_dtype is not None:
            result = {
                k: v.to(compute_dtype) if isinstance(v, torch.Tensor) and torch.is_floating_point(v) else v
                for k, v in forward_kwargs.items()
            }
        else:
            result = forward_kwargs

        if expected_dtype is not None:
            assert result["pixel_values"].dtype == expected_dtype
        else:
            assert result["pixel_values"].dtype == torch.float32
        assert result["image_grid_thw"].dtype == torch.int64


Member

what are we testing here? just that the .to() works? if so I think it's out of the scope of the trl tests

akshan-main (Author) commented Feb 16, 2026

You're right, I'll drop the test.

akshan-main force-pushed the fix_grpo_vlm_pixel_dtype branch from 697ba32 to 68b0c28 on February 16, 2026 at 17:27