
Fix: Enable prefill phase key value caching of nemotron/minitron models #34742

Merged
zucchini-nlp merged 9 commits into huggingface:main from jeongin601:main
Nov 25, 2024

Conversation

@jeongin601
Contributor

@jeongin601 commented Nov 15, 2024

What does this PR do?

Fixes #34739

Problem

The current implementation does not enable key-value caching for the Nemotron and Minitron models.
This can be verified with the short example below, which inspects the key and value caches returned by a forward pass.

Modification

I modified the code to enable key-value caching during the prefill phase, following the approach used in the `modeling_llama.py` file.
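
Roughly, the change amounts to something like the sketch below (illustrative only, modeled on the cache handling in `modeling_llama.py`; the exact conditional in the merged diff may differ). It makes the prefill forward pass create a cache object when caching is requested but none was passed, so the keys and values computed during prefill are actually stored and returned:

from transformers.cache_utils import DynamicCache

# Sketch, not the exact merged code: inside the model's forward pass,
# initialize an empty cache for the prefill phase when the caller asked
# for caching but did not provide a cache object.
if use_cache and past_key_values is None:
    past_key_values = DynamicCache()
# ...each decoder layer then writes its keys/values into `past_key_values`,
# and the populated cache is returned in the model output.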

Key-value caching example code

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load Minitron model and tokenizer from Hugging Face
model_name = "your-minitron-model-name"  # Replace with the actual Minitron model name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Set the model to evaluation mode
model.eval()

# Sample input text
input_text = "Hello, how are you?"

# Tokenize the input
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

# First forward pass (prefill phase)
with torch.no_grad():
    outputs = model(input_ids, use_cache=True)  # Set use_cache=True
    logits = outputs.logits
    past_key_values = outputs.past_key_values

# Check the output
print("Logits shape:", logits.shape)
print("Number of layers in past_key_values:", len(past_key_values))
print("Shape of keys and values in the first layer:")
print("Key shape:", past_key_values[0][0].shape)
print("Value shape:", past_key_values[0][1].shape)

# Add new input to test cache utilization
new_input_text = " What about you?"
new_input_ids = tokenizer(new_input_text, return_tensors="pt").input_ids

# Pass the new input along with the previous key-value cache
with torch.no_grad():
    outputs_with_cache = model(new_input_ids, past_key_values=past_key_values, use_cache=True)

# Check results after caching
new_logits = outputs_with_cache.logits
new_past_key_values = outputs_with_cache.past_key_values

print("New logits shape:", new_logits.shape)
print("Number of layers in new past_key_values:", len(new_past_key_values))

As-Is Result

[Screenshot 2024-11-15 1:15:33 PM: example output before the fix]

Key-value caching is not performed.

To-be Result

[Screenshot 2024-11-15 4:15:44 PM: example output after the fix]

Key-value caching is enabled.
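
As an extra sanity check (a hedged continuation of the example above, reusing its variables and assuming the usual `[batch, num_heads, seq_len, head_dim]` layout of the cached key tensors), the cache returned after the second pass should hold the prefill tokens plus the newly processed ones:

# Continuation of the example above (sketch): the total cached length should
# equal the prefill length plus the number of tokens in the second pass.
prefill_len = input_ids.shape[1]
total_len = new_past_key_values[0][0].shape[2]  # seq_len dimension of the first layer's key cache
assert total_len == prefill_len + new_input_ids.shape[1]
print(f"Cache holds {total_len} tokens ({prefill_len} of them from the prefill pass)")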

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker
Can you please check my modification? :)

@LysandreJik
Member

cc @ArthurZucker @gante @zucchini-nlp

Member

@zucchini-nlp left a comment

Thanks for adding this! Let's remove the deprecation warning, otherwise LGTM!

Comment on lines +791 to +795
logger.warning_once(
    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)
Member

I don't think we need to add a deprecation message for newly added models; we can support only the new `Cache` objects.

Contributor Author

Thanks for reviewing my code! I removed the deprecation warning. :)

Member

Sorry if I wasn't clear, I meant totally removing support for the tuple format, and thus the `from_legacy_cache` call.

I'll ping the core maintainer after that for the final review :)

Contributor Author

Oh, sorry, I got it wrong. I have now removed support for tuple-shaped `past_key_values`. Is this what you meant?
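
For anyone who was passing the legacy tuple-of-tuples format to these models, a minimal migration sketch (assuming `DynamicCache` is exported from the top-level `transformers` namespace, as in recent versions; `legacy_past_key_values` is a placeholder for a previously saved tuple cache):

from transformers import DynamicCache

# Sketch: convert a legacy tuple-of-tuples cache once, outside the model,
# and pass the resulting Cache object to the forward pass or generate().
cache = DynamicCache.from_legacy_cache(legacy_past_key_values)
outputs = model(new_input_ids, past_key_values=cache, use_cache=True)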

Collaborator

@ArthurZucker left a comment

One suggestion and we can merge!

Comment thread on src/transformers/models/nemotron/modeling_nemotron.py (outdated)
@ArthurZucker
Collaborator

BTW, this seems to be related to #34274.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@jeongin601
Contributor Author

One suggestion and we can merge!

I updated it! :) Thanks

@zucchini-nlp merged commit 318fe25 into huggingface:main Nov 25, 2024
@ydshieh
Collaborator

ydshieh commented Nov 26, 2024

Hi @jeongin601, thank you for this PR ❤️ !

It seems this PR introduces some regressions: below are the 3 failing tests. Would you be up to take a look? You can find the job run log here.

In any case, thank you in advance.

        "nemotron": {
            "single-gpu": [
                {
                    "test": "tests/models/nemotron/test_modeling_nemotron.py::NemotronModelTest::test_torchscript_output_attentions",
                    "commit": "318fe25f22a99ce1226f8d2aadc268b40f7e55af",
                    "pr_number": 34742,
                    "author": "jeongin601",
                    "merged_by": "zucchini-nlp"
                },
                {
                    "test": "tests/models/nemotron/test_modeling_nemotron.py::NemotronModelTest::test_torchscript_output_hidden_state",
                    "commit": "318fe25f22a99ce1226f8d2aadc268b40f7e55af",
                    "pr_number": 34742,
                    "author": "jeongin601",
                    "merged_by": "zucchini-nlp"
                },
                {
                    "test": "tests/models/nemotron/test_modeling_nemotron.py::NemotronModelTest::test_torchscript_simple",
                    "commit": "318fe25f22a99ce1226f8d2aadc268b40f7e55af",
                    "pr_number": 34742,
                    "author": "jeongin601",
                    "merged_by": "zucchini-nlp"
                }
            ]
        }
    },

@ydshieh
Collaborator

ydshieh commented Nov 26, 2024

@ArthurZucker @zucchini-nlp A kind reminder: don't hesitate to ask for slow CI 🙂 - let's use the tools we have to make our lives easier 🙏

@ArthurZucker
Collaborator

Ah, it makes sense: TorchScript does not support the DynamicCache class!

@ArthurZucker
Collaborator

(AFAIR)

@ydshieh
Collaborator

ydshieh commented Nov 28, 2024

Ah ok, I will check then. But if we eventually move forward to DynamicCache and drop the legacy cache, would it mean TorchScript is not going to work for many models?

@ArthurZucker
Collaborator

Yeah 👀 unless used with optimum!
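
For readers hitting the TorchScript limitation discussed here, one workaround is to keep plain tensors at the boundary of the scripted/traced code and rebuild the cache object in eager code. A sketch, assuming the `to_legacy_cache()` / `from_legacy_cache()` helpers on `DynamicCache` remain available and that `outputs.past_key_values` is a `DynamicCache`:

from transformers import DynamicCache

# Sketch: export the cache as a tuple of (key, value) tensors per layer,
# which TorchScript-friendly code can handle, then rebuild the Cache object.
legacy = outputs.past_key_values.to_legacy_cache()
# ... hand `legacy` to scripted/traced code as plain tensors ...
cache = DynamicCache.from_legacy_cache(legacy)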

BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
…ls (huggingface#34742)

* modeling nemotron kv caching bugfix

Signed-off-by: jeongin601 <0200angela@gmail.com>

* test file deleted

Signed-off-by: jeongin601 <0200angela@gmail.com>

* code refinement

Signed-off-by: jeongin601 <0200angela@gmail.com>

* remove unused variables

Signed-off-by: jeongin601 <0200angela@gmail.com>

* import block sorted

* removed deprecation warning

Signed-off-by: jeongin601 <0200angela@gmail.com>

* removed support for tuple shape past_key_values

Signed-off-by: jeongin601 <0200angela@gmail.com>

* Update conditional statement for cache initialization

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Signed-off-by: jeongin601 <0200angela@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

Development

Successfully merging this pull request may close these issues.

BUG : Modeling nemotron file does not cache key values even though
