Cache: init empty cache when use_cache#34274

Merged
zucchini-nlp merged 15 commits into huggingface:main from zucchini-nlp:cache-empty-init
Nov 25, 2024

Conversation

Member

@zucchini-nlp zucchini-nlp commented Oct 21, 2024

What does this PR do?

Fixes #34206. As per the title, we have to initialize an empty cache whenever use_cache=True. Additionally, MllamaForCausalLM was not loading correctly for me, so I modified the base model prefix.
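
To make the change concrete, a minimal sketch of the behaviour the linked issue expects (the checkpoint id is a placeholder; the official Mllama weights are gated):

import torch
from transformers import AutoTokenizer, MllamaForCausalLM

model_id = "<any-mllama-text-checkpoint>"  # placeholder, not a real repo id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = MllamaForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, use_cache=True)  # no past_key_values passed in

# With this PR the model initializes an empty DynamicCache itself,
# so the returned past_key_values is no longer None.
print(out.past_key_values)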

@zucchini-nlp zucchini-nlp requested a review from gante October 21, 2024 08:15
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@gante gante left a comment

Two items to check :D

 class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
     config_class = MllamaTextConfig
-    base_model_prefix = "model"
+    base_model_prefix = "language_model"
Contributor

From the docstring of PreTrainedModel, regarding base_model_prefix:

A string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.

What is the hierarchy in the model weights? I'm assuming it's:

model -|- language_model --------------|- model - (...)
       |                               |- lm_head
       |- vision_model-(...)
       |- multi_modal_projector-(...)
           

If that's the case, then I agree with the change, assuming we also change self.model to self.language_model in this class

(make sure all slow tests pass!)

Member Author

Yes, that is how the checkpoint looks, and when I loaded the model it didn't load correctly unless the base prefix was fixed. Unfortunately, the slow tests can't be run because the model is read-token protected and the EU has no access to it 🥲 But I tested with openly mirrored weights.

Contributor

Fair :)

Can we change self.model to self.language_model? Some parts of the codebase call variants of inner_model = getattr(model, model.base_model_prefix)
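
For context, a rough self-contained sketch of that pattern (class and attribute names here are illustrative, not the library code):

class TinyBackbone:
    pass

class TinyForCausalLM:
    base_model_prefix = "language_model"

    def __init__(self):
        self.language_model = TinyBackbone()

model = TinyForCausalLM()
# Helpers in the codebase walk from the head model down to the backbone like this,
# so the prefix must name an attribute that actually exists on the instance.
inner_model = getattr(model, model.base_model_prefix)
assert isinstance(inner_model, TinyBackbone)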

Member Author

That would mean we change the checkpoint state dict keys, right? 🤔 Anyway, lemme verify this and report whether it is possible without touching the checkpoint.

Contributor

@gante gante Nov 4, 2024

That would mean we change the checkpoint state dict keys, right?

Uhmmm, possibly? Not sure 😅 The tests that save and reload would break if it weren't BC, I think.

Member Author

@zucchini-nlp zucchini-nlp Nov 15, 2024

Indeed, I found why it was changed: it was to get the self.base_model attribute for preparing the causal mask in #33677 hehe

And yes, without changing the state dict keys we cannot call it "model". IMO, even if we change the official state dict, there are many mirrors/finetunes which would break BC compared to the first model release. So I think the better way is to bring back the base_model_prefix as it was.

I'm thinking maybe we can have a default method for update_causal_mask and prepare_attention_mask which would be the fallback if the base model has no such method defined? 🤔

EDIT: but wait, Arthur might disagree, as he wanted to have attention mask preparation in all model files instead of having one copy in the general modeling file. In that case, we might need something better than getattr(self, base_model_prefix), as it doesn't work when the same checkpoint is loaded as CausalLM and as ConditionalLM :(

Contributor

🤔

Regarding the generate-specific problem: if base_model = getattr(self, self.base_model_prefix, None) in the generalist prepare_inputs_for_generation (or its downstream usage) is the issue, then my recommendation would be to override prepare_inputs_for_generation in mllama -- more specifically, in the classes where it doesn't work.

Alternatively, we could define _prepare_4d_causal_attention_mask_with_cache_position in all model classes -- write it once in the innermost class, then have the child classes define this function as a parent call.

Either of these solutions would work well for me :) (with a preference for the second: whenever we override prepare_inputs_for_generation, we know for sure we'll have extra maintenance in the future)
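
A minimal sketch of the second option (class names follow the Mllama naming, but the bodies are placeholders; only the delegation pattern is the point):

class MllamaTextModel:
    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(*args, **kwargs):
        ...  # the actual mask construction is written once, here

class MllamaForCausalLM:
    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(*args, **kwargs):
        # wrapper classes only forward to the innermost implementation,
        # so generic code can find the method on any of them
        return MllamaTextModel._prepare_4d_causal_attention_mask_with_cache_position(*args, **kwargs)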

Member Author

Done, and also enabled the compile tests for the CausalLM class to check that it works.

Collaborator

It should indeed be language_model, but probably in a different PR as it is unrelated!

Member Author

Okay, will make a new PR.

Comment thread: tests/generation/test_utils.py
Comment thread: tests/models/qwen2_vl/test_modeling_qwen2_vl.py (Outdated)
Collaborator

@ArthurZucker ArthurZucker left a comment

Thanks LGTM but let's separate unrelated changes!

Comment on lines +1744 to +1746
if use_cache and past_key_values is None:
    past_key_values = DynamicCache()

Collaborator

This does make sense, as it's helping users and it's an old API, but let's promote initializing a cache and passing it! 🤗
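
For reference, the explicit style being promoted looks roughly like this (the checkpoint id is a placeholder):

from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_id = "<any-causal-lm-checkpoint>"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

cache = DynamicCache()  # the user constructs and owns the cache object
inputs = tokenizer("Hello", return_tensors="pt")
out = model(**inputs, past_key_values=cache, use_cache=True)
# `cache` is now populated and can be passed to the next forward pass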

Member Author

We could, but I still think it is a lot easier for users who want a forward pass with cache and do not want extra lines of code for importing and passing the cache object. So I think we'd better keep the default cache for now.

Collaborator

@ArthurZucker ArthurZucker left a comment

Thanks 🤗

    )
    use_cache = False

if use_cache and past_key_values is None:
Collaborator

Missing a torch.jit tracing escape here, no?
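
For context, that kind of guard would look roughly like the sketch below (a hedged illustration, not the exact diff; the helper name is made up):

import torch
from transformers import DynamicCache

def maybe_init_cache(use_cache, past_key_values):
    # Hypothetical helper: skip the implicit cache init while torch.jit is tracing,
    # since a freshly created cache object generally doesn't play well with tracing.
    if use_cache and past_key_values is None and not torch.jit.is_tracing():
        past_key_values = DynamicCache()
    return past_key_values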

@zucchini-nlp zucchini-nlp merged commit c1a8520 into huggingface:main Nov 25, 2024
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* fix

* fix tests

* fix copies

* add docs

* Revert "add docs"

This reverts commit 32d3563.

* qwen move deltas

* mllama can potentially fullgraph compile

* enable mllama compile and fix tests

* remove mllama fixes

Development

Successfully merging this pull request may close these issues.

MllamaForCausalLM not returning past_key_values even with use_cache=True
