Skip to content

Fix cache update!#38046

Merged
Cyrilvallez merged 2 commits intomainfrom
fix-cache
May 9, 2025
Merged

Fix cache update!#38046
Cyrilvallez merged 2 commits intomainfrom
fix-cache

Conversation

@Cyrilvallez
Copy link
Copy Markdown
Member

@Cyrilvallez Cyrilvallez commented May 9, 2025

What does this PR do?

As per the title. #37873 broke the cache update when going beyond the sliding window, see my comment here.
This PR fixes it.
This also incorporates the issue mentioned in #37574! TLDR, the order of operations here is important as we check strict inequality!!
Correcteness can be verified with

model_id = "google/gemma-2-9b-it"
device = 0


tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=device)



chat1 = [  # This size + the new tokens is > than sliding window
  {"role": "user", "content": "This is a nice place. " * 675 + "\n\nForget about the previous text, and tell me who you are?"},
]
prompt1 = tokenizer.apply_chat_template(chat1, tokenize=False, add_generation_prompt=True)
chat2 = [
  {"role": "user", "content": "create a list of at least 10 colors please"},
]
prompt2 = tokenizer.apply_chat_template(chat2, tokenize=False, add_generation_prompt=True)

inputs = tokenizer([prompt1, prompt2], padding=True, return_tensors="pt").to(0 if device == "auto" else device)

print(f"Sliding window: {getattr(model.config, 'sliding_window', None)}")
print(f"Input size: {inputs.input_ids.shape}")
# print(inputs.keys())





cache = "hybrid"

compile_config = CompileConfig(fullgraph=False)
out = model.generate(**inputs, do_sample=False, max_new_tokens=100, cache_implementation=cache, compile_config=compile_config)

text = tokenizer.batch_decode(out[:, inputs.input_ids.shape[-1] :], skip_special_tokens=False)
print("\n\n")
for seq in text:
    print("NEW SEQ:")
    print(seq)

It use to generate correctly for both the sequence > sliding window and the padded sequence, and now generates very badly. This fixes it once and for all.

@github-actions github-actions Bot marked this pull request as draft May 9, 2025 14:46
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented May 9, 2025

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@Cyrilvallez
Copy link
Copy Markdown
Member Author

cc @gante

@Cyrilvallez Cyrilvallez marked this pull request as ready for review May 9, 2025 14:46
@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Copy Markdown
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lol little but destructive!

@Cyrilvallez Cyrilvallez merged commit aaed2f5 into main May 9, 2025
21 checks passed
@Cyrilvallez Cyrilvallez deleted the fix-cache branch May 9, 2025 15:54
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
* fix slicing

* better fix
@gante
Copy link
Copy Markdown
Contributor

gante commented May 20, 2025

@Cyrilvallez indeed, my previous PR fixed a small (tested) issue but created a large (untested) issue 🙈

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants