fix(models): Fix dtype mismatch in SwitchTransformers and TimmWrapperModel (#45074)
Conversation
> Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding to raw router logits.
> This is used later for computing router z-loss.
> """
> # float32 is used to ensure stability. See the discussion of "selective precision" in
I'm a bit worried about this comment - it looks like we're not casting to float32 anymore after this change? We should probably be upcasting the classifier weights to float32 instead of downcasting the hidden states to bfloat16, right?
Sorry about that, you're right. This is also in agreement with the original modeling file, which before 7938e91fa (as mentioned in the PR description) upcasted the classifier weights to float32 to match the hidden_states dtype rather than downcasting, as you rightly noted (context: `_cast_classifier()`). I've restored the applicable logic inline :)
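To make the distinction concrete, here is a minimal sketch of the "upcast for the router" pattern discussed above, assuming standard PyTorch semantics. The function and variable names (`route_tokens`, `hidden_states`) are illustrative only, not the actual `SwitchTransformersTop1Router` code:

```python
import torch

def route_tokens(hidden_states: torch.Tensor, classifier: torch.nn.Linear) -> torch.Tensor:
    # Upcast both activations and classifier weights to float32 for numerical
    # stability ("selective precision"), rather than downcasting hidden_states
    # to the (possibly bf16/fp16) weight dtype.
    bias = classifier.bias.float() if classifier.bias is not None else None
    return torch.nn.functional.linear(
        hidden_states.float(), classifier.weight.float(), bias
    )

hidden = torch.randn(2, 4, 8, dtype=torch.bfloat16)   # bf16 activations
clf = torch.nn.Linear(8, 3).to(torch.bfloat16)        # bf16 classifier
logits = route_tokens(hidden, clf)
assert logits.dtype == torch.float32                  # router logits stay fp32
```

The key point is that the router logits come out in float32 regardless of the model dtype, so the downstream z-loss computation keeps full precision.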
→ Just a nit explaining why I dropped this check: in the current bitsandbytes 0.49.2, and in any version >= BITSANDBYTES_MIN_VERSION = "0.46.1" (from import_utils.py), `SCB` and `CB` aren't direct attributes on the layer. So the guard is effectively a no-op and doesn't serve us.
[For maintainers] Suggested jobs to run (before merge): run-slow: switch_transformers, timm_wrapper

P.S. Just bumping this comment from the last PR as I think it got buried in the CI bot spam. Won't bump this again; I completely understand that maintainers are swamped and your time is extremely valuable! :)
@Rocketknight1 Good day; just checking in to see if there are any updates :) |
Rocketknight1 left a comment
Yes, LGTM now! Also I'm not in too many other places right now; my Twitter is mostly inactive since I wasn't really enjoying the flood of racism in the feed. You can ping me on Discord if you need me for anything though!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Rocketknight1 I understand. Sure thing; I've sent you a Discord DM. Thanks a lot :)
…Model (huggingface#45074)
* fix: Cast inputs to match weight dtype
* new: Add test
* change: Upcast to float32 instead of downcasting
What does this PR do?
The following dtype mismatch use cases were identified and fixed in this PR:
→ Switch Transformers: 7938e91fa refactored all MoE models for vLLM compatibility; in that refactor, the `_cast_classifier()` method was removed from `SwitchTransformersTop1Router` but no dtype cast was added. Casting hidden_states to `classifier.weight.dtype` before the linear call fixes that!

→ TimmWrapper: 6217adc6c8 changed the default dtype behavior to "auto"; in that commit, `pixel_values.to(self.device, self.dtype)` was regressed to `pixel_values.to(self.device)`, dropping the dtype cast. I'm not too sure why it was dropped, but restoring it seems logical to fix the use case.

→ For more details on reproducing the bug and the output screenshots, please visit the linked issue!
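Both fixes boil down to aligning the input dtype with the weight dtype before the op runs. A minimal sketch of the two patterns, assuming standard PyTorch semantics (variable names are illustrative, not the actual modeling code):

```python
import torch

# SwitchTransformers side: match hidden_states to the classifier weight dtype
# before the linear call, so bf16/fp16 checkpoints don't hit a dtype mismatch.
classifier = torch.nn.Linear(8, 3).to(torch.bfloat16)
hidden_states = torch.randn(2, 8)  # float32 activations
logits = classifier(hidden_states.to(classifier.weight.dtype))

# TimmWrapper side: restore the dtype argument in the .to() call so
# pixel_values follow the model dtype as well as its device.
model_dtype = torch.bfloat16
pixel_values = torch.randn(1, 3, 4, 4)
pixel_values = pixel_values.to("cpu", model_dtype)  # device AND dtype
```

Without the casts, both calls raise a runtime dtype-mismatch error when the model is loaded in half precision while the inputs stay float32.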
cc: @Rocketknight1
Fixes #45072
CI run test coverage of this behavior (as suggested by @ydshieh) :):

SwitchTransformers:

→ test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_generate_with_past_key_values
→ test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_model_fp16_forward
→ test_modeling_switch_transformers.py::SwitchTransformerModelIntegrationTests::test_small_logits

TimmWrapper:

→ `TimmWrapperModelTest` does not have explicit bfloat16 forward pass tests; added one in this PR for complete coverage.

Repro output after the fixes (feel free to cross-check):
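For readers who want the shape of such a test, here is a hedged sketch in the spirit of a bfloat16 forward-pass check. It uses a toy module rather than the real `TimmWrapperModel` (which would require downloading a checkpoint), so all names below are illustrative:

```python
import torch

class ToyVisionModel(torch.nn.Module):
    """Stand-in for a timm-wrapped backbone; not the real TimmWrapperModel."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Cast inputs to the parameter dtype, mirroring the PR's fix.
        pixel_values = pixel_values.to(self.proj.weight.dtype)
        return self.proj(pixel_values)

model = ToyVisionModel().to(torch.bfloat16)
out = model(torch.randn(1, 3, 8, 8))  # float32 input vs. bf16 weights
assert out.dtype == torch.bfloat16    # forward succeeds, no dtype mismatch
```

The real test in the PR does the equivalent: load the model in bfloat16, feed float32 pixel values, and assert the forward pass runs without a dtype error.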