
Add support for non-causal (encoder-only) attention #1702

@westers

Description

Problem

AITER's unified_attention currently only supports causal attention:

# File: aiter/ops/triton/unified_attention.py:126
# Source: https://github.com/ROCm/aiter/blob/main/aiter/ops/triton/unified_attention.py#L126
assert causal, "Only causal attention is supported"

This prevents encoder-only models (BERT, RoBERTa, sentence transformers, embeddings) from using AITER's optimized attention on ROCm.
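
For reference, the requested behavior is what PyTorch's scaled_dot_product_attention already computes when is_causal=False. The snippet below is only a semantic reference, not AITER code; the shapes are arbitrary:

# Bidirectional (encoder-only) attention reference, independent of AITER.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Encoder-only models need every token to attend to every other token,
# i.e. is_causal=False, which is exactly the case the assert above rejects.
out = F.scaled_dot_product_attention(q, k, v, is_causal=False)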

Use Case

vLLM pooling models fail on ROCm because:

  1. The ROCM_AITER_FA backend raises NotImplementedError for the ENCODER_ONLY attention type
  2. vLLM falls back to FlexAttention, which has numerical precision issues on ROCm
  3. Result: 33 pooling tests fail on AMD CI

vLLM Issue: vllm-project/vllm#29466
vLLM PR workaround: vllm-project/vllm#31084

Impact

This limitation affects:

  • All encoder-only models on ROCm: BERT, RoBERTa, sentence-transformers, and other embedding models
  • Frameworks: vLLM, Transformers, and other libraries that build on AITER
  • Competitiveness: Prevents ROCm from competing with CUDA for encoder-only workloads
  • Performance: Forces use of generic implementations instead of AITER-optimized kernels

Request

Add bidirectional (non-causal) attention support to unified_attention:

  1. Remove restriction: Remove the assert causal check, or make it conditional
  2. Add parameter: Add an is_causal parameter to the kernel calls
  3. Modify masks: Update the attention mask logic to support bidirectional attention (a sketch follows this list)
  4. Testing: Test with encoder-only models (BERT-base, RoBERTa, sentence-transformers)
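
As a rough illustration of item 3, here is a minimal sketch of the conditional mask logic in plain PyTorch. This is not AITER's Triton kernel; the is_causal flag and the build_attention_mask helper are hypothetical names used only to show the intent:

# Minimal sketch of conditional (causal vs. bidirectional) mask logic.
# build_attention_mask and is_causal are hypothetical, not AITER API.
import torch

def build_attention_mask(seq_len_q, seq_len_k, is_causal, device="cpu"):
    # Bidirectional (encoder-only): every query attends to every key, so no mask.
    if not is_causal:
        return torch.zeros(seq_len_q, seq_len_k, device=device)
    # Causal (decoder): query i attends only to keys j <= i, aligned to the
    # end of the key sequence so previously cached keys stay visible.
    q_idx = torch.arange(seq_len_q, device=device).unsqueeze(1)
    k_idx = torch.arange(seq_len_k, device=device).unsqueeze(0)
    offset = seq_len_k - seq_len_q
    return torch.where(k_idx <= q_idx + offset,
                       torch.tensor(0.0, device=device),
                       torch.tensor(float("-inf"), device=device))

Inside the Triton kernel the same condition would be expressed with tl.where over block offsets, but the idea is identical: apply the causal constraint only when is_causal is true.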

Example Models Affected

  • Embeddings: sentence-transformers/all-MiniLM-L12-v2, intfloat/e5-small
  • Cross-encoders: cross-encoder/ms-marco-MiniLM-L-6-v2
  • Classification: nie3e/sentiment-polish-gpt2-small
  • Token classification: boltuix/NeuroBERT-NER

All of these use ENCODER_ONLY attention (bidirectional, non-causal).

Current Workaround

vLLM is currently using generic FlashAttention for encoder-only models on ROCm, bypassing AITER. This works but doesn't benefit from AITER's optimizations.

Environment

  • ROCm Version: Latest (tested on vLLM AMD CI)
  • AITER Version: v0.1.7
  • Hardware: AMD MI300X, MI250X (AMD CI)
  • Framework: vLLM v0.7+

Thank you for considering this feature request! Adding encoder-only support would greatly benefit the ROCm ecosystem and enable AITER optimizations for a whole class of models currently forced to use generic implementations.
