Test: generate with torch.compile(model.forward) as a fast test #34544

Merged
gante merged 13 commits into huggingface:main from gante:generate_forward_compile_fix on Jan 28, 2025

@gante
Contributor

@gante gante commented Oct 31, 2024

What does this PR do?

Follow-up to #34464

This PR:

  1. Converts test_generate_compile_model_forward to a fast test. This means we will check generate with torch.compile(model.forward) at each commit on ALL models that support StaticCache 💛
  2. Fixes failing cases of test_generate_compile_model_forward whenever possible
  3. Tags models with _supports_static_cache = False #Reason when the model doesn't support torch.compile(model.forward)

py.test tests/models/ -k test_generate_compile is all green and takes ~2 mins to run on all models on my machine.
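
For readers unfamiliar with the pattern, here is a minimal, torch-free sketch of the skip-or-check logic described in points 1 and 3. It is not the actual test from tests/generation/test_utils.py: the model classes, fake_compile (a stand-in for torch.compile), toy_generate, and check_generate_compile are all hypothetical names, and the assumption is that the check compares eager output against compiled output.

```python
class ModelWithStaticCache:
    """Hypothetical stand-in for a model class that supports StaticCache."""

    _supports_static_cache = True

    def forward(self, token):
        return token + 1


class ModelWithoutStaticCache:
    """Hypothetical stand-in for a model tagged as not compile-friendly."""

    _supports_static_cache = False  # Reason: illustrative placeholder only

    def forward(self, token):
        return token + 1


def fake_compile(fn):
    """Placeholder for torch.compile: wraps fn without changing its behavior."""

    def compiled(*args, **kwargs):
        return fn(*args, **kwargs)

    return compiled


def toy_generate(forward, first_token, max_new_tokens):
    """Toy decoding loop: feed the last token back into forward."""
    tokens = [first_token]
    for _ in range(max_new_tokens):
        tokens.append(forward(tokens[-1]))
    return tokens


def check_generate_compile(model_classes):
    """Skip tagged models; otherwise compare eager and 'compiled' generation."""
    results = {}
    for model_class in model_classes:
        if not model_class._supports_static_cache:
            results[model_class.__name__] = "skipped"
            continue
        model = model_class()
        eager = toy_generate(model.forward, 0, 5)
        compiled = toy_generate(fake_compile(model.forward), 0, 5)
        results[model_class.__name__] = "passed" if eager == compiled else "failed"
    return results


results = check_generate_compile([ModelWithStaticCache, ModelWithoutStaticCache])
print(results)
```

In this sketch the tag acts as an opt-out, so a model without static cache support is skipped rather than reported as a failure.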

@gante gante requested review from ArthurZucker and ydshieh October 31, 2024 18:25
Comment thread tests/models/chameleon/test_modeling_chameleon.py Outdated
Comment thread tests/generation/test_utils.py Outdated
Collaborator

@ydshieh ydshieh left a comment

Love this!

Q: Is it really fast ...?

Remark: I feel get_max_cache_length is a better name than get_max_cache_shape, but OK, I know it's not great to change names all the time.

@gante
Contributor Author

gante commented Oct 31, 2024

Q: Is it really fast ...?

@ydshieh yes :D
[Image: Screenshot 2024-10-31 at 18 41 14]

Collaborator

@ArthurZucker ArthurZucker left a comment

I don't mind, though I don't think this should be our priority (full compile vs compile forward in generate!) + I don't see the test being run in the CI! 🤗

@ArthurZucker
Collaborator

Could you just make sure it's run?

@ydshieh2
Contributor

ydshieh2 commented Nov 5, 2024

We need to remove @require_torch_gpu too for def test_generate_compile

Comment thread tests/generation/test_utils.py Outdated
@ydshieh
Collaborator

ydshieh commented Jan 23, 2025

Before merging, feel free to ping me to check for flakiness (if there is any) :-) or anything else you think I can double-check.

Collaborator

@ArthurZucker ArthurZucker left a comment

Thanks, you can ignore my comments and merge 🤗

Comment on lines 359 to 360
Collaborator

Is it possible for HybridCache to inherit from StaticCache?

Collaborator

We might just need an extra class that says CompileCompatible; someone wanted an is_static attr!
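
A rough sketch of what such a marker class could look like, purely as an illustration of the idea: the cache classes below are hypothetical stand-ins rather than the library's StaticCache or HybridCache, and supports_compiled_forward is an invented helper name.

```python
class CompileCompatible:
    """Hypothetical marker mixin: the cache keeps fixed-shape storage."""


class StaticCacheSketch(CompileCompatible):
    """Stand-in for a cache pre-allocated to a fixed maximum length."""

    def __init__(self, max_cache_len):
        self.max_cache_len = max_cache_len


class HybridCacheSketch(CompileCompatible):
    """Stand-in for a cache mixing global and sliding-window layers."""

    def __init__(self, max_cache_len, sliding_window):
        self.max_cache_len = max_cache_len
        self.sliding_window = sliding_window


class DynamicCacheSketch:
    """Stand-in for a cache that grows with each generated token."""

    def __init__(self):
        self.entries = []


def supports_compiled_forward(cache):
    """Check the marker instead of requiring a specific parent class."""
    return isinstance(cache, CompileCompatible)


print(supports_compiled_forward(HybridCacheSketch(16, 4)))
print(supports_compiled_forward(DynamicCacheSketch()))
```

With a marker like this, the hybrid cache would not need to inherit from the static one; both would simply opt in to the same check.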

@gante
Contributor Author

gante commented Jan 27, 2025

(Sorry, the PR is not ready yet, a few cases are still failing 👀 I didn't mean to request a review.)

@gante
Contributor Author

gante commented Jan 28, 2025

Now it's working on all models, including encoder-decoder + cache 🤗

It's not too heavy on our CI: it should add ~2 mins if all models are run. And it should save us many headaches! As we can see in the diff, we had compilation enabled for a bunch of models that don't support it.

[Image: Screenshot 2025-01-28 at 12 48 05]

@gante gante merged commit ece8c42 into huggingface:main Jan 28, 2025
@gante gante deleted the generate_forward_compile_fix branch January 28, 2025 14:10