[RoadMap] Support more Attention Template for training and inference #13

@smallscientist1

Roadmap

Plan to support more attention patterns on more devices.

TileLang Kernel Templates

TileLang kernels.
Example (MHA forward & backward): attention_engine/core/template/tl_template/attn/attn_tl.py

  • MLA decode (@smallscientist1)
  • GQA forward & backward (reference semantics sketched below)
  • Varlen attention forward & backward & decode
  • Block-sparse mask for attention bwd & decode
  • Block-sparse indices for attention fwd & bwd & decode
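
For reference, a minimal PyTorch sketch of the GQA forward semantics the template would implement. This is plain eager code for clarity, not the TileLang kernel itself, and the function name is hypothetical:

```python
# Minimal GQA forward reference (eager PyTorch, semantics only).
import torch
import torch.nn.functional as F

def gqa_forward_reference(q, k, v):
    """q: [B, Hq, S, D]; k, v: [B, Hkv, S, D] with Hq divisible by Hkv."""
    d = q.shape[-1]
    groups = q.shape[1] // k.shape[1]
    # Each KV head is shared by a group of query heads.
    k = k.repeat_interleave(groups, dim=1)
    v = v.repeat_interleave(groups, dim=1)
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / d**0.5
    return torch.einsum("bhqk,bhkd->bhqd", F.softmax(scores, dim=-1), v)
```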

Lowering

Lower customized code, such as score_mod, online_func, and mask_mod, into the kernel.
Example: attention_engine/core/lower/lower.py

  • Customized attention (sigmoid, relu, ...) decode (assigned to @smallscientist1); see the sketch after this list
  • RetNet backward
  • Dynamic max seqlen support (assigned to @smallscientist1)
    • Support MHA prefill & backward with dynamic seqlen
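
The names score_mod, online_func, and mask_mod come from the repo; the signatures below are assumptions for illustration, not the actual API. A minimal sketch of how a sigmoid-attention score_mod and a causal mask_mod might be expressed before lowering inlines them into the kernel template:

```python
# Hedged sketch: signatures are assumed, not the repo's API. score_mod
# rewrites attention scores elementwise; mask_mod decides which (q, kv)
# positions are visible.
import torch

def sigmoid_score_mod(score, b, h, q_idx, kv_idx):
    # Elementwise transform of the attention score; for full sigmoid
    # attention the softmax (the online_func) would also be replaced.
    return torch.sigmoid(score)

def causal_mask_mod(b, h, q_idx, kv_idx):
    # A kv position is visible only at or before the query position.
    return kv_idx <= q_idx
```

A lowering pass (cf. attention_engine/core/lower/lower.py) would inline the bodies of such callables into the generated kernel rather than invoke them per element at runtime.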

Device

  • AMD MI300 kernel template & lowering (@smallscientist1)
  • NVIDIA devices (RTX 4090, A100, ...): hardware config
  • NVIDIA devices (RTX 4090, A100, ...): performance tuning for more templates (assigned to @smallscientist1); see the autotune sketch after this list
    • Implement MHA fwd & bwd autotune in attention_engine/core/template/tl_template/attn/attn_tl.py
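
A minimal sketch of what per-device autotuning could look like: brute-force benchmarking over a small tile-config space. The config keys (block_M, block_N, num_stages) and the build_kernel factory are assumptions for illustration, not the repo's API:

```python
# Hedged autotune sketch: build_kernel(cfg) -> callable is a hypothetical
# factory that compiles a kernel for one tile configuration.
import itertools
import time
import torch

def autotune(build_kernel, args, device="cuda"):
    space = {
        "block_M": [64, 128],   # query-tile rows
        "block_N": [64, 128],   # kv-tile columns
        "num_stages": [1, 2],   # software-pipelining depth
    }
    best_cfg, best_ms = None, float("inf")
    for values in itertools.product(*space.values()):
        cfg = dict(zip(space.keys(), values))
        kernel = build_kernel(cfg)
        kernel(*args)                      # warm up / trigger JIT
        torch.cuda.synchronize(device)
        t0 = time.perf_counter()
        for _ in range(10):
            kernel(*args)
        torch.cuda.synchronize(device)
        ms = (time.perf_counter() - t0) * 100  # avg ms per call over 10 runs
        if ms < best_ms:
            best_cfg, best_ms = cfg, ms
    return best_cfg
```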
