
revert _prepare_4d_causal_attention_mask_with_cache_position for gpt2#41806

Closed
jiqing-feng wants to merge 1 commit into huggingface:main from jiqing-feng:gpt2

Conversation

@jiqing-feng
Contributor

@jiqing-feng jiqing-feng commented Oct 23, 2025

Hi @zucchini-nlp

PR #39754 removed _prepare_4d_causal_attention_mask_with_cache_position from gpt2, which caused a ~40% performance regression on CPU. You can reproduce it with:
numactl -C 0-7 --membind 0 python test.py

import time
import torch
from transformers import pipeline, set_seed, AutoTokenizer

set_seed(42)

model_id = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = 'left'
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

pipe = pipeline("text-generation", model=model_id, tokenizer=tokenizer, torch_dtype=torch.float16, device_map="cpu")

generation_config = pipe.model.generation_config
generation_config.do_sample = False
generation_config.use_cache = True
generation_config.max_new_tokens = 128
generation_config.min_new_tokens = 128
generation_config.cache_implementation = "static"
generation_config.temperature = 1.0
generation_config.top_p = 1.0
generation_config.num_beams = 1
pipe.model.config._attn_implementation = "sdpa"

inputs = "It is done, and submitted. You can play 'Survival of the Tastiest' on Android, and on the web. Playing on the web works, but you have to simulate multiple touch for table moving and that can be a bit confusing. There is a lot I'd like to talk about. I will go through every topic, insted of making the typical what went right/wrong list. Concept Working over the theme was probably one of the hardest tasks which I had to face. Originally, I had an idea of what kind of game I wanted to develop, gameplay wise - something with a lot of enemies/actors"

# Warmup runs
for _ in range(5):
    set_seed(42)
    pipe(inputs, generation_config=generation_config)

# Timed runs
for _ in range(5):
    set_seed(42)
    start = time.time()
    pipe(inputs, generation_config=generation_config)
    end = time.time()
    print(f"{pipe.model.dtype} time costs {(end - start) * 1000} ms")

Reverting _prepare_4d_causal_attention_mask_with_cache_position fixes the regression.
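
For context, the helper being reverted builds the full additive 4D mask eagerly with plain tensor ops. Below is a minimal illustrative sketch of that style of mask preparation (not the exact transformers code; the function name and signature here are only for illustration):

import torch

def prepare_static_4d_causal_mask(attention_mask, seq_len, target_len, dtype, cache_position, batch_size):
    # Illustrative sketch only: build a (batch, 1, seq_len, target_len) additive mask
    # where disallowed positions hold the dtype's minimum value.
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((seq_len, target_len), min_dtype, dtype=dtype)
    # Block keys that lie beyond each query's cache position (future tokens).
    mask *= torch.arange(target_len) > cache_position.reshape(-1, 1)
    mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()
    if attention_mask is not None:
        # Also block padded key positions using the 2D padding mask.
        mask_len = attention_mask.shape[-1]
        padding = mask[..., :mask_len] + attention_mask[:, None, None, :].to(dtype)
        mask[..., :mask_len] = mask[..., :mask_len].masked_fill(padding == 0, min_dtype)
    return mask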

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: gpt2

@jiqing-feng
Contributor Author

run-slow: gpt2

@jiqing-feng jiqing-feng marked this pull request as ready for review October 23, 2025 07:27
@github-actions github-actions Bot requested a review from ArthurZucker October 23, 2025 07:27
@vasqu
Contributor

vasqu left a comment

It's the same issue as in #41639 (vmap in causal masking). I already discussed with @Cyrilvallez that we will add a non-vmap path to the mask creation. This will revert the perf regressions, so I'd like you to wait for that PR instead, as we don't want to reintroduce old functions (which we want to deprecate) into the code.
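
For readers following along, the rough difference is between building the mask through torch.vmap over a per-index mask function and building it with a single broadcasted comparison. A minimal illustrative sketch follows (assumed names, not the actual transformers masking utilities):

import torch

def causal_mask_fn(batch_idx, head_idx, q_idx, kv_idx):
    # Per-index rule: a query may attend to keys at or before its own position.
    return kv_idx <= q_idx

def build_mask_with_vmap(batch_size, q_len, kv_len):
    # Vectorize the scalar rule over (batch, head, query, key) with nested vmaps.
    # Flexible for arbitrary mask functions, but the extra dispatch can be slow on CPU.
    fn = causal_mask_fn
    for in_dims in [(None, None, None, 0), (None, None, 0, None), (None, 0, None, None), (0, None, None, None)]:
        fn = torch.vmap(fn, in_dims=in_dims)
    return fn(torch.arange(batch_size), torch.arange(1), torch.arange(q_len), torch.arange(kv_len))

def build_mask_with_broadcast(batch_size, q_len, kv_len):
    # The same boolean mask from one broadcasted comparison, with no vmap overhead.
    q = torch.arange(q_len)[:, None]
    kv = torch.arange(kv_len)[None, :]
    return (kv <= q).expand(batch_size, 1, q_len, kv_len)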

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng
Contributor Author

OK, please let me know when the PR is ready. Thanks!

@jiqing-feng jiqing-feng deleted the gpt2 branch December 15, 2025 02:11