74 changes: 25 additions & 49 deletions colossalai/nn/layer/parallel_sequence/_operation.py
@@ -19,32 +19,25 @@ class RingQK(torch.autograd.Function):

@staticmethod
@custom_fwd
- def forward(ctx,
- sub_q,
- sub_k,
- batch_size,
- num_attention_heads,
- sub_seq_length):
+ def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length):
# save tensor for backward
ctx.save_for_backward(sub_q, sub_k)
ctx.sub_seq_length = sub_seq_length

# create local segment of attention score
- attention_score = torch.empty(
- batch_size * num_attention_heads,
- sub_seq_length,
- sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
- dtype=sub_q.dtype,
- device=get_current_device()
- )
+ attention_score = torch.empty(batch_size * num_attention_heads,
+ sub_seq_length,
+ sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
+ dtype=sub_q.dtype,
+ device=get_current_device())

# compute local QK^T
part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
start_idx = local_rank * sub_seq_length
end_idx = (local_rank + 1) * sub_seq_length
- attention_score[:, :, start_idx: end_idx] = part_a
+ attention_score[:, :, start_idx:end_idx] = part_a

# compute QK^T in ring-all-reduce style
for i in range(local_world_size - 1):
@@ -63,19 +56,18 @@ def backward(ctx, grad_output):
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)

# calculate gradient of sub_k
- grad_k = torch.matmul(
- grad_output.transpose(2, 1),
- sub_q
- )
+ grad_k = torch.matmul(grad_output.transpose(2, 1), sub_q)

dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
- grad_k = grad_k[:, local_rank * ctx.sub_seq_length: (local_rank + 1) * ctx.sub_seq_length]
+ grad_k = grad_k[:, local_rank * ctx.sub_seq_length:(local_rank + 1) * ctx.sub_seq_length]
grad_k /= local_world_size

# calculate gradient for sub_q
- grad_q = torch.zeros_like(sub_q,
- dtype=sub_q.dtype,
- device=get_current_device(), )
+ grad_q = torch.zeros_like(
+ sub_q,
+ dtype=sub_q.dtype,
+ device=get_current_device(),
+ )

# compute with local sub_k
start_idx, end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length)
@@ -85,7 +77,7 @@ def backward(ctx, grad_output):
for i in range(local_world_size - 1):
sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)
- grad_q += torch.matmul(grad_output[:, :, start_idx: end_idx], sub_k)
+ grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k)

grad_q /= local_world_size

@@ -99,23 +91,16 @@ class RingAV(torch.autograd.Function):

@staticmethod
@custom_fwd
- def forward(ctx,
- attention_score,
- sub_v,
- batch_size,
- num_attention_heads,
- attention_head_size,
- sub_seq_length):
+ def forward(ctx, attention_score, sub_v, batch_size, num_attention_heads, attention_head_size, sub_seq_length):
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length)

- sub_attention_result = torch.zeros(
- batch_size * num_attention_heads,
- sub_seq_length,
- attention_head_size,
- device=get_current_device(),
- dtype=attention_score.dtype)
+ sub_attention_result = torch.zeros(batch_size * num_attention_heads,
+ sub_seq_length,
+ attention_head_size,
+ device=get_current_device(),
+ dtype=attention_score.dtype)

# save tensors for backward
ctx.save_for_backward(attention_score, sub_v)
@@ -144,32 +129,23 @@ def backward(ctx, grad_output):
attention_scores, sub_v = ctx.saved_tensors

# calculate gradient of v
- grad_v = torch.matmul(
- attention_scores.transpose(2, 1),
- grad_output
- )
+ grad_v = torch.matmul(attention_scores.transpose(2, 1), grad_output)
dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE))
grad_v = grad_v[:, local_start_idx:local_end_idx]
grad_v /= local_world_size

# calculate gradient for attention score
- grad_attention_score = torch.zeros_like(attention_scores,
- dtype=grad_output.dtype,
- device=get_current_device())
+ grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device())

# compute with local sub_k
- grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(
- grad_output,
- sub_v.transpose(2, 1))
+ grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1))

# compute QK^T in ring-all-reduce style
for i in range(local_world_size - 1):
sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)

# compute grad_q
- grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(
- grad_output,
- sub_v.transpose(2, 1))
+ grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1))

return grad_attention_score, grad_v, None, None, None, None
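
For orientation, here is a minimal usage sketch, not part of the diff above, showing how the two autograd functions might compose into one sequence-parallel attention step. The helper name ring_self_attention, the 1/sqrt(d) scaling, and the softmax placement are illustrative assumptions; only the RingQK.apply and RingAV.apply argument orders are taken from the signatures in this file.

import torch.nn.functional as F

# Illustrative sketch only; RingQK and RingAV are the classes defined in this file.
# Assumed shapes: sub_q / sub_k / sub_v are (batch * heads, sub_seq_len, head_size),
# i.e. this rank's chunk of the sequence within the SEQUENCE parallel group.
def ring_self_attention(sub_q, sub_k, sub_v, batch_size, num_heads, head_size, sub_seq_len):
    # Scores of the local queries against the full (global) key sequence:
    # (batch * heads, sub_seq_len, sub_seq_len * world_size)
    scores = RingQK.apply(sub_q, sub_k, batch_size, num_heads, sub_seq_len)
    scores = scores / (head_size ** 0.5)   # assumed scaling, as in standard attention
    probs = F.softmax(scores, dim=-1)      # assumed softmax over the full sequence axis
    # Local output chunk: (batch * heads, sub_seq_len, head_size)
    return RingAV.apply(probs, sub_v, batch_size, num_heads, head_size, sub_seq_len)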