Fix bnb fsdp loading for pre-quantized checkpoint #41415
Conversation
cc @winglian
Cyrilvallez left a comment:
Left a few comments about naming for clarity, otherwise LGTM!
    val_kwargs = value.__dict__
    if value.dtype in [torch.uint8, torch.int8]:
Maybe just value.is_floating_point() if that works?
That should work! I think that at some point it should be fine to even remove that check, if the modules are correctly initialized.
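A minimal sketch of the suggested check; the helper name is_quantized_storage is illustrative and not from the PR:

    import torch

    def is_quantized_storage(value: torch.Tensor) -> bool:
        # Before: value.dtype in [torch.uint8, torch.int8] enumerated bnb's packed dtypes.
        # Suggested: treat any non-floating-point tensor as quantized storage.
        return not value.is_floating_point()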
    if value.dtype in [torch.uint8, torch.int8]:
        val_kwargs["requires_grad"] = False
    value = type(value)(value.data.to(param_to), **val_kwargs)

    param_to = "meta" if is_fsdp_enabled() and not is_local_dist_rank_0() else "cpu"
Let's just call it device IMO, param_to is a bit weird
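As a sketch of the rename, assuming the two helpers are the ones defined in transformers.modeling_utils (the wrapper function target_device is illustrative):

    from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0

    def target_device() -> str:
        # Under FSDP, only local rank 0 materializes weights on cpu; the
        # other ranks keep them on the meta device.
        return "meta" if is_fsdp_enabled() and not is_local_dist_rank_0() else "cpu"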
    def update_param_name(self, param_name: str) -> str:
        """
Let's maybe call it get_param_name instead as it does not update it
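A sketch of the renamed method; only the signature comes from the diff, while the class name and body are illustrative (the real mapping logic lives in the quantizer):

    class BnbQuantizerSketch:
        def get_param_name(self, param_name: str) -> str:
            """Return the parameter name to use for this checkpoint key; a pure lookup, hence "get"."""
            # Illustrative: map an 8-bit stat key like "...weight.SCB" back to its base weight.
            base, _, suffix = param_name.rpartition(".")
            return base if suffix == "SCB" else param_name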
    - # special case for gpt_oss model, we wait for the param to leave the meta device before casting it to cpu
    - if model.config.model_type == "gpt_oss" and value.device.type == "meta":
    + # We need to wait until the quantized value is created
    + if value.device.type == "meta":
Still a bit weird to me that we have to do this, but I wanted to investigate further anyway to remove the gpt-oss special exception - already happy to see it a bit more general and not gpt-oss-specific!
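For context, the generalized condition as a self-contained sketch; the helper name should_defer_cast is hypothetical, while the check itself is the one in the diff:

    import torch

    def should_defer_cast(value: torch.Tensor) -> bool:
        # A param still on the meta device means its quantized value has not
        # been created yet, so the cast to cpu must wait -- for every model
        # now, not only gpt_oss.
        return value.device.type == "meta"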
[For maintainers] Suggested jobs to run (before merge): run-slow: mxfp4
Commits:
* fix
* fix
* get_param_name
* fix device name
What does this PR do?
This PR fixes bnb loading when using FSDP with pre-quantized checkpoints. The bug appeared after we changed how quantized checkpoints are loaded: we now need to cache all the quantization stats before creating the quantized weight.
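As a rough sketch of that ordering, with hypothetical names (quant_stats_cache, handle_checkpoint_key, create_quantized_param) and assuming the stat keys follow bitsandbytes' serialization format (weight.SCB for 8-bit, weight.quant_state.* for 4-bit):

    import torch

    quant_stats_cache: dict[str, torch.Tensor] = {}

    def create_quantized_param(name: str, data: torch.Tensor, stats: dict[str, torch.Tensor]) -> None:
        # Stand-in for the quantizer's real creation step, which builds e.g.
        # a bnb Int8Params/Params4bit from the raw data plus its cached stats.
        ...

    def handle_checkpoint_key(name: str, tensor: torch.Tensor) -> None:
        if ".SCB" in name or ".quant_state." in name:
            # Quant stats arrive as separate checkpoint keys: cache them first...
            quant_stats_cache[name] = tensor
        else:
            # ...so the quantized weight is only created once all its stats are known.
            create_quantized_param(name, tensor, quant_stats_cache)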