Enable relative positional embedding in flash attention #7997

Description

@KumoLiu

From reading this thread:
pytorch/pytorch#96099 (comment)
It seems to me that the relative positional embedding can be integrated through the attn_mask argument of scaled_dot_product_attention. However, it can be slow because it does not take the "fast path".

Do you think we can keep this option open for users who want to use flash_attention and rel_pos_embedding together?

Originally posted by @mingxin-zheng in #7977 (comment)
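
A minimal sketch of the idea above, assuming PyTorch 2.x: a floating-point relative positional bias (here a random placeholder standing in for a learned embedding table) is passed as attn_mask to torch.nn.functional.scaled_dot_product_attention. The shapes and names are illustrative assumptions, not MONAI code; with a dense additive mask the call typically cannot dispatch to the flash-attention kernel and falls back to a slower backend, which is the trade-off discussed here.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: batch B, heads H, sequence length L, head dim D.
B, H, L, D = 2, 4, 16, 32
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

# Relative positional bias of shape (H, L, L); in practice this would be
# gathered from a learned embedding table indexed by relative distance.
rel_pos_bias = torch.randn(H, L, L)

# A float attn_mask is added to the attention logits before softmax, which
# is how the relative bias can be folded into SDPA. Note: with this kind of
# mask PyTorch generally cannot use the flash-attention kernel and selects
# a slower (math / memory-efficient) backend instead.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=rel_pos_bias)
print(out.shape)  # torch.Size([2, 4, 16, 32])
```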
