
[Quantisation] account for nested tensors from quantisers#44228

Open
JonoLF wants to merge 2 commits into huggingface:main from JonoLF:fix/nested_tensor_names

Conversation


@JonoLF JonoLF commented Feb 23, 2026

What does this PR do?

When using a quantisation config with a colpali-engine model like so:

import torch
from colpali_engine.models import BiQwen2_5, BiQwen2_5_Processor
from transformers import QuantoConfig
from transformers.utils import is_accelerate_available, is_flash_attn_2_available

model_name = "nomic-ai/nomic-embed-multimodal-3b"

attention = None
if is_flash_attn_2_available():
    attention = "flash_attention_2"

assert is_accelerate_available(), (
    "Accelerate library needed for auto device mapping"
)

quantisation_config = QuantoConfig(weights="int4")

model = BiQwen2_5.from_pretrained(
    model_name,
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation=attention,
    quantization_config=quantisation_config,
).eval()

processor = BiQwen2_5_Processor.from_pretrained(model_name)

Using transformers 5.3.0.dev0, the following error would occur:

    model = BiQwen2_5.from_pretrained(
            ~~~~~~~~~~~~~~~~~~~~~~~~~^
        model_name,
        ^^^^^^^^^^^
    ...<3 lines>...
        quantization_config=quantisation_config,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ).eval()
    ^
  File ".../.venv/lib/python3.13/site-packages/colpali_engine/models/qwen2_5/biqwen2_5/modeling_biqwen2_5.py", line 27, in from_pretrained
    return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)
           ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../.venv/lib/python3.13/site-packages/transformers/modeling_utils.py", line 4127, in from_pretrained
    loading_info = model.load_adapter(
        _adapter_model_path,
    ...<2 lines>...
        adapter_kwargs=adapter_kwargs,
    )
  File ".../.venv/lib/python3.13/site-packages/transformers/integrations/peft.py", line 575, in load_adapter
    loading_info, _ = self._load_pretrained_model(
                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        model=self,
        ^^^^^^^^^^^
    ...<2 lines>...
        load_config=load_config,
        ^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File ".../.venv/lib/python3.13/site-packages/transformers/modeling_utils.py", line 4175, in _load_pretrained_model
    caching_allocator_warmup(model, expanded_device_map, load_config.hf_quantizer)
    ~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../.venv/lib/python3.13/site-packages/transformers/modeling_utils.py", line 4731, in caching_allocator_warmup
    total_byte_count = get_total_byte_count(model, accelerator_device_map, hf_quantizer)
  File ".../.venv/lib/python3.13/site-packages/transformers/modeling_utils.py", line 4688, in get_total_byte_count
    param = model.get_parameter_or_buffer(param_name)
  File ".../.venv/lib/python3.13/site-packages/transformers/modeling_utils.py", line 4592, in get_parameter_or_buffer
    module, param_name = get_module_from_name(self, target)
                         ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
  File ".../.venv/lib/python3.13/site-packages/transformers/quantizers/quantizers_utils.py", line 21, in get_module_from_name
    module = module.get_submodule(module_name)
  File ".../.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 732, in get_submodule
    raise AttributeError("`" + item + "` is not an nn.Module")
AttributeError: `weight` is not an nn.Module

Upon inspection, get_module_from_name assumed that the parameter name was only the string after the last '.' in the tensor_name, so when it received a tensor name like:

tensor_name='visual.blocks.0.attn.qkv.weight._data._data'

the module lookup failed: it treated the final _data as the param name and everything before it, including weight, as a module path.
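
To make the failure concrete, here is a minimal sketch of that assumption (get_module_from_name_old is an illustrative name, not the actual transformers source):

from torch import nn

def get_module_from_name_old(model: nn.Module, tensor_name: str):
    # Assume everything after the last '.' is the parameter name.
    if "." in tensor_name:
        module_name, param_name = tensor_name.rsplit(".", 1)
        # For 'visual.blocks.0.attn.qkv.weight._data._data' this asks for
        # the "submodule" 'visual.blocks.0.attn.qkv.weight._data';
        # get_submodule raises AttributeError at 'weight', which is a
        # Parameter rather than an nn.Module.
        module = model.get_submodule(module_name)
    else:
        module, param_name = model, tensor_name
    return module, param_name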

This fix correctly identifies the parameter as weight._data._data by checking whether each component along the dotted path is an instance of a torch module, and cutting the name at the last one that is.
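
A minimal sketch of that check, walking the dotted path and stopping at the deepest component that is still a module (illustrative, not the exact diff):

from torch import nn

def get_module_from_name_fixed(model: nn.Module, tensor_name: str):
    parts = tensor_name.split(".")
    module = model
    for i, part in enumerate(parts):
        candidate = getattr(module, part, None)
        if isinstance(candidate, nn.Module):
            # Still inside the module tree, keep descending.
            module = candidate
        else:
            # First non-module component: the rest is the param name,
            # e.g. (qkv, 'weight._data._data') for the name above.
            return module, ".".join(parts[i:])
    return module, ""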

Since get_module_from_name can now return a nested param name, get_parameter_or_buffer recurses through the name's components to fetch the leaf param to be returned.
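
A minimal sketch of that recursion, where resolve_nested_param is a hypothetical helper name:

def resolve_nested_param(module, param_name: str):
    # Follow a nested name such as 'weight._data._data' attribute by
    # attribute until the leaf tensor is reached, e.g.
    # resolve_nested_param(qkv, 'weight._data._data') returns the
    # innermost tensor wrapped by the quantised parameter.
    target = module
    for attr in param_name.split("."):
        target = getattr(target, attr)
    return target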

pytest tests/utils/test_modeling_utils.py passes; I'm not entirely sure what else to test.

I'm assuming that the nested names are introduced either by the quantisation process or are a facet of the colpali model, but again, I'm not entirely sure.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@SunMarc @MekkCyber

@JonoLF JonoLF force-pushed the fix/nested_tensor_names branch from 052c8e5 to d3485ea on February 23, 2026 15:11
@JonoLF JonoLF changed the title [WIP] account for nested tensors from quantisers [Quantisation] account for nested tensors from quantisers Feb 24, 2026
@JonoLF JonoLF force-pushed the fix/nested_tensor_names branch from de3920f to a2909bd on February 25, 2026 13:58
Member

@SunMarc SunMarc left a comment


Can you add a test for that in quanto? I think I've added some tests in other quantizers to make sure that the memory calculation is right, and it was using get_total_byte_count.

@JonoLF JonoLF force-pushed the fix/nested_tensor_names branch from a2909bd to 6986cde on March 17, 2026 11:44
@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=44228&sha=6986cd
