System Info
- `transformers` version: 4.57.1
- Platform: Linux-6.8.0-1043-aws-x86_64-with-glibc2.35
- Python version: 3.11.7
- Huggingface_hub version: 0.36.0
- Safetensors version: 0.6.2
- Accelerate version: 1.12.0
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.8.0+cu128 (CUDA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: no
- Using GPU in script?: yes
- GPU type: NVIDIA L40S
Who can help?
@gante @ArthurZucker
Information
Tasks
Reproduction
Hi folks! I was doing some testing with the static cache and noticed that it breaks when the batch size changes:
RuntimeError: index_copy_(): Source/destination tensor must have same slice shapes. Destination slice shape: 4 20 64 at dimension 2 and source slice shape: 8 20 64 at dimension 0.
Is this expected behavior or a known limitation? The issue is that, in order to use the static cache with CUDA graphs, the cache needs to be pre-allocated. If it crashes like this, we need a separate cache for each batch size, which quickly blows up VRAM usage...
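For a rough sense of scale (my own back-of-envelope numbers, assuming fp16 and whisper-large-v3-turbo's decoder config of 4 layers, 20 heads, head dim 64), a single pre-allocated cache instance with the sizes used below already costs about 1 GiB, so one copy per batch size adds up fast:

# Back-of-envelope VRAM for one EncoderDecoderCache (assumed fp16,
# whisper-large-v3-turbo decoder: 4 layers, 20 heads, head_dim 64).
layers, heads, head_dim, bytes_per_el = 4, 20, 64, 2
max_batch_size, max_cache_len, enc_len = 32, 256, 1500

def cache_bytes(seq_len):
    # K and V per layer, each [max_batch_size, heads, seq_len, head_dim]
    return 2 * layers * max_batch_size * heads * seq_len * head_dim * bytes_per_el

total = cache_bytes(max_cache_len) + cache_bytes(enc_len)  # self + cross
print(f"{total / 2**30:.2f} GiB per cache instance")  # -> 1.07 GiB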
import torch
from transformers import WhisperForConditionalGeneration, AutoProcessor
from transformers.cache_utils import StaticCache, EncoderDecoderCache
device = 'cuda:0'
torch_dtype = torch.float16
model_id = "openai/whisper-large-v3-turbo"
model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="sdpa", device_map=device)
model.generation_config.cache_implementation = "static"
@torch.no_grad()
def run_decoder(model, labels, encoder_outputs, past_key_values, cache_position=None):
    # Single decoder forward pass, writing K/V into the pre-allocated static cache.
    out_decoder = model.model.decoder(
        labels,
        encoder_hidden_states=encoder_outputs,
        past_key_values=past_key_values,
        cache_position=cache_position,
        use_cache=True,
        return_dict=True,
    )
    # Greedy next token from the logits at the last position.
    cur_token = model.proj_out(out_decoder.last_hidden_state[:, -1:]).argmax(dim=-1)
    past_key_values = out_decoder.past_key_values
    return cur_token, past_key_values
max_batch_size = 32
max_cache_len = 256
enc_len = 1500
decoder = model.model.decoder
# Pre-allocate the static caches: one for decoder self-attention, one for
# encoder-decoder cross-attention (fixed at the encoder output length).
self_cache = StaticCache(
    config=decoder.config,
    max_batch_size=max_batch_size,
    max_cache_len=max_cache_len,
    device=device,
    dtype=torch_dtype,
)
cross_cache = StaticCache(
    config=decoder.config,
    max_batch_size=max_batch_size,
    max_cache_len=enc_len,
    device=device,
    dtype=torch_dtype,
)
past_key_values = EncoderDecoderCache(self_cache, cross_cache)
# Reuse the same pre-allocated cache across several batch sizes.
for bs in [8, 4, 2, 1]:
    past_key_values.reset()
    assert bs <= max_batch_size, "batch_size should be <= max_batch_size"
    seq_length = 3
    # Decoder prompt: <|startoftranscript|><|en|><|transcribe|>
    labels = torch.tensor([[50258, 50259, 50360]] * bs, device=device, dtype=torch.int64)
    cache_position = torch.arange(seq_length, device=device, dtype=torch.int64)
    encoder_outputs = torch.randn([bs, enc_len, 1280], device=device, dtype=torch_dtype)
    cur_token, past_key_values_out = run_decoder(model, labels, encoder_outputs, past_key_values, cache_position=cache_position)
    # RuntimeError: index_copy_(): Source/destination tensor must have same slice shapes. Destination slice shape: 4 20 64 at dimension 2 and source slice shape: 8 20 64 at dimension 0.
Expected behavior
A single static cache instance should work across multiple batch sizes, as long as batch_size <= max_batch_size.
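For reference, the only workaround I've found that keeps a single cache instance is to pad every batch up to max_batch_size and slice the outputs back down. A minimal sketch (pad_to_batch is my own hypothetical helper, building on the repro above):

def pad_to_batch(t, target_bs, pad_value=0):
    # Hypothetical helper: grow dim 0 of `t` to `target_bs`; real rows stay first.
    pad = t.new_full((target_bs - t.shape[0], *t.shape[1:]), pad_value)
    return torch.cat([t, pad], dim=0)

# Always run the decoder at max_batch_size so the cache's batch dim matches;
# padded rows write junk into the cache, but their K/V are never read by real rows.
labels_p = pad_to_batch(labels, max_batch_size, pad_value=50258)
enc_p = pad_to_batch(encoder_outputs, max_batch_size)
cur_token, past_key_values = run_decoder(model, labels_p, enc_p, past_key_values, cache_position=cache_position)
cur_token = cur_token[:bs]  # drop the padded rows

This keeps VRAM flat across batch sizes, but it wastes compute on the padded rows, so having StaticCache accept any batch_size <= max_batch_size would still be much nicer.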