BUG: Modeling nemotron file does not cache key values even though use_cache=True is set #34739

@jeongin601

Description

System Info

huggingface-hub-0.26.2
tokenizers-0.20.3
transformers-4.47.0.dev0
Python 3.10.12
Driver Version: 535.129.03
CUDA Version: 12.3

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Description

You can run the code below to reproduce the prefill key-value caching problem with Minitron models.
I used the "nvidia/Minitron-8B-Base" model.

Code

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the Minitron model and tokenizer from Hugging Face
model_name = "nvidia/Minitron-8B-Base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Set the model to evaluation mode
model.eval()

# Sample input text
input_text = "Hello, how are you?"

# Tokenize the input
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

# First forward pass (prefill phase)
with torch.no_grad():
    outputs = model(input_ids, use_cache=True)  # Set use_cache=True
    logits = outputs.logits
    past_key_values = outputs.past_key_values

# Check the output
print("Logits shape:", logits.shape)
print("Number of layers in past_key_values:", len(past_key_values))
print("Shape of keys and values in the first layer:")
print("Key shape:", past_key_values[0][0].shape)
print("Value shape:", past_key_values[0][1].shape)

# Add new input to test cache utilization
new_input_text = " What about you?"
new_input_ids = tokenizer(new_input_text, return_tensors="pt").input_ids

# Pass the new input along with the previous key-value cache
with torch.no_grad():
    outputs_with_cache = model(new_input_ids, past_key_values=past_key_values, use_cache=True)

# Check results after caching
new_logits = outputs_with_cache.logits
new_past_key_values = outputs_with_cache.past_key_values

print("New logits shape:", new_logits.shape)
print("Number of layers in new past_key_values:", len(new_past_key_values))

Expected behavior

As-Is
[Screenshot (2024-11-15): console output showing past_key_values is None]

The returned past_key_values is None ('NoneType'), which means the key-value cache is not being returned despite use_cache=True.
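For reference, here is a minimal pure-PyTorch sketch of what use_cache=True is expected to do (a toy single-head attention step, not the actual transformers implementation; toy_attention_step is a hypothetical helper): the prefill pass returns per-layer key/value tensors, and the next pass concatenates the new keys/values onto the cached ones instead of returning None.

```python
import torch

def toy_attention_step(x, w_q, w_k, w_v, past_kv=None):
    # Project the new tokens to queries/keys/values.
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    if past_kv is not None:
        # Reuse the cached prefill keys/values: concatenate along the seq dim.
        past_k, past_v = past_kv
        k = torch.cat([past_k, k], dim=1)
        v = torch.cat([past_v, v], dim=1)
    attn = torch.softmax(q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5, dim=-1)
    out = attn @ v
    # Return the updated cache so the caller can feed it back in.
    return out, (k, v)

torch.manual_seed(0)
d = 8
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))

# Prefill: 5 tokens, cache is built.
prefill = torch.randn(1, 5, d)
out, past_kv = toy_attention_step(prefill, w_q, w_k, w_v)
print(past_kv[0].shape)  # torch.Size([1, 5, 8])

# Decode step: 1 new token attends to 5 cached + 1 new positions.
step = torch.randn(1, 1, d)
out, past_kv = toy_attention_step(step, w_q, w_k, w_v, past_kv)
print(past_kv[0].shape)  # torch.Size([1, 6, 8])
```

With the Minitron model above, the analogous result would be past_key_values entries whose sequence dimension grows on each call; instead it comes back as None.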

Metadata

    Labels

    Cache, Usage, General questions about the library, bug
