
Update param_element_size #42818

Merged: SunMarc merged 17 commits into main from clean-param-size on Dec 18, 2025

Conversation

@SunMarc (Member) commented on Dec 11, 2025:

What does this PR do?

This PR fixes how we calculate the param size for quantized models. This should be simpler to hack around!

I added some tests to check that it works for the following methods: bnb, finegrained_fp8, torchao, mxfp4, and quanto.

cc @Cyrilvallez as you were concerned at some point.
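
For context, the idea is that each quantization backend reports how many bytes a single element of a given parameter takes via param_element_size, instead of the loading code special-casing every backend. A minimal sketch of what an override might look like, assuming only the signatures visible in the snippets further down this thread (the class name and the 4-bit figure are illustrative, not the actual implementation):

from transformers.quantizers import HfQuantizer


# Illustrative sketch only: a hypothetical 4-bit backend reports half a byte per
# element for params it quantizes, and falls back to the tensor's own element size.
class MyFourBitQuantizer(HfQuantizer):
    def param_element_size(self, model, param_name: str, param) -> float:
        if self.param_needs_quantization(model, param_name):
            return 0.5  # 4 bits per element
        return param.element_size()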

@SunMarc changed the title from "Update" to "Update param_element_size" on Dec 11, 2025.
@SunMarc requested a review from MekkCyber on December 11, 2025 at 15:39.
Comment on lines -4414 to -4415
if not is_quantized or not hf_quantizer.param_needs_quantization(self, key):
_load_parameter_into_model(self, key, value)

@SunMarc (Member, Author):

It should be fine to remove the condition, no? cc @Cyrilvallez

@Cyrilvallez (Member):

Well, yes and no... Previously there was a

else:
    hf_quantizer.create_quantized_param(...)

in case the missing weight needed to be quantized (we quantized the new random weight).

I'm not sure why it was removed, but it would probably be the best way, no?

@SunMarc (Member, Author) replied on Dec 12, 2025:

We removed create_quantized_param as we don't need it anymore since we have the quantize ops xD. Let's see how to deal with that in another PR, cc @MekkCyber, but I don't think this is urgent.

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@MekkCyber (Contributor) left a comment:

Awesome 🧼

Comment thread src/transformers/quantizers/base.py

@Cyrilvallez (Member) left a comment:

SUPER NICE! 🤗 Thanks a lot for doing that! High time we finally have a nice and clean way to pre-allocate with quantization! 🚀
Just added a few comments, then we can merge!


Comment thread src/transformers/modeling_utils.py Outdated
Comment thread src/transformers/quantizers/quantizer_bnb_4bit.py
Comment thread src/transformers/modeling_utils.py Outdated
else None
)

modules_sizes, _ = compute_module_sizes(model, hf_quantizer, only_modules=False)

@SunMarc (Member, Author) commented on Dec 12, 2025:

Using this instead, cc @Cyrilvallez, if that's fine.

Comment thread src/transformers/modeling_utils.py Outdated
Comment on lines +4039 to +4041
# We need parameters + buffers here, as state_dict does not count non-persistent buffers which are taking space
expected_keys = [name for name, _ in model.named_parameters()] + [name for name, _ in model.named_buffers()]

@SunMarc (Member, Author):

There were some differences because of this.

Comment on lines +471 to +473
# check that we get the same value, as we use `compute_module_sizes` in `get_total_byte_count`
assert total_byte_count == model_size[""]
assert quantized_total_byte_count == quantized_model_size[""]

@SunMarc (Member, Author):

Here I check that we get the same result.

@SunMarc requested a review from Cyrilvallez on December 12, 2025 at 17:59.

@Cyrilvallez (Member) commented:

Will have a last look next Monday!

Comment thread src/transformers/modeling_utils.py Outdated
Comment on lines +4039 to +4040
expected_keys = list(model.state_dict().keys())
# We need parameters + buffers here, as state_dict does not count non-persistent buffers which are taking space
expected_keys = [name for name, _ in model.named_parameters()] + [name for name, _ in model.named_buffers()]

@Cyrilvallez (Member):

Hmm, the point here is that non-persistent buffers are NOT loaded with the other params (because they are non-persistent, of course), so there is no need to account for them when allocating memory before loading.
Thus I believe we should revert this change.
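
For illustration, a tiny standalone module (not from this PR) showing the distinction: a non-persistent buffer takes up memory at runtime but never appears in state_dict(), so nothing is loaded into it from the checkpoint:

import torch
from torch import nn


class ToyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("persistent_buf", torch.zeros(4))  # saved in and loaded from state_dict
        self.register_buffer("transient_buf", torch.zeros(4), persistent=False)  # runtime-only


m = ToyModule()
print(list(m.state_dict().keys()))              # ['persistent_buf'] only
print([name for name, _ in m.named_buffers()])  # both buffers still occupy memory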

@SunMarc (Member, Author):

Indeed xD, I was only trying to match the numbers and didn't think too far. In any case, the tests I wrote should still be valid!

Comment thread src/transformers/modeling_utils.py Outdated
Comment on lines +4543 to +4544
modules_sizes, _ = compute_module_sizes(model, hf_quantizer, only_modules=False)
for param_name, device in accelerator_device_map.items():

@Cyrilvallez (Member):

Hmm, here we iterate twice over all params for no reason... Better to go back to the old loop and mimic what's being done in compute_module_sizes by using dtype_size = hf_quantizer.param_element_size(model, name, param) if we have a quantizer!

@SunMarc (Member, Author):

Fair, sounds good!

@Cyrilvallez (Member) left a comment:

LGTM! Left a final comment but that's it!
Feel free to merge after the conflicts on quantizers have been resolved 🤗
Thanks again for this, will make everything much much smoother! 🤗

Comment on lines +4523 to +4560
def get_total_byte_count(
    model: PreTrainedModel, accelerator_device_map: dict, hf_quantizer: Optional[HfQuantizer] = None
):
    """
    This utility function calculates the total bytes count needed to load the model on each device.
    This is useful for caching_allocator_warmup as we want to know how much cache we need to pre-allocate.
    """

    total_byte_count = defaultdict(lambda: 0)
    tied_param_names = model.all_tied_weights_keys.keys()

    tp_plan = getattr(model, "_tp_plan", []) or []
    tp_plan_regex = (
        re.compile("|".join([re.escape(plan) for plan in tp_plan]))
        if _torch_distributed_available and torch.distributed.is_initialized()
        else None
    )

    for param_name, device in accelerator_device_map.items():
        # Skip if the parameter has already been accounted for (tied weights)
        if param_name in tied_param_names:
            continue

        param = model.get_parameter_or_buffer(param_name)

        if hf_quantizer is not None:
            dtype_size = hf_quantizer.param_element_size(model, param_name, param)
        else:
            dtype_size = param.element_size()

        param_byte_count = param.numel() * dtype_size

        if tp_plan_regex is not None:
            generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
            param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1

        total_byte_count[device] += param_byte_count
    return total_byte_count

@Cyrilvallez (Member):

Nit: if you don't mind, I find that it's easier to follow if everything is inside the same function in this case, so IMO I would put it back in caching_allocator_warmup.
No super strong opinions here though, so if you think you'll ever need it elsewhere, we can keep it separate.

@SunMarc (Member, Author):

I had to do this as I need it to test that we get the correct allocation ;D
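
Roughly, the added test cross-checks the two code paths against each other; a sketch of the idea (the import location, the pre-loaded quantized model, and the device map below are assumptions, while the two helpers and their signatures come from the diff quoted above):

# Hypothetical cross-check mirroring the assertions quoted earlier in the thread.
# Assumed import location, based on the file markers above (src/transformers/modeling_utils.py):
from transformers.modeling_utils import compute_module_sizes, get_total_byte_count

# `model` is assumed to be an already-loaded quantized model exposing its `hf_quantizer`.
device_map = {name: 0 for name, _ in model.named_parameters()}  # everything on device 0

total_byte_count = get_total_byte_count(model, device_map, hf_quantizer=model.hf_quantizer)
module_sizes, _ = compute_module_sizes(model, model.hf_quantizer, only_modules=False)

# The "" entry of module_sizes aggregates the whole model; per the test added in this PR,
# both paths should report the same byte count.
assert sum(total_byte_count.values()) == module_sizes[""]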

@github-actions (Contributor) commented:

[For maintainers] Suggested jobs to run (before merge)

run-slow: bnb, finegrained_fp8, mxfp4, quanto_integration, torchao_integration

@SunMarc merged commit dd8057a into main on Dec 18, 2025 (26 checks passed).
@SunMarc deleted the clean-param-size branch on December 18, 2025 at 14:26.
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
* clean

* int

* check

* better

* working

* remove unrelated stuff

* rm print

* torchao

* Fix

* added

* fix quanto

* revert

* reverted

* rm comment

* fix