# Mask-driven call: construct an engine with block_size=BLOCK_SIZE and invoke it
# immediately on q/k/v plus a block-level mask. The result is not bound to a
# name here. Per the shape notes, the mask carries one entry per
# (query block, key/value block) pair for each batch element.
# NOTE(review): mask dtype and truthiness convention are not visible here — confirm.
AttentionEngine(block_size=BLOCK_SIZE)(
    q,           # [batch, seqlenq, head, dimqk]
    k,           # [batch, seqlenkv, head, dimqk]
    v,           # [batch, seqlenkv, head, dimv]
    block_mask,  # [batch, seqlenq//BLOCK_SIZE, seqlenkv//BLOCK_SIZE]
)
# Index-driven call: construct an engine with block_size=BLOCK_SIZE and invoke it
# immediately on q/k/v plus two extra tensors, block_indices and
# selected_block_num. The result is not bound to a name here.
# NOTE(review): presumably block_indices lists up to MAX_BLOCKS key/value block
# ids per query block and head, and selected_block_num says how many of those
# slots are in use — confirm against the engine's documentation.
AttentionEngine(block_size=BLOCK_SIZE)(
    q,                   # [batch, seqlenq, head, dimqk]
    k,                   # [batch, seqlenkv, head, dimqk]
    v,                   # [batch, seqlenkv, head, dimv]
    block_indices,       # [batch, seqlenq//BLOCK_SIZE, head, MAX_BLOCKS]
    selected_block_num,  # [batch, head]
)