Mamba & RecurrentGemma: enable strict signature #31549
gante merged 3 commits into huggingface:main
Conversation
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,  # `attention_mask` is passed by the tokenizer and we don't want it
Alternatively, we can accept `attention_mask` and raise an exception when it is not None or not all ones.
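For reference, a minimal sketch of that alternative, using a hypothetical `_validate_attention_mask` helper that is not part of this PR:

```python
from typing import Optional

import torch


def _validate_attention_mask(attention_mask: Optional[torch.Tensor]) -> None:
    # Hypothetical check: Mamba has no attention, so a mask is only tolerable
    # when it is absent or uninformative (all ones, i.e. no padding).
    if attention_mask is not None and not torch.all(attention_mask == 1):
        raise ValueError(
            "This model has no attention and cannot honour padded positions; "
            "`attention_mask` must be None or all ones."
        )
```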
ArthurZucker left a comment
Let's googoogogogogo 🚀
| model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) | ||
| model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) |
Yesssss I think I have a PR open where I did this! Finally!
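For context, the pattern applauded above as a self-contained sketch (the function name is a stand-in for the real `prepare_inputs_for_generation` logic): the flags are only added to the model inputs when they are truthy, so nothing unexpected reaches a strict forward signature.

```python
def build_generation_inputs(input_ids, use_cache=None, output_attentions=None, output_hidden_states=None):
    # Start from the arguments every call needs.
    model_inputs = {"input_ids": input_ids, "use_cache": use_cache}
    # Forward the optional flags only when they are set, instead of always
    # passing them (possibly as None) to the model.
    model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
    model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
    return model_inputs
```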
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,  # `attention_mask` is passed by the tokenizer and we don't want it
@amyeroberts I had a look and it should be fine: this PR removes `**kwargs` from the model class (e.g. `MambaModel`), while the FSDP PR ensures there are `**kwargs` in the decoder layers (e.g. `FalconDecoderLayer`).
We can see on main that the models themselves don't have `**kwargs`, even after the FSDP fix (e.g. llama) 🤗
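A rough illustration of that split, with simplified stand-in classes rather than the real transformers code:

```python
from typing import Optional


class ToyDecoderLayer:
    # The FSDP-related fix keeps **kwargs at the decoder-layer level.
    def forward(self, hidden_states, attention_mask=None, **kwargs):
        return hidden_states


class ToyModel:
    # The model-level forward keeps a strict signature: no **kwargs, so an
    # unexpected argument like `attention_mask` is not silently swallowed.
    def forward(
        self,
        input_ids,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        return input_ids
```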
What does this PR do?
Fixes #31540
Mamba accepts `**kwargs`, and thus `attention_mask` can be passed. Many users therefore assume it behaves just like other models and supports left-padding.

RecurrentGemma also accepts `**kwargs`, but only so that `generate` doesn't crash.

This PR enables a strict signature on Mamba and RecurrentGemma.