StaticCache crashes when the batch-size changes #42454

@mobicham

Description

System Info

- `transformers` version: 4.57.1
- Platform: Linux-6.8.0-1043-aws-x86_64-with-glibc2.35
- Python version: 3.11.7
- Huggingface_hub version: 0.36.0
- Safetensors version: 0.6.2
- Accelerate version: 1.12.0
- Accelerate config:    not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.8.0+cu128 (CUDA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: <fill in>
- Using GPU in script?: <fill in>
- GPU type: NVIDIA L40S

Who can help?

@gante @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hi folks! I was doing some testing with StaticCache and noticed that it breaks when the batch size changes:

RuntimeError: index_copy_(): Source/destination tensor must have same slice shapes. Destination slice shape: 4 20 64 at dimension 2 and source slice shape: 8 20 64 at dimension 0.

Is this expected behavior or a known limitation? The issue is that, in order to use a static cache with CUDA graphs, the cache needs to be pre-allocated. If it crashes like this, we need a separate cache for each batch size, which quickly blows up VRAM usage...
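For context, the failure mode can be reproduced with plain tensors, independent of the model. The shapes below are illustrative, loosely following a (batch, heads, seq, head_dim) cache layout; the variable names are hypothetical, not transformers internals:

```python
import torch

# Hypothetical stand-in for a pre-allocated static KV cache buffer:
# (max_batch_size, num_heads, max_cache_len, head_dim)
cache = torch.zeros(8, 20, 256, 64)

# New key/value states for a *smaller* batch of 4, seq_len 3
new_states = torch.randn(4, 20, 3, 64)
cache_position = torch.arange(3)

try:
    # A static cache writes along the sequence dimension (dim 2); the
    # non-indexed batch dimensions no longer match, so this raises.
    cache.index_copy_(2, cache_position, new_states)
except RuntimeError as e:
    # Same "Source/destination tensor must have same slice shapes" error
    # as in the report above.
    print(e)
```

`index_copy_` requires the source tensor to match the destination in every dimension except the indexed one, which is why a cache allocated for one batch size rejects updates from another.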

import torch
from transformers import WhisperForConditionalGeneration, AutoProcessor
from transformers.cache_utils import StaticCache, EncoderDecoderCache

device = 'cuda:0'
torch_dtype = torch.float16
model_id = "openai/whisper-large-v3-turbo"

model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="sdpa", device_map=device) 
model.generation_config.cache_implementation = "static"

@torch.no_grad()
def run_encoder(model, labels, encoder_outputs, past_key_values, cache_position=None):
    # Note: despite the name, this runs the *decoder* against pre-computed
    # encoder hidden states.
    out_decoder = model.model.decoder(
        labels,
        encoder_hidden_states=encoder_outputs,
        past_key_values=past_key_values,
        cache_position=cache_position,
        use_cache=True,
        return_dict=True,
    )

    cur_token = model.proj_out(out_decoder.last_hidden_state[:, -1:]).argmax(dim=-1)
    past_key_values = out_decoder.past_key_values

    return cur_token, past_key_values

max_batch_size = 32
max_cache_len = 256
enc_len = 1500
decoder = model.model.decoder

# Cache
self_cache = StaticCache(
    config=decoder.config,
    max_batch_size=max_batch_size,
    max_cache_len=max_cache_len,
    device=device,
    dtype=torch_dtype,
)

cross_cache = StaticCache(
    config=decoder.config,
    max_batch_size=max_batch_size,
    max_cache_len=enc_len,
    device=device,
    dtype=torch_dtype,
)

past_key_values = EncoderDecoderCache(self_cache, cross_cache)

for bs in [8, 4, 2, 1]:
    past_key_values.reset()
    assert bs <= max_batch_size, "batch_size should be <= max_batch_size"
    seq_length = 3
    labels = torch.tensor([[50258, 50259, 50360]] * bs, device=device, dtype=torch.int64)
    cache_position = torch.arange(seq_length, device=device, dtype=torch.int64)
    encoder_outputs = torch.randn([bs, enc_len, 1280], device=device, dtype=torch_dtype)

    cur_token, past_key_values_out = run_encoder(model, labels, encoder_outputs, past_key_values, cache_position=cache_position)

# RuntimeError: index_copy_(): Source/destination tensor must have same slice shapes. Destination slice shape: 4 20 64 at dimension 2 and source slice shape: 8 20 64 at dimension 0.

Expected behavior

A single StaticCache instance should work across multiple batch sizes, as long as batch_size <= max_batch_size.
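In the meantime, one workaround (a sketch under my own assumptions, not an official transformers API) is to pad every input batch up to max_batch_size before the decoder call and slice the outputs back down afterwards, so the cache update shapes stay fixed. `pad_batch` below is a hypothetical helper:

```python
import torch

def pad_batch(x: torch.Tensor, max_batch_size: int) -> torch.Tensor:
    """Pad the batch dimension up to max_batch_size by repeating the last
    entry. Hypothetical helper, not part of transformers."""
    bs = x.shape[0]
    if bs >= max_batch_size:
        return x
    pad = x[-1:].expand(max_batch_size - bs, *x.shape[1:])
    return torch.cat([x, pad], dim=0)

# Usage: run the model at a fixed batch size, then keep only the real rows.
labels = torch.zeros(4, 3, dtype=torch.int64)   # true batch of 4
padded = pad_batch(labels, 8)                   # shape (8, 3)
logits = padded.float()                         # stand-in for a model forward pass
real_out = logits[: labels.shape[0]]            # slice back to the true batch of 4
```

This wastes compute on the padded rows, but it keeps a single pre-allocated cache (and a single captured CUDA graph) valid for all batch sizes up to max_batch_size.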
