40 changes: 25 additions & 15 deletions colossalai/shardformer/layer/attn.py
@@ -431,7 +431,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None
 
     @staticmethod
-    def get_double_ring_groups(sp_group, inner_ring_size=None):
+    def get_double_ring_groups(sp_group, tp_group, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -442,7 +442,8 @@ def get_double_ring_groups(sp_group, inner_ring_size=None):
             Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
         """
         sp_size = dist.get_world_size(sp_group)
-        sp_rank = dist.get_rank(sp_group)
+        dist.get_rank(sp_group)
+        dist.get_world_size(tp_group)
 
         if inner_ring_size is None:
             if torch.cuda.device_count() >= dist.get_world_size():
@@ -471,19 +472,24 @@ def get_double_ring_groups(sp_group, inner_ring_size=None):
         inner_ring_group = None
         inter_ring_group = None
 
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+        groups = int(world_size / sp_size)
         # Create inner ring groups
-        for i in range(inner_ring_size):
-            ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
-            group = dist.new_group(ranks)
-            if sp_rank in ranks:
-                inner_ring_group = group
+        for group_id in range(groups):
+            for i in range(inner_ring_size):
+                ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size))
+                group = dist.new_group(ranks)
+                if rank in ranks:
+                    inner_ring_group = group
 
         # Create inter ring groups
-        for i in range(num_rings):
-            ranks = list(range(i, sp_size, num_rings))
-            group = dist.new_group(ranks)
-            if sp_rank in ranks:
-                inter_ring_group = group
+        for group_id in range(groups):
+            for i in range(num_rings):
+                ranks = list(range(i + group_id * num_rings, world_size, sp_size))
+                group = dist.new_group(ranks)
+                if rank in ranks:
+                    inter_ring_group = group
 
         return inner_ring_group, inter_ring_group

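To make the new grouping concrete, the snippet below replays the index arithmetic of the rewritten loops for example sizes (world_size=8, sp_size=4, inner_ring_size=2; these values and num_rings = sp_size // inner_ring_size are assumptions for illustration, since num_rings is assigned above this hunk). Each rank ends up in exactly one inner ring and one inter ring, which is what the `if rank in ranks` checks rely on.

```python
# Standalone sketch of the rank partitioning produced by the new loops above.
# Example sizes only; num_rings = sp_size // inner_ring_size is assumed from
# code outside this hunk.
world_size = 8        # total ranks
sp_size = 4           # ranks per sequence-parallel group
inner_ring_size = 2
num_rings = sp_size // inner_ring_size
groups = world_size // sp_size

inner_rings = [
    list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size))
    for group_id in range(groups)
    for i in range(inner_ring_size)
]
inter_rings = [
    list(range(i + group_id * num_rings, world_size, sp_size))
    for group_id in range(groups)
    for i in range(num_rings)
]

print(inner_rings)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(inter_rings)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```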
@@ -493,6 +499,7 @@ def attention(
         k,
         v,
         sp_group,
+        tp_group,
         attention_mask_type,
         cu_seqlens=None,
         max_seqlen=None,
@@ -537,7 +544,6 @@ def attention(
             RingAttention.ATTN_DONE = torch.cuda.Event()
         if RingAttention.SP_STREAM is None:
             RingAttention.SP_STREAM = torch.cuda.Stream()
-
         assert (
             q.shape[2] == k.shape[2]
         ), "Q, K and V having different sequence lengths (inference or cross-attn)\
@@ -550,7 +556,9 @@ def attention(
 
         if RingAttention.SP_GROUP is not sp_group:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(
+                sp_group, tp_group, inner_ring_size
+            )
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else:
@@ -597,6 +605,7 @@ def attention(
             attention_mask_type == AttnMaskType.PADDED_CAUSAL,
             inner_ring_group,
             inter_ring_group,
+            tp_group,
         )
 
         if attention_mask_type == AttnMaskType.PADDED_CAUSAL:
@@ -627,6 +636,7 @@ def forward(
         is_packed: Optional[bool] = False,
         inner_ring_group: Optional[dist.ProcessGroup] = None,
         inter_ring_group: Optional[dist.ProcessGroup] = None,
+        tp_group: Optional[dist.ProcessGroup] = None,
     ):
 
         cu_seqlens_q = cu_seqlens_kv = cu_seqlens
@@ -1123,7 +1133,7 @@ def _other_ring_backward(ring_num_idx, dq):
         if not is_packed:
             dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]
 
-        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None)
+        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None)
 
     @staticmethod
     def prepare_varlen_batch(
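The extra trailing None in backward's return tuple corresponds to the new tp_group argument of forward: torch.autograd.Function requires backward to return one gradient (or None) per forward input, and a process group is not differentiable. A minimal toy sketch of that rule (illustrative only, not ColossalAI code):

```python
import torch

class Scale(torch.autograd.Function):
    """Toy example: backward must return one entry per forward input.

    Adding a non-tensor argument (here `group`, standing in for tp_group)
    means backward has to return an extra None for that slot.
    """

    @staticmethod
    def forward(ctx, x, factor, group=None):  # three inputs after ctx
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # one return value per forward input: grad for x, None for factor, None for group
        return grad_out * ctx.factor, None, None


x = torch.randn(4, requires_grad=True)
Scale.apply(x, 2.0, None).sum().backward()
print(x.grad)  # each entry equals 2.0
```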
2 changes: 2 additions & 0 deletions colossalai/shardformer/modeling/llama.py
@@ -563,12 +563,14 @@ def forward(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
+        tp_group = shard_config.tensor_parallel_process_group
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query_states,
                 key_states,
                 value_states,
                 sp_group,
+                tp_group,
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
             )