
Delay and probably avoid unnecessary graph breaks in _upad_input of modeling_flash_attention_utils.py#41097

Open
cyyever wants to merge 1 commit into huggingface:main from cyyever:flash_attn

Conversation

@cyyever
Contributor

@cyyever cyyever commented Sep 23, 2025

What does this PR do?

It works by refactoring `_get_unpad_data` so that `.item()` on the max-sequence-length value is deferred rather than called eagerly.
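For context, calling `.item()` on a tensor forces a host-device sync and, under `torch.compile`, a graph break. Below is a minimal sketch of a `_get_unpad_data`-style helper that keeps the max sequence length as a 0-dim tensor instead of materializing it immediately; the function name and exact shapes follow the upstream helper's convention, but this is an illustrative sketch under those assumptions, not the PR's actual diff:

```python
import torch
import torch.nn.functional as F

def get_unpad_data_lazy(attention_mask: torch.Tensor):
    """Sketch: like _get_unpad_data, but max_seqlen_in_batch stays a
    0-dim tensor, so no .item() (and thus no graph break) happens here."""
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # Kept as a tensor; the caller decides if/when a Python int is needed.
    max_seqlen_in_batch = seqlens_in_batch.max()
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
    )
    return indices, cu_seqlens, max_seqlen_in_batch

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
indices, cu_seqlens, max_seqlen = get_unpad_data_lazy(mask)
# .item() is only paid at the call site that truly needs a Python int:
assert max_seqlen.item() == 3
assert cu_seqlens.tolist() == [0, 3, 5]
assert indices.tolist() == [0, 1, 2, 4, 5]
```

The design question debated below is whether any call site can actually avoid materializing the int, given where this helper is invoked.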

@cyyever cyyever marked this pull request as draft September 23, 2025 10:51
@cyyever cyyever force-pushed the flash_attn branch 2 times, most recently from 8cf2ae6 to dc5f4fe, on September 23, 2025 10:54
@cyyever cyyever marked this pull request as ready for review September 23, 2025 10:57
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
@cyyever cyyever changed the title Avoid unnecessary graph breaks in _upad_input of modeling_flash_attention_utils.py Delay and probably avoid unnecessary graph breaks in _upad_input of modeling_flash_attention_utils.py Sep 23, 2025
@Rocketknight1
Member

cc @Cyrilvallez for attention

@Cyrilvallez
Member

Hmm, I don't really see how this delays the graph break. Moreover, `max_seqlen_in_batch_k` is now a `Tensor` instead of an `int`, which is wrong.

@cyyever
Contributor Author

cyyever commented Sep 23, 2025

@Cyrilvallez It delays the break: `.item()` is called only when `query_length == kv_seq_len`; in the other cases, `.item()` is avoided entirely.
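The deferral pattern being described can be sketched as follows; the helper name and signature here are hypothetical, chosen only to illustrate how the `.item()` call (and hence the sync/graph break) is confined to the one branch that needs a Python int:

```python
import torch

def resolve_max_seqlen_k(max_seqlen_k: torch.Tensor,
                         query_length: int,
                         kv_seq_len: int):
    """Hypothetical helper: materialize the 0-dim tensor to a Python int
    only on the branch that requires one; other branches keep the tensor
    and trigger no host sync."""
    if query_length == kv_seq_len:
        # Only this path needs a Python int, so the graph break (if any)
        # is paid here rather than unconditionally upfront.
        return max_seqlen_k.item()
    return max_seqlen_k

t = torch.tensor(7)
assert resolve_max_seqlen_k(t, 4, 4) == 7          # int path: .item() runs
assert torch.is_tensor(resolve_max_seqlen_k(t, 2, 4))  # tensor path: no sync
```

Whether the tensor-valued path is acceptable downstream is exactly the objection raised above, since callers may expect an `int`.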

@Cyrilvallez
Member

Yes, but the function is called only if attention_mask is not None in `_flash_attention_forward`, in which case the graph break is unavoidable, if I'm not mistaken 🤔

@cyyever
Contributor Author

cyyever commented Sep 29, 2025

@Cyrilvallez 😭



3 participants