Problem
AITER's unified_attention currently supports only causal attention:

```python
# File: aiter/ops/triton/unified_attention.py:126
# Source: https://github.com/ROCm/aiter/blob/main/aiter/ops/triton/unified_attention.py#L126
assert causal, "Only causal attention is supported"
```

This prevents encoder-only models (BERT, RoBERTa, sentence transformers, embedding models) from using AITER's optimized attention on ROCm.
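For reference, the behavior encoder-only models need matches PyTorch's built-in scaled_dot_product_attention with is_causal=False. A minimal sketch in plain PyTorch (independent of AITER) showing the two semantics:

```python
# Reference semantics only, not AITER code: encoder-only models need
# bidirectional attention, i.e. is_causal=False.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

causal_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # what AITER supports today
bidir_out = F.scaled_dot_product_attention(q, k, v, is_causal=False)   # what this request adds

# The results differ at every query position except the last one, which
# already sees the full sequence under a causal mask.
print(torch.allclose(causal_out, bidir_out))  # False
```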
Use Case
vLLM pooling models fail on ROCm because:
- The ROCM_AITER_FA backend raises NotImplementedError for the ENCODER_ONLY attention type
- vLLM falls back to FlexAttention, which has numerical precision issues on ROCm
- Result: 33 pooling tests failing on AMD CI
vLLM Issue: vllm-project/vllm#29466
vLLM PR workaround: vllm-project/vllm#31084
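For illustration, a minimal pooling run of the kind that hits this path (model taken from the list below; vLLM's embed API, with exact backend choice depending on version and platform):

```python
# Minimal vLLM embedding workload of the kind affected on ROCm: the model is
# encoder-only, so ROCM_AITER_FA is rejected and vLLM falls back to a
# generic attention implementation.
from vllm import LLM

llm = LLM(model="sentence-transformers/all-MiniLM-L12-v2", task="embed")
outputs = llm.embed(["AITER should support bidirectional attention"])
print(len(outputs[0].outputs.embedding))  # 384 for all-MiniLM-L12-v2
```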
Impact
This limitation affects:
- All encoder-only models on ROCm: BERT, RoBERTa, sentence-transformers, embedding models
- Frameworks: vLLM, Transformers, and other frameworks that use AITER
- Competitiveness: Prevents ROCm from competing with CUDA for encoder-only workloads
- Performance: Forces use of generic implementations instead of AITER-optimized kernels
Request
Add bidirectional (non-causal) attention support to unified_attention:
- Remove restriction: remove the assert causal check, or make it conditional
- Add parameter: add an is_causal parameter to the kernel calls
- Modify masks: update the attention mask logic to support bidirectional attention (see the sketch after this list)
- Testing: test with encoder-only models (BERT-base, RoBERTa, sentence-transformers)
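For concreteness, below is a rough sketch of what the mask change could look like. This is a simplified flash-attention-style Triton kernel written for this issue, not AITER's actual unified_attention internals; the IS_CAUSAL constexpr and all other names are illustrative:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _attn_fwd(q_ptr, k_ptr, v_ptr, o_ptr,
              seq_len, head_dim, sm_scale,
              IS_CAUSAL: tl.constexpr,
              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
              BLOCK_D: tl.constexpr):
    # One program handles BLOCK_M query rows of a single (batch, head) slice.
    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_D)
    bounds_m = (offs_m[:, None] < seq_len) & (offs_d[None, :] < head_dim)

    q = tl.load(q_ptr + offs_m[:, None] * head_dim + offs_d[None, :],
                mask=bounds_m, other=0.0)

    # Online-softmax accumulators (flash-attention style).
    acc = tl.zeros((BLOCK_M, BLOCK_D), dtype=tl.float32)
    m_i = tl.full((BLOCK_M,), float("-inf"), dtype=tl.float32)
    l_i = tl.zeros((BLOCK_M,), dtype=tl.float32)

    for start_n in range(0, seq_len, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        bounds_n = (offs_n[:, None] < seq_len) & (offs_d[None, :] < head_dim)
        k = tl.load(k_ptr + offs_n[:, None] * head_dim + offs_d[None, :],
                    mask=bounds_n, other=0.0)
        qk = tl.dot(q, tl.trans(k)) * sm_scale

        # Out-of-bounds keys are always masked out.
        mask = offs_n[None, :] < seq_len
        if IS_CAUSAL:
            # The triangular mask applies only in causal mode; bidirectional
            # (encoder-only) attention skips it so every query sees every key.
            mask = mask & (offs_n[None, :] <= offs_m[:, None])
        qk = tl.where(mask, qk, float("-inf"))

        m_new = tl.maximum(m_i, tl.max(qk, 1))
        # Guard fully-masked rows so exp(-inf - -inf) cannot produce NaN.
        m_safe = tl.where(m_new == float("-inf"), 0.0, m_new)
        p = tl.exp(qk - m_safe[:, None])
        alpha = tl.exp(m_i - m_safe)

        v = tl.load(v_ptr + offs_n[:, None] * head_dim + offs_d[None, :],
                    mask=bounds_n, other=0.0)
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_new

    tl.store(o_ptr + offs_m[:, None] * head_dim + offs_d[None, :],
             acc / l_i[:, None], mask=bounds_m)


def attention(q, k, v, is_causal: bool):
    # q, k, v: contiguous (seq_len, head_dim) tensors for one head, on GPU.
    seq_len, head_dim = q.shape
    o = torch.empty((seq_len, head_dim), dtype=torch.float32, device=q.device)
    BLOCK_M = BLOCK_N = 32
    grid = (triton.cdiv(seq_len, BLOCK_M),)
    _attn_fwd[grid](q, k, v, o, seq_len, head_dim, head_dim ** -0.5,
                    IS_CAUSAL=is_causal, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
                    BLOCK_D=triton.next_power_of_2(head_dim))
    return o
```

Gating the triangular mask behind a constexpr means the causal and bidirectional variants compile to separate specializations, so the causal fast path pays nothing. Outputs with is_causal=False can be validated against torch.nn.functional.scaled_dot_product_attention on the same inputs, which would cover the Testing item above.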
Example Models Affected
- Embeddings: sentence-transformers/all-MiniLM-L12-v2, intfloat/e5-small
- Cross-encoders: cross-encoder/ms-marco-MiniLM-L-6-v2
- Classification: nie3e/sentiment-polish-gpt2-small
- Token classification: boltuix/NeuroBERT-NER
All of these use ENCODER_ONLY attention (bidirectional, non-causal).
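A quick way to confirm the bidirectional requirement, sketched with the transformers library and the first model above: in an encoder, an early token's hidden state changes when later tokens change, which a causal-only kernel cannot reproduce.

```python
# Encoder-only models are bidirectional: a token's hidden state depends on
# tokens AFTER it, so a causal-only kernel computes the wrong result.
import torch
from transformers import AutoModel, AutoTokenizer

name = "sentence-transformers/all-MiniLM-L12-v2"
tok = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name).eval()

with torch.no_grad():
    a = model(**tok("the bank of the river", return_tensors="pt")).last_hidden_state
    b = model(**tok("the bank of the street", return_tensors="pt")).last_hidden_state

# Position 2 is "bank" in both inputs; only the final word differs.
# Under causal attention these would match; bidirectionally they do not.
print(torch.allclose(a[0, 2], b[0, 2]))  # False
```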
Current Workaround
vLLM is currently using generic FlashAttention for encoder-only models on ROCm, bypassing AITER. This works but doesn't benefit from AITER's optimizations.
References
- vLLM Issue vllm-project/vllm#29466: "[CI Failure]: mi325_1: Language Models Test (Extended Pooling)" - 33 pooling tests failing on ROCm
- vLLM PR vllm-project/vllm#31084: "Fix ROCm attention backend selection for encoder-only models" - workaround using generic FlashAttention
- AITER source: unified_attention.py - https://github.com/ROCm/aiter/blob/main/aiter/ops/triton/unified_attention.py#L126
- AMD CI Board
Environment
- ROCm Version: Latest (tested on vLLM AMD CI)
- AITER Version: v0.1.7
- Hardware: AMD MI300X, MI250X (AMD CI)
- Framework: vLLM v0.7+
Thank you for considering this feature request! Adding encoder-only support would greatly benefit the ROCm ecosystem and enable AITER optimizations for a whole class of models currently forced to use generic implementations.