Cache: init empty cache when use_cache (#34274)
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```diff
 class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
     config_class = MllamaTextConfig
-    base_model_prefix = "model"
+    base_model_prefix = "language_model"
```
From the docstring of `PreTrainedModel`, regarding `base_model_prefix`:

> A string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
How is the hierarchy in the model weights? I'm assuming it's
```
model ─┬─ language_model ─┬─ model (...)
       │                  └─ lm_head
       ├─ vision_model (...)
       └─ multi_modal_projector (...)
```
If that's the case, then I agree with the change, assuming we also change `self.model` to `self.language_model` in this class (make sure all slow tests pass!)
Yes, that is what the checkpoint looks like, and when I loaded the model it didn't load correctly unless the base prefix was fixed. The slow tests unfortunately can't be run because the model is read-token protected and the EU has no access to it 🥲 But I tested with openly mirrored weights.
Fair :)
Can we change `self.model` to `self.language_model`? Some parts of the codebase call variants of `inner_model = getattr(model, model.base_model_prefix)`.
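To make the concern concrete, here is a minimal sketch (toy classes, not the real transformers code) of why `base_model_prefix` has to match the actual attribute name: generic code paths fetch the inner model with `getattr(model, model.base_model_prefix)`.

```python
# Toy stand-ins for the real classes, to illustrate the getattr pattern only.
class ToyBaseModel:
    pass

class ToyForCausalLM:
    base_model_prefix = "language_model"  # must match the attribute set below

    def __init__(self):
        self.language_model = ToyBaseModel()

def get_inner_model(model):
    # Pattern used by generic helpers in the codebase: resolve the inner
    # model through the declared prefix, None if the attribute is missing.
    return getattr(model, model.base_model_prefix, None)

model = ToyForCausalLM()
inner = get_inner_model(model)
print(inner is model.language_model)  # True: prefix and attribute agree
```

If the prefix said `"model"` while the attribute is `language_model`, the lookup would silently return `None`.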
That would mean we change the checkpoint state dict keys, right? 🤔 Anyway, let me verify this and say whether it is possible without touching the checkpoint.
> that would mean we change the checkpoint state dict keys right?

Uhmmm, possibly? Not sure 😅 The tests that save and reload would break if it were not BC, I think.
Indeed, I found why it was changed: it was to get the `self.base_model` method for preparing the causal mask in #33677 hehe.

And yes, without changing the state dict keys we cannot call it `"model"`. IMO even if we change the official state dict, there are many mirrors/finetunes which would be BC-breaking compared to the first model release. So I think the better way is to bring back the base model prefix as it was.

I'm thinking maybe we can have a default method for `update_causal_mask` and `prepare_attention_mask` which would be the fallback if the base model has no such method defined? 🤔

EDIT: but wait, Arthur might disagree, as he wanted to have attention-mask preparation in all model files instead of one copy in the general modeling file. In that case, we might need something better than `getattr(self, base_model_prefix)`, as it doesn't work when the same checkpoint is loaded as CausalLM and as ConditionalLM :(
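The fallback idea above could look roughly like this sketch (class and method names are illustrative, not the actual transformers API): try the base model's mask-preparation method via the prefix, and fall back to a generic default when it is missing.

```python
# Hypothetical sketch of a "default with fallback" pattern, not real library code.
class CausalLMSketch:
    base_model_prefix = "model"

    def _default_update_causal_mask(self, mask):
        # generic fallback shared by all models
        return mask

    def update_causal_mask(self, mask):
        base = getattr(self, self.base_model_prefix, None)
        fn = getattr(base, "_update_causal_mask", None)
        if fn is not None:
            return fn(mask)  # model-specific implementation wins when present
        return self._default_update_causal_mask(mask)

lm = CausalLMSketch()  # has no `self.model` attribute -> takes the fallback path
print(lm.update_causal_mask("mask"))  # mask
```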
🤔
Regarding the generate-specific problem: if `base_model = getattr(self, self.base_model_prefix, None)` in the generalist `prepare_inputs_for_generation` (or its downstream usage) is the issue, then my recommendation would be to overwrite `prepare_inputs_for_generation` in mllama -- more specifically, in the classes where it doesn't work.

Alternatively, we could define `_prepare_4d_causal_attention_mask_with_cache_position` in all model classes -- write it once in the innermost class, then the child classes would define this function as a parent call.

Any of these solutions would work well for me :) (with a preference for the second: when we rewrite `prepare_inputs_for_generation`, we know for sure we'll have extra maintenance in the future)
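The second option can be sketched as follows (toy classes standing in for the real modeling classes): define the helper once in the innermost class, and have each child class define it as a call into that implementation so generic code always finds it on whatever class it is given.

```python
# Illustrative sketch only; the real helper builds a 4D causal mask.
class MllamaTextModelSketch:
    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(mask):
        # innermost class owns the single real implementation (identity here)
        return mask

class MllamaForCausalLMSketch(MllamaTextModelSketch):
    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(mask):
        # child defines the method as a parent call, so callers never need
        # to resolve base_model_prefix to find it
        return MllamaTextModelSketch._prepare_4d_causal_attention_mask_with_cache_position(mask)

print(MllamaForCausalLMSketch._prepare_4d_causal_attention_mask_with_cache_position("m"))  # m
```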
Done, and also enabled compile tests for the CausalLM class to test that it works.
It should indeed be `language_model`, but probably in a different PR as it is unrelated!
Okay, will make a new PR.
ArthurZucker
left a comment
Thanks LGTM but let's separate unrelated changes!
```diff
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
```
This does make sense as it's helping users, and it's an old API, but let's promote initializing a cache and passing it! 🤗
We could, but I still think it is a lot easier for users who want a forward pass with cache not to need extra lines of code for importing and passing the cache object. So I think we'd better keep the default cache for now.
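The behavior being discussed can be sketched as follows (the `DynamicCache` class here is a local stand-in for `transformers.DynamicCache`, and `forward` is a simplified stand-in for the model's forward pass): with this PR, passing `use_cache=True` without a cache object creates one for the user.

```python
class DynamicCache:
    """Stand-in for transformers.DynamicCache, for illustration only."""
    def __init__(self):
        self.layers = []

def forward(use_cache=False, past_key_values=None):
    # The behavior added by this PR: default-initialize an empty cache
    # so the user does not have to import and pass one explicitly.
    if use_cache and past_key_values is None:
        past_key_values = DynamicCache()
    return past_key_values

print(type(forward(use_cache=True)).__name__)  # DynamicCache
print(forward(use_cache=False))                # None
```

An explicitly passed cache is still respected, so the old explicit API keeps working unchanged.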
```diff
             )
             use_cache = False

+        if use_cache and past_key_values is None:
```
Missing the torch jit tracing escape here, no?
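The escape the reviewer is asking about can be sketched like this (a hypothetical helper; the real check would use `torch.jit.is_tracing()`, passed in here as a callable so the sketch stays self-contained): cache initialization should be skipped while tracing, since a dynamically created cache object is not traceable.

```python
def maybe_init_cache(use_cache, past_key_values, is_tracing, make_cache):
    # Only default-initialize the cache outside of torch.jit tracing;
    # is_tracing and make_cache are injected for illustration.
    if use_cache and past_key_values is None and not is_tracing():
        past_key_values = make_cache()
    return past_key_values

print(type(maybe_init_cache(True, None, lambda: False, dict)).__name__)  # dict
print(maybe_init_cache(True, None, lambda: True, dict))                  # None under tracing
```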
* fix
* fix tests
* fix copies
* add docs
* Revert "add docs" (this reverts commit 32d3563)
* qwen move deltas
* mllama can potentially fullgraph compile
* enable mllama compile and fix tests
* remove mllama fixes
What does this PR do?
Fixes #34206. As per the title, we have to initialize an empty cache whenever `use_cache=True`. Additionally, `MllamaForCausalLM` was not loading correctly for me, so I modified the base model prefix.