30 changes: 19 additions & 11 deletions colossalai/shardformer/layer/attn.py
@@ -690,6 +690,13 @@ def _forward(q, k, v, causal):
             )
             return out, softmax_lse, rng_state
 
+        def _kv_comm(i):
+            # Avoid overwriting attn input when it shares mem with buffer
+            if not RingAttention.ATTN_DONE.query():
+                kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
+            if i < local_sp_size - 1:
+                local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+
         def _local_ring_forward():
             # (Hopefully) overlap output correction with next flash attn
             for i in range(local_sp_size):
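Note on the refactor above: `_kv_comm` packages one step of a double-buffered (ping-pong) ring exchange, where slot `i % 2` holds the KV block being consumed and slot `(i + 1) % 2` receives the next one. A minimal standalone sketch of that pattern with raw `torch.distributed` P2P ops (the helper name, ring topology, and buffer layout here are illustrative assumptions, not ColossalAI's actual comm API):

```python
import torch
import torch.distributed as dist

def ring_kv_step(kv_buffers: list, i: int, ring_size: int):
    """Send the KV slot being consumed to the next rank while
    receiving the previous rank's block into the other slot."""
    if i >= ring_size - 1:
        return None  # after ring_size - 1 exchanges, every block was seen
    rank = dist.get_rank()
    ops = [
        dist.P2POp(dist.isend, kv_buffers[i % 2], (rank + 1) % ring_size),
        dist.P2POp(dist.irecv, kv_buffers[(i + 1) % 2], (rank - 1) % ring_size),
    ]
    # Pairing the send and recv in one batch avoids ring deadlock
    return dist.batch_isend_irecv(ops)  # .wait() each work before buffer reuse
```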
@@ -698,12 +705,8 @@ def _local_ring_forward():
                     # NOTE: waiting outside the current stream will NOT correctly synchronize.
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
-
-                    # Avoid overwriting attn input when it shares mem with buffer
-                    if not RingAttention.ATTN_DONE.query():
-                        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
-                    if i < local_sp_size - 1:
-                        local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+                    if i == 0:
+                        _kv_comm(i)
 
                     if i == 0:
                         # Compute with local KV; no mask
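The parity arithmetic in this hunk is easy to misread: `_kv_comm(i)` launches step i's exchange on `local_kv_comms[i % 2]`, while the wait at the top of iteration i targets `local_kv_comms[(i + 1) % 2]`, which is exactly the group that launched step i - 1, since `(i + 1) % 2 == (i - 1) % 2`. A two-line sanity check of that invariant:

```python
# The group waited on at iteration i is the one that launched step i - 1.
for i in range(1, 8):  # any local_sp_size works
    assert (i + 1) % 2 == (i - 1) % 2
```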
@@ -734,6 +737,9 @@ def _local_ring_forward():
                         rng_states[i],
                     ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
                 RingAttention.ATTN_DONE.record()
+                # Pipeline the next KV comm with output correction instead of the next flash attn
+                # to minimize idle time when comm takes longer than attn.
+                _kv_comm(i + 1)
 
                 block_softmax_lse[i % 2] = (
                     block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
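Placement matters here: calling `_kv_comm(i + 1)` immediately after `ATTN_DONE.record()` lets the next KV transfer run while the LSE rescaling below occupies the compute stream. A minimal sketch of this overlap pattern with plain CUDA streams and events (`flash_attn_block`, `do_correction`, and `next_kv_exchange` are hypothetical placeholders for the attn kernel, the rescale, and the P2P launch):

```python
import torch

comm_stream = torch.cuda.Stream()
attn_done = torch.cuda.Event()

def overlapped_step(q, kv, next_kv_exchange, do_correction):
    out, lse = flash_attn_block(q, kv)     # attn kernel on the current stream
    attn_done.record()                     # kv buffer is now safe to overwrite
    with torch.cuda.stream(comm_stream):
        comm_stream.wait_event(attn_done)  # don't clobber kv mid-kernel
        next_kv_exchange()                 # launch the next ring send/recv
    return do_correction(out, lse)         # rescale overlaps the transfer
```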
@@ -761,15 +767,13 @@ def _other_ring_forward(ring_num_idx, out, softmax_lse):
             # all new KVs from the previous inner ring
             for i in range(local_sp_size):
                 with torch.cuda.stream(sp_streams[i % 2]):
-                    if not RingAttention.ATTN_DONE.query():
-                        kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
-                    if i < local_sp_size - 1:
-                        local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
-
                     # Send & recv KV
                     if i > 0:
                         local_kv_comms[(i + 1) % 2].wait()
 
+                    if i == 0:
+                        _kv_comm(i)
+
                     if ring_num_idx > inter_ring_rank:
                         kv_block = kv_buffers[i % 2]
                         (
@@ -778,6 +782,8 @@ def _other_ring_forward(ring_num_idx, out, softmax_lse):
                             rng_states[i + local_sp_size * ring_num_idx],
                         ) = _forward(q1, kv_block[0], kv_block[1], causal=False)
                         RingAttention.ATTN_DONE.record()
+
+                        _kv_comm(i + 1)
                         block_softmax_lse[i % 2] = (
                             block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                         )
@@ -792,6 +798,8 @@ def _other_ring_forward(ring_num_idx, out, softmax_lse):
                             rng_states[i + local_sp_size * ring_num_idx],
                         ) = _forward(q, kv_block[0], kv_block[1], causal=False)
                         RingAttention.ATTN_DONE.record()
+
+                        _kv_comm(i + 1)
                         block_softmax_lse[i % 2] = (
                             block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
                         )
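For reference, the correction these `block_softmax_lse` tensors feed is the standard log-sum-exp merge of partial attention outputs over disjoint KV chunks: with `lse = logaddexp(lse1, lse2)`, the combined output is `exp(lse1 - lse) * out1 + exp(lse2 - lse) * out2`. A minimal sketch under assumed shapes (this mirrors the math, not ColossalAI's exact kernel):

```python
import torch

def merge_attn_outputs(out1, lse1, out2, lse2):
    """out*: (Sq, H, D); lse*: (Sq, H, 1) in float32 for stability."""
    lse = torch.logaddexp(lse1, lse2)  # merged normalizer in log space
    out = torch.exp(lse1 - lse) * out1 + torch.exp(lse2 - lse) * out2
    return out, lse
```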