Conversation
```python
if not is_quantized or not hf_quantizer.param_needs_quantization(self, key):
    _load_parameter_into_model(self, key, value)
```
It should be fine to remove the condition, no? cc @Cyrilvallez
Well, yes and no... Previously there was an
```python
else:
    hf_quantizer.create_quantized_param(...)
```
in case the missing weight needed to be quantized (we quantized the new random weight). I'm not sure why it was removed, but it should probably still be the best way, no?
We removed `create_quantized_param` as we didn't need it anymore since we have the quantize ops xD. Let's see how to deal with that in another PR cc @MekkCyber, but I don't think this is urgent.
Cyrilvallez left a comment:
SUPER NICE! 🤗 Thanks a lot for doing that! High time we finally have a nice and clean way to pre-allocate with quantization! 🚀
Just added a few comments, then we can merge!
```python
modules_sizes, _ = compute_module_sizes(model, hf_quantizer, only_modules=False)
```
Using this instead, cc @Cyrilvallez, if that's fine.
```python
# We need parameters + buffers here, as state_dict does not count non-persistent buffers which are taking space
expected_keys = [name for name, _ in model.named_parameters()] + [name for name, _ in model.named_buffers()]
```
there were some differences because of this
```python
# check that we get the same value, as we use `compute_module_sizes` in `get_total_byte_count`
assert total_byte_count == model_size[""]
assert quantized_total_byte_count == quantized_model_size[""]
```
Here I check that we have the same result.
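Pieced together, the check might be wired up roughly like this (a hedged sketch; the actual test fixtures are not shown in this thread):
```python
# Sketch only: the fixture wiring is assumed, not copied from the PR's tests
byte_count_per_device = get_total_byte_count(model, accelerator_device_map, hf_quantizer)
total_byte_count = sum(byte_count_per_device.values())
model_size, _ = compute_module_sizes(model, hf_quantizer, only_modules=False)
assert total_byte_count == model_size[""]  # "" indexes the whole-model total
```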
Will have a last look next Monday!
```diff
- expected_keys = list(model.state_dict().keys())
+ # We need parameters + buffers here, as state_dict does not count non-persistent buffers which are taking space
+ expected_keys = [name for name, _ in model.named_parameters()] + [name for name, _ in model.named_buffers()]
```
Humm, the fact here is that non-persistent buffers are NOT loaded with the other params (because they are non-persistent, of course), so there is no need to account for them when allocating memory before loading.
Thus I believe we should revert this change.
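To illustrate the point, here is a minimal standalone snippet (the module and buffer names are made up for the demo) showing that non-persistent buffers appear in `named_buffers()` but not in `state_dict()`, so they are never loaded from a checkpoint:
```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4))
        # Persistent buffer: saved in (and loaded from) the state_dict
        self.register_buffer("persistent_buf", torch.zeros(4))
        # Non-persistent buffer: takes memory, but is never serialized or loaded
        self.register_buffer("transient_buf", torch.zeros(4), persistent=False)

m = Toy()
print(sorted(m.state_dict().keys()))            # ['persistent_buf', 'weight']
print(sorted(n for n, _ in m.named_buffers()))  # ['persistent_buf', 'transient_buf']
```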
Indeed xD, I was only trying to match the numbers and didn't think too far. In any case, the tests I wrote should still be valid!
```python
modules_sizes, _ = compute_module_sizes(model, hf_quantizer, only_modules=False)
for param_name, device in accelerator_device_map.items():
```
Humm, here we iterate twice over all params for no reason... Better to go back to the old loop and mimic what's being done in `compute_module_sizes` by using `dtype_size = hf_quantizer.param_element_size(model, name, param)` if we have a quantizer!
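Concretely, the suggestion amounts to a single pass that folds the quantizer-aware element size into the existing loop; a minimal sketch, with variable names assumed from the diff above:
```python
# Sketch of the proposed single loop (not the PR's final code verbatim)
for param_name, device in accelerator_device_map.items():
    param = model.get_parameter_or_buffer(param_name)
    dtype_size = (
        hf_quantizer.param_element_size(model, param_name, param)
        if hf_quantizer is not None
        else param.element_size()
    )
    total_byte_count[device] += param.numel() * dtype_size
```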
Cyrilvallez left a comment:
LGTM! Left a final comment but that's it!
Feel free to merge after the conflicts on quantizers have been resolved 🤗
Thanks again for this, will make everything much much smoother! 🤗
```python
def get_total_byte_count(
    model: PreTrainedModel, accelerator_device_map: dict, hf_quantizer: Optional[HfQuantizer] = None
):
    """
    This utility function calculates the total byte count needed to load the model on each device.
    This is useful for caching_allocator_warmup as we want to know how much cache we need to pre-allocate.
    """
    total_byte_count = defaultdict(lambda: 0)
    tied_param_names = model.all_tied_weights_keys.keys()

    tp_plan = getattr(model, "_tp_plan", []) or []
    tp_plan_regex = (
        re.compile("|".join([re.escape(plan) for plan in tp_plan]))
        if _torch_distributed_available and torch.distributed.is_initialized()
        else None
    )

    for param_name, device in accelerator_device_map.items():
        # Skip if the parameter has already been accounted for (tied weights)
        if param_name in tied_param_names:
            continue

        param = model.get_parameter_or_buffer(param_name)

        if hf_quantizer is not None:
            dtype_size = hf_quantizer.param_element_size(model, param_name, param)
        else:
            dtype_size = param.element_size()

        param_byte_count = param.numel() * dtype_size

        if tp_plan_regex is not None:
            generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
            param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1

        total_byte_count[device] += param_byte_count
    return total_byte_count
```
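As an aside, the TP-plan name normalization inside the loop above can be seen in isolation; a standalone illustration (the parameter name is made up):
```python
import re

# A layer-indexed parameter name collapses to its generic form, so a single
# TP-plan entry can match the corresponding parameter in every layer
name = "model.layers.3.mlp.down_proj.weight"
print(re.sub(r"\.\d+\.", ".*.", name))  # model.layers.*.mlp.down_proj.weight
```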
Nit: if you don't mind, I find it easier to follow if everything is inside the same function in this case, so IMO I would put it back in `caching_allocator_warmup`.
No super strong opinions here though, so if you think you'll ever need it elsewhere we can keep it separate.
I had to do this as I need to test that we get the correct allocation ;D
[For maintainers] Suggested jobs to run (before merge): run-slow: bnb, finegrained_fp8, mxfp4, quanto_integration, torchao_integration
What does this PR do?
This PR fixes how we calculate the param size for quantized models. This should be simpler to hack around!
I added some tests to check that it works for the following methods: `bnb`, `finegrained_fp8`, `torchao`, `mxfp4`, `quanto`.
cc @Cyrilvallez as you were concerned at some point.