
Fix bnb fsdp loading for pre-quantized checkpoint #41415

Merged
SunMarc merged 4 commits into main from fix-fsdp-quant on Oct 9, 2025

Conversation

@SunMarc (Member) commented Oct 7, 2025

What does this PR do?

This PR fixes bnb loading of pre-quantized checkpoints when using FSDP. The regression was introduced when we changed how quantized checkpoints are loaded: all the quantization stats now need to be cached before the quantized weight is created.
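For context, here is a minimal standalone sketch of the caching pattern described above. The cache layout, the expected key set, and the build_quantized_weight helper are hypothetical stand-ins, not the actual transformers/bitsandbytes code:

```python
import torch

# Hypothetical: quantized stats for a weight (absmax, quant_map, ...) can
# arrive in any order while iterating over checkpoint shards, so they are
# cached first and the quantized weight is only built once all are present.
EXPECTED_STATS = {"absmax", "quant_map"}  # illustrative key set
quant_stats_cache: dict[str, dict[str, torch.Tensor]] = {}

def build_quantized_weight(stats: dict[str, torch.Tensor]) -> torch.Tensor:
    # Stand-in for the real factory (e.g. a bnb Params4bit/Int8Params call).
    return stats["absmax"]

def maybe_create_quantized_param(name: str, tensor: torch.Tensor):
    base, _, stat = name.rpartition(".")
    cache = quant_stats_cache.setdefault(base, {})
    cache[stat] = tensor
    if EXPECTED_STATS <= cache.keys():
        # All stats cached: safe to create the quantized weight now.
        return build_quantized_weight(cache)
    return None  # still waiting for more stats

maybe_create_quantized_param("layer.weight.absmax", torch.ones(4))
assert maybe_create_quantized_param("layer.weight.quant_map", torch.ones(4)) is not None
```

Under FSDP this interacts with device placement, since non-rank-0 processes keep params on the meta device while the stats are being cached; the device logic discussed in the review below accounts for that.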

@SunMarc changed the title from "Fix bnb fsdp loading" to "Fix bnb fsdp loading for pre-quantized checkpoint" on Oct 7, 2025
@SunMarc requested a review from @Cyrilvallez on October 7, 2025, 15:31
@SunMarc (Member, Author) commented Oct 7, 2025

cc @winglian

@HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc added the for patch label (Tag issues / labels that should be included in the next patch) on Oct 7, 2025
@Cyrilvallez (Member) left a comment

Left a few comments about naming for clarity, otherwise LGTM!

Comment thread on src/transformers/modeling_utils.py (Outdated), lines +772 to +773:
val_kwargs = value.__dict__
if value.dtype in [torch.uint8, torch.int8]:
@Cyrilvallez (Member):

Maybe just value.is_floating_point() if that works?

@SunMarc (Member, Author):

That should work! I think that at some point it should be fine to even remove that check, if the modules are correctly initialized.
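As a quick standalone illustration of the suggestion (not the PR's code): for bnb weights stored as uint8/int8, the explicit dtype list and the floating-point predicate agree. Note that not is_floating_point() is strictly broader (it is also true for int32, int64, bool, ...), which is harmless here since bnb quantized storage is integer:

```python
import torch

# Only dtypes relevant to this code path; for e.g. int32 the two checks differ.
for dtype in (torch.uint8, torch.int8, torch.float16, torch.float32):
    t = torch.zeros(1, dtype=dtype)
    assert (t.dtype in [torch.uint8, torch.int8]) == (not t.is_floating_point())
```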

Comment thread on src/transformers/modeling_utils.py (Outdated):
if value.dtype in [torch.uint8, torch.int8]:
val_kwargs["requires_grad"] = False
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
param_to = "meta" if is_fsdp_enabled() and not is_local_dist_rank_0() else "cpu"
@Cyrilvallez (Member):

Let's just call it device IMO, param_to is a bit weird.

@SunMarc (Member, Author):

Done.
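For readers following along, a standalone rendering of the renamed logic; the helper below is illustrative, while the PR itself inlines the expression using transformers' is_fsdp_enabled() and is_local_dist_rank_0():

```python
def target_device(fsdp_enabled: bool, local_rank_0: bool) -> str:
    # Renamed from `param_to` to `device` in the PR; semantics unchanged:
    # under FSDP only local rank 0 materializes weights on CPU, the other
    # ranks keep them on "meta" to avoid duplicating host memory.
    return "meta" if fsdp_enabled and not local_rank_0 else "cpu"

assert target_device(fsdp_enabled=True, local_rank_0=False) == "meta"
assert target_device(fsdp_enabled=True, local_rank_0=True) == "cpu"
assert target_device(fsdp_enabled=False, local_rank_0=False) == "cpu"
```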

Comment on lines +157 to +158
def update_param_name(self, param_name: str) -> str:
"""
@Cyrilvallez (Member):

Let's maybe call it get_param_name instead, as it does not update it.

@SunMarc (Member, Author):

Done.
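A hypothetical sketch of why "get" fits better than "update": the method maps a checkpoint parameter name to the name that should be used, returning a value without mutating anything. The suffix-stripping behaviour below is invented for illustration only:

```python
class QuantizerNamingSketch:
    def get_param_name(self, param_name: str) -> str:
        # Pure function of its input: returns a (possibly remapped) name
        # and mutates nothing, hence "get" rather than "update".
        if param_name.endswith(".absmax"):
            return param_name[: -len(".absmax")]
        return param_name

sketch = QuantizerNamingSketch()
assert sketch.get_param_name("model.layers.0.weight.absmax") == "model.layers.0.weight"
```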

Comment on lines -769 to +770:

- # special case for gpt_oss model, we wait for the param to leave the meta device before casting it to cpu
- if model.config.model_type == "gpt_oss" and value.device.type == "meta":
+ # We need to wait until the quantized value is created
+ if value.device.type == "meta":
@Cyrilvallez (Member):

Still a bit weird to me that we have to do this, but I wanted to investigate further anyway to remove the gpt-oss special exception. Already happy to see it a bit more general and not gpt-oss-specific!
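To make the generalized guard concrete, a toy standalone version (not the PR's surrounding code): a tensor still on the meta device has no real storage yet, so the cast has to wait until the quantized value has actually been created, regardless of model type:

```python
import torch

value = torch.empty(4, device="meta")  # placeholder with no backing storage

if value.device.type == "meta":
    # Quantized value not created yet: defer the cast. Calling .to("cpu") on
    # a meta tensor would raise, since there is no data to copy out.
    pass
else:
    value = value.to("cpu")
```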

@github-actions (Contributor) bot commented Oct 9, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: mxfp4

@SunMarc merged commit 823fab4 into main on Oct 9, 2025. 26 checks passed.
@SunMarc deleted the fix-fsdp-quant branch on October 9, 2025, 16:05.
AhnJoonSung pushed a commit to AhnJoonSung/transformers that referenced this pull request on Oct 12, 2025:

* fix
* fix
* get_param_name
* fix device name

Cyrilvallez pushed a commit that referenced this pull request on Oct 14, 2025:

* fix
* fix
* get_param_name
* fix device name
