support CP in native flash attention #12829
Conversation
Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
Since native attention only supports Ulysses Attention, we need an attention backend that works for Ring Attention. XPU enables `_scaled_dot_product_flash_attention` in torch, so we can use it for ring attention.
sayakpaul
left a comment
Cool work! Could you also supplement a fully working code snippet?
Yes, the PR also works for CUDA.
You can run it with `torchrun --nproc-per-node 4 test.py`; without this PR, the output is corrupted.
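For reference, a minimal sketch of what such a `test.py` could look like, assuming the `ContextParallelConfig` / `enable_parallelism` API referenced later in this thread; the checkpoint, prompt, and degree value are illustrative, and the device setup assumes CUDA (adjust accordingly for XPU):

```python
# test.py -- run with: torchrun --nproc-per-node 4 test.py
# Minimal sketch; model ID, prompt, and ring_degree are illustrative.
import torch
import torch.distributed as dist
from diffusers import ContextParallelConfig, FluxPipeline

dist.init_process_group()
rank = dist.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(device)

# Ring attention across the 4 ranks, using the native flash SDPA backend
# that this PR adds context-parallel support for.
pipeline.transformer.set_attention_backend("_native_flash")
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=4))

generator = torch.Generator().manual_seed(0)
image = pipeline(
    "A photo of a cat", num_inference_steps=28, generator=generator
).images[0]

if rank == 0:
    image.save("output.png")
dist.destroy_process_group()
```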
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Hmm, it should raise an error, no? On
Weird. Will check and fix this. Cc: @DN6 |
Signed-off-by: Wang, Yi <yi.a.wang@intel.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Okay, I tracked it down. The order in which we're calling things matters. For example, if we do:

```python
pipeline.transformer.set_attention_backend("_native_flash")
pipeline.transformer.enable_parallelism(config=cp_config)
```

it rightfully errors out:

```
[rank0]: ValueError: Context parallelism is enabled but the attention processor 'FluxAttnProcessor' is using backend '_native_flash' which does not support context parallelism. Please set a compatible attention backend: ['_native_cudnn', 'flash', 'native', 'sage'] using `model.set_attention_backend()` before calling `enable_parallelism()`.
```

But for any other combination, it silently passes through. Will fix.
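For concreteness, a hypothetical sketch of one ordering that could slip through (my reading of the above, reusing the same `pipeline` and `cp_config`): the compatibility check runs inside `enable_parallelism()`, so a backend swapped in afterwards is never re-validated.

```python
# Reversed order: parallelism is enabled first, then the backend is changed.
# Nothing re-checks backend compatibility at this point, so an incompatible
# combination proceeds silently instead of raising (before the fix).
pipeline.transformer.enable_parallelism(config=cp_config)
pipeline.transformer.set_attention_backend("_native_flash")
```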


What does this PR do?
Native flash attention can support both Ulysses Attention and Ring Attention.
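As an illustrative sketch (the degree values are assumptions), the two modes would be selected through `ContextParallelConfig`:

```python
from diffusers import ContextParallelConfig

# Ulysses attention: all-to-all exchange over attention heads.
ulysses_config = ContextParallelConfig(ulysses_degree=4)

# Ring attention: the sequence is sharded across ranks and K/V blocks are
# rotated ring-style; this PR lets the _native_flash backend
# (torch's _scaled_dot_product_flash_attention) be used for this path.
ring_config = ContextParallelConfig(ring_degree=4)
```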