
BLIP model has a performance regression in compile mode after the cache refactor #39774

@jiqing-feng

Description


System Info

transformers version: 4.55.0.dev0
Platform: Linux-6.11.0-28-generic-x86_64-with-glibc2.35
Python version: 3.11.13
Huggingface_hub version: 0.34.2
Safetensors version: 0.5.3
Accelerate version: 1.8.1
Accelerate config: not found
DeepSpeed version: not installed
PyTorch version (accelerator?): 2.9.0.dev20250714+cpu (NA)
Tensorflow version (GPU?): not installed (NA)
Flax version (CPU?/GPU?/TPU?): not installed (NA)
Jax version: not installed
JaxLib version: not installed
Using distributed or parallel set-up in script?:

Who can help?

@zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

TORCH_LOGS="+graph_breaks,+recompiles" python test.py

import time
import requests
import torch
import PIL.Image
from transformers import pipeline

model_id = "Salesforce/blip-image-captioning-base"
image_to_text = pipeline("image-to-text", model=model_id, device="cpu", torch_dtype=torch.float16)
image_url = "https://ankur3107.github.io/assets/images/image-captioning-example.png"
image = PIL.Image.open(requests.get(image_url, stream=True, timeout=3000).raw)

# Warm up the eager pipeline before timing it
for _ in range(10):
    output = image_to_text(image)

start = time.time()
output = image_to_text(image)
end = time.time()
print(f"eager mode pipeline latency {end - start}")

# The original script passed an undefined `args.backend`; "inductor"
# (torch.compile's default backend) is assumed here so the script runs as-is
backend = "inductor"
image_to_text.model.vision_model.forward = torch.compile(image_to_text.model.vision_model.forward, backend=backend)
image_to_text.model.text_decoder.forward = torch.compile(image_to_text.model.text_decoder.forward, backend=backend)

# Warm up again so compilation (and any recompilation) happens outside the timed call
for _ in range(10):
    output = image_to_text(image)

start = time.time()
output = image_to_text(image)
end = time.time()
print(f"compile mode pipeline latency {end - start}")
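A single timed call after warm-up can be noisy. As a hedged alternative for measuring the eager vs. compile speed-up, the following sketch takes the median over several timed runs after discarding warm-up calls. The helper name `median_latency` and its defaults are hypothetical and not part of transformers or PyTorch.

```python
import statistics
import time
from typing import Callable


def median_latency(fn: Callable[[], object], warmup: int = 10, iters: int = 20) -> float:
    """Return the median wall-clock latency of fn() in seconds (hypothetical helper).

    Warm-up calls run first and are discarded, so one-time costs such as
    torch.compile tracing or recompilation are not included in the samples.
    """
    if iters < 1:
        raise ValueError("iters must be at least 1")
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)
```

Usage, assuming the `image_to_text` pipeline and `image` from the script: call `median_latency(lambda: image_to_text(image))` once before the `torch.compile` lines and once after, then divide the eager result by the compiled one to get the speed-up factor. The median is used instead of the mean so that an occasional slow call (for example a late recompile) does not dominate the result.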

Output logs:

W0730 06:58:23.995000 2266976 torch/_dynamo/convert_frame.py:1067] [12/8] torch._dynamo hit config.recompile_limit (8)
W0730 06:58:23.995000 2266976 torch/_dynamo/convert_frame.py:1067] [12/8]    function: 'forward' (/home/jiqing/transformers/src/transformers/models/blip/modeling_blip_text.py:358)
W0730 06:58:23.995000 2266976 torch/_dynamo/convert_frame.py:1067] [12/8]    last reason: 12/7: tensor 'past_key_value.self_attention_cache.layers[7].keys' size mismatch at index 2. expected 1, actual 2
W0730 06:58:23.995000 2266976 torch/_dynamo/convert_frame.py:1067] [12/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0730 06:58:23.995000 2266976 torch/_dynamo/convert_frame.py:1067] [12/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
W0730 06:58:30.593000 2266976 torch/_dynamo/convert_frame.py:1067] [11/8] torch._dynamo hit config.recompile_limit (8)
W0730 06:58:30.593000 2266976 torch/_dynamo/convert_frame.py:1067] [11/8]    function: '__call__' (/home/jiqing/transformers/src/transformers/modeling_layers.py:61)
W0730 06:58:30.593000 2266976 torch/_dynamo/convert_frame.py:1067] [11/8]    last reason: 11/7: len(args[5].is_updated) == 6       
W0730 06:58:30.593000 2266976 torch/_dynamo/convert_frame.py:1067] [11/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0730 06:58:30.593000 2266976 torch/_dynamo/convert_frame.py:1067] [11/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
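Both warnings follow the same pattern: Dynamo guards on a value that changes from call to call (the size at index 2 of a cached key tensor, or `len(args[5].is_updated)`), so each new value triggers a recompile until `config.recompile_limit` (8) is reached. The following is a hypothetical, heavily simplified pure-Python model of a guard-keyed compile cache, written only to illustrate that pattern. It is not Dynamo's implementation and ignores real features such as automatic dynamic shapes; the class name and the key-tensor shape used in the simulation are illustrative assumptions.

```python
from typing import Callable, Dict, Hashable

RECOMPILE_LIMIT = 8  # mirrors the value in the log: config.recompile_limit (8)


class ToyShapeSpecializingCompiler:
    """Hypothetical toy model: one cached "compiled" entry per guard value.

    A call whose guard key matches a cached entry reuses it. A miss
    "recompiles" until `limit` entries exist; after that, misses run eagerly.
    """

    def __init__(self, fn: Callable, limit: int = RECOMPILE_LIMIT):
        self.fn = fn
        self.limit = limit
        self.cache: Dict[Hashable, Callable] = {}
        self.compiles = 0
        self.eager_fallbacks = 0

    def __call__(self, guard_key: Hashable, *args):
        if guard_key in self.cache:  # guard passes: reuse the cached entry
            return self.cache[guard_key](*args)
        if len(self.cache) >= self.limit:  # limit hit: give up and run eagerly
            self.eager_fallbacks += 1
            return self.fn(*args)
        self.compiles += 1  # guard miss: "recompile" for this new value
        self.cache[guard_key] = self.fn  # stand-in for a compiled graph
        return self.fn(*args)


def step(x):
    return x * 2


toy = ToyShapeSpecializingCompiler(step)
# Simulate decoding: the key cache's sequence dimension (index 2) grows by one per step,
# so every step presents a shape the cache has never seen before.
for seq_len in range(1, 21):
    toy((1, 12, seq_len, 64), seq_len)
print(toy.compiles, toy.eager_fallbacks)  # → 8 12
```

In this toy model, the first eight decoding steps each pay a compile and every later step runs without a compiled graph, which is one way a guard on a growing cache can erase the expected speed-up.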

Expected behavior

Before PR #38635, the script ran well and compile mode achieved a 1.5x speed-up over eager mode.

Metadata


Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
