
add xformers dep, xformers attn for gpt2 #22665

Closed
ethansmith2000 wants to merge 1 commit into huggingface:main from ethansmith2000:add-xformers-gpt2

Conversation

@ethansmith2000

What does this PR do?

Add xformers as a dependency and implement xformers attention for GPT-2.
I am a bit of a novice at this, but I would like to contribute to giving all models in the transformers library xformers support.

This PR is likely not ready to merge yet, but I was hoping to get some feedback on what I could provide.

Fixes # (issue)
Reduces VRAM usage and increases speed.
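
For context, here is a minimal sketch (not the actual diff in this PR) of how a GPT-2-style attention block could call xformers' memory-efficient attention. It assumes xformers is installed with a supported CUDA GPU; the module layout (nn.Linear instead of GPT-2's Conv1D) and the shape handling are simplifications.

```python
import torch
import torch.nn as nn
import xformers.ops as xops


class XFormersSelfAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.c_attn = nn.Linear(hidden_size, 3 * hidden_size)  # fused q/k/v projection, as in GPT-2
        self.c_proj = nn.Linear(hidden_size, hidden_size)
        self.dropout = dropout

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, _ = hidden_states.shape
        q, k, v = self.c_attn(hidden_states).chunk(3, dim=-1)
        # xformers expects (batch, seq, heads, head_dim)
        q = q.view(bsz, seq_len, self.num_heads, self.head_dim)
        k = k.view(bsz, seq_len, self.num_heads, self.head_dim)
        v = v.view(bsz, seq_len, self.num_heads, self.head_dim)
        out = xops.memory_efficient_attention(
            q, k, v,
            attn_bias=xops.LowerTriangularMask(),  # causal mask for decoder-only GPT-2
            p=self.dropout if self.training else 0.0,
        )
        # back to (batch, seq, hidden)
        out = out.reshape(bsz, seq_len, -1)
        return self.c_proj(out)
```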

@sgugger
Collaborator

sgugger commented Apr 10, 2023

cc @younesbelkada I don't know if this is redundant with the BetterTransformer integration.

Contributor

@younesbelkada left a comment


Thanks a lot for your PR!
Indeed this might collide with the recent features of the BetterTransformer API of the optimum library. This API modifies the modeling script of supported models and replaces some core operations with native torch functions such as torch.nn.functional.scaled_dot_product_attention: https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention

As you can read in the PyTorch documentation, the function above uses xformers as well, termed "memory-efficient attention". Hence my understanding is that this PR might duplicate the recent SDPA integration in optimum's BetterTransformer API.
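
For illustration (not part of the original comment), a minimal sketch of the native PyTorch 2.0 operator that BetterTransformer swaps in; it dispatches to flash, memory-efficient (xformers-derived), or math kernels depending on inputs and hardware. Tensor shapes follow the (batch, heads, seq, head_dim) convention the function expects.

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 12, 128, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# is_causal=True applies the GPT-2-style lower-triangular mask inside the kernel
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 12, 128, 64])
```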

I would love to hear @fxmarty's thoughts here in case I missed a few details.

Thanks!

@fxmarty
Contributor

fxmarty commented Apr 11, 2023

Related: #22386

I don't think colliding is an issue: if it is simply better in most cases, making it the default for users natively in transformers (with some refactoring) makes sense. However, for now, PyTorch's SDPA has some limitations (see the sketch after this list):

  • no scale argument (some archs do not scale query/key)
  • no speedup/memory savings for custom attention mask (flash and mem-efficient not supported)
  • no support for mixed fp16/fp32, as in some models where the softmax is computed in fp32 while the rest is in fp16
  • the C++ implementation works on all hardware; mem-efficient and flash are Nvidia-only
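
As a rough illustration of those limitations (an editorial sketch assuming PyTorch 2.0 and a CUDA device, not code from the thread): forcing a single backend with torch.backends.cuda.sdp_kernel makes it visible when an input pattern, such as an explicit attention mask, can only run on the math (C++) fallback.

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 64, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
mask = torch.zeros(1, 1, 64, 64, device="cuda", dtype=torch.float16)  # custom additive mask

# Flash-only: in PyTorch 2.0 the flash kernel does not accept an explicit attn_mask,
# so with the other backends disabled the call errors out instead of speeding up.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    try:
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    except RuntimeError as e:
        print("flash kernel rejected the masked call:", e)

# Math-only: always runs, but without the flash / mem-efficient speedups.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Note also that in PyTorch 2.0 scaled_dot_product_attention has no `scale` argument,
# so architectures that do not scale query/key cannot express that through this API.
```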

@github-actions
Contributor

github-actions Bot commented May 8, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions Bot closed this May 17, 2023