Skip to content

Fix dtype mismatches in SwitchTransformers and TimmWrapperModel for bfloat16#45085

Closed
hkc5 wants to merge 1 commit intohuggingface:mainfrom
hkc5:fix-45072-dtype-mismatches
Closed

Fix dtype mismatches in SwitchTransformers and TimmWrapperModel for bfloat16#45085
hkc5 wants to merge 1 commit intohuggingface:mainfrom
hkc5:fix-45072-dtype-mismatches

Conversation

@hkc5
Copy link
Copy Markdown

@hkc5 hkc5 commented Mar 28, 2026

This PR fixes #45072.

Changes

SwitchTransformers

  • Fixed a bug in SwitchTransformersTop1Router.forward() where router_logits was being reassigned to the max probability values instead of keeping the raw logits from the classifier. This caused dtype mismatches when using bfloat16.
  • The fix introduces router_max_probs as a new variable to hold the max probability values, while router_logits now correctly retains the raw logits for computing the router z-loss.

TimmWrapperModel

  • Fixed a bug where pixel_values was only being cast to the device but not to the model's dtype in TimmWrapperModel.forward(). This caused dtype mismatches when the model was loaded in bfloat16 but the input was float32.
  • The fix aligns TimmWrapperModel with TimmWrapperForImageClassification which already had the correct behavior.

Testing

Both fixes have been verified to resolve the dtype mismatch issues described in #45072.

…float16

- SwitchTransformersTop1Router: Don't reassign router_logits, use router_max_probs instead

- TimmWrapperModel: Cast pixel_values to model dtype in forward()

Fixes huggingface#45072
@github-actions
Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: switch_transformers, timm_wrapper

@github-actions
Copy link
Copy Markdown
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45085&sha=4005bd

@Rocketknight1
Copy link
Copy Markdown
Member

Duplicate of #45074!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG][CI] SwitchTransformers and TimmWrapperModel dtype mismatches in bfloat16 inference

2 participants