Fix Qwen3Next dtype API usage#41735
Conversation
vasqu
left a comment
LGTM overall, let's revert the unrelated changes tho!
# Safety: if the model is sharded across multiple devices (hf_device_map/device_map) and we are
# doing sampling, enable `remove_invalid_values` by default to avoid NaN/Inf logits causing CUDA
# asserts during multinomial sampling. Users can still override this by passing the flag explicitly.
Unrelated changes? Also in logits_process.py
Yeah, if you could revert the unrelated changes this PR is good!
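For context on the quoted safety comment: the NaN/Inf scrubbing it describes is what transformers' `InfNanRemoveLogitsProcessor` (the processor behind the `remove_invalid_values` flag, defined in `logits_process.py`) performs. A minimal sketch of its effect on a batch of logits:

```python
import torch
from transformers.generation.logits_process import InfNanRemoveLogitsProcessor

# NaN or +Inf logits can trigger a device-side assert inside
# torch.multinomial during sampling; this processor rewrites them
# into finite values before the sampling step.
processor = InfNanRemoveLogitsProcessor()

scores = torch.tensor([[1.0, float("nan"), float("inf"), -2.0]])
clean = processor(input_ids=None, scores=scores)

print(torch.isfinite(clean).all())  # all values are now finite
```

Setting `remove_invalid_values=True` on `model.generate(...)` (or on a `GenerationConfig`) adds this processor to the logits-processing pipeline automatically.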
Replace torch.get_current_dtype() with torch.get_default_dtype() to fix FLA compatibility
Force-pushed 15485b7 to 6717030
[For maintainers] Suggested jobs to run (before merge): run-slow: qwen3_next
Thanks for pointing it out. I have cleaned up this PR and moved the other changes to #41734. Sorry for the delay. Please take a look and let me know if any further improvements are needed.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This might be a naive question, but how come this didn't make it into v4.57.3? When will this be in a release version?
This PR fixes an invalid PyTorch API usage in the Qwen3Next model.
Changes
Replaces torch.get_current_dtype() with torch.get_default_dtype() in both the modular and modeling files.
Technical Details
torch.get_current_dtype() is not a valid PyTorch API; the correct function is torch.get_default_dtype(), which returns the global default dtype setting.
Testing
The changes have been tested with make fixup and pass the repository consistency checks.
Fix #41732
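To illustrate the API in question: torch.get_default_dtype() reads the process-wide default floating-point dtype (which torch.set_default_dtype can change), whereas torch.get_current_dtype() does not exist in PyTorch. A minimal sketch:

```python
import torch

# torch.get_current_dtype() is not a PyTorch function; the supported
# API for reading the global default floating dtype is:
dtype = torch.get_default_dtype()
print(dtype)  # torch.float32 unless changed via torch.set_default_dtype

# The default can be changed globally, e.g. in mixed-precision setups,
# which is why model code should query it rather than hard-code float32:
torch.set_default_dtype(torch.float64)
print(torch.get_default_dtype())  # torch.float64
torch.set_default_dtype(torch.float32)  # restore
```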