
Fix to tuple conversion with config #39257

Open
qubvel wants to merge 1 commit into huggingface:main from qubvel:fix-return-tuple

Conversation

@qubvel
Contributor

@qubvel qubvel commented Jul 7, 2025

What does this PR do?

Setting return_dict=False via the config fails for models whose sub-models are wrapped with can_return_tuple or check_model_inputs:

import torch
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(vocab_size=256, hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
model = LlamaForCausalLM(config)

# default: ModelOutput
input_ids = torch.tensor([[0, 1, 2, 3]])
with torch.no_grad():
    output = model(input_ids)

print(output)


# passing return_dict=False as a kwarg 
input_ids = torch.tensor([[0, 1, 2, 3]])
with torch.no_grad():
    output = model(input_ids, return_dict=False)

print(output)


# ERROR: setting return_dict=False in the config
model.config.return_dict = False
with torch.no_grad():
    output = model(input_ids)

print(output)

# Traceback (most recent call last):
#   File "/home/ubuntu/projects/transformers/test_llama_small.py", line 17, in <module>
#     output = model(input_ids)
#   File "/home/ubuntu/projects/transformers/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
#     return self._call_impl(*args, **kwargs)
#   File "/home/ubuntu/projects/transformers/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
#     return forward_call(*args, **kwargs)
#   File "/home/ubuntu/projects/transformers/src/transformers/utils/generic.py", line 962, in wrapper
#     output = func(self, *args, **kwargs)
#   File "/home/ubuntu/projects/transformers/src/transformers/models/llama/modeling_llama.py", line 506, in forward
#     hidden_states = outputs.last_hidden_state
# AttributeError: 'tuple' object has no attribute 'last_hidden_state'

        if return_dict_passed is not None:
            return_dict = return_dict_passed
-       output = func(self, *args, **kwargs)
+       output = func(self, *args, **kwargs, return_dict=True)
Contributor Author

This way, it works as long as **kwargs (TransformersKwargs) are properly propagated from the top module down to each wrapped submodule.

For context: previously, we recursively set the module attribute _is_top_module to avoid passing **kwargs everywhere.
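For illustration, the decorator's contract after this change can be sketched framework-free. Everything below (SimpleOutput, DummyModel, the decorator body) is a hypothetical stand-in, not the actual transformers implementation: the wrapper resolves the user's return_dict intent, always calls the inner forward with return_dict=True so nested wrapped modules return structured outputs, and converts to a tuple only at its own boundary.

```python
# Hypothetical sketch of a can_return_tuple-style wrapper; not the real
# transformers code. SimpleOutput stands in for ModelOutput.
from dataclasses import astuple, dataclass
from functools import wraps


@dataclass
class SimpleOutput:
    last_hidden_state: list

    def to_tuple(self):
        return astuple(self)


def can_return_tuple(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # an explicit return_dict kwarg wins; otherwise fall back to the config
        return_dict_passed = kwargs.pop("return_dict", None)
        return_dict = (
            return_dict_passed
            if return_dict_passed is not None
            else getattr(self.config, "return_dict", True)
        )
        # always request a structured output from the inner forward, so any
        # wrapped sub-model also sees return_dict=True ...
        output = func(self, *args, **kwargs, return_dict=True)
        # ... and convert to a tuple only at this boundary
        return output if return_dict else output.to_tuple()

    return wrapper


class DummyConfig:
    return_dict = True


class DummyModel:
    def __init__(self):
        self.config = DummyConfig()

    @can_return_tuple
    def forward(self, input_ids, return_dict=True):
        return SimpleOutput(last_hidden_state=[h * 2 for h in input_ids])


model = DummyModel()
print(model.forward([1, 2]))   # SimpleOutput(last_hidden_state=[2, 4])
model.config.return_dict = False
print(model.forward([1, 2]))   # ([2, 4],)
```

With this shape, a sub-model wrapped the same way never sees a tuple from its parent, which is exactly the failure mode in the traceback above.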

Comment on lines -926 to -946
def set_attribute_for_modules(module: "torch.nn.Module", key: str, value: Any):
    """
    Set a value to a module and all submodules.
    """
    setattr(module, key, value)
    for submodule in module.children():
        set_attribute_for_modules(submodule, key, value)


def del_attribute_from_modules(module: "torch.nn.Module", key: str):
    """
    Delete a value from a module and all submodules.
    """
    # because we might remove it previously in case it's a shared module, e.g. activation function
    if hasattr(module, key):
        delattr(module, key)

    for submodule in module.children():
        del_attribute_from_modules(submodule, key)


Contributor Author

no longer needed
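To illustrate what the removed helpers did, here is a torch-free sketch; Node is a hypothetical stand-in for torch.nn.Module exposing a children() iterator. The hasattr guard matters because a shared submodule (e.g. a shared activation) is visited more than once during deletion.

```python
# Torch-free sketch of the removed recursion helpers. `Node` is a
# hypothetical stand-in for torch.nn.Module.
class Node:
    def __init__(self, *children):
        self._children = list(children)

    def children(self):
        return iter(self._children)


def set_attribute_for_modules(module, key, value):
    # set the attribute on the module and, recursively, all submodules
    setattr(module, key, value)
    for submodule in module.children():
        set_attribute_for_modules(submodule, key, value)


def del_attribute_from_modules(module, key):
    # hasattr guard: a shared submodule may already have been visited
    # and had the attribute deleted
    if hasattr(module, key):
        delattr(module, key)
    for submodule in module.children():
        del_attribute_from_modules(submodule, key)


shared = Node()  # appears twice in the tree, like a shared activation
root = Node(Node(shared), shared)
set_attribute_for_modules(root, "_is_top_module", False)
print(shared._is_top_module)            # False
del_attribute_from_modules(root, "_is_top_module")
print(hasattr(root, "_is_top_module"))  # False
```

The kwargs-propagation approach in this PR makes this whole recursive bookkeeping unnecessary, which is why both helpers are deleted.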

@qubvel qubvel marked this pull request as ready for review July 7, 2025 16:15
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
