From c823dc82e931d09eeb8c862e0ee7de6438353424 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 9 Oct 2024 15:24:08 +0000 Subject: [PATCH 1/8] improve comments --- colossalai/shardformer/layer/attn.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 5f0e9261c0de..5359f284468e 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -422,13 +422,17 @@ class RingAttention(torch.autograd.Function): ATTN_DONE: torch.cuda.Event = None SP_STREAM: torch.cuda.Stream = None SP_GROUP: dist.ProcessGroup = None - # duplicate process group for concurrent NCCL streams - # while both PyTorch and NCCL warns(https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7) - # against this, in practice it seems to work fine. + + # NOTE: Duplicating PGs for concurrent NCCL streams is a risky hack -- while it may increase throughput, + # both PyTorch and NCCL warn against this. (https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7) + # LoongTrain's original double ring impl. uses concurrent PGs (https://github.com/InternLM/InternEvo/blob/d0a19fb1f513ddbb53d6ba94bd87569b8a3ce5e7/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L191) + # but I confirmed with Pytorch developers this can cause obscure + # "Software caused connection abort" errors. (https://github.com/pytorch/pytorch/issues/132852) + # NOTE: So in general, a smarter idea is to try putting as many P2P calls as possible in one `batch_isend_irecv`. 
INNER_RING_GROUP: dist.ProcessGroup = None - INNER_RING_GROUP_COPY: dist.ProcessGroup = None + # INNER_RING_GROUP_COPY: dist.ProcessGroup = None INTER_RING_GROUP: dist.ProcessGroup = None - INTER_RING_GROUP_COPY: dist.ProcessGroup = None + # INTER_RING_GROUP_COPY: dist.ProcessGroup = None @staticmethod def get_double_ring_groups(sp_group, inner_ring_size=None): @@ -733,6 +737,8 @@ def _local_ring_forward(): # NOTE: waiting outside the current stream will NOT correctly synchronize. if i > 0: local_kv_comms[(i + 1) % 2].wait() + + # Prefetch if i == 0: _kv_comm(i) @@ -799,6 +805,7 @@ def _other_ring_forward(ring_num_idx, out, softmax_lse): if i > 0: local_kv_comms[(i + 1) % 2].wait() + # Prefetch if i == 0: _kv_comm(i) From 12bac5e0395f2cc34e4390018ad529fdce4b9bcb Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 9 Oct 2024 15:25:38 +0000 Subject: [PATCH 2/8] improve comments --- colossalai/shardformer/layer/attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 5359f284468e..e3d565e8f00f 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -772,7 +772,7 @@ def _local_ring_forward(): ) = _forward(q_block, kv_block[0], kv_block[1], causal=False) RingAttention.ATTN_DONE.record() # Pipeline the next KV comm with output correction instead of the next flash attn - # to minimize idle time when comm takes longer than attn. + # kernel, to minimize bubble when comm takes longer than attn. 
_kv_comm(i + 1) block_softmax_lse[i % 2] = ( From 9557d28fe0d1ce0ae528023db67a30d7cab6aa3c Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 9 Oct 2024 17:07:36 +0000 Subject: [PATCH 3/8] update --- colossalai/shardformer/layer/attn.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index e3d565e8f00f..6ec774aebe13 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -425,10 +425,11 @@ class RingAttention(torch.autograd.Function): # NOTE: Duplicating PGs for concurrent NCCL streams is a risky hack -- while it may increase throughput, # both PyTorch and NCCL warn against this. (https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7) - # LoongTrain's original double ring impl. uses concurrent PGs (https://github.com/InternLM/InternEvo/blob/d0a19fb1f513ddbb53d6ba94bd87569b8a3ce5e7/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L191) - # but I confirmed with Pytorch developers this can cause obscure - # "Software caused connection abort" errors. (https://github.com/pytorch/pytorch/issues/132852) - # NOTE: So in general, a smarter idea is to try putting as many P2P calls as possible in one `batch_isend_irecv`. + # LoongTrain's original double ring impl. uses concurrent PGs + # (https://github.com/InternLM/InternEvo/blob/e52f2ffc9acf818e8f2b1f97dfc69ceb2f06e154/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L192) + # but I confirmed with PyTorch developers this can cause obscure "Software caused connection abort" errors. + # (https://github.com/pytorch/pytorch/issues/132852) + # NOTE: In general, a smarter idea is to put as many P2P calls as possible into one `batch_isend_irecv`. 
INNER_RING_GROUP: dist.ProcessGroup = None # INNER_RING_GROUP_COPY: dist.ProcessGroup = None INTER_RING_GROUP: dist.ProcessGroup = None From cba8b33faa8c698d8f3a94071e126ccc5b5c10d3 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 11 Oct 2024 16:49:09 +0000 Subject: [PATCH 4/8] update --- colossalai/shardformer/layer/attn.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 6ec774aebe13..06bf366085a4 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -675,7 +675,8 @@ def forward( sp_size = dist.get_world_size(sp_group) sp_rank = dist.get_rank(sp_group) - # Attempt to achieve concurrent comm in the two-stream forward + + # Create two comms corresponding to two CUDA streams local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)] inter_ring_comm = RingComm(inter_ring_group) local_sp_size = dist.get_world_size(inner_ring_group) @@ -683,7 +684,7 @@ def forward( inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0 num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1 - # Non-contiguous indexing copies to a new contiguous tensor, + # Any type of indexing(slicing doesn't) copies to a new contiguous tensor, # so only do it once if sp_rank != sp_size - 1: q1 = q[half_idx_back] @@ -935,7 +936,7 @@ def backward(ctx, dout, _): local_sp_rank = dist.get_rank(sp_group) sp_size = dist.get_world_size(sp_group) - # Using separate streams (pg) for concurrent kv and dkv comm may + # NOTE: Using separate streams (PG) for concurrent kv and dkv comm may # cause NCCL "software caused connection abort" here... 
local_kv_comm = RingComm(local_kv_group) local_dkv_comm = RingComm(local_kv_group) From 4a6d73ef67a4dd3d2b699ead9ba79d83be2cabc6 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sat, 12 Oct 2024 19:29:47 +0000 Subject: [PATCH 5/8] update --- colossalai/shardformer/layer/attn.py | 2 +- colossalai/shardformer/layer/utils.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 06bf366085a4..c2f72e7c6df0 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -1158,7 +1158,7 @@ def prepare_varlen_batch( Returns: torch.Tensor: - Packed input embeddings of shape [B, Sq // sp_size, ...]. + Packed input embeddings of shape [B, Sq // sp_size, ...] if is_2d is True, else [T, ...]. Dict[str, Any]: A dictionary containing mask info. diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 4512e0c680f3..54891c8dad96 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -375,6 +375,9 @@ def split_varlen_zigzag( if isinstance(batch, torch.Tensor): batch = [batch] + # seq: (B, Sq, h, n) + # seq = seq[:, :rank * (seqlen // sp_size), ...] 
+ for i, packed_seq in enumerate(batch): device = packed_seq.device dtype = packed_seq.dtype From 3273eb52aba3775e834c8c1eecaedd80d710958e Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sat, 12 Oct 2024 19:42:15 +0000 Subject: [PATCH 6/8] update --- colossalai/shardformer/layer/attn.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index c2f72e7c6df0..78ef9249b709 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -781,7 +781,9 @@ def _local_ring_forward(): block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() ) # (H, T) -> (T, H, 1) assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1] - # Output and log sum exp correction. Ideally overlap this with the next flash attn kernel. + # Output and log sum exp correction. + # NOTE: Ideally overlap this with the next flash attn kernel, + # since attn uses Tensor Core and rescale is element-wise and doesn't use Tensor Core. # In reality this always finishes before next flash attn; no need for extra sync. if i == 0: out = block_out[0] @@ -1153,7 +1155,7 @@ def prepare_varlen_batch( position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None. is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first token of each sequence. - is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten + is_2d (bool, optional): Whether to return 2D embeddings padded to max_seqlen // sp_size or flatten the batch dim to a packed 1d sequence. Contingent on model forward shape definitions. Returns: From 6fef58bf50a41ea32712b440d170cb6e928b33d2 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sat, 12 Oct 2024 21:23:34 +0000 Subject: [PATCH 7/8] improve method arg names, add more perf. 
comments and TODOs --- colossalai/shardformer/layer/attn.py | 72 ++++++++++++++++----------- colossalai/shardformer/layer/utils.py | 29 ++++++----- 2 files changed, 58 insertions(+), 43 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 78ef9249b709..bea9bf17bd5e 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -633,7 +633,13 @@ def forward( inner_ring_group: Optional[dist.ProcessGroup] = None, inter_ring_group: Optional[dist.ProcessGroup] = None, ): - + """ + Forward supporting both packed (varlen) and batched(fixed length, no padding) sequences. + No separate version for batched seq (hard to maintain), which incurs + some overhead in sequence splitting due to python for loops. + Uses two CUDA streams to overlap softmax denominator correction with next flash attn + (see comments below). + """ cu_seqlens_q = cu_seqlens_kv = cu_seqlens max_seqlen_q = max_seqlen_kv = max_seqlen cu_seqlens_half = cu_seqlens // 2 @@ -701,6 +707,7 @@ def forward( rng_states = [None for _ in range(sp_size)] sp_streams = [torch.cuda.current_stream(), sp_stream] + # Helper to pass args to FA def _forward(q, k, v, causal): ( _, @@ -731,6 +738,7 @@ def _kv_comm(i): if i < local_sp_size - 1: local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) + # Forward within a node def _local_ring_forward(): # (Hopefully) overlap output correction with next flash attn for i in range(local_sp_size): @@ -781,10 +789,15 @@ def _local_ring_forward(): block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() ) # (H, T) -> (T, H, 1) assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1] + # Output and log sum exp correction. - # NOTE: Ideally overlap this with the next flash attn kernel, - # since attn uses Tensor Core and rescale is element-wise and doesn't use Tensor Core. 
- # In reality this always finishes before next flash attn; no need for extra sync. + # Ideally overlap this with the next flash attn kernel, + # since attn uses Tensor Core and rescale is element-wise, memory-bound and uses CUDA cores. + # (NOTE that this is the same as ping-pong scheduling idea in FA3) + # TODO However sometimes while the GPU has scheduled the next kernel, + # it's reluctant to launch it in overlap. Some potential causes: + # 1. need lower-level CUDA scheduling 2. further benchmark against Megatron-LM + # 3. register spilling by FA kernel. if i == 0: out = block_out[0] softmax_lse = block_softmax_lse[0] @@ -800,9 +813,10 @@ def _local_ring_forward(): torch.cuda.current_stream().wait_stream(sp_stream) return out, softmax_lse + # Forward for inter-node (the outer ring in 2D ring) def _other_ring_forward(ring_num_idx, out, softmax_lse): # Loop through the inner ring after receiving - # all new KVs from the previous inner ring + # all new KVs from another ring for i in range(local_sp_size): with torch.cuda.stream(sp_streams[i % 2]): # Send & recv KV @@ -906,7 +920,8 @@ def _other_ring_forward(ring_num_idx, out, softmax_lse): def backward(ctx, dout, _): """ During backward, we accumulate q grads on each rank locally, but iterate kv and their grads - over all ranks for accumulation. + over all ranks for accumulation. We avoid using two streams due to backward using doubled + buffers and more comm cost. 
""" (q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9] rng_states = ctx.saved_tensors[9:] @@ -970,6 +985,7 @@ def backward(ctx, dout, _): dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers] # (T, H, D) del k, v + # Helper to pass args to FA def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal): _flash_attn_backward( dout, @@ -990,8 +1006,7 @@ def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal): **misc_kwargs, ) - # NOTE: We avoid using two streams due to doubled buffers - # and that backward is more communication intensive. + # Backward within a node def _local_ring_backward(): for i in range(local_sp_size): if i > 0: @@ -1054,6 +1069,7 @@ def _local_ring_backward(): dkv_send = dkv_buffers[(local_sp_size - 1) % 2] return dq, dkv_recv, dkv_send + # Backward for inter-node (the outer ring in 2D ring) def _other_ring_backward(ring_num_idx, dq): if ring_num_idx > inter_ring_rank: # Indexing is expensive @@ -1138,34 +1154,34 @@ def _other_ring_backward(ring_num_idx, dq): @staticmethod def prepare_varlen_batch( - attention_mask: torch.Tensor, + padding_mask: torch.Tensor, sp_group: dist.ProcessGroup, inputs_embeds: torch.Tensor = None, position_ids: Optional[torch.Tensor] = None, is_label: bool = False, - is_2d: bool = True, + is_batched_seq: bool = True, ): + # TODO: support setting a batch dim (fix packing length) for packed mode, so that + # DP can be used (needs to modify dataloader too) """ Preprocess a batch of padded sequence by splitting input sequence by sp_size - sequence-wise and packing them into one sequence. Updates the mask info accordingly. + seq-wise and packing them into one sequence. Updates the mask info accordingly. Args: - attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked. 
+ padding_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked. sp_group (dist.ProcessGroup): Process group for sequence parallelism inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...] position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None. is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first token of each sequence. - is_2d (bool, optional): Whether to return 2D embeddings padded to max_seqlen // sp_size or flatten - the batch dim to a packed 1d sequence. Contingent on model forward shape definitions. + is_batched_seq (bool, optional): If True, then the input is a batch of (potentially padded) sequences + of shape [B, Sq, ...]; else a packed sequence of shape [T, ...]. Returns: - torch.Tensor: - Packed input embeddings of shape [B, Sq // sp_size, ...] if is_2d is True, else [T, ...]. - - Dict[str, Any]: + inputs_embeds (torch.Tensor): + Packed input embeddings of shape [B, Sq // sp_size, ...] if is_batched_seq, else [T, ...]. + mask_info (Dict[str, Any]): A dictionary containing mask info. - - torch.Tensor: + position_ids (torch.Tensor): Packed position ids of shape [..., Sq // sp_size]. 
""" @@ -1173,12 +1189,11 @@ def prepare_varlen_batch( sp_size = dist.get_world_size(group=sp_group) sp_rank = dist.get_rank(group=sp_group) mask_info = {} - mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False) + mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(padding_mask, return_indices=False) - # Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size) - # Split mask to compute local nonzero position indices + # Unpad, split seq-wise, then pad to (B, max_seqlen // sp_size) # (B, Sq) -> (B, max_seqlen // sp_size) - attention_mask = attention_mask[:, : mask_info["max_seqlen"]] + padding_mask = padding_mask[:, : mask_info["max_seqlen"]] if inputs_embeds is not None: inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]] inputs_embeds = split_varlen_zigzag( @@ -1186,11 +1201,12 @@ def prepare_varlen_batch( mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], - is_2d=is_2d, + is_batched_seq=is_batched_seq, is_label=is_label, ) - attention_mask = split_varlen_zigzag( - attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d + # Split mask to get local nonzero seq positions + padding_mask = split_varlen_zigzag( + padding_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_batched_seq=is_batched_seq ) if position_ids is not None: @@ -1203,7 +1219,7 @@ def prepare_varlen_batch( ) mask_info["max_seqlen"] //= sp_size - mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + mask_info["valid_indices"] = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() mask_info["cu_seqlens"] //= sp_size mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL return inputs_embeds, mask_info, position_ids diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 54891c8dad96..2df68e18c64d 100644 --- a/colossalai/shardformer/layer/utils.py +++ 
b/colossalai/shardformer/layer/utils.py @@ -295,8 +295,8 @@ def split_batch_zigzag( batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False ) -> Union[torch.Tensor, List[torch.Tensor]]: """ - Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask - in the causal setting will result in the preceding ranks having much less workload. + Split the input sequence batch . Naively spliting the attention mask in the causal setting + will result in the preceding ranks having much less workload. We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2). For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |. @@ -346,31 +346,30 @@ def split_varlen_zigzag( cu_seqlens: torch.Tensor, sp_group: ProcessGroup, max_seqlen: int = 0, - is_2d: bool = False, + is_batched_seq: bool = False, is_label: bool = False, ) -> Union[List[torch.Tensor], torch.Tensor]: - """Split each sequence in a batch of packed sequences in a zigzag fashion. - For each tensor in batch, return packed sequences if is_2d is False; - else return a padded batch of sequences. - + """Split a packed seq/batch of padded sequences in a Zigzag fashion. + Different from split_batch_zigzag, inputs here have variable sequence lengths. Args: - batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d. + batch (List[torch.Tensor]): Packed sequences of shape (T, ...), or (B, Sq, ...) if is_batched_seq, + where T is the total number of tokens. cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting. sp_group (ProcessGroup): The process group for sequence parallelism. max_seqlen (int): The maximum sequence length in the batch before splitting. - is_2d (bool): If True, then input has batch size and sequence length split into two dimensions. 
+ is_batched_seq (bool): If True, then the input is a batch of sequences padded to the same len. is_label (bool): If True, mask out the first token in each sequence (). Returns: - batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size) - or (B, max_seqlen // sp_size, ...) if is_2d + batch (List[torch.Tensor]): Packed sequences of shape (T, ..) + or (B, max_seqlen // sp_size, ...) if is_batched_seq """ sp_size = dist.get_world_size(sp_group) sp_rank = dist.get_rank(sp_group) if sp_size == 1: return batch - if is_2d: + if is_batched_seq: assert max_seqlen > 0, "max_seqlen must be provided for 2D input" if isinstance(batch, torch.Tensor): @@ -382,7 +381,7 @@ def split_varlen_zigzag( device = packed_seq.device dtype = packed_seq.dtype - if is_2d: + if is_batched_seq: assert max_seqlen % (sp_size * 2) == 0 # Recreate a padded tensor with the new max seqlen shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:]) @@ -401,7 +400,7 @@ def split_varlen_zigzag( seqlen % (2 * sp_size) == 0 ), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting" - if is_2d: + if is_batched_seq: seq = packed_seq[j][:seqlen] if is_label: # Shift one position to the right for next token prediction @@ -418,7 +417,7 @@ def split_varlen_zigzag( seq = seq.chunk(sp_size * 2) local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]]) - if is_2d: + if is_batched_seq: batch[i] = local_seq.contiguous() else: batch[i] = torch.cat(local_seq, dim=0) From 373c1d9bfc18cae54c9619a47c9851172b20eebc Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 15 Oct 2024 22:00:17 +0000 Subject: [PATCH 8/8] update --- colossalai/shardformer/layer/attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index bea9bf17bd5e..95aeb6ca8c27 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ 
-682,7 +682,7 @@ def forward( sp_size = dist.get_world_size(sp_group) sp_rank = dist.get_rank(sp_group) - # Create two comms corresponding to two CUDA streams + # Create communicators corresponding to two CUDA streams local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)] inter_ring_comm = RingComm(inter_ring_group) local_sp_size = dist.get_world_size(inner_ring_group) @@ -690,7 +690,7 @@ def forward( inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0 num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1 - # Any type of indexing(slicing doesn't) copies to a new contiguous tensor, + # Any type of indexing (but not slicing) copies to a new contiguous tensor, # so only do it once if sp_rank != sp_size - 1: q1 = q[half_idx_back]