diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 5f0e9261c0de..4e7cd35dd189 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -431,7 +431,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None
 
     @staticmethod
-    def get_double_ring_groups(sp_group, inner_ring_size=None):
+    def get_double_ring_groups(sp_group, tp_group, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -442,7 +442,8 @@ def get_double_ring_groups(sp_group, inner_ring_size=None):
             Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
         """
         sp_size = dist.get_world_size(sp_group)
-        sp_rank = dist.get_rank(sp_group)
+        dist.get_rank(sp_group)
+        dist.get_world_size(tp_group)
 
         if inner_ring_size is None:
             if torch.cuda.device_count() >= dist.get_world_size():
@@ -471,19 +472,24 @@ def get_double_ring_groups(sp_group, inner_ring_size=None):
 
         inner_ring_group = None
         inter_ring_group = None
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+        groups = int(world_size / sp_size)
 
         # Create inner ring groups
-        for i in range(inner_ring_size):
-            ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
-            group = dist.new_group(ranks)
-            if sp_rank in ranks:
-                inner_ring_group = group
+        for group_id in range(groups):
+            for i in range(inner_ring_size):
+                ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size))
+                group = dist.new_group(ranks)
+                if rank in ranks:
+                    inner_ring_group = group
 
         # Create inter ring groups
-        for i in range(num_rings):
-            ranks = list(range(i, sp_size, num_rings))
-            group = dist.new_group(ranks)
-            if sp_rank in ranks:
-                inter_ring_group = group
+        for group_id in range(groups):
+            for i in range(num_rings):
+                ranks = list(range(i + group_id * num_rings, world_size, sp_size))
+                group = dist.new_group(ranks)
+                if rank in ranks:
+                    inter_ring_group = group
 
         return inner_ring_group, inter_ring_group
 
@@ -493,6 +499,7 @@ def attention(
         k,
         v,
         sp_group,
+        tp_group,
         attention_mask_type,
         cu_seqlens=None,
         max_seqlen=None,
@@ -537,7 +544,6 @@ def attention(
             RingAttention.ATTN_DONE = torch.cuda.Event()
         if RingAttention.SP_STREAM is None:
             RingAttention.SP_STREAM = torch.cuda.Stream()
-
         assert (
             q.shape[2] == k.shape[2]
         ), "Q, K and V having different sequence lengths (inference or cross-attn)\
@@ -550,7 +556,9 @@ def attention(
 
         if RingAttention.SP_GROUP is not sp_group:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(
+                sp_group, tp_group, inner_ring_size
+            )
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else:
@@ -597,6 +605,7 @@ def attention(
             attention_mask_type == AttnMaskType.PADDED_CAUSAL,
             inner_ring_group,
             inter_ring_group,
+            tp_group,
         )
 
         if attention_mask_type == AttnMaskType.PADDED_CAUSAL:
@@ -627,6 +636,7 @@ def forward(
         is_packed: Optional[bool] = False,
         inner_ring_group: Optional[dist.ProcessGroup] = None,
         inter_ring_group: Optional[dist.ProcessGroup] = None,
+        tp_group: Optional[dist.ProcessGroup] = None,
     ):
         cu_seqlens_q = cu_seqlens_kv = cu_seqlens
 
@@ -1123,7 +1133,7 @@ def _other_ring_backward(ring_num_idx, dq):
         if not is_packed:
             dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]
 
-        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None)
+        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None)
 
     @staticmethod
     def prepare_varlen_batch(
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 47c17e7494f2..fc5bcac6b8d4 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -563,12 +563,14 @@ def forward(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
+        tp_group = shard_config.tensor_parallel_process_group
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query_states,
                 key_states,
                 value_states,
                 sp_group,
+                tp_group,
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
             )