System Info
huggingface-hub-0.26.2
tokenizers-0.20.3
transformers-4.47.0.dev0
Python 3.10.12
Driver Version: 535.129.03
CUDA Version: 12.3
Who can help?
No response
Information
Tasks
Reproduction
Description
You can run the code below to reproduce the prefill key-value caching problem with Minitron models.
I used the "nvidia/Minitron-8B-Base" model.
Code
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the Minitron model and tokenizer from Hugging Face
model_name = "nvidia/Minitron-8B-Base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Set the model to evaluation mode
model.eval()

# Sample input text
input_text = "Hello, how are you?"

# Tokenize the input
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

# First forward pass (prefill phase)
with torch.no_grad():
    outputs = model(input_ids, use_cache=True)  # Request the key-value cache
logits = outputs.logits
past_key_values = outputs.past_key_values

# Check the output
print("Logits shape:", logits.shape)
print("Number of layers in past_key_values:", len(past_key_values))
print("Shape of keys and values in the first layer:")
print("Key shape:", past_key_values[0][0].shape)
print("Value shape:", past_key_values[0][1].shape)

# Add new input to test cache utilization
new_input_text = " What about you?"
new_input_ids = tokenizer(new_input_text, return_tensors="pt").input_ids

# Pass the new input along with the previous key-value cache
with torch.no_grad():
    outputs_with_cache = model(new_input_ids, past_key_values=past_key_values, use_cache=True)

# Check results after caching
new_logits = outputs_with_cache.logits
new_past_key_values = outputs_with_cache.past_key_values
print("New logits shape:", new_logits.shape)
print("Number of layers in new past_key_values:", len(new_past_key_values))
Expected behavior
As-Is

past_key_values is of type NoneType, which means the key-value cache from the prefill pass is not being returned, so it cannot be reused for subsequent forward passes.
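For reference, the behavior I expected can be sketched with a toy key-value cache in plain Python (no transformers involved; all names here are illustrative, not the library's API): the prefill pass populates the cache, and a later pass appends only the new token's keys/values while attention still sees the full history.

```python
import math

def attend(query, keys, values):
    # Scaled dot-product attention for a single query vector.
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(len(query))
              for key in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    return [sum(w * v[i] for w, v in zip(weights, values))
            for i in range(len(values[0]))]

class KVCache:
    """Toy per-layer key/value cache, appended to on every forward pass."""
    def __init__(self):
        self.keys, self.values = [], []
    def update(self, new_keys, new_values):
        self.keys.extend(new_keys)
        self.values.extend(new_values)
        return self.keys, self.values

# Prefill: process the full prompt, populating the cache.
cache = KVCache()
prompt_keys = [[1.0, 0.0], [0.0, 1.0]]
prompt_values = [[0.5, 0.5], [0.2, 0.8]]
keys, values = cache.update(prompt_keys, prompt_values)

# Decode step: only the new token's key/value is appended, but attention
# still covers the prompt tokens through the cache.
keys, values = cache.update([[1.0, 1.0]], [[0.9, 0.1]])
out = attend([1.0, 0.0], keys, values)
print(len(cache.keys))  # 3 — the two prompt tokens plus the new token
```

With Minitron, the equivalent of `cache` comes back as None after prefill, so there is nothing to carry into the second forward pass.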