Skip to content

Flash-attn performance: remove cuda sync during inference#33570

Merged
Cyrilvallez merged 1 commit intohuggingface:mainfrom
Cyrilvallez:fix-flash
Oct 7, 2024
Merged

Flash-attn performance: remove cuda sync during inference#33570
Cyrilvallez merged 1 commit intohuggingface:mainfrom
Cyrilvallez:fix-flash

Conversation

@Cyrilvallez
Copy link
Copy Markdown
Member

What does this PR do?

#31629 & #32241 introduced a functionality in FA2 intended for training efficiency. However, it adds unnecessary cuda synchronization at inference time in every forward pass due to always checking (torch.diff(position_ids, dim=-1) >= 0).all() in the elif condition. This PR fixes the performance issue by simply switching the order of the different checks in the elif condition, to make good use of Python's default short-circuit evaluation. Indeed, at inference time, query_length will always be 1 except during prefill, thus we will short-circuit torch synchronization all the time.

Performance degradation was not so significant, but this PR allows to win back around 5-10% speed at inference time from the quick tests I ran.

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Cyrilvallez
Copy link
Copy Markdown
Member Author

cc @ArthurZucker, forgot to ping you

Copy link
Copy Markdown
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👀 nice hack, awesome that you found about it!

@Cyrilvallez Cyrilvallez merged commit 1f33023 into huggingface:main Oct 7, 2024
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
…e#33570)

Switch conditions to use short-circuit during inference
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants