Reenable SDPA's FA2 During Training with torch.compile #30442
fxmarty merged 7 commits into huggingface:main from
Conversation
Tagging @ArthurZucker and @younesbelkada for review.
ArthurZucker
left a comment
I hate having this if else.... but I guess it's for the best here.
I wish this was natively supported.
Quick question: I guess this adds a guard?
Otherwise, LGTM, and the slow tests will be triggered once merged.
fyi @fxmarty when you come back
Not sure why the CI errored out after these formatting changes. Locally I still have
fxmarty
left a comment
Thank you! It looks okay to me, just suggested a style change.
@warner-benjamin can you make sure the CI is green? The failing ones seem unrelated; maybe merging main will do.
This PR causes a dynamo graph break at
@tombousso Yes. Why is it an issue for you? Do you see perf degradation? AFAIK there is no obvious way around it, except maybe using newer APIs from pytorch/pytorch#114823 and possibly other PRs.
https://pytorch.org/docs/main/cond.html may be the way?
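For context, a hedged sketch of what branching on `q_len > 1` via `torch.cond` could look like, assuming a recent PyTorch where the prototype `torch.cond` operator is available (this is not something the PR implements):

```python
import torch
import torch.nn.functional as F

def causal_attn(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

def non_causal_attn(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=False)

@torch.compile(fullgraph=True)
def attention(q, k, v):
    # torch.cond traces both branches into the graph instead of breaking on the
    # shape-dependent predicate.
    return torch.cond(q.shape[-2] > 1, causal_attn, non_causal_attn, (q, k, v))

out = attention(*(torch.randn(1, 2, 16, 8) for _ in range(3)))
```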
Yes, I was seeing perf degradation. I was hoping to get a graph with no breaks to make it easier to see what's going on, and to give the compiler the best opportunity to make optimizations.
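As an aside, a hedged sketch of two common ways to surface graph breaks, assuming a recent PyTorch (`f` below is a stand-in function, not the model from this PR):

```python
import torch

def f(x):
    # A Python branch on tensor data: a typical cause of a dynamo graph break.
    if x.sum() > 0:
        return x.cos()
    return x.sin()

# Option 1: make breaks fatal so the offending line is reported.
compiled = torch.compile(f, fullgraph=True)  # calling compiled(x) errors at the first break

# Option 2: enumerate graphs and break reasons without failing.
# (TORCH_LOGS="graph_breaks" also logs breaks at compile time.)
explanation = torch._dynamo.explain(f)(torch.randn(4))
print(explanation.graph_break_count, explanation.break_reasons)
```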
@tombousso Could you open an issue for that?
This PR resolves #30010 and completes #30070 by reenabling the SDPA Flash Attention 2 kernel for `torch.compile` when the model is training. During eval, SDPA dispatches to the efficient kernel with the same logic as in #30070.

This PR fixes the situation where SDPA attention models use a small amount of memory during training in eager mode but use a large amount of memory, or OOM, when compiled, because the wrong SDPA kernel was being used. It shouldn't affect exporting or generation when the model is in eval mode.
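As an illustration of the dispatch described above, here is a minimal sketch (not the transformers implementation; it assumes a CUDA GPU with FA2 support and a recent PyTorch): SDPA can only select the FlashAttention-2 kernel when no explicit `attn_mask` is passed, so the training path relies on `is_causal=True`, while eval/generation passes the prepared mask and falls back to the efficient kernel.

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Training-style call: no attn_mask and is_causal=True -> FlashAttention-2 is eligible.
# Restricting SDPA to the flash backend confirms this call can use it.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out_train = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Eval/generation-style call: an explicit mask rules the flash kernel out,
# so SDPA dispatches to the memory-efficient (or math) kernel instead.
causal_mask = torch.ones(128, 128, dtype=torch.bool, device="cuda").tril()
out_eval = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
```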
Moving the `is_causal` dispatch logic from inline to an if statement is required to support both `fullgraph=True` and `dynamic=True`. The current code errors out with `dynamic=True` because `q_len > 1` is not the correct bool type, but wrapping it in `bool(q_len > 1)` to fix dynamic breaks `fullgraph=True`.
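A minimal sketch of the shape of that change (not the exact diff; `sdpa_forward` is a placeholder name):

```python
import torch
import torch.nn.functional as F

def sdpa_forward(query, key, value, causal_mask=None):
    q_len = query.shape[-2]

    # Inline form: passing `is_causal=causal_mask is None and q_len > 1` directly
    # in the SDPA call errors under torch.compile(dynamic=True), where `q_len > 1`
    # is a symbolic bool rather than a plain Python bool. Hoisting the dispatch
    # into its own if statement keeps both `fullgraph=True` and `dynamic=True` working.
    if causal_mask is None and q_len > 1:
        is_causal = True
    else:
        is_causal = False

    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=causal_mask, is_causal=is_causal
    )
```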
The Llama tests that I could run either all pass or fail in the same state as on main (`LlamaIntegrationTest::test_conversion` and `LlamaIntegrationTest::test_compile_static_cache`). I couldn't run the Gemma tests due to a model gating error despite having access to Gemma.