From 0593f04eefe9425f5c4565ae0ac0b92861f1b110 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Sun, 10 Dec 2023 00:34:12 +0800 Subject: [PATCH 01/50] sequence parallel optimization --- .../booster/plugin/hybrid_parallel_plugin.py | 2 + colossalai/shardformer/layer/__init__.py | 2 + colossalai/shardformer/layer/_operation.py | 38 +++ colossalai/shardformer/modeling/llama.py | 157 ++++++++++- colossalai/shardformer/policies/llama.py | 28 ++ colossalai/shardformer/shard/shard_config.py | 1 + tests/kit/model_zoo/transformers/llama.py | 25 +- tests/test_shardformer/test_layer/demo.py | 161 ++++++++++++ .../test_layer/test_sequence_parallel.py | 248 ++++++++++++++++++ tests/test_shardformer/test_model/_utils.py | 18 +- .../test_model/test_llama_seq.py | 140 ++++++++++ .../test_model/test_shard_llama.py | 7 + 12 files changed, 819 insertions(+), 8 deletions(-) create mode 100644 tests/test_shardformer/test_layer/demo.py create mode 100644 tests/test_shardformer/test_layer/test_sequence_parallel.py create mode 100644 tests/test_shardformer/test_model/test_llama_seq.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index f51cb060c356..537146c1e470 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -970,6 +970,7 @@ def __init__( pp_style: str = "1f1b", num_model_chunks: int = 1, enable_metadata_cache: bool = True, + test_seq_parallelism: bool = False, ) -> None: super().__init__() assert ( @@ -1043,6 +1044,7 @@ def __init__( enable_sequence_parallelism=enable_sequence_parallelism, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, + test_seq_parallelism=test_seq_parallelism, ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index c9b4317a6f17..0e368dbf94ee 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,4 +1,5 @@ from .attn import AttnMaskType, ColoAttention +from ._operation import all_to_all_comm from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row @@ -26,4 +27,5 @@ "ParallelModule", "AttnMaskType", "ColoAttention", + "all_to_all_comm", ] diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 241770901ed7..0afb36937e19 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -515,6 +515,33 @@ def backward(ctx, grad_output): return _split(grad_output, ctx.dim, ctx.process_group), None, None +class _AllToAll(torch.autograd.Function): + """All-to-all communication. + + Args: + input_: input matrix + process_group: communication group + scatter_dim: scatter dimension + gather_dim: gather dimension + """ + + @staticmethod + def forward(ctx, input_, process_group, scatter_dim, gather_dim): + ctx.process_group = process_group + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + world_size = dist.get_world_size(process_group) + return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim) + + @staticmethod + def backward(ctx, *grad_output): + process_group = ctx.process_group + scatter_dim = ctx.gather_dim + gather_dim = ctx.scatter_dim + return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim) + return (return_grad, None, None, None) + + class HookParameter(torch.autograd.Function): """In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm""" @@ -608,6 +635,13 @@ def _reduce_scatter(input_, dim=1, process_group=None): return output +def _all_to_all(input_, world_size, group, scatter_dim, gather_dim): + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) @@ -650,3 +684,7 @@ def reduce_forward(input_, process_group): def reduce_backward(input_, process_group): return _ReduceBackward.apply(input_, process_group) + + +def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1): + return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 29dc8200f338..52d3bf18d439 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -1,8 +1,12 @@ +import math import warnings -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple import torch +import torch.distributed as dist import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_outputs import ( BaseModelOutputWithPast, @@ -16,6 +20,12 @@ from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, cross_entropy_1d +from colossalai.shardformer.layer._operation import all_to_all_comm + + +def print_rank(prompt, value, rank=0): + if dist.get_rank() == rank: + print(f"rank-{rank}, {prompt}: {value}") try: from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask @@ -714,3 +724,148 @@ def forward( ) return forward +def test_llama_seq_parallel_attention(): + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + print_rank("cos-0", cos.shape) # torch.Size([16, 32]) + print_rank("cos-position", cos[position_ids]) # torch.Size([1, 1, 4, 32]) + + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + print_rank("cos-1", cos.shape) # torch.Size([1, 1, 4, 32]) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + print_rank("hidden_states-origin", hidden_states.shape) + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # introduce sequence parallel + query_states = all_to_all_comm(query_states) + key_states = all_to_all_comm(key_states) + value_states = all_to_all_comm(value_states) + + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + past_key_value_length = 0 + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + print_rank("cos", cos.shape) + + print_rank("position_ids", position_ids) + if position_ids is not None: + position_ids = torch.arange( + past_key_value_length, q_len + past_key_value_length, dtype=torch.long, device=hidden_states.device + ) + + print_rank("position_ids-2", position_ids) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + return forward diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index db8468713f66..6a8536f893ba 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -2,6 +2,7 @@ from functools import partial from typing import Callable, Dict, List, Union +import torch.distributed as dist import torch.nn as nn from torch import Tensor from torch.nn import Module @@ -13,6 +14,7 @@ get_llama_flash_attention_forward, get_llama_model_forward_for_flash_attn, get_lm_forward_with_dist_cross_entropy, + test_llama_seq_parallel_attention, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -49,6 +51,32 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: self.shard_config.enable_sequence_parallelism = False warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + # todo: seq + if self.shard_config.test_seq_parallelism: + sequence_parallelism_size = dist.get_world_size() + decoder_attribute_replacement = { + "num_heads": self.model.config.num_attention_heads // sequence_parallelism_size, + "head_dim": self.model.config.hidden_size // self.model.config.num_attention_heads, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["num_key_value_heads"] = ( + self.model.config.num_key_value_heads // sequence_parallelism_size + ) + decoder_attribute_replacement["num_key_value_groups"] = ( + self.model.config.hidden_size // self.model.config.num_attention_heads + ) + policy[LlamaAttention] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + ) + + self.append_or_create_method_replacement( + description={ + "forward": test_llama_seq_parallel_attention(), + }, + policy=policy, + target_key=LlamaAttention, + ) + if self.shard_config.enable_tensor_parallelism: decoder_attribute_replacement = { "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index da27341d9c29..b805bac48d65 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -35,6 +35,7 @@ class ShardConfig: enable_sequence_parallelism: bool = False enable_sequence_overlap: bool = False parallel_output: bool = True + test_seq_parallelism: bool = False extra_kwargs: Dict[str, Any] = field(default_factory=dict) # TODO padding vocab # make_vocab_size_divisible_by: int = 128 diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 4730642705ff..9cc3f03a67cc 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -27,10 +27,29 @@ def data_gen(): # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') # ----------------------------------- + # input_ids = torch.Tensor( + # [[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]] + # ).long() + input_ids = torch.Tensor( - [[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]] + [ + [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], + [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], + [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], + [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], + ] + ).long() + + # attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long() + attention_mask = torch.Tensor( + [ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ] ).long() - attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long() + return dict(input_ids=input_ids, attention_mask=attention_mask) # label is needed for casual lm @@ -49,7 +68,7 @@ def data_gen_for_casual_lm(): loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = LlamaConfig( - num_hidden_layers=4, + num_hidden_layers=1, hidden_size=128, intermediate_size=256, num_attention_heads=4, diff --git a/tests/test_shardformer/test_layer/demo.py b/tests/test_shardformer/test_layer/demo.py new file mode 100644 index 000000000000..33bf7a7fac2e --- /dev/null +++ b/tests/test_shardformer/test_layer/demo.py @@ -0,0 +1,161 @@ +import os +import time + +import torch +import torch.distributed as dist +import torch.nn.functional as F + + +class AllGatherLinearWithRingCommunication(torch.autograd.Function): + """ + col-linear with hidden all_gather + + Y: [batch_size, seq_len / TP_size, hidden_size] + A: [batch_size, hidden_size, w_len / TP_size] + | + | Ring-based LinearOverlap + v + YA: [batch_size, seq_len, w_len / TP_size] + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group): + # Input expected: (input_, bias) sharded on the row(sequence dim) and weight on the col + ctx.save_for_backward(input_, weight, bias) + ctx.use_bias = bias is not None + ctx.process_group = process_group + + # if bias is not None: + # output = F.linear(input_, weight, bias) + # else: + # output = F.linear(input_, weight) + group_size = dist.get_world_size(process_group) + cur_rank = dist.get_rank(process_group) + + input_shape = input_.shape + weight_shape = weight.shape + + output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)] + # output_tensor = torch.empty((input_shape[0], input_shape[1] * group_size, weight_shape[0]), device=input_.device) + + # initialization of ring communication + input_shape[1] + recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 + send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 + recv_tensor = input_.clone() + send_tensor = input_.clone() + input_tensor = input_.clone() + # output_pt = output_tensor + + recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) + send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) + handles = dist.batch_isend_irecv([recv_op, send_op]) + # first round: special case, retrive from local tensor + output_tensors[0] = F.linear(input_, weight) + # output_pt[:][:local_seq_len][:] = F.linear(input_, weight) + # output_pt = output_pt[:][local_seq_len:][:] + for i in range(group_size - 2): + handles[0].wait() + handles[1].wait() + + tmp_tensor = input_tensor + input_tensor = recv_tensor + recv_tensor = tmp_tensor + send_tensor = input_tensor.clone() + + recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) + send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) + handles = dist.batch_isend_irecv([recv_op, send_op]) + + # actual computation + output_tensors[i + 1] = F.linear(input_tensor, weight) + # output_pt[:][:local_seq_len][:] = F.linear(input_, weight) + # output_pt = output_pt[:][local_seq_len:][:] + + # final round: special case, no need to send/recv again + handles[0].wait() + # output_pt[:][:local_seq_len][:] = F.linear(recv_tensor, weight) + output_tensors[group_size - 1] = F.linear(recv_tensor, weight) + handles[1].wait() + return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=1) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + input, weight, bias = ctx.saved_tensors + use_bias = ctx.use_bias + + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. + if use_bias: + bias.view(bias.shape) + + total_input = input + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # all-reduce scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + +def main(): + dist.init_process_group("nccl") + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + + scale_c = 4 + y = torch.randn(4, 5120 * scale_c, 1024 * scale_c, requires_grad=False).cuda() + w = torch.randn(256 * scale_c, 1024 * scale_c, requires_grad=False).cuda() + + trial_time = 5 + + ## warm up + tensor_list = [torch.zeros_like(y) for _ in range(4)] + dist.all_gather(tensor_list, y) + + Y = torch.cat(tensor_list, dim=1) + torch_out = F.linear(Y, w) + + ring_out = AllGatherLinearWithRingCommunication.apply(y, w, None, None) + ## + + tic = time.perf_counter() + for _ in range(trial_time): + tensor_list = [torch.zeros_like(y) for _ in range(4)] + dist.all_gather(tensor_list, y) + + Y = torch.cat(tensor_list, dim=1) + torch_out = F.linear(Y, w) + print(torch_out[0][0][0]) + toc = time.perf_counter() + print(f"original function in {toc - tic:0.4f} seconds") + + tic = time.perf_counter() + for _ in range(trial_time): + ring_out = AllGatherLinearWithRingCommunication.apply(y, w, None, None) + print(ring_out[0][0][0]) + toc = time.perf_counter() + print(f"fused function in {toc - tic:0.4f} seconds") + + if not torch.allclose(torch_out, ring_out, atol=1e-3): + raise RuntimeError("ring_overlap: failed!!") + print("ring_overlap: pass.") + + +if __name__ == "__main__": + main() diff --git a/tests/test_shardformer/test_layer/test_sequence_parallel.py b/tests/test_shardformer/test_layer/test_sequence_parallel.py new file mode 100644 index 000000000000..c2ad6918cd2b --- /dev/null +++ b/tests/test_shardformer/test_layer/test_sequence_parallel.py @@ -0,0 +1,248 @@ +import copy + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer import all_to_all_comm +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +class DistributedAttention(torch.nn.Module): + """Initialization. + + Arguments: + local_attention (Module): local attention with q,k,v + sequence_process_group (ProcessGroup): sequence parallel process group + scatter_idx (int): scatter_idx for all2all comm + gather_idx (int): gather_idx for all2all comm + """ + + def __init__( + self, + heads_num, + hidden_dim, + q_proj, + k_proj, + v_proj, + out_proj, + sequence_process_group: dist.ProcessGroup = None, + scatter_idx: int = 2, + gather_idx: int = 1, + ) -> None: + super(DistributedAttention, self).__init__() + self.spg = sequence_process_group + self.scatter_idx = scatter_idx + self.gather_idx = gather_idx + self.heads_num = heads_num + self.hidden_dim = hidden_dim + assert hidden_dim % heads_num == 0 + self.head_dim = hidden_dim // heads_num + + self.q = q_proj + self.k = k_proj + self.v = v_proj + self.out = out_proj + + def attn(self, q, k, v): + batch_size, seq_len = q.shape[0], q.shape[1] + + scale = self.head_dim**0.5 + qk = torch.matmul(q, k.transpose(-2, -1)) / scale + + # if attn_mask is not None: + # mask = attn_mask == 0 + # qk[mask] = torch.tensor(float('-inf')) + + weights = F.softmax(qk, dim=-1) + + attention_score = torch.matmul(weights, v) + + return attention_score + + def forward(self, x) -> Tensor: + """forward + + Arguments: + query (Tensor): query input to the layer + key (Tensor): key input to the layer + value (Tensor): value input to the layer + args: other args + + Returns: + * output (Tensor): context output + """ + + # in shape : e.g., [s/p:h:] + query = self.q(x) + key = self.k(x) + value = self.v(x) + # TODO Merge three alltoall calls into one + query_layer = all_to_all_comm(query, self.spg, self.scatter_idx, self.gather_idx) + key_layer = all_to_all_comm(key, self.spg, self.scatter_idx, self.gather_idx) + value_layer = all_to_all_comm(value, self.spg, self.scatter_idx, self.gather_idx) + + # out shape : e.g., [s:h/p:] + attn_score = self.attn(query_layer, key_layer, value_layer) + + output = all_to_all_comm(attn_score, self.spg, self.gather_idx, self.scatter_idx) + + # output e.g., [s/p::h] + output = self.out(output) + + return output + + +class MultiHeadAttn(nn.Module): + def __init__(self, head_num, hidden_dim, q_proj, k_proj, v_proj, out_proj): + super(MultiHeadAttn, self).__init__() + self.head_num = head_num + self.hidden_dim = hidden_dim + assert hidden_dim % head_num == 0 + self.head_dim = hidden_dim // head_num + + self.q = q_proj + self.k = k_proj + self.v = v_proj + self.out = out_proj + + def attn(self, q, k, v): + batch_size, seq_len = q.shape[0], q.shape[1] + + scale = self.head_dim**0.5 + qk = torch.matmul(q, k.transpose(-2, -1)) / scale + + # if attn_mask is not None: + # mask = attn_mask == 0 + # qk[mask] = torch.tensor(float('-inf')) + + weights = F.softmax(qk, dim=-1) + + attention_score = torch.matmul(weights, v) + + return attention_score + + def split(self, x, batch_size, seq_len): + res = x.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2) + return res + + def forward(self, x): + batch_size, seq_len, hidden_dim = x.shape + assert hidden_dim == self.hidden_dim, "hidden_dim should be equal to self.hidden_dim" + query_mha = self.split(self.q(x), batch_size, seq_len) + key_mha = self.split(self.k(x), batch_size, seq_len) + value_mha = self.split(self.v(x), batch_size, seq_len) + score_mha = self.attn(query_mha, key_mha, value_mha) + score_mha_final = score_mha.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_dim) + output_mha = self.out(score_mha_final) + + return output_mha + + +def seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size): + seq_len = seq_len + hidden_dim = hidden_dim + head_num = head_num + batch_size = batch_size + world_size = dist.get_world_size() + + q_proj = nn.Linear(hidden_dim, hidden_dim) + k_proj = nn.Linear(hidden_dim, hidden_dim) + v_proj = nn.Linear(hidden_dim, hidden_dim) + out_proj = nn.Linear(hidden_dim, hidden_dim) + + q_proj_copy = copy.deepcopy(q_proj) + k_proj_copy = copy.deepcopy(k_proj) + v_proj_copy = copy.deepcopy(v_proj) + out_proj_copy = copy.deepcopy(out_proj) + + x = torch.randn(batch_size, seq_len, hidden_dim).cuda() + x_unshard = x.clone() + x_unshard.requires_grad_(True) + x_input = torch.chunk(x.clone(), world_size, dim=1)[dist.get_rank()] + x_input.requires_grad_(True) + + # x_unshard = torch.randn(batch_size, seq_len, hidden_dim).cuda() + # x_unshard.requires_grad_(True) + # x_input = torch.chunk(x_unshard.clone(), world_size, dim=1)[dist.get_rank()] + # x_input.requires_grad_(True) + + # Multi-head Attention + mhn = MultiHeadAttn(head_num, hidden_dim, q_proj, k_proj, v_proj, out_proj).cuda() + # Multi-head Attention forward + mhn_out = mhn(x_unshard) + + # Sequence parallel Attention + dist_attn = DistributedAttention(head_num, hidden_dim, q_proj_copy, k_proj_copy, v_proj_copy, out_proj_copy).cuda() + # Sequence parallel Attention forward + dist_attn_out = dist_attn(x_input) + # gather the output of sequence parallel attention + out_list = [torch.empty_like(dist_attn_out) for _ in range(world_size)] + dist.all_gather(out_list, dist_attn_out) + seq_out = torch.cat(out_list, dim=1) + + # forward result check + assert_close(seq_out, mhn_out) + + # Multi-head Attention backward + mhn_out.sum().backward() + q_grad = mhn.q.weight.grad + k_grad = mhn.k.weight.grad + v_grad = mhn.v.weight.grad + o_grad = mhn.out.weight.grad + x_grad = x_unshard.grad + + # Sequence parallel Attention backward + dist_attn_out.sum().backward() + q_grad_seq = dist_attn.q.weight.grad + k_grad_seq = dist_attn.k.weight.grad + v_grad_seq = dist_attn.v.weight.grad + o_grad_seq = dist_attn.out.weight.grad + x_grad_seq = x_input.grad + # all_reduce the grad of sequence parallel attention weight + dist.all_reduce(q_grad_seq) + dist.all_reduce(k_grad_seq) + dist.all_reduce(v_grad_seq) + dist.all_reduce(o_grad_seq) + # gather the grad of sequence parallel attention input + x_grad_seq_list = [torch.empty_like(x_grad_seq) for _ in range(world_size)] + dist.all_gather(x_grad_seq_list, x_grad_seq) + x_grad_seq_gather = torch.cat(x_grad_seq_list, dim=1) + + # backward result check + assert_close(q_grad_seq, q_grad) + assert_close(k_grad_seq, k_grad) + assert_close(v_grad_seq, v_grad) + assert_close(o_grad_seq, o_grad) + assert_close(x_grad_seq_gather, x_grad) + + # print_rank('x_grad', x_grad_seq, 0) + # print_rank('x_grad', x_grad_seq, 1) + # print_rank('x_grad', x_grad_seq, 2) + # print_rank('x_grad', x_grad_seq, 3) + + +@parameterize("seq_len", [128]) +@parameterize("hidden_dim", [64]) +@parameterize("head_num", [4]) +@parameterize("batch_size", [4]) +def run_seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size): + seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size) + + +def check_all2all_attn(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_seq_parallel_attn() + + +@rerun_if_address_is_in_use() +def test_all_to_all_attention(): + spawn(check_all2all_attn, nprocs=4) + + +if __name__ == "__main__": + test_all_to_all_attention() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 85be9a242715..5bb347ab049f 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -161,6 +161,16 @@ def _criterion(outputs, inputs): return loss data = data_gen_fn() + shard_test_data = {} + for k, v in data.items(): + shard_test_data[k] = ( + data[k].clone() + if booster.plugin.shard_config.test_seq_parallelism is False + else torch.chunk(data[k].clone(), dist.get_world_size(), dim=1)[dist.get_rank()] + ) + unshard_test_data = {} + for k, v in data.items(): + unshard_test_data[k] = data[k].clone() if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.tp_size != 0: seq_len = data["input_ids"].shape[-1] @@ -190,15 +200,15 @@ def _criterion(outputs, inputs): ) sharded_loss = sharded_output["loss"] else: - data = {k: v.cuda() for k, v in data.items()} - sharded_output = sharded_model(**data) + shard_test_data = {k: v.cuda() for k, v in shard_test_data.items()} + sharded_output = sharded_model(**shard_test_data) sharded_loss = criterion(sharded_output) sharded_optimizer.backward(sharded_loss) org_model.train() - data = {k: v.cuda() for k, v in data.items()} - org_output = org_model(**data) + unshard_test_data = {k: v.cuda() for k, v in unshard_test_data.items()} + org_output = org_model(**unshard_test_data) org_loss = criterion(org_output) org_loss.backward() diff --git a/tests/test_shardformer/test_model/test_llama_seq.py b/tests/test_shardformer/test_model/test_llama_seq.py new file mode 100644 index 000000000000..e4931b97c241 --- /dev/null +++ b/tests/test_shardformer/test_model/test_llama_seq.py @@ -0,0 +1,140 @@ +import os + +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + + +def print_rank(prompt, value, rank=0): + if dist.get_rank() == rank: + print(f"rank-{rank}, {prompt}: {value}") + + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + llama_model = unwrap_model(org_model, "LlamaModel", "model") + shard_llama_model = unwrap_model(sharded_model, "LlamaModel", "model") + + row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] + col_layer_for_check = ["layers[0].self_attn.o_proj"] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 1e-6, 1e-4 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check( + llama_model, shard_llama_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == "LlamaModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if stage_manager is None or stage_manager.is_first_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight( + llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + + # check grads + check_all_grad_tensors(grads_to_check) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 1, + "pp_size": 1, + "num_microbatches": 1, + "test_seq_parallelism": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + } + ], +) +def run_llama_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name != "transformers_llama": + continue + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_llama_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(): + spawn(check_llama, 4) + + +if __name__ == "__main__": + test_llama() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 126ff23a9f25..f4822fc79b69 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -2,6 +2,7 @@ import pytest import torch +import torch.distributed as dist import colossalai from colossalai.logging import disable_existing_loggers @@ -20,6 +21,12 @@ unwrap_model, ) + +def print_rank(prompt, value, rank=0): + if dist.get_rank() == rank: + print(f"rank-{rank}, {prompt}: {value}") + + os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" From fe5fac2057b65da1acaf061d8b786cb8d3bacf6c Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 3 Jan 2024 17:30:14 +0800 Subject: [PATCH 02/50] validate sequence parallel in llama (code to be polished) --- .../booster/plugin/hybrid_parallel_plugin.py | 9 + colossalai/shardformer/layer/utils.py | 17 +- colossalai/shardformer/modeling/llama.py | 175 +++++++++++++++--- colossalai/shardformer/policies/llama.py | 10 +- tests/kit/model_zoo/transformers/llama.py | 4 +- tests/test_shardformer/test_model/_utils.py | 32 ++-- .../test_model/test_llama_seq.py | 5 +- 7 files changed, 205 insertions(+), 47 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 537146c1e470..536494a2abe7 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -168,6 +168,15 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None): Returns: None """ + + if self.shard_config.test_seq_parallelism: + if grads is not None: + # Synchronize provided gradient tensors across the tensor parallelism group. + SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads) + else: + # Synchronize gradients from the model across the tensor parallelism group. + SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module) + if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism: if grads is not None: # Synchronize provided gradient tensors across the tensor parallelism group. diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 0d2cc1b3370d..bef51a3031b2 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -53,23 +53,24 @@ def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None." # Get the size of the process group, which determines whether synchronization is needed. - tp_size = get_world_size(tp_group) if tp_group is not None else 1 - - if tp_size == 1: - # If the process group size is 1, no synchronization is required. - return + # tp_size = get_world_size(tp_group) if tp_group is not None else 1 + # if tp_size == 1: + # # If the process group size is 1, no synchronization is required. + # return + print("在这呢") if model is not None: # If `model` is provided, extract partial derived gradients from the model's parameters. grads = [] for p in model.parameters(): - if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p): + if p.grad is not None: grads.append(p.grad.data) + print("在这呢", len(grads)) # Flatten and reduce the gradients using the specified process group. coalesced = _flatten_dense_tensors(grads) - dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group) - + dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=None) + print("all-reduce了") # Unflatten the synchronized gradients and update the model's gradients. for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): buf.copy_(synced) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 52d3bf18d439..6317a578a004 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -1,6 +1,6 @@ import math import warnings -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -20,7 +20,7 @@ from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, cross_entropy_1d -from colossalai.shardformer.layer._operation import all_to_all_comm +from colossalai.shardformer.layer._operation import all_to_all_comm, gather_forward_split_backward def print_rank(prompt, value, rank=0): @@ -735,12 +735,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - print_rank("cos-0", cos.shape) # torch.Size([16, 32]) - print_rank("cos-position", cos[position_ids]) # torch.Size([1, 1, 4, 32]) - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - print_rank("cos-1", cos.shape) # torch.Size([1, 1, 4, 32]) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed @@ -765,7 +761,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - print_rank("hidden_states-origin", hidden_states.shape) + ret = None if self.config.pretraining_tp > 1: key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp query_slices = self.q_proj.weight.split( @@ -787,6 +783,7 @@ def forward( query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) + ret = (query_states,) # introduce sequence parallel query_states = all_to_all_comm(query_states) @@ -803,18 +800,15 @@ def forward( past_key_value_length = 0 if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - + past_key_value_length = past_key_value[0].shape[2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - print_rank("cos", cos.shape) - - print_rank("position_ids", position_ids) - if position_ids is not None: - position_ids = torch.arange( - past_key_value_length, q_len + past_key_value_length, dtype=torch.long, device=hidden_states.device - ) - - print_rank("position_ids-2", position_ids) + # modify position ids + position_ids_device = position_ids.device if position_ids is not None else hidden_states.device + position_ids = torch.arange( + past_key_value_length, q_len + past_key_value_length, dtype=torch.long, device=position_ids_device + ) + position_ids = position_ids.unsqueeze(0).view(-1, q_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -846,7 +840,6 @@ def forward( # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" @@ -854,7 +847,9 @@ def forward( ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + # attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + attn_output = all_to_all_comm(attn_output, None, scatter_dim=1, gather_dim=2) if self.config.pretraining_tp > 1: attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) @@ -862,10 +857,148 @@ def forward( attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) else: attn_output = self.o_proj(attn_output) - + # ret = (past_key_value, ) if not output_attentions: attn_weights = None + return attn_output, attn_weights, past_key_value, ret + + return forward + + +def test_llama_seq_parallel_model(): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length + ) - return attn_output, attn_weights, past_key_value + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + decoder_output = [] + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + decoder_output.append((layer_outputs[-2], layer_outputs[-1])) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + hidden_states = gather_forward_split_backward(hidden_states, 1, None) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + ret_model = "test_shard_model" + + return ( + BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ), + decoder_output, + ret_model, + ) return forward diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 6a8536f893ba..9aa1053b1ec2 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -15,6 +15,7 @@ get_llama_model_forward_for_flash_attn, get_lm_forward_with_dist_cross_entropy, test_llama_seq_parallel_attention, + test_llama_seq_parallel_model, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -63,7 +64,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: self.model.config.num_key_value_heads // sequence_parallelism_size ) decoder_attribute_replacement["num_key_value_groups"] = ( - self.model.config.hidden_size // self.model.config.num_attention_heads + self.model.config.num_attention_heads // self.model.config.num_key_value_heads ) policy[LlamaAttention] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, @@ -76,6 +77,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=LlamaAttention, ) + self.append_or_create_method_replacement( + description={ + "forward": test_llama_seq_parallel_model(), + }, + policy=policy, + target_key=LlamaModel, + ) if self.shard_config.enable_tensor_parallelism: decoder_attribute_replacement = { diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 9cc3f03a67cc..288d4d71b322 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -68,8 +68,8 @@ def data_gen_for_casual_lm(): loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = LlamaConfig( - num_hidden_layers=1, - hidden_size=128, + num_hidden_layers=2, + hidden_size=64, intermediate_size=256, num_attention_heads=4, max_position_embeddings=128, diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 5bb347ab049f..2d907fa4063c 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -23,6 +23,11 @@ from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +def print_rank(prompt, value, rank=0): + if dist.get_rank() == rank: + print(f"rank-{rank}, {prompt}: {value}") + + def build_model( model_fn, enable_fused_normalization=True, @@ -123,7 +128,6 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c sharded_model = copy.deepcopy(org_model) if use_lazy_init: ctx.materialize(org_model) - org_model = org_model.cuda() org_optimizer = Adam(org_model.parameters(), lr=1e-3) sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3) @@ -163,11 +167,15 @@ def _criterion(outputs, inputs): data = data_gen_fn() shard_test_data = {} for k, v in data.items(): - shard_test_data[k] = ( - data[k].clone() - if booster.plugin.shard_config.test_seq_parallelism is False - else torch.chunk(data[k].clone(), dist.get_world_size(), dim=1)[dist.get_rank()] - ) + if k == "attention_mask": + shard_test_data[k] = data[k].clone() + else: + shard_test_data[k] = ( + data[k].clone() + if booster.plugin.shard_config.test_seq_parallelism is False + # else data[k].clone() + else torch.chunk(data[k].clone(), dist.get_world_size(), dim=1)[dist.get_rank()] + ) unshard_test_data = {} for k, v in data.items(): unshard_test_data[k] = data[k].clone() @@ -201,15 +209,12 @@ def _criterion(outputs, inputs): sharded_loss = sharded_output["loss"] else: shard_test_data = {k: v.cuda() for k, v in shard_test_data.items()} - sharded_output = sharded_model(**shard_test_data) - + sharded_output, decoder_output_shard, ret_model_shard = sharded_model(**shard_test_data) sharded_loss = criterion(sharded_output) sharded_optimizer.backward(sharded_loss) - org_model.train() unshard_test_data = {k: v.cuda() for k, v in unshard_test_data.items()} - org_output = org_model(**unshard_test_data) - + org_output, decoder_output_ori, ret_model_ori = org_model(**unshard_test_data) org_loss = criterion(org_output) org_loss.backward() @@ -222,13 +227,16 @@ def check_output_hidden_state( stage_manager: Optional[PipelineStageManager] = None, atol: float = 1e-5, rtol: float = 1e-3, - dim: int = 0, + booster: Booster = None, ): org_hidden_state = org_output.last_hidden_state if stage_manager and stage_manager.is_last_stage(ignore_chunk=True): sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] else: + # if booster and booster.plugin.shard_config.test_seq_parallelism: + # sharded_hidden_state = sharded_output + # else: sharded_hidden_state = sharded_output.last_hidden_state assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol) diff --git a/tests/test_shardformer/test_model/test_llama_seq.py b/tests/test_shardformer/test_model/test_llama_seq.py index e4931b97c241..061d83bd9138 100644 --- a/tests/test_shardformer/test_model/test_llama_seq.py +++ b/tests/test_shardformer/test_model/test_llama_seq.py @@ -76,7 +76,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == "LlamaModel": - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, booster=booster) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) @@ -92,7 +92,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grads check_all_grad_tensors(grads_to_check) - torch.cuda.empty_cache() @@ -105,7 +104,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 1, "test_seq_parallelism": True, "use_lazy_init": True, - "precision": "fp16", + "precision": "fp32", "initial_scale": 1, } ], From ee95f9406fe3c5c58934664e5e7d3a28aa9d2451 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Fri, 5 Jan 2024 17:27:38 +0800 Subject: [PATCH 03/50] shardformer api writing --- .../booster/plugin/hybrid_parallel_plugin.py | 41 ++++++++++-- colossalai/cluster/process_group_mesh.py | 1 - colossalai/shardformer/layer/linear.py | 25 ++++---- colossalai/shardformer/layer/utils.py | 17 +++-- colossalai/shardformer/modeling/bert.py | 20 +++--- colossalai/shardformer/modeling/bloom.py | 18 +++--- colossalai/shardformer/modeling/chatglm2.py | 22 +++---- colossalai/shardformer/modeling/gpt2.py | 22 +++---- colossalai/shardformer/modeling/llama.py | 30 ++++----- colossalai/shardformer/policies/bert.py | 22 ++++--- colossalai/shardformer/shard/shard_config.py | 55 +++++++++++++--- tests/test_shardformer/test_model/_utils.py | 62 +++++++++++++++---- .../test_model/test_shard_bert.py | 8 +++ 13 files changed, 227 insertions(+), 116 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 536494a2abe7..5bdd16188ceb 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -34,7 +34,8 @@ from .pp_plugin_base import PipelinePluginBase -DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 +DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3 +SUPPORT_SP_MODE = ["1", "2", "3"] PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} @@ -946,6 +947,7 @@ def __init__( self, tp_size: int, pp_size: int, + sp_size: int = 1, precision: str = "fp16", zero_stage: int = 0, enable_all_optimization: bool = False, @@ -953,6 +955,7 @@ def __init__( enable_flash_attention: bool = False, enable_jit_fused: bool = False, enable_sequence_parallelism: bool = False, + sequence_parallelism_mode: str = None, enable_sequence_overlap: bool = False, parallel_output: bool = True, num_microbatches: Optional[int] = None, @@ -984,14 +987,36 @@ def __init__( super().__init__() assert ( dist.get_world_size() % (tp_size * pp_size) == 0 - ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" if enable_sequence_parallelism: - assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism" + self.sequence_parallelism_mode = sequence_parallelism_mode if sequence_parallelism_mode is not None else "1" + assert ( + self.sequence_parallelism_mode in SUPPORT_SP_MODE + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" + if self.sequence_parallelism_mode in ["1", "2"]: + assert ( + tp_size > 1 + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" + if sp_size != 1: + warnings.warn( + f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." + ) + self.sp_size = tp_size + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + elif self.sequence_parallelism_mode in ["3"]: + assert ( + tp_size == 1 + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism" + self.sp_size = sp_size + self.dp_size = dist.get_world_size() // (sp_size * pp_size) + else: + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + assert sp_size == 1, f"sp_size can only be set to a >1 number when enable_sequence_parallelism is True" + self.sp_size = 1 self.tp_size = tp_size self.pp_size = pp_size - self.dp_size = dist.get_world_size() // (tp_size * pp_size) self.precision = precision self.zero_stage = zero_stage self.cpu_offload = cpu_offload @@ -1000,7 +1025,7 @@ def __init__( self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) self.stage_manager = None self.schedule = None self.custom_policy = custom_policy @@ -1041,9 +1066,14 @@ def __init__( self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["1", "2"]: + self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + else: + self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS) self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, + sequence_parallel_process_group=self.sp_group, pipeline_stage_manager=self.stage_manager, enable_tensor_parallelism=self.tp_size > 1, enable_all_optimization=self.enable_all_optimization, @@ -1051,6 +1081,7 @@ def __init__( enable_flash_attention=self.enable_flash_attention, enable_jit_fused=self.enable_jit_fused, enable_sequence_parallelism=enable_sequence_parallelism, + sequence_parallelism_mode=sequence_parallelism_mode, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, test_seq_parallelism=test_seq_parallelism, diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index ae3956c693ab..1f32541a7b21 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -225,4 +225,3 @@ def get_group_along_axis( # no need to cache it explicitly, since it will be cached in `create_group_along_axis` return self.create_group_along_axis(axis, indices_at_axis, backend=backend) return self._ranks_to_group[ranks_in_group] - \ No newline at end of file diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index eeb0ef39975f..069f7c8fd8c6 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -74,7 +74,7 @@ def __init__( device: torch.device = None, process_group: ProcessGroup = None, gather_output: bool = False, - seq_parallel: bool = False, + seq_parallel_mode: str = None, seq_parallel_dim: int = 1, overlap: torch.cuda.Stream = None, skip_bias_add: bool = False, @@ -89,7 +89,7 @@ def __init__( self.in_features = in_features self.out_features = out_features self.gather_output = gather_output - self.seq_parallel = seq_parallel + self.seq_parallel_mode = seq_parallel_mode self.seq_parallel_dim = seq_parallel_dim self.overlap = overlap self.skip_bias_add = skip_bias_add @@ -196,12 +196,13 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - if self.seq_parallel: + + if self.seq_parallel_mode is None: + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + elif self.seq_parallel_mode == "1": output_parallel = linear_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap ) - else: - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) if self.gather_output: # All-gather across the partitions. @@ -245,7 +246,8 @@ def __init__( dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, - seq_parallel: bool = False, + # seq_parallel: bool = False, + seq_parallel_mode: str = None, seq_parallel_dim: int = 1, parallel_input: bool = True, skip_bias_add: bool = False, @@ -265,7 +267,7 @@ def __init__( self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group - self.seq_parallel = seq_parallel + self.seq_parallel_mode = seq_parallel_mode self.seq_parallel_dim = seq_parallel_dim self.num_partitions = dist.get_world_size(self.process_group) @@ -408,13 +410,14 @@ def forward(self, input_: Tensor) -> Tensor: handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - output_parallel = linear_with_async_comm(input_, self.weight, None, None, False) - if self.seq_parallel: + output_parallel = F.linear(input_, self.weight) + + if self.seq_parallel_mode is None: + output = reduce_forward(output_parallel, self.process_group) + elif self.seq_parallel_mode == "1": output = linear_reducescatter_forward_gather_backward( output_parallel, self.process_group, self.seq_parallel_dim ) - else: - output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index bef51a3031b2..0d2cc1b3370d 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -53,24 +53,23 @@ def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None." # Get the size of the process group, which determines whether synchronization is needed. - # tp_size = get_world_size(tp_group) if tp_group is not None else 1 + tp_size = get_world_size(tp_group) if tp_group is not None else 1 + + if tp_size == 1: + # If the process group size is 1, no synchronization is required. + return - # if tp_size == 1: - # # If the process group size is 1, no synchronization is required. - # return - print("在这呢") if model is not None: # If `model` is provided, extract partial derived gradients from the model's parameters. grads = [] for p in model.parameters(): - if p.grad is not None: + if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p): grads.append(p.grad.data) - print("在这呢", len(grads)) # Flatten and reduce the gradients using the specified process group. coalesced = _flatten_dense_tensors(grads) - dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=None) - print("all-reduce了") + dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group) + # Unflatten the synchronized gradients and update the model's gradients. for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): buf.copy_(synced) diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 7411e1d0ec46..99cd7acd312e 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -186,13 +186,14 @@ def bert_model_forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config is not None and shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group - ) - if encoder_hidden_states is not None: - encoder_hidden_states = split_forward_gather_backward( - encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + if shard_config.sequence_parallelism_mode == "1": + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group ) + if encoder_hidden_states is not None: + encoder_hidden_states = split_forward_gather_backward( + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): if stage_manager.is_first_stage() and idx == 0: @@ -240,9 +241,10 @@ def custom_forward(*inputs): # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config is not None and shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group - ) + if shard_config.sequence_parallelism_mode == "1": + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index d94c30d29e71..370017384616 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -213,10 +213,11 @@ def bloom_model_forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - if shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "1": + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) start_idx, end_idx = stage_index[0], stage_index[1] for i, (block, layer_past) in enumerate( @@ -261,10 +262,11 @@ def custom_forward(*inputs): all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) # When sequence parallelism done, gather the output tensor in forward and split it in backward - if shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "1": + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if stage_manager.is_last_stage(): # Add last hidden state diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index a3e000e6ef66..9bfb7053943b 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -191,12 +191,11 @@ def chatglm_model_forward( all_hidden_states = () if output_hidden_states else None start_idx, end_idx = stage_index[0], stage_index[1] - if shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward( - hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group, - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "1": + hidden_states = split_forward_gather_backward( + hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + ) for idx in range(start_idx, end_idx): layer = self.encoder._get_layer(idx) if output_hidden_states: @@ -222,12 +221,11 @@ def chatglm_model_forward( if use_cache: presents = presents + (kv_cache,) - if shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward( - hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group, - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "1": + hidden_states = gather_forward_split_backward( + hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index ea22cfb15a33..1d034fff1639 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -218,12 +218,11 @@ def gpt2_model_forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - if shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "1": + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) # Going through held blocks. start_idx, end_idx = stage_index[0], stage_index[1] @@ -278,12 +277,11 @@ def custom_forward(*inputs): all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) # When sequence parallelism done, gather the output tensor in forward and split it in backward - if shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "1": + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if stage_manager.is_last_stage(): hidden_states = self.ln_f(hidden_states) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 6317a578a004..81af2cf1acd3 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -466,9 +466,7 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -761,7 +759,6 @@ def forward( output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - ret = None if self.config.pretraining_tp > 1: key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp query_slices = self.q_proj.weight.split( @@ -783,7 +780,6 @@ def forward( query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - ret = (query_states,) # introduce sequence parallel query_states = all_to_all_comm(query_states) @@ -857,10 +853,10 @@ def forward( attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) else: attn_output = self.o_proj(attn_output) - # ret = (past_key_value, ) + if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value, ret + return attn_output, attn_weights, past_key_value return forward @@ -938,7 +934,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None - decoder_output = [] + for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -968,7 +964,7 @@ def custom_forward(*inputs): output_attentions=output_attentions, use_cache=use_cache, ) - decoder_output.append((layer_outputs[-2], layer_outputs[-1])) + hidden_states = layer_outputs[0] if use_cache: @@ -979,7 +975,9 @@ def custom_forward(*inputs): hidden_states = self.norm(hidden_states) + # Todo: Maybe this line can be optimized hidden_states = gather_forward_split_backward(hidden_states, 1, None) + # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) @@ -988,17 +986,11 @@ def custom_forward(*inputs): if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - ret_model = "test_shard_model" - - return ( - BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ), - decoder_output, - ret_model, + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, ) return forward diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 0ab63b7650c1..1a5dba0e0064 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -66,8 +66,10 @@ def module_policy(self): else: norm_cls = col_nn.LayerNorm - use_sequence_parallel = self.shard_config.enable_sequence_parallelism + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None overlap = self.shard_config.enable_sequence_overlap + sp_partial_derived = sp_mode in ["1"] + if self.shard_config.enable_tensor_parallelism: policy[BertLayer] = ModulePolicyDescription( attribute_replacement={ @@ -84,17 +86,17 @@ def module_policy(self): SubModuleReplacementDescription( suffix="attention.self.query", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="attention.self.key", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="attention.self.value", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="attention.self.dropout", @@ -103,7 +105,7 @@ def module_policy(self): SubModuleReplacementDescription( suffix="attention.output.dense", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={"seq_parallel_mode": sp_mode}, ), SubModuleReplacementDescription( suffix="attention.output.dropout", @@ -112,12 +114,12 @@ def module_policy(self): SubModuleReplacementDescription( suffix="intermediate.dense", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="output.dense", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={"seq_parallel_mode": sp_mode}, ), SubModuleReplacementDescription( suffix="output.dropout", @@ -139,7 +141,7 @@ def module_policy(self): ] ) - if use_sequence_parallel: + if sp_mode == "1": self.append_or_create_method_replacement( description={"forward": bert_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, @@ -153,12 +155,12 @@ def module_policy(self): SubModuleReplacementDescription( suffix="attention.output.LayerNorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="output.LayerNorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index b805bac48d65..6a3bf38fa5f9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -1,3 +1,4 @@ +import warnings from dataclasses import dataclass, field from typing import Any, Dict, Optional @@ -7,6 +8,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager __all__ = ["ShardConfig"] +SUPPORT_SP_MODE = ["1", "2", "3"] @dataclass @@ -26,13 +28,15 @@ class ShardConfig: enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. """ tensor_parallel_process_group: Optional[ProcessGroup] = None + sequence_parallel_process_group: Optional[ProcessGroup] = None pipeline_stage_manager: Optional[PipelineStageManager] = None enable_tensor_parallelism: bool = True + enable_all_optimization: bool = False enable_fused_normalization: bool = False enable_flash_attention: bool = False enable_jit_fused: bool = False - enable_all_optimization: bool = False enable_sequence_parallelism: bool = False + sequence_parallelism_mode: str = None enable_sequence_overlap: bool = False parallel_output: bool = True test_seq_parallelism: bool = False @@ -47,18 +51,53 @@ class ShardConfig: def tensor_parallel_size(self): return self._tensor_parallel_size + @property + def sequence_parallel_size(self): + return self._sequence_parallel_size + def __post_init__(self): - if not self.enable_tensor_parallelism and self.enable_sequence_parallelism: - raise ValueError( - "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True" + if self.enable_sequence_parallelism: + self.sequence_parallelism_mode = ( + "1" if self.sequence_parallelism_mode is None else self.sequence_parallelism_mode ) - if not self.enable_sequence_parallelism and self.enable_sequence_overlap: - raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True") + assert ( + self.sequence_parallelism_mode in SUPPORT_SP_MODE + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" + if self.sequence_parallelism_mode in ["1", "2"]: + assert ( + self.enable_tensor_parallelism + ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True" + elif self.sequence_parallelism_mode in ["3"]: + assert ( + not self.enable_tensor_parallelism + ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False" + if self.enable_sequence_overlap: + self.enable_sequence_overlap = False + warnings.warn( + f"The enable_sequence_overlap flag will be ignored in sequence parallelism mode {self.sequence_parallelism_mode}" + ) + else: + if self.sequence_parallelism_mode: + self.sequence_parallelism_mode = None + warnings.warn( + f"The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False" + ) + assert ( + not self.enable_sequence_overlap + ), f"enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True" + + # get the tensor parallel size if not self.enable_tensor_parallelism: self._tensor_parallel_size = 1 else: - # get the parallel size self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + + # get the sequence parallel size + if not self.enable_sequence_parallelism: + self._sequence_parallel_size = 1 + else: + self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group) + # turn on all optimization if all_optimization is set to True if self.enable_all_optimization: self._turn_on_all_optimization() @@ -73,6 +112,8 @@ def _turn_on_all_optimization(self): self.enable_jit_fused = True self.enable_sequence_parallelism = True self.enable_sequence_overlap = True + # todo modify default sequence parallelism mode + self.sequence_parallelism_mode = "1" def _infer(self): """ diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 2d907fa4063c..b7312563489d 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -165,6 +165,20 @@ def _criterion(outputs, inputs): return loss data = data_gen_fn() + + if ( + booster.plugin.shard_config.enable_sequence_parallelism + and booster.plugin.shard_config.sequence_parallelism_mode in ["1", "2"] + and booster.plugin.tp_size != 0 + ): + seq_len = data["input_ids"].shape[-1] + lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) + times = lcm // seq_len + input_shape = data["input_ids"].shape + for k, v in data.items(): + if v.shape == input_shape: + data[k] = v.repeat((1,) * (v.dim() - 1) + (times,)) + shard_test_data = {} for k, v in data.items(): if k == "attention_mask": @@ -180,24 +194,38 @@ def _criterion(outputs, inputs): for k, v in data.items(): unshard_test_data[k] = data[k].clone() - if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.tp_size != 0: - seq_len = data["input_ids"].shape[-1] - lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) - times = lcm // seq_len - input_shape = data["input_ids"].shape - for k, v in data.items(): - if v.shape == input_shape: - data[k] = v.repeat((1,) * (v.dim() - 1) + (times,)) + # if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0: + # seq_len = data["input_ids"].shape[-1] + # lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) + # times = lcm // seq_len + # input_shape = data["input_ids"].shape + # for k, v in data.items(): + # if v.shape == input_shape: + # data[k] = v.repeat((1,) * (v.dim() - 1) + (times,)) + + # sharded_model.train() + # if booster.plugin.stage_manager is not None: + # for k, v in data.items(): + # if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: + # new_shape = [1] * v.dim() + # new_shape[0] = 4 + # data[k] = v.to("cuda").repeat(*new_shape) + + # data_iter = iter([data]) + # sharded_output = booster.execute_pipeline( + # data_iter, sharded_model, _criterion, sharded_optimizer, return_loss=True, return_outputs=True + # ) + # sharded_loss = sharded_output["loss"] sharded_model.train() if booster.plugin.stage_manager is not None: - for k, v in data.items(): + for k, v in shard_test_data.items(): if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: new_shape = [1] * v.dim() new_shape[0] = 4 - data[k] = v.to("cuda").repeat(*new_shape) + shard_test_data[k] = v.to("cuda").repeat(*new_shape) - data_iter = iter([data]) + data_iter = iter([shard_test_data]) sharded_output = booster.execute_pipeline( data_iter, sharded_model, @@ -207,14 +235,22 @@ def _criterion(outputs, inputs): return_outputs=True, ) sharded_loss = sharded_output["loss"] + else: shard_test_data = {k: v.cuda() for k, v in shard_test_data.items()} - sharded_output, decoder_output_shard, ret_model_shard = sharded_model(**shard_test_data) + sharded_output = sharded_model(**shard_test_data) sharded_loss = criterion(sharded_output) sharded_optimizer.backward(sharded_loss) + org_model.train() + if booster.plugin.stage_manager is not None: + for k, v in unshard_test_data.items(): + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + unshard_test_data[k] = v.to("cuda").repeat(*new_shape) unshard_test_data = {k: v.cuda() for k, v in unshard_test_data.items()} - org_output, decoder_output_ori, ret_model_ori = org_model(**unshard_test_data) + org_output = org_model(**unshard_test_data) org_loss = criterion(org_output) org_loss.backward() diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 768bd95bdb42..ddfd30388998 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -155,6 +155,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, def run_bert_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + import torch.distributed as dist + + def print_rank(prompt, value, rank=0): + if dist.get_rank() == rank: + print(f"rank-{rank}, {prompt}: {value}") + + print_rank("test_config", test_config) + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) From 98a2eeb4816948348143e0b7d294097e5704c358 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Mon, 8 Jan 2024 16:12:50 +0800 Subject: [PATCH 04/50] integrate sequence parallel in ShardFormer --- .../booster/plugin/hybrid_parallel_plugin.py | 44 ++++++++++++------- .../shardformer/layer/qkv_fused_linear.py | 28 ++++++------ colossalai/shardformer/layer/utils.py | 22 +++++++--- colossalai/shardformer/modeling/llama.py | 4 +- colossalai/shardformer/policies/bloom.py | 19 ++++---- colossalai/shardformer/policies/chatglm2.py | 15 ++++--- colossalai/shardformer/policies/gpt2.py | 31 ++++++------- colossalai/shardformer/policies/llama.py | 24 +++++----- colossalai/shardformer/shard/shard_config.py | 2 - tests/test_shardformer/test_model/_utils.py | 12 ++--- .../test_model/test_llama_seq.py | 3 +- .../test_model/test_shard_bert.py | 9 ---- 12 files changed, 114 insertions(+), 99 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 5bdd16188ceb..cef9c9faf641 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -54,6 +54,7 @@ def __init__( shard_config: ShardConfig, dp_group: ProcessGroup, tp_group: ProcessGroup, + sp_group: ProcessGroup, use_ddp: bool, ddp_config: dict, custom_policy: Policy, @@ -62,6 +63,7 @@ def __init__( self.shard_config = shard_config self.dp_group = dp_group self.tp_group = tp_group + self.sp_group = sp_group self.use_dpp = use_ddp self.require_grad_sync = True @@ -170,21 +172,30 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None): None """ - if self.shard_config.test_seq_parallelism: - if grads is not None: - # Synchronize provided gradient tensors across the tensor parallelism group. - SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads) + if self.shard_config.enable_sequence_parallelism: + if self.shard_config.sequence_parallelism_mode in ["1", "2"]: + # If sequence parallelism is enabled and mode is 1 or 2, gradients are synchronized + # across the tensor parallelism group. + group = self.tp_group + require_flag = True + elif self.shard_config.sequence_parallelism_mode == "3": + # If sequence parallelism is enabled and mode is 3, gradients are synchronized + # across the sequence parallelism group. + group = self.sp_group + require_flag = False else: - # Synchronize gradients from the model across the tensor parallelism group. - SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module) + raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}") - if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism: if grads is not None: # Synchronize provided gradient tensors across the tensor parallelism group. - SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads) + SeqParallelUtils.allreduce_partial_data_grad( + process_group=group, grads=grads, require_flag=require_flag + ) else: # Synchronize gradients from the model across the tensor parallelism group. - SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module) + SeqParallelUtils.allreduce_partial_data_grad( + process_group=group, model=self.module, require_flag=require_flag + ) def forward(self, *args, **kwargs): if self.convert_fn is not None: @@ -740,7 +751,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]: if self.require_grad_sync and grads_to_sync is not None: # Synchronize sequence parallelism gradients if required. - SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync) + SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync) else: return @@ -947,7 +958,7 @@ def __init__( self, tp_size: int, pp_size: int, - sp_size: int = 1, + sp_size: int = None, precision: str = "fp16", zero_stage: int = 0, enable_all_optimization: bool = False, @@ -982,7 +993,6 @@ def __init__( pp_style: str = "1f1b", num_model_chunks: int = 1, enable_metadata_cache: bool = True, - test_seq_parallelism: bool = False, ) -> None: super().__init__() assert ( @@ -1008,11 +1018,13 @@ def __init__( assert ( tp_size == 1 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism" - self.sp_size = sp_size - self.dp_size = dist.get_world_size() // (sp_size * pp_size) + self.sp_size = dist.get_world_size() if sp_size is None else sp_size + self.dp_size = dist.get_world_size() // (self.sp_size * pp_size) else: self.dp_size = dist.get_world_size() // (tp_size * pp_size) - assert sp_size == 1, f"sp_size can only be set to a >1 number when enable_sequence_parallelism is True" + assert ( + sp_size == 1 or sp_size is None + ), f"sp_size can only be set to a >1 number when enable_sequence_parallelism is True" self.sp_size = 1 self.tp_size = tp_size @@ -1084,7 +1096,6 @@ def __init__( sequence_parallelism_mode=sequence_parallelism_mode, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, - test_seq_parallelism=test_seq_parallelism, ) self.amp_config = dict( initial_scale=initial_scale, @@ -1159,6 +1170,7 @@ def configure( shard_config=self.shard_config, dp_group=self.dp_group, tp_group=self.tp_group, + sp_group=self.sp_group, use_ddp=use_ddp, ddp_config=self.ddp_config, custom_policy=self.custom_policy, diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 12476d050600..6feca7ee5fc3 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -175,7 +175,8 @@ def __init__( process_group: ProcessGroup = None, async_communication: bool = False, gather_output: bool = False, - seq_parallel: bool = False, + # seq_parallel: bool = False, + seq_parallel_mode: str = None, overlap: bool = False, skip_bias_add: bool = False, n_fused: int = 3, @@ -190,7 +191,7 @@ def __init__( self.in_features = in_features self.out_features = out_features self.gather_output = gather_output - self.seq_parallel = seq_parallel + self.seq_parallel_mode = seq_parallel_mode self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device @@ -312,17 +313,17 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - if self.seq_parallel: - input_parallel = input_ - output_parallel = matmul_gather_forward_reducescatter_backward( - input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap - ) - else: + if self.seq_parallel_mode is None: # Set up backprop all-reduce. input_parallel = reduce_backward(input_, self.process_group) output_parallel = matmul_with_async_comm( input_parallel, self.weight, bias, self.process_group, self.async_communication ) + elif self.seq_parallel_mode == "1": + input_parallel = input_ + output_parallel = matmul_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap + ) if self.gather_output: # All-gather across the partitions. @@ -366,7 +367,8 @@ def __init__( dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, - seq_parallel: bool = False, + # seq_parallel: bool = False, + seq_parallel_mode: str = None, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -385,7 +387,7 @@ def __init__( self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group - self.seq_parallel = seq_parallel + self.seq_parallel_mode = seq_parallel_mode self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -529,10 +531,10 @@ def forward(self, input_: Tensor) -> Tensor: output = torch.cat(output_parallel_list, dim=-1) else: output_parallel = torch.matmul(input_, self.weight) - if self.seq_parallel: - output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) - else: + if self.seq_parallel_mode is None: output = reduce_forward(output_parallel, self.process_group) + elif self.seq_parallel_mode == "1": + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 0d2cc1b3370d..b83a12438ed2 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -35,7 +35,12 @@ def is_sp_partial_derived_param(param): return getattr(param, "partial_derived", False) @staticmethod - def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None): + def allreduce_partial_data_grad( + process_group: ProcessGroup, + model: nn.Module = None, + grads: List[torch.Tensor] = None, + require_flag: bool = True, + ): """ Allreduce partial derived gradients across the specified process group. @@ -53,9 +58,9 @@ def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None." # Get the size of the process group, which determines whether synchronization is needed. - tp_size = get_world_size(tp_group) if tp_group is not None else 1 + group_size = get_world_size(process_group) if process_group is not None else 1 - if tp_size == 1: + if group_size == 1: # If the process group size is 1, no synchronization is required. return @@ -63,12 +68,15 @@ def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, # If `model` is provided, extract partial derived gradients from the model's parameters. grads = [] for p in model.parameters(): - if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p): - grads.append(p.grad.data) + if p.grad is not None: + if require_flag and SeqParallelUtils.is_sp_partial_derived_param(p): + grads.append(p.grad.data) + elif not require_flag: + grads.append(p.grad.data) # Flatten and reduce the gradients using the specified process group. coalesced = _flatten_dense_tensors(grads) - dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group) + dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group) # Unflatten the synchronized gradients and update the model's gradients. for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): @@ -76,7 +84,7 @@ def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, else: # If `grads` are provided explicitly, synchronize those gradients directly. coalesced = _flatten_dense_tensors(grads) - dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group) + dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group) for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): buf.copy_(synced) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 81af2cf1acd3..d3f5157bd26f 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -722,7 +722,7 @@ def forward( ) return forward -def test_llama_seq_parallel_attention(): +def get_llama_seq_parallel_attention_forward(): def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -861,7 +861,7 @@ def forward( return forward -def test_llama_seq_parallel_model(): +def get_llama_seq_parallel_model_forward(): def forward( self, input_ids: torch.LongTensor = None, diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index eddfafdcbcdc..2d235b8a0085 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -55,8 +55,11 @@ def module_policy(self): norm_cls = col_nn.FusedLayerNorm else: norm_cls = col_nn.LayerNorm - use_sequence_parallel = self.shard_config.enable_sequence_parallelism + + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None overlap = self.shard_config.enable_sequence_overlap + sp_partial_derived = sp_mode in ["1"] + if self.shard_config.enable_tensor_parallelism: policy[BloomBlock] = ModulePolicyDescription( attribute_replacement={ @@ -70,12 +73,12 @@ def module_policy(self): SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={"seq_parallel_mode": sp_mode}, ), SubModuleReplacementDescription( suffix="self_attention.attention_dropout", @@ -84,12 +87,12 @@ def module_policy(self): SubModuleReplacementDescription( suffix="mlp.dense_h_to_4h", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={"seq_parallel_mode": sp_mode}, ), ], ) @@ -132,19 +135,19 @@ def module_policy(self): SubModuleReplacementDescription( suffix="input_layernorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="post_attention_layernorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, target_key=BloomBlock, ) - if use_sequence_parallel: + if sp_mode == "1": self.append_or_create_method_replacement( description={"forward": get_bloom_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index d1ad9f91478b..7d1ff3f8e59e 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -55,8 +55,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: norm_cls = col_nn.RMSNorm else: norm_cls = col_nn.LayerNorm - use_sequence_parallel = self.shard_config.enable_sequence_parallelism + + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None overlap = self.shard_config.enable_sequence_overlap + sp_partial_derived = sp_mode in ["1"] + if self.shard_config.enable_tensor_parallelism: policy[ChatGLMModel] = ModulePolicyDescription( attribute_replacement={}, @@ -91,12 +94,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0, "overlap": overlap}, + kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0}, + kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0}, ), SubModuleReplacementDescription( suffix="self_attention.core_attention.attention_dropout", @@ -110,12 +113,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="input_layernorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="post_attention_layernorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, @@ -145,7 +148,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) # use sequence parallel - if use_sequence_parallel: + if sp_mode == "1": self.append_or_create_method_replacement( description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 5b43ecaed0c7..4ce94c9fec66 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -50,8 +50,11 @@ def module_policy(self): norm_cls = col_nn.FusedLayerNorm else: norm_cls = col_nn.LayerNorm - use_sequence_parallel = self.shard_config.enable_sequence_parallelism + + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None overlap = self.shard_config.enable_sequence_overlap + sp_partial_derived = sp_mode in ["1"] + if self.shard_config.enable_tensor_parallelism: policy[GPT2Model] = ModulePolicyDescription( sub_module_replacement=[ @@ -76,32 +79,26 @@ def module_policy(self): SubModuleReplacementDescription( suffix="attn.c_attn", target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 3, - "seq_parallel": use_sequence_parallel, - "overlap": overlap, - }, + kwargs={"n_fused": 3, "seq_parallel_mode": sp_mode, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="attn.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, kwargs={ - "seq_parallel": use_sequence_parallel, + "seq_parallel_mode": sp_mode, }, ), SubModuleReplacementDescription( suffix="mlp.c_fc", target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 1, - "seq_parallel": use_sequence_parallel, - "overlap": overlap, - }, + kwargs={"n_fused": 1, "seq_parallel_mode": sp_mode, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={ + "seq_parallel_mode": sp_mode, + }, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", @@ -133,18 +130,18 @@ def module_policy(self): SubModuleReplacementDescription( suffix="ln_1", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="ln_2", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="ln_cross_attn", target_module=norm_cls, ignore_if_not_exist=True, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, @@ -164,7 +161,7 @@ def module_policy(self): "forward": get_gpt_model_forward_for_flash_attn(self.shard_config) } - if self.shard_config.enable_sequence_parallelism: + if sp_mode == "1": policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} return policy diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 9aa1053b1ec2..243ef9bf1e84 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -2,7 +2,6 @@ from functools import partial from typing import Callable, Dict, List, Union -import torch.distributed as dist import torch.nn as nn from torch import Tensor from torch.nn import Module @@ -14,8 +13,8 @@ get_llama_flash_attention_forward, get_llama_model_forward_for_flash_attn, get_lm_forward_with_dist_cross_entropy, - test_llama_seq_parallel_attention, - test_llama_seq_parallel_model, + get_llama_seq_parallel_attention_forward, + get_llama_seq_parallel_model_forward, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -48,13 +47,18 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: else: norm_cls = RMSNorm - if self.shard_config.enable_sequence_parallelism: + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + # overlap = self.shard_config.enable_sequence_overlap + # sp_partial_derived = sp_mode in ["1"] + + # todo: Support SP for LlaMa model + if sp_mode == "1": self.shard_config.enable_sequence_parallelism = False + self.shard_config.sequence_parallelism_mode = None + sp_mode = None warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") - - # todo: seq - if self.shard_config.test_seq_parallelism: - sequence_parallelism_size = dist.get_world_size() + elif sp_mode == "3": + sequence_parallelism_size = self.shard_config.sequence_parallel_size decoder_attribute_replacement = { "num_heads": self.model.config.num_attention_heads // sequence_parallelism_size, "head_dim": self.model.config.hidden_size // self.model.config.num_attention_heads, @@ -72,14 +76,14 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: self.append_or_create_method_replacement( description={ - "forward": test_llama_seq_parallel_attention(), + "forward": get_llama_seq_parallel_attention_forward(), }, policy=policy, target_key=LlamaAttention, ) self.append_or_create_method_replacement( description={ - "forward": test_llama_seq_parallel_model(), + "forward": get_llama_seq_parallel_model_forward(), }, policy=policy, target_key=LlamaModel, diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 6a3bf38fa5f9..2d18a839a5d6 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -39,8 +39,6 @@ class ShardConfig: sequence_parallelism_mode: str = None enable_sequence_overlap: bool = False parallel_output: bool = True - test_seq_parallelism: bool = False - extra_kwargs: Dict[str, Any] = field(default_factory=dict) # TODO padding vocab # make_vocab_size_divisible_by: int = 128 # pipeline_parallel_size: int diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index b7312563489d..8f9a158ba012 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -185,10 +185,10 @@ def _criterion(outputs, inputs): shard_test_data[k] = data[k].clone() else: shard_test_data[k] = ( - data[k].clone() - if booster.plugin.shard_config.test_seq_parallelism is False - # else data[k].clone() - else torch.chunk(data[k].clone(), dist.get_world_size(), dim=1)[dist.get_rank()] + torch.chunk(data[k].clone(), booster.plugin.shard_config.sequence_parallel_size, dim=1)[dist.get_rank()] + if booster.plugin.shard_config.enable_sequence_parallelism + and booster.plugin.shard_config.sequence_parallelism_mode == "3" + else data[k].clone() ) unshard_test_data = {} for k, v in data.items(): @@ -263,16 +263,12 @@ def check_output_hidden_state( stage_manager: Optional[PipelineStageManager] = None, atol: float = 1e-5, rtol: float = 1e-3, - booster: Booster = None, ): org_hidden_state = org_output.last_hidden_state if stage_manager and stage_manager.is_last_stage(ignore_chunk=True): sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] else: - # if booster and booster.plugin.shard_config.test_seq_parallelism: - # sharded_hidden_state = sharded_output - # else: sharded_hidden_state = sharded_output.last_hidden_state assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol) diff --git a/tests/test_shardformer/test_model/test_llama_seq.py b/tests/test_shardformer/test_model/test_llama_seq.py index 061d83bd9138..25be9706533e 100644 --- a/tests/test_shardformer/test_model/test_llama_seq.py +++ b/tests/test_shardformer/test_model/test_llama_seq.py @@ -102,7 +102,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 1, "pp_size": 1, "num_microbatches": 1, - "test_seq_parallelism": True, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "3", "use_lazy_init": True, "precision": "fp32", "initial_scale": 1, diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index ddfd30388998..9cd0b57365df 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -154,15 +154,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) def run_bert_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") - - import torch.distributed as dist - - def print_rank(prompt, value, rank=0): - if dist.get_rank() == rank: - print(f"rank-{rank}, {prompt}: {value}") - - print_rank("test_config", test_config) - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) From 28c11b7e397037ebe08c1a0b50492dbba876cf10 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Tue, 9 Jan 2024 14:16:05 +0800 Subject: [PATCH 05/50] fix pp bugs and sp bugs for LlaMa model --- tests/kit/model_zoo/transformers/llama.py | 2 +- tests/test_shardformer/test_layer/demo.py | 161 ------------------ tests/test_shardformer/test_model/_utils.py | 25 +-- .../test_model/test_llama_seq.py | 140 --------------- .../test_model/test_shard_llama.py | 10 ++ 5 files changed, 12 insertions(+), 326 deletions(-) delete mode 100644 tests/test_shardformer/test_layer/demo.py delete mode 100644 tests/test_shardformer/test_model/test_llama_seq.py diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 288d4d71b322..ec97e8c142b7 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -68,7 +68,7 @@ def data_gen_for_casual_lm(): loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = LlamaConfig( - num_hidden_layers=2, + num_hidden_layers=4, hidden_size=64, intermediate_size=256, num_attention_heads=4, diff --git a/tests/test_shardformer/test_layer/demo.py b/tests/test_shardformer/test_layer/demo.py deleted file mode 100644 index 33bf7a7fac2e..000000000000 --- a/tests/test_shardformer/test_layer/demo.py +++ /dev/null @@ -1,161 +0,0 @@ -import os -import time - -import torch -import torch.distributed as dist -import torch.nn.functional as F - - -class AllGatherLinearWithRingCommunication(torch.autograd.Function): - """ - col-linear with hidden all_gather - - Y: [batch_size, seq_len / TP_size, hidden_size] - A: [batch_size, hidden_size, w_len / TP_size] - | - | Ring-based LinearOverlap - v - YA: [batch_size, seq_len, w_len / TP_size] - """ - - @staticmethod - def forward(ctx, input_, weight, bias, process_group): - # Input expected: (input_, bias) sharded on the row(sequence dim) and weight on the col - ctx.save_for_backward(input_, weight, bias) - ctx.use_bias = bias is not None - ctx.process_group = process_group - - # if bias is not None: - # output = F.linear(input_, weight, bias) - # else: - # output = F.linear(input_, weight) - group_size = dist.get_world_size(process_group) - cur_rank = dist.get_rank(process_group) - - input_shape = input_.shape - weight_shape = weight.shape - - output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)] - # output_tensor = torch.empty((input_shape[0], input_shape[1] * group_size, weight_shape[0]), device=input_.device) - - # initialization of ring communication - input_shape[1] - recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 - send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 - recv_tensor = input_.clone() - send_tensor = input_.clone() - input_tensor = input_.clone() - # output_pt = output_tensor - - recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) - send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) - handles = dist.batch_isend_irecv([recv_op, send_op]) - # first round: special case, retrive from local tensor - output_tensors[0] = F.linear(input_, weight) - # output_pt[:][:local_seq_len][:] = F.linear(input_, weight) - # output_pt = output_pt[:][local_seq_len:][:] - for i in range(group_size - 2): - handles[0].wait() - handles[1].wait() - - tmp_tensor = input_tensor - input_tensor = recv_tensor - recv_tensor = tmp_tensor - send_tensor = input_tensor.clone() - - recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) - send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) - handles = dist.batch_isend_irecv([recv_op, send_op]) - - # actual computation - output_tensors[i + 1] = F.linear(input_tensor, weight) - # output_pt[:][:local_seq_len][:] = F.linear(input_, weight) - # output_pt = output_pt[:][local_seq_len:][:] - - # final round: special case, no need to send/recv again - handles[0].wait() - # output_pt[:][:local_seq_len][:] = F.linear(recv_tensor, weight) - output_tensors[group_size - 1] = F.linear(recv_tensor, weight) - handles[1].wait() - return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=1) - - @staticmethod - def backward(ctx, grad_output): - return grad_output - input, weight, bias = ctx.saved_tensors - use_bias = ctx.use_bias - - # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. - if use_bias: - bias.view(bias.shape) - - total_input = input - grad_input = grad_output.matmul(weight) - grad_output = grad_output.contiguous() - # Convert the tensor shapes to 2D for execution compatibility - if len(grad_output.shape) > 2: - grad_output = grad_output.view(-1, grad_output.shape[-1]) - total_input = total_input.view(-1, total_input.shape[-1]) - - if ctx.async_grad_allreduce: - # Asynchronous all-reduce - handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) - # Delay the start of weight gradient computation shortly (3us) to have - # all-reduce scheduled first and have GPU resources allocated - _ = torch.empty(1, device=grad_output.device) + 1 - - grad_weight = grad_output.t().matmul(total_input) - grad_bias = grad_output.sum(dim=0) if use_bias else None - - if ctx.async_grad_allreduce: - handle.wait() - - return grad_input, grad_weight, grad_bias, None, None, None - - -def main(): - dist.init_process_group("nccl") - local_rank = int(os.environ["LOCAL_RANK"]) - torch.cuda.set_device(local_rank) - - scale_c = 4 - y = torch.randn(4, 5120 * scale_c, 1024 * scale_c, requires_grad=False).cuda() - w = torch.randn(256 * scale_c, 1024 * scale_c, requires_grad=False).cuda() - - trial_time = 5 - - ## warm up - tensor_list = [torch.zeros_like(y) for _ in range(4)] - dist.all_gather(tensor_list, y) - - Y = torch.cat(tensor_list, dim=1) - torch_out = F.linear(Y, w) - - ring_out = AllGatherLinearWithRingCommunication.apply(y, w, None, None) - ## - - tic = time.perf_counter() - for _ in range(trial_time): - tensor_list = [torch.zeros_like(y) for _ in range(4)] - dist.all_gather(tensor_list, y) - - Y = torch.cat(tensor_list, dim=1) - torch_out = F.linear(Y, w) - print(torch_out[0][0][0]) - toc = time.perf_counter() - print(f"original function in {toc - tic:0.4f} seconds") - - tic = time.perf_counter() - for _ in range(trial_time): - ring_out = AllGatherLinearWithRingCommunication.apply(y, w, None, None) - print(ring_out[0][0][0]) - toc = time.perf_counter() - print(f"fused function in {toc - tic:0.4f} seconds") - - if not torch.allclose(torch_out, ring_out, atol=1e-3): - raise RuntimeError("ring_overlap: failed!!") - print("ring_overlap: pass.") - - -if __name__ == "__main__": - main() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 8f9a158ba012..de03ed06fa5e 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -181,7 +181,7 @@ def _criterion(outputs, inputs): shard_test_data = {} for k, v in data.items(): - if k == "attention_mask": + if k == "attention_mask" or k == "labels": shard_test_data[k] = data[k].clone() else: shard_test_data[k] = ( @@ -194,29 +194,6 @@ def _criterion(outputs, inputs): for k, v in data.items(): unshard_test_data[k] = data[k].clone() - # if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0: - # seq_len = data["input_ids"].shape[-1] - # lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) - # times = lcm // seq_len - # input_shape = data["input_ids"].shape - # for k, v in data.items(): - # if v.shape == input_shape: - # data[k] = v.repeat((1,) * (v.dim() - 1) + (times,)) - - # sharded_model.train() - # if booster.plugin.stage_manager is not None: - # for k, v in data.items(): - # if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: - # new_shape = [1] * v.dim() - # new_shape[0] = 4 - # data[k] = v.to("cuda").repeat(*new_shape) - - # data_iter = iter([data]) - # sharded_output = booster.execute_pipeline( - # data_iter, sharded_model, _criterion, sharded_optimizer, return_loss=True, return_outputs=True - # ) - # sharded_loss = sharded_output["loss"] - sharded_model.train() if booster.plugin.stage_manager is not None: for k, v in shard_test_data.items(): diff --git a/tests/test_shardformer/test_model/test_llama_seq.py b/tests/test_shardformer/test_model/test_llama_seq.py deleted file mode 100644 index 25be9706533e..000000000000 --- a/tests/test_shardformer/test_model/test_llama_seq.py +++ /dev/null @@ -1,140 +0,0 @@ -import os - -import pytest -import torch -import torch.distributed as dist - -import colossalai -from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.layer.utils import Randomizer -from colossalai.tensor.d_tensor.api import clear_layout_converter -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import ( - build_model_from_hybrid_plugin, - check_all_grad_tensors, - check_loss, - check_output_hidden_state, - check_weight, - get_grad_tensors_for_check, - run_forward_backward_with_hybrid_plugin, - unwrap_model, -) - - -def print_rank(prompt, value, rank=0): - if dist.get_rank() == rank: - print(f"rank-{rank}, {prompt}: {value}") - - -os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" - - -def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( - model_fn, loss_fn, test_config - ) - org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster - ) - - stage_manager = booster.plugin.stage_manager - tp_group = booster.plugin.tp_group - - # unwrap model - llama_model = unwrap_model(org_model, "LlamaModel", "model") - shard_llama_model = unwrap_model(sharded_model, "LlamaModel", "model") - - row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] - col_layer_for_check = ["layers[0].self_attn.o_proj"] - - # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. - grads_to_check = {} - if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config["precision"] == "fp32": - atol, rtol = 1e-6, 1e-4 - else: - atol, rtol = 5e-3, 5e-3 - row_layer_grads = get_grad_tensors_for_check( - llama_model, shard_llama_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False - ) - col_layer_grads = get_grad_tensors_for_check( - llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False - ) - grads_to_check.update(col_layer_grads) - grads_to_check.update(row_layer_grads) - - # optimizer executes step - org_optimizer.step() - sharded_optimizer.step() - - # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): - if test_config["precision"] == "fp32": - atol, rtol = 1e-5, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - - if org_model.__class__.__name__ == "LlamaModel": - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, booster=booster) - - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - - # check weights - if stage_manager is None or stage_manager.is_first_stage(): - if test_config["precision"] == "fp32": - atol, rtol = 1e-4, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - check_weight( - llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False - ) - - # check grads - check_all_grad_tensors(grads_to_check) - torch.cuda.empty_cache() - - -@parameterize( - "test_config", - [ - { - "tp_size": 1, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "3", - "use_lazy_init": True, - "precision": "fp32", - "initial_scale": 1, - } - ], -) -def run_llama_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") - - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name != "transformers_llama": - continue - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) - - clear_layout_converter() - Randomizer.reset_index() - torch.cuda.empty_cache() - - -def check_llama(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_llama_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_llama(): - spawn(check_llama, 4) - - -if __name__ == "__main__": - test_llama() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index f4822fc79b69..d9413675b641 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -100,6 +100,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 1, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "3", + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, From cd41e4233e99a322340ef74db0bf1a3470f29d74 Mon Sep 17 00:00:00 2001 From: Zhongkai Zhao Date: Wed, 10 Jan 2024 15:03:54 +0800 Subject: [PATCH 06/50] integrating ring-based sequence parallelism into ShardFormer * [sequence parallelism]: Add fused megatron function * integrating ring-based sequence parallelism into ShardFormer --------- Co-authored-by: linsj20 --- .../booster/plugin/hybrid_parallel_plugin.py | 2 +- colossalai/shardformer/layer/_operation.py | 165 +++++++++++++++++- colossalai/shardformer/layer/linear.py | 16 +- colossalai/shardformer/layer/utils.py | 2 + colossalai/shardformer/modeling/llama.py | 72 +++++--- colossalai/shardformer/policies/llama.py | 38 ++-- tests/test_shardformer/test_model/_utils.py | 4 +- .../test_model/test_shard_llama.py | 10 ++ 8 files changed, 261 insertions(+), 48 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index cef9c9faf641..6c8a7ff4d2bc 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1012,7 +1012,7 @@ def __init__( warnings.warn( f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." ) - self.sp_size = tp_size + self.sp_size = 1 self.dp_size = dist.get_world_size() // (tp_size * pp_size) elif self.sequence_parallelism_mode in ["3"]: assert ( diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 0afb36937e19..e3ffaca4ed43 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -167,6 +167,49 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias, None, None, None +def _AllgatherLinear(input_, weight, process_group): + group_size = dist.get_world_size(process_group) + cur_rank = dist.get_rank(process_group) + + input_shape = input_.shape + weight_shape = weight.shape + + output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)] + + # initialization of ring communication + input_shape[1] + recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 + send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 + recv_tensor = input_.clone() + send_tensor = input_.clone() + + recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) + send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) + handles = dist.batch_isend_irecv([send_op, recv_op]) + # first round: special case, retrive from local tensor + output_tensors[0] = F.linear(input_, weight) + for i in range(group_size - 2): + for handle in handles: + handle.wait() + + tmp_tensor = send_tensor + send_tensor = recv_tensor + recv_tensor = tmp_tensor + + recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) + send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) + handles = dist.batch_isend_irecv([recv_op, send_op]) + + # actual computation + output_tensors[i + 1] = F.linear(send_tensor, weight) + + # final round: special case, no need to send/recv again + for handle in handles: + handle.wait() + output_tensors[group_size - 1] = F.linear(recv_tensor, weight) + return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=1) + + class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """Gather input from sequence parallel in forward and reduce-scatter gradient in backward @@ -186,12 +229,11 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, ctx.dim = dim ctx.overlap = overlap - input_parallel = _gather(input_, dim, process_group) - if bias is not None: + input_parallel = _gather(input_, dim, process_group) output = F.linear(input_parallel, weight, bias) else: - output = F.linear(input_parallel, weight) + output = _AllgatherLinear(input_, weight, process_group) return output @@ -297,7 +339,116 @@ def backward(ctx, grad_output): return output, grad_weight, grad_bias, None, None, None, None +def _ReduceScatterLinear(input_, weight, process_group): + group_size = dist.get_world_size(process_group) + cur_rank = dist.get_rank(process_group) + + input_shape = input_.shape + + # initialization of ring communication + # communicate(e.g.): 0->1->2->3 + # compute(e.g.): 3->2->1->0 + input_tensors = list(torch.split(input_, int(input_shape[1] / group_size), dim=1)) + input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank] + input_tensors.reverse() + recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 + send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 + + # first round: special case, no reduce operation + output_tensor = F.linear(input_tensors[0], weight) + recv_tensor = output_tensor.clone() + send_tensor = output_tensor.clone() + recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) + send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) + handles = dist.batch_isend_irecv([recv_op, send_op]) + for i in range(group_size - 2): + # actual computation + output_tensor = F.linear(input_tensors[i + 1], weight) + + for handle in handles: + handle.wait() + output_tensor += recv_tensor + + tmp_tensor = send_tensor + send_tensor = output_tensor + output_tensor = tmp_tensor + + recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) + send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) + handles = dist.batch_isend_irecv([recv_op, send_op]) + + # final round: special case, no need to send/recv again + output_tensor = F.linear(input_tensors[group_size - 1], weight) + for handle in handles: + handle.wait() + output_tensor += recv_tensor + return output_tensor + + class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): + """Reduce-scatter input from sequence parallel in forward and gather gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, dim): + ctx.save_for_backward(input_, weight, bias) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.dim = dim + + if bias is not None: + partial_output = F.linear(input_, weight, bias) + else: + return _ReduceScatterLinear(input_, weight, process_group) + + output_shape = list(partial_output.shape) + assert ( + output_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " + output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group) + + output_list = [ + item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous() + dist.reduce_scatter(output, output_list, group=process_group) + + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight, bias = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm + if use_bias: + bias = bias.view(bias.shape) + + grad_output = _gather(grad_output, dim, process_group) + + # TODO Need to fully optimize + total_input = input_ + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + return grad_input, grad_weight, grad_bias, None, None + + +class _ReduceScatterForwardGatherBackward(torch.autograd.Function): """Gather input from sequence parallel in forward and reduce-scatter gradient in backward Args: @@ -658,8 +809,12 @@ def linear_gather_forward_reducescatter_backward( ) -def linear_reducescatter_forward_gather_backward(input_, process_group, dim): - return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim) +def reducescatter_forward_gather_backward(input_, process_group, dim): + return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim) + + +def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1): + return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim) def matmul_gather_forward_reducescatter_backward( diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 069f7c8fd8c6..23593ab8cb2f 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -199,7 +199,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: if self.seq_parallel_mode is None: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) - elif self.seq_parallel_mode == "1": + elif self.seq_parallel_mode in ["1", "2"]: output_parallel = linear_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap ) @@ -410,14 +410,24 @@ def forward(self, input_: Tensor) -> Tensor: handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - output_parallel = F.linear(input_, self.weight) - if self.seq_parallel_mode is None: + output_parallel = F.linear(input_, self.weight) output = reduce_forward(output_parallel, self.process_group) elif self.seq_parallel_mode == "1": + output_parallel = F.linear(input_, self.weight) output = linear_reducescatter_forward_gather_backward( output_parallel, self.process_group, self.seq_parallel_dim ) + elif self.seq_parallel_mode == "2": + # TODO how to maintain compatibility? + # output = reducescatter_forward_gather_backward( + # output_parallel, self.process_group, self.seq_parallel_dim + # ) + output = linear_reducescatter_forward_gather_backward( + input_, + self.weight, + dim=self.seq_parallel_dim, + ) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index b83a12438ed2..0f7206ec1c27 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -75,6 +75,8 @@ def allreduce_partial_data_grad( grads.append(p.grad.data) # Flatten and reduce the gradients using the specified process group. + if len(grads) == 0: + return coalesced = _flatten_dense_tensors(grads) dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index d3f5157bd26f..c7f4f90fda1c 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -17,16 +17,18 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import ( + _gather, + all_to_all_comm, + gather_forward_split_backward, + split_forward_gather_backward, +) from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, cross_entropy_1d from colossalai.shardformer.layer._operation import all_to_all_comm, gather_forward_split_backward -def print_rank(prompt, value, rank=0): - if dist.get_rank() == rank: - print(f"rank-{rank}, {prompt}: {value}") - try: from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask @@ -436,7 +438,8 @@ def llama_for_sequence_classification_forward( return {"hidden_states": hidden_states} -def get_llama_flash_attention_forward(shard_config: ShardConfig): + +def get_llama_flash_attention_forward(sp_mode, sp_size): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb llama_version = 2 @@ -457,6 +460,8 @@ def forward( **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() + if sp_mode == "2": + q_len *= sp_size assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -722,7 +727,8 @@ def forward( ) return forward -def get_llama_seq_parallel_attention_forward(): + +def get_llama_seq_parallel_attention_forward(sp_mode, sp_size): def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -759,6 +765,11 @@ def forward( output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + # sp: modify sp_len when sequence parallel mode is 2 + if sp_mode == "2": + q_len *= sp_size if self.config.pretraining_tp > 1: key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp query_slices = self.q_proj.weight.split( @@ -781,30 +792,21 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - # introduce sequence parallel - query_states = all_to_all_comm(query_states) - key_states = all_to_all_comm(key_states) - value_states = all_to_all_comm(value_states) - - bsz, q_len, _ = query_states.size() + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "3": + query_states = all_to_all_comm(query_states) + key_states = all_to_all_comm(key_states) + value_states = all_to_all_comm(value_states) + bsz, q_len, _ = query_states.size() query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] - past_key_value_length = 0 if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - past_key_value_length = past_key_value[0].shape[2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - # modify position ids - position_ids_device = position_ids.device if position_ids is not None else hidden_states.device - position_ids = torch.arange( - past_key_value_length, q_len + past_key_value_length, dtype=torch.long, device=position_ids_device - ) - position_ids = position_ids.unsqueeze(0).view(-1, q_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -836,6 +838,7 @@ def forward( # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" @@ -843,9 +846,12 @@ def forward( ) attn_output = attn_output.transpose(1, 2).contiguous() - # attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) - attn_output = all_to_all_comm(attn_output, None, scatter_dim=1, gather_dim=2) + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "3": + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + attn_output = all_to_all_comm(attn_output, None, scatter_dim=1, gather_dim=2) + else: + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) if self.config.pretraining_tp > 1: attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) @@ -860,8 +866,8 @@ def forward( return forward - -def get_llama_seq_parallel_model_forward(): + +def get_llama_seq_parallel_model_forward(sp_mode, sp_size): def forward( self, input_ids: torch.LongTensor = None, @@ -892,11 +898,16 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + # sp: modify seq_length when using sequence parallel + seq_length *= sp_size + seq_length_with_past = seq_length past_key_values_length = 0 if past_key_values is not None: past_key_values_length = past_key_values[0][0].shape[2] + # modify past_key_values_length when using sequence parallel + past_key_values_length *= sp_size seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: @@ -909,13 +920,20 @@ def forward( position_ids = position_ids.view(-1, seq_length).long() if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + if sp_mode == "2": + input_ids = _gather(input_ids, 1, None) + inputs_embeds = self.embed_tokens(input_ids) + input_ids = input_ids.chunk(4, dim=1)[torch.distributed.get_rank()] + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, None) + else: + inputs_embeds = self.embed_tokens(input_ids) # embed positions if attention_mask is None: attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) + attention_mask = _gather(attention_mask, 1, None) attention_mask = self._prepare_decoder_attention_mask( attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 243ef9bf1e84..a87d876de503 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -48,42 +48,53 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: norm_cls = RMSNorm sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + sp_size = self.shard_config.sequence_parallel_size # overlap = self.shard_config.enable_sequence_overlap # sp_partial_derived = sp_mode in ["1"] - # todo: Support SP for LlaMa model if sp_mode == "1": self.shard_config.enable_sequence_parallelism = False self.shard_config.sequence_parallelism_mode = None sp_mode = None warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + elif sp_mode == "2": + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size), + }, + policy=policy, + target_key=LlamaAttention, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size), + }, + policy=policy, + target_key=LlamaModel, + ) elif sp_mode == "3": - sequence_parallelism_size = self.shard_config.sequence_parallel_size decoder_attribute_replacement = { - "num_heads": self.model.config.num_attention_heads // sequence_parallelism_size, + "num_heads": self.model.config.num_attention_heads // sp_size, "head_dim": self.model.config.hidden_size // self.model.config.num_attention_heads, } if getattr(self.model.config, "num_key_value_heads", False): - decoder_attribute_replacement["num_key_value_heads"] = ( - self.model.config.num_key_value_heads // sequence_parallelism_size - ) + decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size decoder_attribute_replacement["num_key_value_groups"] = ( self.model.config.num_attention_heads // self.model.config.num_key_value_heads ) policy[LlamaAttention] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) - self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_attention_forward(), + "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size), }, policy=policy, target_key=LlamaAttention, ) self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_model_forward(), + "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size), }, policy=policy, target_key=LlamaModel, @@ -105,30 +116,37 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, + kwargs=dict(seq_parallel_mode=sp_mode), ), ], ) @@ -171,7 +189,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_llama_flash_attention_forward(self.shard_config), + "forward": get_llama_flash_attention_forward(sp_mode, sp_size), }, policy=policy, target_key=LlamaAttention, diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index de03ed06fa5e..e33d672be1bd 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -181,13 +181,13 @@ def _criterion(outputs, inputs): shard_test_data = {} for k, v in data.items(): - if k == "attention_mask" or k == "labels": + if k == "labels": shard_test_data[k] = data[k].clone() else: shard_test_data[k] = ( torch.chunk(data[k].clone(), booster.plugin.shard_config.sequence_parallel_size, dim=1)[dist.get_rank()] if booster.plugin.shard_config.enable_sequence_parallelism - and booster.plugin.shard_config.sequence_parallelism_mode == "3" + and booster.plugin.shard_config.sequence_parallelism_mode in ["2", "3"] else data[k].clone() ) unshard_test_data = {} diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index d9413675b641..d6878b26cdff 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -100,6 +100,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "2", + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, { "tp_size": 1, "pp_size": 1, From 391dc644effe9f00c2d59e3dc4642be3454a0559 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 10 Jan 2024 15:38:46 +0800 Subject: [PATCH 07/50] fix bugs when useing sp and flashattention together --- colossalai/shardformer/layer/linear.py | 4 ++-- colossalai/shardformer/modeling/llama.py | 9 ++++----- colossalai/shardformer/policies/llama.py | 2 +- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 23593ab8cb2f..a921cee9f08b 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -411,10 +411,10 @@ def forward(self, input_: Tensor) -> Tensor: output = torch.cat(output_parallel_list, dim=-1) else: if self.seq_parallel_mode is None: - output_parallel = F.linear(input_, self.weight) + output_parallel = linear_with_async_comm(input_, self.weight, None, None, False) output = reduce_forward(output_parallel, self.process_group) elif self.seq_parallel_mode == "1": - output_parallel = F.linear(input_, self.weight) + output_parallel = linear_with_async_comm(input_, self.weight, None, None, False) output = linear_reducescatter_forward_gather_backward( output_parallel, self.process_group, self.seq_parallel_dim ) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index c7f4f90fda1c..00ff5c78bc24 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -438,8 +438,7 @@ def llama_for_sequence_classification_forward( return {"hidden_states": hidden_states} - -def get_llama_flash_attention_forward(sp_mode, sp_size): +def get_llama_flash_attention_forward(shard_config): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb llama_version = 2 @@ -460,8 +459,8 @@ def forward( **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - if sp_mode == "2": - q_len *= sp_size + if shard_config.sequence_parallelism_mode == "2": + q_len *= shard_config.sequence_parallel_size assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -866,7 +865,7 @@ def forward( return forward - + def get_llama_seq_parallel_model_forward(sp_mode, sp_size): def forward( self, diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index a87d876de503..e512cb79b91e 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -189,7 +189,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_llama_flash_attention_forward(sp_mode, sp_size), + "forward": get_llama_flash_attention_forward(shard_config=self.shard_config), }, policy=policy, target_key=LlamaAttention, From 13fc14cc5c7563c9e38e91440e3d09977a779dce Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Fri, 12 Jan 2024 16:43:35 +0800 Subject: [PATCH 08/50] fix operation function name --- colossalai/shardformer/layer/_operation.py | 1 - colossalai/shardformer/layer/linear.py | 3 ++- colossalai/shardformer/layer/qkv_fused_linear.py | 4 ++-- colossalai/shardformer/modeling/chatglm2.py | 2 ++ 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index e3ffaca4ed43..3167129c195c 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -401,7 +401,6 @@ def forward(ctx, input_, weight, bias, process_group, dim): ctx.use_bias = bias is not None ctx.process_group = process_group ctx.dim = dim - if bias is not None: partial_output = F.linear(input_, weight, bias) else: diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index a921cee9f08b..a908a862da88 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -28,6 +28,7 @@ linear_reducescatter_forward_gather_backward, linear_with_async_comm, reduce_forward, + reducescatter_forward_gather_backward, split_forward_gather_backward, ) from .parallel_module import ParallelModule @@ -415,7 +416,7 @@ def forward(self, input_: Tensor) -> Tensor: output = reduce_forward(output_parallel, self.process_group) elif self.seq_parallel_mode == "1": output_parallel = linear_with_async_comm(input_, self.weight, None, None, False) - output = linear_reducescatter_forward_gather_backward( + output = reducescatter_forward_gather_backward( output_parallel, self.process_group, self.seq_parallel_dim ) elif self.seq_parallel_mode == "2": diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 6feca7ee5fc3..aa408c5fbb3e 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -25,12 +25,12 @@ from ._operation import ( gather_forward_split_backward, - linear_reducescatter_forward_gather_backward, linear_with_async_comm, matmul_gather_forward_reducescatter_backward, matmul_with_async_comm, reduce_backward, reduce_forward, + reducescatter_forward_gather_backward, split_forward_gather_backward, ) from .parallel_module import ParallelModule @@ -534,7 +534,7 @@ def forward(self, input_: Tensor) -> Tensor: if self.seq_parallel_mode is None: output = reduce_forward(output_parallel, self.process_group) elif self.seq_parallel_mode == "1": - output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index 9bfb7053943b..81528a0a2d62 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -72,6 +72,7 @@ def forward( # hidden_states: [s, b, h] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) + print("hidden_states", hidden_states.shape) # Self attention. attention_output, kv_cache = self.self_attention( layernorm_output, @@ -379,6 +380,7 @@ def forward( dim=0, process_group=shard_config.tensor_parallel_process_group, ) + print("inputs_embeds", inputs_embeds.shape) hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( inputs_embeds, full_attention_mask, From 83e6044a75470e8e4a5a32976ecc573bab5b36cd Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 17 Jan 2024 13:38:00 +0800 Subject: [PATCH 09/50] support flash attention for ulysses-style sp --- colossalai/shardformer/modeling/chatglm2.py | 2 - colossalai/shardformer/modeling/llama.py | 116 ++++++++++++++++-- tests/kit/model_zoo/transformers/llama.py | 4 +- tests/test_shardformer/test_model/_utils.py | 6 + .../test_model/test_shard_llama.py | 24 ++++ 5 files changed, 140 insertions(+), 12 deletions(-) diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index 81528a0a2d62..9bfb7053943b 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -72,7 +72,6 @@ def forward( # hidden_states: [s, b, h] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) - print("hidden_states", hidden_states.shape) # Self attention. attention_output, kv_cache = self.self_attention( layernorm_output, @@ -380,7 +379,6 @@ def forward( dim=0, process_group=shard_config.tensor_parallel_process_group, ) - print("inputs_embeds", inputs_embeds.shape) hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( inputs_embeds, full_attention_mask, diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 00ff5c78bc24..3f6b55999ad0 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -28,7 +28,6 @@ from ..layer import ColoAttention, cross_entropy_1d from colossalai.shardformer.layer._operation import all_to_all_comm, gather_forward_split_backward - try: from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask @@ -37,6 +36,11 @@ LATEST_VERSION = False +def print_rank(prompt, value, rank=0): + if dist.get_rank() == rank: + print(f"rank-{rank}, {prompt}: {value}") + + class LlamaPipelineForwards: """ This class serves as a micro library for forward function substitution of Llama models @@ -459,13 +463,27 @@ def forward( **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - if shard_config.sequence_parallelism_mode == "2": + sp_mode = shard_config.sequence_parallelism_mode + sp_size = shard_config.sequence_parallel_size + + if sp_mode == "2": q_len *= shard_config.sequence_parallel_size - assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." + # assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "3": + query_states = all_to_all_comm(query_states) + key_states = all_to_all_comm(key_states) + value_states = all_to_all_comm(value_states) + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -489,7 +507,35 @@ def forward( attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + # me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) + # query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) + # key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) + # value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape) + + # flash_attention_mask = None + # attn_mask_type = AttnMaskType.causal + # if not getattr(shard_config, "causal_lm", False) and attention_mask != None: + # if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + # raise ValueError( + # f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + # ) + # flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + # attn_mask_type = AttnMaskType.paddedcausal + # hidden_size = self.hidden_size // sp_size if sp_mode == "3" else self.hidden_size + + # attention = ColoAttention(embed_dim=hidden_size, num_heads=self.num_heads) + # attn_output = attention( + # query_states, + # key_states, + # value_states, + # attn_mask=flash_attention_mask, + # attn_mask_type=attn_mask_type, + # origin_attn_mask=attention_mask, + # ) + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "3": + attn_output = all_to_all_comm(attn_output, None, scatter_dim=1, gather_dim=2) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value @@ -727,6 +773,7 @@ def forward( return forward + def get_llama_seq_parallel_attention_forward(sp_mode, sp_size): def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -765,7 +812,6 @@ def forward( use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - # sp: modify sp_len when sequence parallel mode is 2 if sp_mode == "2": q_len *= sp_size @@ -867,6 +913,59 @@ def forward( def get_llama_seq_parallel_model_forward(sp_mode, sp_size): + # Copied from transformers.models.bart.modeling_bart._make_causal_mask + def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + # Copied from transformers.models.bart.modeling_bart._expand_mask + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + # inverted_mask = 1.0 - expanded_mask + inverted_mask = expanded_mask.mul_(-1).add_(1.0) + return inverted_mask.masked_fill_(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + def forward( self, input_ids: torch.LongTensor = None, @@ -932,9 +1031,10 @@ def forward( attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) + attention_mask = _gather(attention_mask, 1, None) - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_decoder_attention_mask( attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length ) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index ec97e8c142b7..49244ea688b4 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -68,10 +68,10 @@ def data_gen_for_casual_lm(): loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = LlamaConfig( - num_hidden_layers=4, + num_hidden_layers=8, hidden_size=64, intermediate_size=256, - num_attention_heads=4, + num_attention_heads=8, max_position_embeddings=128, num_labels=16, ) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index e33d672be1bd..33f631689a67 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -165,6 +165,11 @@ def _criterion(outputs, inputs): return loss data = data_gen_fn() + # for k, v in data.items(): + # size = list(v.shape) + # tg_size = [1] * len(size) + # tg_size[1] = 64 * 2 + # data[k] = v.repeat(tg_size) if ( booster.plugin.shard_config.enable_sequence_parallelism @@ -232,6 +237,7 @@ def _criterion(outputs, inputs): org_loss.backward() return org_loss, org_output, sharded_loss, sharded_output + # return sharded_loss, sharded_output, sharded_loss, sharded_output def check_output_hidden_state( diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index d6878b26cdff..a3600f208707 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -106,6 +106,27 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "2", + "enable_flash_attention": True, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "2", + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "3", "use_lazy_init": True, "precision": "fp32", "initial_scale": 1, @@ -116,6 +137,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "3", + "enable_flash_attention": True, "use_lazy_init": True, "precision": "fp32", "initial_scale": 1, @@ -177,6 +199,8 @@ def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name != "transformers_llama": + continue check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() From 755769139a07e0d8a22dbad43b119d2282e25d10 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 17 Jan 2024 14:06:20 +0800 Subject: [PATCH 10/50] clarify sp process group --- colossalai/shardformer/modeling/llama.py | 19 ++++++++++--------- colossalai/shardformer/policies/llama.py | 7 ++++--- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 3f6b55999ad0..988d4360b52f 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -465,6 +465,7 @@ def forward( bsz, q_len, _ = hidden_states.size() sp_mode = shard_config.sequence_parallelism_mode sp_size = shard_config.sequence_parallel_size + sp_group = shard_config.sequence_parallel_process_group if sp_mode == "2": q_len *= shard_config.sequence_parallel_size @@ -476,9 +477,9 @@ def forward( # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "3": - query_states = all_to_all_comm(query_states) - key_states = all_to_all_comm(key_states) - value_states = all_to_all_comm(value_states) + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) bsz, q_len, _ = query_states.size() query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -535,7 +536,7 @@ def forward( # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "3": - attn_output = all_to_all_comm(attn_output, None, scatter_dim=1, gather_dim=2) + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value @@ -774,7 +775,7 @@ def forward( return forward -def get_llama_seq_parallel_attention_forward(sp_mode, sp_size): +def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -839,9 +840,9 @@ def forward( # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "3": - query_states = all_to_all_comm(query_states) - key_states = all_to_all_comm(key_states) - value_states = all_to_all_comm(value_states) + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) bsz, q_len, _ = query_states.size() query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -894,7 +895,7 @@ def forward( # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "3": attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) - attn_output = all_to_all_comm(attn_output, None, scatter_dim=1, gather_dim=2) + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) else: attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index e512cb79b91e..98524c03daea 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -12,9 +12,9 @@ LlamaPipelineForwards, get_llama_flash_attention_forward, get_llama_model_forward_for_flash_attn, - get_lm_forward_with_dist_cross_entropy, get_llama_seq_parallel_attention_forward, get_llama_seq_parallel_model_forward, + get_lm_forward_with_dist_cross_entropy, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -49,6 +49,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None sp_size = self.shard_config.sequence_parallel_size + sp_group = self.shard_config.sequence_parallel_process_group # overlap = self.shard_config.enable_sequence_overlap # sp_partial_derived = sp_mode in ["1"] # todo: Support SP for LlaMa model @@ -60,7 +61,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: elif sp_mode == "2": self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size), + "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), }, policy=policy, target_key=LlamaAttention, @@ -87,7 +88,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size), + "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), }, policy=policy, target_key=LlamaAttention, From 9698a87922a12d775883555302ed7881154465cc Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 17 Jan 2024 14:17:48 +0800 Subject: [PATCH 11/50] fix compatibility bugs in moe plugin --- colossalai/booster/plugin/moe_hybrid_parallel_plugin.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index ae372dd034e0..83888e5069a7 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -254,6 +254,9 @@ def __init__( self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + # TODO: Currently moe only support partially sequence parallel + self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, @@ -365,6 +368,7 @@ def configure( shard_config=self.shard_config, dp_group=self.dp_group, tp_group=self.tp_group, + sp_group=self.sp_group, use_ddp=use_ddp, ddp_config=self.ddp_config, custom_policy=self.custom_policy, From 7a31083dcef443e4258c913d650985fa934a3388 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 17 Jan 2024 15:22:51 +0800 Subject: [PATCH 12/50] fix fused linear bugs --- .../shardformer/layer/qkv_fused_linear.py | 6 ++--- .../test_gpt2_qkv_fused_linear_1d.py | 27 ++++++++++++------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index aa408c5fbb3e..6c5fb41494f0 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -150,7 +150,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): device (`torch.device`): The device of parameters, defaults to None. n_fused (int): The number items fused, defaults to 3 (QKV). process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. - seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None. gather_output (bool, optional): If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output which is :math:`Y_i = XA_i`, defaults to False @@ -175,7 +175,6 @@ def __init__( process_group: ProcessGroup = None, async_communication: bool = False, gather_output: bool = False, - # seq_parallel: bool = False, seq_parallel_mode: str = None, overlap: bool = False, skip_bias_add: bool = False, @@ -348,7 +347,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): dtype (`torch.dtype`): The dtype of parameters, defaults to None. parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, - seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None. which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. @@ -367,7 +366,6 @@ def __init__( dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, - # seq_parallel: bool = False, seq_parallel_mode: str = None, parallel_input: bool = True, skip_bias_add: bool = False, diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index e056860ede57..06586c0dc8ca 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -56,13 +56,18 @@ def rearrange(tensor: torch.Tensor, dim: int): return rearanged_tensor -def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): +def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module( - linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, n_fused=3, overlap=overlap + linear_copy, + process_group=None, + gather_output=True, + seq_parallel_mode=seq_parallel_mode, + n_fused=3, + overlap=overlap, ) assert linear.weight.shape == torch.Size([48, 192]) @@ -79,7 +84,9 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool) # check computation correctness x = torch.rand(1, 4, 48).cuda() out = linear(x) - x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + x_for_shard = ( + torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] if seq_parallel_mode == "1" else x.expand_as(x.clone()) + ) gather_out = linear_conv_col(x_for_shard) assert_close(rearrange(out, -1), gather_out) @@ -91,14 +98,14 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool) assert_close(target_grad, linear_conv_col.weight.grad) -def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): +def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() linear_row = GPT2FusedLinearConv1D_Row.from_native_module( - linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel + linear_copy, process_group=None, parallel_input=False, seq_parallel_mode=seq_parallel_mode ) assert linear.weight.shape == torch.Size([48, 192]) @@ -115,7 +122,7 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): x = torch.rand(1, 4, 48).cuda() out = linear(x) gather_out = linear_row(x) - target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + target_out = torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] if seq_parallel_mode == "1" else out assert_close(target_out, gather_out) # check backward correctness @@ -128,11 +135,11 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): @parameterize("lazy_init", [False, True]) -@parameterize("seq_parallel", [False, True]) +@parameterize("seq_parallel_mode", ["1", None]) @parameterize("overlap", [True]) -def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool): - check_linear_conv_1d_col(lazy_init, seq_parallel, overlap) - check_linear_conv_1d_row(lazy_init, seq_parallel) +def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool): + check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap) + check_linear_conv_1d_row(lazy_init, seq_parallel_mode) def run_dist(rank, world_size, port): From 74457dfb91eaa50cab3716c9123f6b687f71cbdf Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 17 Jan 2024 16:50:35 +0800 Subject: [PATCH 13/50] fix linear layer test --- .../test_gpt2_qkv_fused_linear_1d.py | 4 +- .../test_layer/test_linear_1d.py | 42 +++++++++++-------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index 06586c0dc8ca..e4351ddae7f4 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -85,7 +85,7 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b x = torch.rand(1, 4, 48).cuda() out = linear(x) x_for_shard = ( - torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] if seq_parallel_mode == "1" else x.expand_as(x.clone()) + x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] ) gather_out = linear_conv_col(x_for_shard) assert_close(rearrange(out, -1), gather_out) @@ -122,7 +122,7 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool): x = torch.rand(1, 4, 48).cuda() out = linear(x) gather_out = linear_row(x) - target_out = torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] if seq_parallel_mode == "1" else out + target_out = out if seq_parallel_mode is None else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] assert_close(target_out, gather_out) # check backward correctness diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index defa4afb919b..28b4c36f7e89 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -15,13 +15,13 @@ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" -def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): +def check_linear_1d_col(lazy_init: bool, seq_parallel_mode: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() linear_col = Linear1D_Col.from_native_module( - linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, overlap=overlap + linear_copy, process_group=None, gather_output=True, seq_parallel_mode=seq_parallel_mode, overlap=overlap ) # ensure that the parameters are distributed @@ -43,7 +43,9 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + x_for_shard = ( + x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + ) x_for_shard.requires_grad_(True) out = linear(x_for_unshard) @@ -63,20 +65,20 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): assert x_for_unshard.grad is not None target_unshard_gard = ( x_for_unshard.grad - if seq_parallel is False + if seq_parallel_mode is None else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] ) assert_close(target_unshard_gard, x_for_shard.grad) -def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): +def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() linear_row = Linear1D_Row.from_native_module( - linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel + linear_copy, process_group=None, parallel_input=False, seq_parallel_mode=seq_parallel_mode ) assert linear_row.weight.shape == torch.Size([128, 16]) @@ -98,7 +100,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): # run forward out = linear(x_for_unshard) gather_out = linear_row(x_for_shard) - target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + target_out = out if seq_parallel_mode is None else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] assert_close(target_out, gather_out) # check backward correctness @@ -115,7 +117,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): assert_close(x_for_unshard.grad, x_for_shard.grad) -def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool): +def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear_1 = nn.Linear(32, 128).cuda() @@ -125,10 +127,10 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool linear_1_copy = nn.Linear(32, 128).cuda() linear_2_copy = nn.Linear(128, 32).cuda() linear_col = Linear1D_Col.from_native_module( - linear_1_copy, process_group=None, gather_output=False, seq_parallel=seq_parallel, overlap=overlap + linear_1_copy, process_group=None, gather_output=False, seq_parallel_mode=seq_parallel_mode, overlap=overlap ) linear_row = Linear1D_Row.from_native_module( - linear_2_copy, process_group=None, parallel_input=True, seq_parallel=seq_parallel + linear_2_copy, process_group=None, parallel_input=True, seq_parallel_mode=seq_parallel_mode ) linear_1.load_state_dict(linear_col.state_dict()) @@ -141,13 +143,17 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + x_for_shard = ( + x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + ) x_for_shard.requires_grad_(True) # run forward unshard_out = linear_2(linear_1(x_for_unshard)) shard_out = linear_row(linear_col(x_for_shard)) - target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()] + target_out = ( + unshard_out if seq_parallel_mode is None else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()] + ) assert_close(target_out, shard_out) # check backward correctness @@ -163,19 +169,19 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool assert x_for_unshard.grad is not None target_unshard_gard = ( x_for_unshard.grad - if seq_parallel is False + if seq_parallel_mode is None else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] ) assert_close(target_unshard_gard, x_for_shard.grad) @parameterize("lazy_init", [False, True]) -@parameterize("seq_parallel", [False, True]) +@parameterize("seq_parallel_mode", [None, "1"]) @parameterize("overlap", [True]) -def run_dist_linear_test(lazy_init, seq_parallel, overlap): - check_linear_1d_col(lazy_init, seq_parallel, overlap) - check_linear_1d_row(lazy_init, seq_parallel) - check_linear_col_plus_row(lazy_init, seq_parallel, overlap) +def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap): + check_linear_1d_col(lazy_init, seq_parallel_mode, overlap) + check_linear_1d_row(lazy_init, seq_parallel_mode) + check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap) def check_dist_linear(rank, world_size, port): From 858f55dac08e67d0e2aa259d07e1cfafd77e43b1 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Tue, 23 Jan 2024 16:58:15 +0800 Subject: [PATCH 14/50] support gpt model all-to-all sp --- colossalai/shardformer/modeling/gpt2.py | 72 ++++++++--- colossalai/shardformer/policies/gpt2.py | 20 ++- colossalai/shardformer/policies/llama.py | 2 +- tests/kit/model_zoo/transformers/gpt.py | 46 +++++-- tests/test_shardformer/test_model/_utils.py | 32 ++--- .../test_model/test_shard_gpt2.py | 117 ++++++++++++------ 6 files changed, 205 insertions(+), 84 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 1d034fff1639..f98a28b24ce4 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -1,6 +1,7 @@ from typing import Dict, List, Optional, Tuple, Union import torch +import torch.distributed as dist from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -22,7 +23,12 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention -from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.layer._operation import ( + _gather, + all_to_all_comm, + gather_forward_split_backward, + split_forward_gather_backward, +) from colossalai.shardformer.shard import ShardConfig from ..layer import cross_entropy_1d @@ -98,6 +104,11 @@ def _get_attention_mask( return attention_mask, encoder_attention_mask +def print_rank(prompt, value, rank=0): + if dist.get_rank() == rank: + print(f"rank-{rank}, {prompt}: {value}") + + class GPT2PipelineForwards: """ This class serves as a micro library for forward function substitution of GPT2 models @@ -789,7 +800,7 @@ def gpt2_for_sequence_classification_forward( ) -def get_gpt2_flash_attention_forward(): +def get_gpt2_flash_attention_forward(sp_mode, sp_size, sp_group): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention def forward( @@ -816,9 +827,15 @@ def forward( attention_mask = encoder_attention_mask else: query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) + + if sp_mode == "3": + query = all_to_all_comm(query) + key = all_to_all_comm(key) + value = all_to_all_comm(value) + + query = split_heads(query, self.num_heads, self.head_dim) + key = split_heads(key, self.num_heads, self.head_dim) + value = split_heads(value, self.num_heads, self.head_dim) if layer_past is not None: past_key, past_value = layer_past @@ -838,6 +855,9 @@ def forward( dropout_p = self.attn_dropout.p if self.training else 0.0 attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + if sp_mode == "3": + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) + attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) outputs = (attn_output, present, None) @@ -847,7 +867,8 @@ def forward( return forward -def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig): +# def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig): +def gpt2_sequence_parallel_forward_fn(sp_mode, sp_size, sp_group): def forward( self: GPT2Model, input_ids: Optional[torch.LongTensor] = None, @@ -886,16 +907,23 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device + # use variable seq_len to replace input_shape[-1] + seq_len = input_shape[-1] + if sp_mode in ["2", "3"]: + seq_len *= sp_size + if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) + token_type_ids = token_type_ids.view(-1, seq_len) if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) + position_ids = position_ids.view(-1, seq_len) if past_key_values is None: past_length = 0 past_key_values = tuple([None] * len(self.h)) else: past_length = past_key_values[0][0].size(-2) + if sp_mode in ["2", "3"]: + past_length *= sp_size if position_ids is None: position_ids = torch.arange( past_length, @@ -905,6 +933,13 @@ def forward( ) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + # split position ids when using sequence parallel + if sp_mode in ["2", "3"]: + position_ids = torch.chunk(position_ids.clone(), sp_size, dim=1)[dist.get_rank(sp_group)] + + if sp_mode in ["2", "3"]: + attention_mask = _gather(attention_mask, 1, sp_group) + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -1120,6 +1155,9 @@ def forward( encoder_hidden_states, encoder_attention_mask, ) + # output_shape = input_shape + (hidden_states.size(-1),) + # output_shape = input_shape[:-1] + (seq_len, ) + (hidden_states.size(-1),) + output_shape = (-1,) + input_shape[1:-1] + (seq_len,) + (hidden_states.size(-1),) if self.gradient_checkpointing and self.training: if use_cache: @@ -1134,13 +1172,10 @@ def forward( all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None - # split the input tensor along sequence dimension - # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - hidden_states = split_forward_gather_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - ) + if sp_mode == "1": + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward(hidden_states, dim=1, process_group=sp_group) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): # Model parallel @@ -1203,13 +1238,10 @@ def custom_forward(*inputs): hidden_states = hidden_states.to("cuda:" + str(k + 1)) # When sequence parallelism done, gather the output tensor in forward and split it in backward - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - ) + hidden_states = gather_forward_split_backward(hidden_states, dim=1, process_group=sp_group) hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) # Add last hidden state if output_hidden_states: diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 4ce94c9fec66..2137387cdf6c 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -52,9 +52,21 @@ def module_policy(self): norm_cls = col_nn.LayerNorm sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + sp_size = self.shard_config.sequence_parallel_size + sp_group = self.shard_config.sequence_parallel_process_group overlap = self.shard_config.enable_sequence_overlap sp_partial_derived = sp_mode in ["1"] + if sp_mode == "2": + pass + elif sp_mode == "3": + decoder_attribute_replacement = { + "num_heads": self.model.config.num_attention_heads // sp_size, + } + policy[GPT2Attention] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + ) + if self.shard_config.enable_tensor_parallelism: policy[GPT2Model] = ModulePolicyDescription( sub_module_replacement=[ @@ -151,7 +163,7 @@ def module_policy(self): if self.shard_config.enable_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_gpt2_flash_attention_forward(), + "forward": get_gpt2_flash_attention_forward(sp_mode, sp_size, sp_group), }, policy=policy, target_key=GPT2Attention, @@ -161,8 +173,10 @@ def module_policy(self): "forward": get_gpt_model_forward_for_flash_attn(self.shard_config) } - if sp_mode == "1": - policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} + if sp_mode is not None: + policy[GPT2Model].method_replacement = { + "forward": gpt2_sequence_parallel_forward_fn(sp_mode, sp_size, sp_group) + } return policy diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 98524c03daea..dda9cb7cc1d1 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -76,7 +76,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: elif sp_mode == "3": decoder_attribute_replacement = { "num_heads": self.model.config.num_attention_heads // sp_size, - "head_dim": self.model.config.hidden_size // self.model.config.num_attention_heads, + # "head_dim": self.model.config.hidden_size // self.model.config.num_attention_heads, } if getattr(self.model.config, "num_key_value_heads", False): decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 24f9627c269c..b7372b6f9607 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -18,8 +18,23 @@ def data_gen(): # tokenized_input = tokenizer(input, return_tensors='pt') # input_ids = tokenized_input['input_ids'] # attention_mask = tokenized_input['attention_mask'] - input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + # input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) + # attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + input_ids = torch.tensor( + [ + [15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779], + [15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779], + ], + dtype=torch.int64, + ) + attention_mask = torch.tensor( + [ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ], + dtype=torch.int64, + ) + return dict(input_ids=input_ids, attention_mask=attention_mask) @@ -35,9 +50,9 @@ def data_gen_for_question_answering(): # question answering data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - start_positions = torch.tensor([0], dtype=torch.int64) + start_positions = torch.tensor([[0], [0]], dtype=torch.int64) data["start_positions"] = start_positions - end_positions = torch.tensor([1], dtype=torch.int64) + end_positions = torch.tensor([[1], [1]], dtype=torch.int64) data["end_positions"] = end_positions return data @@ -46,14 +61,20 @@ def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64) + data["labels"] = torch.tensor( + [ + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1], + ], + dtype=torch.int64, + ) return data def data_gen_for_sequence_classification(): # sequence classification data gen data = data_gen() - data["labels"] = torch.tensor([1], dtype=torch.int64) + data["labels"] = torch.tensor([[1], [1]], dtype=torch.int64) return data @@ -61,12 +82,18 @@ def date_gen_for_double_heads(): num_choices = 2 batch_size = 2 input_ids = torch.tensor( - [[15496, 11, 616, 3290, 318, 13779, 318, 13779], [15496, 11, 616, 3290, 318, 13779, 318, 13779]], + [ + [15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779], + [15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779], + ], + dtype=torch.int64, + ) + attention_mask = torch.tensor( + [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64, ) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) - mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64) + mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64) mc_token_ids = torch.arange(0, num_choices, dtype=torch.int64) mc_token_ids = mc_token_ids.expand((batch_size, num_choices)) multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous() @@ -103,6 +130,7 @@ def date_gen_for_double_heads(): hidden_dropout=0, problem_type="single_label_classification", pad_token_id=50256, + tie_word_embeddings=False, ) config_for_token_classification = copy.deepcopy(config) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 33f631689a67..f6496ef4402a 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,5 +1,4 @@ import copy -import math from contextlib import nullcontext from typing import Any, Callable, Dict, List, Optional @@ -171,26 +170,29 @@ def _criterion(outputs, inputs): # tg_size[1] = 64 * 2 # data[k] = v.repeat(tg_size) - if ( - booster.plugin.shard_config.enable_sequence_parallelism - and booster.plugin.shard_config.sequence_parallelism_mode in ["1", "2"] - and booster.plugin.tp_size != 0 - ): - seq_len = data["input_ids"].shape[-1] - lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) - times = lcm // seq_len - input_shape = data["input_ids"].shape - for k, v in data.items(): - if v.shape == input_shape: - data[k] = v.repeat((1,) * (v.dim() - 1) + (times,)) + # if ( + # booster.plugin.shard_config.enable_sequence_parallelism + # and booster.plugin.shard_config.sequence_parallelism_mode in ["1", "2"] + # and booster.plugin.tp_size != 0 + # ): + # seq_len = data["input_ids"].shape[-1] + # lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) + # times = lcm // seq_len + # input_shape = data["input_ids"].shape + # for k, v in data.items(): + # if v.shape == input_shape: + # data[k] = v.repeat((1,) * (v.dim() - 1) + (times,)) shard_test_data = {} for k, v in data.items(): - if k == "labels": + if k not in ["input_ids", "attention_mask"]: shard_test_data[k] = data[k].clone() else: + # todo: check the correctness of using dim=-1: to be compatible with date_gen_for_double_heads() shard_test_data[k] = ( - torch.chunk(data[k].clone(), booster.plugin.shard_config.sequence_parallel_size, dim=1)[dist.get_rank()] + torch.chunk(data[k].clone(), booster.plugin.shard_config.sequence_parallel_size, dim=-1)[ + dist.get_rank() + ] if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.shard_config.sequence_parallelism_mode in ["2", "3"] else data[k].clone() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index d59d7e4ad499..1501a8e12fb1 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -1,5 +1,6 @@ import pytest import torch +import torch.distributed as dist import colossalai from colossalai.logging import disable_existing_loggers @@ -19,6 +20,11 @@ ) +def print_rank(prompt, value, rank=0): + if dist.get_rank() == rank: + print(f"rank-{rank}, {prompt}: {value}") + + def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config @@ -131,24 +137,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - { - "tp_size": 2, - "pp_size": 2, - "num_microbatches": 4, - "enable_all_optimization": True, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 4, - "enable_all_optimization": True, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - }, { "tp_size": 4, "pp_size": 1, @@ -168,28 +156,85 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 2, "num_microbatches": 4, "enable_all_optimization": False, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "2", + "enable_flash_attention": True, "use_lazy_init": True, "precision": "fp32", - }, - { - "tp_size": 2, - "pp_size": 1, - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", "initial_scale": 1, }, + # { + # "tp_size": 1, + # "pp_size": 1, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "3", + # "enable_flash_attention": True, + # "use_lazy_init": True, + # "precision": "fp32", + # "initial_scale": 1, + # }, + # { + # "tp_size": 2, + # "pp_size": 2, + # "num_microbatches": 4, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 1, + # "pp_size": 2, + # "num_microbatches": 4, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 4, + # "pp_size": 1, + # "enable_all_optimization": False, + # "enable_flash_attention": True, + # "use_lazy_init": False, + # "precision": "fp32", + # }, + # { + # "tp_size": 2, + # "pp_size": 1, + # "enable_all_optimization": True, + # "use_lazy_init": False, + # "precision": "fp32", + # }, + # { + # "tp_size": 2, + # "pp_size": 2, + # "num_microbatches": 4, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "precision": "fp32", + # }, + # { + # "tp_size": 2, + # "pp_size": 1, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 1, + # "pp_size": 2, + # "num_microbatches": 2, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 1, + # "precision": "fp16", + # "initial_scale": 1, + # }, ], ) @clear_cache_before_run() From 0b115b4465a6ca34ff5d3d839bad0c99f4671ca7 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Tue, 23 Jan 2024 17:01:07 +0800 Subject: [PATCH 15/50] modify shard data dimension (meant to be dim=-1) --- tests/test_shardformer/test_model/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index f6496ef4402a..882e479b2909 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -190,8 +190,8 @@ def _criterion(outputs, inputs): else: # todo: check the correctness of using dim=-1: to be compatible with date_gen_for_double_heads() shard_test_data[k] = ( - torch.chunk(data[k].clone(), booster.plugin.shard_config.sequence_parallel_size, dim=-1)[ - dist.get_rank() + torch.chunk(data[k].clone(), booster.plugin.shard_config.sequence_parallel_size, dim=1)[ + dist.get_rank(booster.plugin.shard_config.sequence_parallel_process_group) ] if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.shard_config.sequence_parallelism_mode in ["2", "3"] From d146040d6e7e317d0a3657876e5fb9ff875abcb8 Mon Sep 17 00:00:00 2001 From: linsj20 Date: Tue, 23 Jan 2024 17:01:42 +0800 Subject: [PATCH 16/50] support megtron-style sp and distributed attn for llama model * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability --- colossalai/shardformer/layer/_operation.py | 43 ++- colossalai/shardformer/layer/linear.py | 18 +- colossalai/shardformer/modeling/llama.py | 354 ++++++++++++++++--- colossalai/shardformer/policies/llama.py | 38 +- colossalai/shardformer/shard/shard_config.py | 18 +- tests/test_shardformer/test_model/_utils.py | 6 - 6 files changed, 390 insertions(+), 87 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 3167129c195c..ec245760af6c 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -210,6 +210,41 @@ def _AllgatherLinear(input_, weight, process_group): return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=1) +class _GatherForwardReduceScatterBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. + + """ + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.process_group = process_group + ctx.dim = dim + + return _gather(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + dim = ctx.dim + process_group = ctx.process_group + + # do reduce-scatter + new_shape = list(grad_output.shape) + assert ( + new_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + grad_list = [item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)] + output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device) + dist.reduce_scatter(output, grad_list, group=process_group) + + return output, None, None + + class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """Gather input from sequence parallel in forward and reduce-scatter gradient in backward @@ -386,7 +421,7 @@ def _ReduceScatterLinear(input_, weight, process_group): class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): - """Reduce-scatter input from sequence parallel in forward and gather gradient in backward + """Reduce-scatter input from sequence parallel in forward and gather gradient in backward with ring Args: input_ (`torch.Tensor`): The input tensor from sequence parallel region. @@ -448,7 +483,7 @@ def backward(ctx, grad_output): class _ReduceScatterForwardGatherBackward(torch.autograd.Function): - """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + """Reduce-scatter input from sequence parallel in forward and gather gradient in backward Args: input_ (`torch.Tensor`): The input tensor from sequence parallel region. @@ -808,6 +843,10 @@ def linear_gather_forward_reducescatter_backward( ) +def gather_forward_reducescatter_backward(input_, process_group, dim): + return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim) + + def reducescatter_forward_gather_backward(input_, process_group, dim): return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index a908a862da88..20a9f0328cfc 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -24,6 +24,8 @@ from ._operation import ( gather_forward_split_backward, + gather_forward_reducescatter_backward, + reducescatter_forward_gather_backward, linear_gather_forward_reducescatter_backward, linear_reducescatter_forward_gather_backward, linear_with_async_comm, @@ -200,7 +202,10 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: if self.seq_parallel_mode is None: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) - elif self.seq_parallel_mode in ["1", "2"]: + elif self.seq_parallel_mode == "1": + input_parallel = gather_forward_reducescatter_backward(input_parallel, self.process_group, self.seq_parallel_dim) + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False) + elif self.seq_parallel_mode == "2": output_parallel = linear_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap ) @@ -412,18 +417,15 @@ def forward(self, input_: Tensor) -> Tensor: output = torch.cat(output_parallel_list, dim=-1) else: if self.seq_parallel_mode is None: - output_parallel = linear_with_async_comm(input_, self.weight, None, None, False) + output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output = reduce_forward(output_parallel, self.process_group) elif self.seq_parallel_mode == "1": - output_parallel = linear_with_async_comm(input_, self.weight, None, None, False) + #output = linear_with_async_comm(input_, self.weight, None, None, False) + output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output = reducescatter_forward_gather_backward( - output_parallel, self.process_group, self.seq_parallel_dim + output_parallel, self.process_group, self.seq_parallel_dim ) elif self.seq_parallel_mode == "2": - # TODO how to maintain compatibility? - # output = reducescatter_forward_gather_backward( - # output_parallel, self.process_group, self.seq_parallel_dim - # ) output = linear_reducescatter_forward_gather_backward( input_, self.weight, diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 988d4360b52f..1fe4be571ad0 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -22,6 +22,8 @@ all_to_all_comm, gather_forward_split_backward, split_forward_gather_backward, + gather_forward_reducescatter_backward, + reducescatter_forward_gather_backward, ) from colossalai.shardformer.shard import ShardConfig @@ -441,7 +443,6 @@ def llama_for_sequence_classification_forward( hidden_states = transformer_outputs.get("hidden_states") return {"hidden_states": hidden_states} - def get_llama_flash_attention_forward(shard_config): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb @@ -465,23 +466,24 @@ def forward( bsz, q_len, _ = hidden_states.size() sp_mode = shard_config.sequence_parallelism_mode sp_size = shard_config.sequence_parallel_size - sp_group = shard_config.sequence_parallel_process_group - - if sp_mode == "2": + + if sp_mode in["1", "2"]: q_len *= shard_config.sequence_parallel_size - # assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." - + assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." + query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "3": - query_states = all_to_all_comm(query_states, sp_group) - key_states = all_to_all_comm(key_states, sp_group) - value_states = all_to_all_comm(value_states, sp_group) + query_states = all_to_all_comm(query_states) + key_states = all_to_all_comm(key_states) + value_states = all_to_all_comm(value_states) bsz, q_len, _ = query_states.size() + if shard_config.sequence_parallel_size < 4: + print(query_states.shape) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -536,7 +538,7 @@ def forward( # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "3": - attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) + attn_output = all_to_all_comm(attn_output, None, scatter_dim=1, gather_dim=2) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value @@ -814,7 +816,7 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() # sp: modify sp_len when sequence parallel mode is 2 - if sp_mode == "2": + if sp_mode in["1", "2"]: q_len *= sp_size if self.config.pretraining_tp > 1: key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp @@ -866,24 +868,49 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + # TODO (linshengjie) Block attention with ring + #### + block_wise = False + seq_len = query_states[2] + seq_block = 1024 + if block_wise and seq_len > seq_block: + assert query_states.shape[2] % seq_block == 0 + block_num = query_states.shape[2] // seq_block + + query_states_chunks = query_states.chunk(block_num, dim=2) + if attention_mask is not None: + attention_mask_chunks = attention_mask.chunk(block_num, dim=2) + attn_output_chunks = [] + + + for i in range(block_num): + attn_weights = torch.matmul(query_states_chunks[i], key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: + attn_weights = attn_weights + attention_mask_chunks[i] + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output_chunks.append(torch.matmul(attn_weights, value_states)) + attn_output = torch.cat(attn_output_chunks, dim=2) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" ) - attn_weights = attn_weights + attention_mask - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + #### if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( @@ -913,52 +940,72 @@ def forward( return forward -def get_llama_seq_parallel_model_forward(sp_mode, sp_size): +import torch.distributed as dist + +def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): + + logger = logging.get_logger(__name__) + # Copied from transformers.models.bart.modeling_bart._make_causal_mask - def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 + def _make_causal_mask_partial( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0, sp_group = None ): """ Make causal mask used for bi-directional self-attention. """ bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + world_size = dist.get_world_size() + tgt_len *= world_size + + mask = torch.full((tgt_len, tgt_len // world_size), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1) * world_size, device=device) + + block_size = tgt_len // world_size + idx = dist.get_rank() + off = idx * block_size + + mask.masked_fill_(mask_cond[off:off+block_size] < (mask_cond + 1).view(mask.size(-1) * world_size, 1), 0) mask = mask.to(dtype) if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + mask = torch.cat([torch.zeros(tgt_len // world_size, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, (tgt_len + past_key_values_length) // world_size) + # Copied from transformers.models.bart.modeling_bart._expand_mask - def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + def _expand_mask_partial(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None, sp_group = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - # inverted_mask = 1.0 - expanded_mask - inverted_mask = expanded_mask.mul_(-1).add_(1.0) - return inverted_mask.masked_fill_(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + world_size = dist.get_world_size() + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len * world_size, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length): + def _prepare_decoder_attention_mask_partial(attention_mask, input_shape, inputs_embeds, past_key_values_length, sp_group = None): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( + combined_attention_mask = _make_causal_mask_partial( input_shape, inputs_embeds.dtype, device=inputs_embeds.device, past_key_values_length=past_key_values_length, + sp_group=sp_group ) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + expanded_attn_mask = _expand_mask_partial(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1], sp_group=sp_group).to( inputs_embeds.device ) combined_attention_mask = ( @@ -967,6 +1014,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, return combined_attention_mask + def forward( self, input_ids: torch.LongTensor = None, @@ -998,7 +1046,8 @@ def forward( raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") # sp: modify seq_length when using sequence parallel - seq_length *= sp_size + if sp_mode in ["2", "3"]: + seq_length *= sp_size seq_length_with_past = seq_length past_key_values_length = 0 @@ -1020,28 +1069,46 @@ def forward( if inputs_embeds is None: if sp_mode == "2": - input_ids = _gather(input_ids, 1, None) + input_ids = _gather(input_ids, 1, sp_group) inputs_embeds = self.embed_tokens(input_ids) - input_ids = input_ids.chunk(4, dim=1)[torch.distributed.get_rank()] - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, None) + input_ids = input_ids.chunk(sp_size, dim=1)[torch.distributed.get_rank(sp_group)] + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) else: inputs_embeds = self.embed_tokens(input_ids) + # TODO Internal function + use_distributed_mask = False + # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) + if sp_mode is None or use_distributed_mask is False: + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) - attention_mask = _gather(attention_mask, 1, None) + if sp_mode in ["2", "3"]: + attention_mask = _gather(attention_mask, 1, sp_group) - attention_mask = _prepare_decoder_attention_mask( - attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length - ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length + ) + else: + world_size = dist.get_world_size(sp_group) + assert seq_length_with_past % world_size == 0 + attention_mask = torch.ones( + (batch_size, seq_length_with_past // world_size), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = _prepare_decoder_attention_mask_partial( + attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length, sp_group + ) + attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + attention_mask = _gather(attention_mask, 1, sp_group) hidden_states = inputs_embeds + if sp_mode == "1": + hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) - if self.gradient_checkpointing and self.training: + if (self.gradient_checkpointing or sp_mode in ["2", "3"]) and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." @@ -1058,8 +1125,7 @@ def forward( all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - + if (self.gradient_checkpointing or sp_mode in ["2", "3"]) and self.training: def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value @@ -1074,6 +1140,7 @@ def custom_forward(*inputs): position_ids, ) else: + layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, @@ -1094,7 +1161,7 @@ def custom_forward(*inputs): hidden_states = self.norm(hidden_states) # Todo: Maybe this line can be optimized - hidden_states = gather_forward_split_backward(hidden_states, 1, None) + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) # add hidden states from the last decoder layer if output_hidden_states: @@ -1112,3 +1179,178 @@ def custom_forward(*inputs): ) return forward + + +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): + from transformers import LlamaForCausalLM + + def forward( + self: LlamaForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + if shard_config.enable_tensor_parallelism: + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + ) + else: + shift_logits = shift_logits.view(-1, self.config.vocab_size) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return forward + + +def get_llama_decoder_seq_parallel_model_forward(sp_mode, sp_size, sp_group): + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + if sp_mode == "1": + hidden_states = gather_forward_reducescatter_backward(hidden_states, sp_group, 1) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + if sp_mode == "1": + hidden_states = reducescatter_forward_gather_backward(hidden_states, sp_group, 1) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + if sp_mode == "1": + hidden_states = gather_forward_reducescatter_backward(hidden_states, sp_group, 1) + + hidden_states = self.mlp(hidden_states) + + if sp_mode == "1": + hidden_states = reducescatter_forward_gather_backward(hidden_states, sp_group, 1) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + return forward diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index dda9cb7cc1d1..2ee28075407a 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -14,6 +14,7 @@ get_llama_model_forward_for_flash_attn, get_llama_seq_parallel_attention_forward, get_llama_seq_parallel_model_forward, + get_llama_decoder_seq_parallel_model_forward, get_lm_forward_with_dist_cross_entropy, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -54,10 +55,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: # sp_partial_derived = sp_mode in ["1"] # todo: Support SP for LlaMa model if sp_mode == "1": - self.shard_config.enable_sequence_parallelism = False - self.shard_config.sequence_parallelism_mode = None - sp_mode = None - warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=LlamaModel, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=LlamaAttention, + ) elif sp_mode == "2": self.append_or_create_method_replacement( description={ @@ -68,7 +79,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size), + "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group), }, policy=policy, target_key=LlamaModel, @@ -95,7 +106,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size), + "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group), }, policy=policy, target_key=LlamaModel, @@ -177,6 +188,17 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_key=LlamaDecoderLayer, ) + ''' + if sp_mode == "1" and False: + self.append_or_create_method_replacement( + description={ + "forward": get_llama_decoder_seq_parallel_model_forward(sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=LlamaDecoderLayer, + ) + ''' + self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="norm", @@ -319,7 +341,9 @@ def module_policy(self): policy = super().module_policy() - if self.shard_config.enable_tensor_parallelism: + setattr(self.shard_config, "causal_lm", True) + + if self.shard_config.enable_tensor_parallelism and not self.shard_config.enable_sequence_parallelism: # add a new item for casual lm new_item = { LlamaForCausalLM: ModulePolicyDescription( diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 2d18a839a5d6..b7310d81b314 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -54,6 +54,10 @@ def sequence_parallel_size(self): return self._sequence_parallel_size def __post_init__(self): + # turn on all optimization if all_optimization is set to True + if self.enable_all_optimization: + self._turn_on_all_optimization() + if self.enable_sequence_parallelism: self.sequence_parallelism_mode = ( "1" if self.sequence_parallelism_mode is None else self.sequence_parallelism_mode @@ -96,10 +100,6 @@ def __post_init__(self): else: self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group) - # turn on all optimization if all_optimization is set to True - if self.enable_all_optimization: - self._turn_on_all_optimization() - def _turn_on_all_optimization(self): """ Turn on all optimization. @@ -108,10 +108,12 @@ def _turn_on_all_optimization(self): self.enable_fused_normalization = True self.enable_flash_attention = True self.enable_jit_fused = True - self.enable_sequence_parallelism = True - self.enable_sequence_overlap = True - # todo modify default sequence parallelism mode - self.sequence_parallelism_mode = "1" + if self.enable_tensor_parallelism: + self.enable_sequence_parallelism = True + self.enable_sequence_overlap = True + # todo modify default sequence parallelism mode and process group + self.sequence_parallelism_mode = "1" + self.sequence_parallel_process_group = self.tensor_parallel_process_group def _infer(self): """ diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 882e479b2909..c0ef27255ab9 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -164,11 +164,6 @@ def _criterion(outputs, inputs): return loss data = data_gen_fn() - # for k, v in data.items(): - # size = list(v.shape) - # tg_size = [1] * len(size) - # tg_size[1] = 64 * 2 - # data[k] = v.repeat(tg_size) # if ( # booster.plugin.shard_config.enable_sequence_parallelism @@ -239,7 +234,6 @@ def _criterion(outputs, inputs): org_loss.backward() return org_loss, org_output, sharded_loss, sharded_output - # return sharded_loss, sharded_output, sharded_loss, sharded_output def check_output_hidden_state( From 362b5b62236ecbc961ad57edfef0fd21120cd17f Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 24 Jan 2024 17:31:56 +0800 Subject: [PATCH 17/50] finish sp mode 3 support for gpt --- colossalai/shardformer/modeling/llama.py | 57 +++++++++++-------- .../test_layer/test_sequence_parallel.py | 2 +- tests/test_shardformer/test_model/_utils.py | 2 +- 3 files changed, 34 insertions(+), 27 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 1fe4be571ad0..2d3a135d6d88 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -20,10 +20,10 @@ from colossalai.shardformer.layer._operation import ( _gather, all_to_all_comm, - gather_forward_split_backward, - split_forward_gather_backward, gather_forward_reducescatter_backward, + gather_forward_split_backward, reducescatter_forward_gather_backward, + split_forward_gather_backward, ) from colossalai.shardformer.shard import ShardConfig @@ -443,6 +443,7 @@ def llama_for_sequence_classification_forward( hidden_states = transformer_outputs.get("hidden_states") return {"hidden_states": hidden_states} + def get_llama_flash_attention_forward(shard_config): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb @@ -466,11 +467,11 @@ def forward( bsz, q_len, _ = hidden_states.size() sp_mode = shard_config.sequence_parallelism_mode sp_size = shard_config.sequence_parallel_size - - if sp_mode in["1", "2"]: + + if sp_mode in ["1", "2"]: q_len *= shard_config.sequence_parallel_size assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." - + query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) @@ -816,7 +817,7 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() # sp: modify sp_len when sequence parallel mode is 2 - if sp_mode in["1", "2"]: + if sp_mode in ["1", "2"]: q_len *= sp_size if self.config.pretraining_tp > 1: key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp @@ -882,9 +883,10 @@ def forward( attention_mask_chunks = attention_mask.chunk(block_num, dim=2) attn_output_chunks = [] - for i in range(block_num): - attn_weights = torch.matmul(query_states_chunks[i], key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states_chunks[i], key_states.transpose(2, 3)) / math.sqrt( + self.head_dim + ) if attention_mask is not None: attn_weights = attn_weights + attention_mask_chunks[i] attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) @@ -942,19 +944,23 @@ def forward( import torch.distributed as dist -def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): +def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): logger = logging.get_logger(__name__) # Copied from transformers.models.bart.modeling_bart._make_causal_mask def _make_causal_mask_partial( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0, sp_group = None + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sp_group=None, ): """ Make causal mask used for bi-directional self-attention. """ bsz, tgt_len = input_ids_shape - world_size = dist.get_world_size() + world_size = dist.get_world_size(sp_group) tgt_len *= world_size mask = torch.full((tgt_len, tgt_len // world_size), torch.finfo(dtype).min, device=device) @@ -964,23 +970,24 @@ def _make_causal_mask_partial( idx = dist.get_rank() off = idx * block_size - mask.masked_fill_(mask_cond[off:off+block_size] < (mask_cond + 1).view(mask.size(-1) * world_size, 1), 0) + mask.masked_fill_(mask_cond[off : off + block_size] < (mask_cond + 1).view(mask.size(-1) * world_size, 1), 0) mask = mask.to(dtype) if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len // world_size, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + mask = torch.cat( + [torch.zeros(tgt_len // world_size, past_key_values_length, dtype=dtype, device=device), mask], dim=-1 + ) return mask[None, None, :, :].expand(bsz, 1, tgt_len, (tgt_len + past_key_values_length) // world_size) - # Copied from transformers.models.bart.modeling_bart._expand_mask - def _expand_mask_partial(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None, sp_group = None): + def _expand_mask_partial(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None, sp_group=None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len - world_size = dist.get_world_size() + world_size = dist.get_world_size(sp_group) expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len * world_size, src_len).to(dtype) @@ -988,9 +995,10 @@ def _expand_mask_partial(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Option return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask_partial(attention_mask, input_shape, inputs_embeds, past_key_values_length, sp_group = None): + def _prepare_decoder_attention_mask_partial( + attention_mask, input_shape, inputs_embeds, past_key_values_length, sp_group=None + ): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None @@ -1000,21 +1008,20 @@ def _prepare_decoder_attention_mask_partial(attention_mask, input_shape, inputs_ inputs_embeds.dtype, device=inputs_embeds.device, past_key_values_length=past_key_values_length, - sp_group=sp_group + sp_group=sp_group, ) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask_partial(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1], sp_group=sp_group).to( - inputs_embeds.device - ) + expanded_attn_mask = _expand_mask_partial( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1], sp_group=sp_group + ).to(inputs_embeds.device) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask ) return combined_attention_mask - def forward( self, input_ids: torch.LongTensor = None, @@ -1126,6 +1133,7 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if (self.gradient_checkpointing or sp_mode in ["2", "3"]) and self.training: + def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value @@ -1140,7 +1148,6 @@ def custom_forward(*inputs): position_ids, ) else: - layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, @@ -1287,7 +1294,6 @@ def forward( def get_llama_decoder_seq_parallel_model_forward(sp_mode, sp_size, sp_group): - def forward( self, hidden_states: torch.Tensor, @@ -1353,4 +1359,5 @@ def forward( outputs += (present_key_value,) return outputs + return forward diff --git a/tests/test_shardformer/test_layer/test_sequence_parallel.py b/tests/test_shardformer/test_layer/test_sequence_parallel.py index c2ad6918cd2b..72a9fd09fbc6 100644 --- a/tests/test_shardformer/test_layer/test_sequence_parallel.py +++ b/tests/test_shardformer/test_layer/test_sequence_parallel.py @@ -216,7 +216,7 @@ def seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size): # backward result check assert_close(q_grad_seq, q_grad) assert_close(k_grad_seq, k_grad) - assert_close(v_grad_seq, v_grad) + assert_close(v_grad_seq, v_grad, atol=1e-4, rtol=1e-4) assert_close(o_grad_seq, o_grad) assert_close(x_grad_seq_gather, x_grad) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index c0ef27255ab9..4d4e54e0178c 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -185,7 +185,7 @@ def _criterion(outputs, inputs): else: # todo: check the correctness of using dim=-1: to be compatible with date_gen_for_double_heads() shard_test_data[k] = ( - torch.chunk(data[k].clone(), booster.plugin.shard_config.sequence_parallel_size, dim=1)[ + torch.chunk(data[k].clone(), booster.plugin.shard_config.sequence_parallel_size, dim=-1)[ dist.get_rank(booster.plugin.shard_config.sequence_parallel_process_group) ] if booster.plugin.shard_config.enable_sequence_parallelism From 7293b1631d35818c64b3355834045ba5dbcb59f8 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 24 Jan 2024 19:51:58 +0800 Subject: [PATCH 18/50] using all_to_all_single when batch size is 1 --- colossalai/shardformer/layer/_operation.py | 43 +++++++++++++++++-- tests/kit/model_zoo/transformers/llama.py | 12 +++--- .../test_layer/test_sequence_parallel.py | 2 +- .../test_model/test_shard_llama.py | 4 +- 4 files changed, 49 insertions(+), 12 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index ec245760af6c..bca16aabef25 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -230,7 +230,7 @@ def forward(ctx, input_, process_group, dim): @staticmethod def backward(ctx, grad_output): dim = ctx.dim - process_group = ctx.process_group + process_group = ctx.process_group # do reduce-scatter new_shape = list(grad_output.shape) @@ -238,7 +238,9 @@ def backward(ctx, grad_output): new_shape[dim] % dist.get_world_size(process_group) == 0 ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) - grad_list = [item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)] + grad_list = [ + item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim) + ] output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device) dist.reduce_scatter(output, grad_list, group=process_group) @@ -716,7 +718,13 @@ def forward(ctx, input_, process_group, scatter_dim, gather_dim): ctx.scatter_dim = scatter_dim ctx.gather_dim = gather_dim world_size = dist.get_world_size(process_group) - return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim) + bsz, _, _ = input_.shape + + # using all_to_all_single when batch size is 1 + if bsz == 1: + return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim) + else: + return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim) @staticmethod def backward(ctx, *grad_output): @@ -827,6 +835,35 @@ def _all_to_all(input_, world_size, group, scatter_dim, gather_dim): return torch.cat(output_list, dim=gather_dim).contiguous() +def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim): + inp_shape = list(input_.shape) + inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size + if scatter_dim < 2: + input_t = input_.reshape([seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]).contiguous() + else: + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + input_t = ( + input_.reshape([-1, seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]) + .transpose(0, 1) + .contiguous() + ) + + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) + + # if scattering the seq-dim, transpose the heads back to the original dimension + if scatter_dim < 2: + output = output.transpose(0, 1).contiguous() + + return output.reshape( + inp_shape[:gather_dim] + + [ + inp_shape[gather_dim] * seq_world_size, + ] + + inp_shape[gather_dim + 1 :] + ).contiguous() + + def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 49244ea688b4..26d4a6125788 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -34,9 +34,9 @@ def data_gen(): input_ids = torch.Tensor( [ [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], - [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], - [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], - [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], + # [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], + # [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], + # [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], ] ).long() @@ -44,9 +44,9 @@ def data_gen(): attention_mask = torch.Tensor( [ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], ] ).long() diff --git a/tests/test_shardformer/test_layer/test_sequence_parallel.py b/tests/test_shardformer/test_layer/test_sequence_parallel.py index 72a9fd09fbc6..fb1471d591ae 100644 --- a/tests/test_shardformer/test_layer/test_sequence_parallel.py +++ b/tests/test_shardformer/test_layer/test_sequence_parallel.py @@ -229,7 +229,7 @@ def seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size): @parameterize("seq_len", [128]) @parameterize("hidden_dim", [64]) @parameterize("head_num", [4]) -@parameterize("batch_size", [4]) +@parameterize("batch_size", [1]) def run_seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size): seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index a3600f208707..e1a159ea3048 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -161,7 +161,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 4, "pp_size": 1, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32", }, @@ -173,7 +173,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": False, "precision": "fp32", }, - {"tp_size": 2, "pp_size": 1, "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, { "tp_size": 2, "pp_size": 1, From 65db8b2cd5827fe27c2d4e710399537b47541e7c Mon Sep 17 00:00:00 2001 From: linsj20 Date: Fri, 26 Jan 2024 16:16:05 +0800 Subject: [PATCH 19/50] support mode 2 sp in gpt2 (#5) * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * refactor ring implementation * support mode 2 sp in gpt2 --- colossalai/shardformer/layer/_operation.py | 242 +++++++++++++----- colossalai/shardformer/layer/linear.py | 4 +- .../shardformer/layer/qkv_fused_linear.py | 12 +- colossalai/shardformer/modeling/gpt2.py | 7 +- colossalai/shardformer/modeling/llama.py | 7 +- colossalai/shardformer/policies/gpt2.py | 2 +- 6 files changed, 196 insertions(+), 78 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index bca16aabef25..1e2b29340c8f 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -167,47 +167,57 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias, None, None, None -def _AllgatherLinear(input_, weight, process_group): +def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False): + # currently only support one single tensor as output group_size = dist.get_world_size(process_group) cur_rank = dist.get_rank(process_group) - input_shape = input_.shape - weight_shape = weight.shape - - output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)] + #output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)] # initialization of ring communication - input_shape[1] recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 - recv_tensor = input_.clone() - send_tensor = input_.clone() - - recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) - send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) - handles = dist.batch_isend_irecv([send_op, recv_op]) + recv_tensors = {} + send_tensors = {} + for k, v in input_to_gather.items(): + recv_tensors[k] = v.clone() + send_tensors[k] = v.clone() + + def communicate_step(): + comm_ops = [] + for k in recv_tensors: + comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group)) + comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group)) + return dist.batch_isend_irecv(comm_ops) + + def switch_step(): + for k in recv_tensors: + tmp_tensor = send_tensors[k] + send_tensors[k] = recv_tensors[k] + recv_tensors[k] = tmp_tensor + + output_tensors = [] + + handles = communicate_step() # first round: special case, retrive from local tensor - output_tensors[0] = F.linear(input_, weight) + output_tensors.append(func(**input_to_gather, **input_local)) for i in range(group_size - 2): for handle in handles: handle.wait() - tmp_tensor = send_tensor - send_tensor = recv_tensor - recv_tensor = tmp_tensor + switch_step() - recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) - send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) - handles = dist.batch_isend_irecv([recv_op, send_op]) + handles = communicate_step() # actual computation - output_tensors[i + 1] = F.linear(send_tensor, weight) + output_tensors.append(func(**send_tensors, **input_local)) # final round: special case, no need to send/recv again for handle in handles: handle.wait() - output_tensors[group_size - 1] = F.linear(recv_tensor, weight) - return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=1) + output_tensors.append(func(**recv_tensors, **input_local)) + + return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim) class _GatherForwardReduceScatterBackward(torch.autograd.Function): @@ -247,6 +257,41 @@ def backward(ctx, grad_output): return output, None, None +class _GatherForwardReduceScatterBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. + + """ + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.process_group = process_group + ctx.dim = dim + + return _gather(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + dim = ctx.dim + process_group = ctx.process_group + + # do reduce-scatter + new_shape = list(grad_output.shape) + assert ( + new_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + grad_list = [item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)] + output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device) + dist.reduce_scatter(output, grad_list, group=process_group) + + return output, None, None + + class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """Gather input from sequence parallel in forward and reduce-scatter gradient in backward @@ -258,7 +303,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group @@ -266,11 +311,27 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, ctx.dim = dim ctx.overlap = overlap - if bias is not None: - input_parallel = _gather(input_, dim, process_group) - output = F.linear(input_parallel, weight, bias) + if ring is True: + input_to_gather = {} + input_local = {} + input_to_gather['input'] = input_ + input_local['weight'] = weight + + output = _ring_as_gather( + F.linear, + input_to_gather=input_to_gather, + input_local=input_local, + process_group=process_group, + ) + + if bias is not None: + output += bias else: - output = _AllgatherLinear(input_, weight, process_group) + input_parallel = _gather(input_, dim, process_group) + if bias is not None: + output = F.linear(input_parallel, weight, bias) + else: + output = F.linear(input_parallel, weight) return output @@ -373,34 +434,43 @@ def backward(ctx, grad_output): # wait until reduce-scatter finished reducescatter_handle.wait() - return output, grad_weight, grad_bias, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None -def _ReduceScatterLinear(input_, weight, process_group): +def _ring_as_reducescatter(func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1): + # currently only support one single tensor as output group_size = dist.get_world_size(process_group) cur_rank = dist.get_rank(process_group) - input_shape = input_.shape - # initialization of ring communication - # communicate(e.g.): 0->1->2->3 - # compute(e.g.): 3->2->1->0 - input_tensors = list(torch.split(input_, int(input_shape[1] / group_size), dim=1)) - input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank] - input_tensors.reverse() recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 + input_tensors = [] + for _ in range(group_size): + input_tensors.append({}) + for k, v in input_to_reducescatter.items(): + input_shape = v.shape + assert input_shape[reducescatter_dim] % group_size == 0 + _input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim)) + for i in range(group_size): + input_tensors[i][k] = _input_tensors[i] + input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank] + input_tensors.reverse() - # first round: special case, no reduce operation - output_tensor = F.linear(input_tensors[0], weight) + output_tensor = func(**input_tensors[0], **input_local) recv_tensor = output_tensor.clone() send_tensor = output_tensor.clone() - recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) - send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) - handles = dist.batch_isend_irecv([recv_op, send_op]) + + def communicate_step(): + recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) + send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) + return dist.batch_isend_irecv([recv_op, send_op]) + + handles = communicate_step() + # first round: special case, retrive from local tensor for i in range(group_size - 2): # actual computation - output_tensor = F.linear(input_tensors[i + 1], weight) + output_tensor = func(**input_tensors[i + 1], **input_local) for handle in handles: handle.wait() @@ -410,12 +480,10 @@ def _ReduceScatterLinear(input_, weight, process_group): send_tensor = output_tensor output_tensor = tmp_tensor - recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) - send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) - handles = dist.batch_isend_irecv([recv_op, send_op]) + handles = communicate_step() # final round: special case, no need to send/recv again - output_tensor = F.linear(input_tensors[group_size - 1], weight) + output_tensor = func(**input_tensors[-1], **input_local) for handle in handles: handle.wait() output_tensor += recv_tensor @@ -433,27 +501,44 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, dim): + def forward(ctx, input_, weight, bias, process_group, dim, ring): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.dim = dim - if bias is not None: - partial_output = F.linear(input_, weight, bias) + + if ring is True: + input_to_reducescatter = {} + input_local = {} + input_to_reducescatter['input'] = input_ + input_local['weight'] = weight + + if bias is not None: + input_to_reducescatter['bias'] = bias + + output = _ring_as_reducescatter( + F.linear, + input_to_reducescatter=input_to_reducescatter, + input_local=input_local, + process_group=process_group, + ) else: - return _ReduceScatterLinear(input_, weight, process_group) + if bias is not None: + partial_output = F.linear(input_, weight, bias) + else: + partial_output = F.linear(input_, weight) - output_shape = list(partial_output.shape) - assert ( - output_shape[dim] % dist.get_world_size(process_group) == 0 - ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " - output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group) + output_shape = list(partial_output.shape) + assert ( + output_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " + output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group) - output_list = [ - item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim) - ] - output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous() - dist.reduce_scatter(output, output_list, group=process_group) + output_list = [ + item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous() + dist.reduce_scatter(output, output_list, group=process_group) return output @@ -481,7 +566,7 @@ def backward(ctx, grad_output): grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None - return grad_input, grad_weight, grad_bias, None, None + return grad_input, grad_weight, grad_bias, None, None, None class _ReduceScatterForwardGatherBackward(torch.autograd.Function): @@ -530,7 +615,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group @@ -538,9 +623,24 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, ctx.dim = dim ctx.overlap = overlap - input_parallel = _gather(input_, dim, process_group) + if ring is True: + input_to_gather = {} + input_local = {} + input_to_gather['input'] = input_ + input_local['other'] = weight + + output = _ring_as_gather( + torch.matmul, + input_to_gather=input_to_gather, + input_local=input_local, + process_group=process_group, + gather_dim=dim + ) + + else: + input_parallel = _gather(input_, dim, process_group) - output = torch.matmul(input_parallel, weight) + output = torch.matmul(input_parallel, weight) if bias is not None: output = output + bias @@ -620,7 +720,7 @@ def backward(ctx, grad_output): # wait until reduce-scatter finished reducescatter_handle.wait() - return output, grad_weight, grad_bias, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): @@ -873,10 +973,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre def linear_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False ): return _LinearWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring ) @@ -888,15 +988,15 @@ def reducescatter_forward_gather_backward(input_, process_group, dim): return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim) -def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1): - return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim) +def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False): + return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring) def matmul_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False ): return _MatmulWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring ) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 20a9f0328cfc..a773783b9f19 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -207,7 +207,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False) elif self.seq_parallel_mode == "2": output_parallel = linear_gather_forward_reducescatter_backward( - input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap + input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True ) if self.gather_output: @@ -429,7 +429,9 @@ def forward(self, input_: Tensor) -> Tensor: output = linear_reducescatter_forward_gather_backward( input_, self.weight, + process_group=self.process_group, dim=self.seq_parallel_dim, + ring=True, ) if not self.skip_bias_add: diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 6c5fb41494f0..a5d75db8a740 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -323,6 +323,11 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: output_parallel = matmul_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap ) + elif self.seq_parallel_mode == "2": + input_parallel = input_ + output_parallel = matmul_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap, True + ) if self.gather_output: # All-gather across the partitions. @@ -528,10 +533,14 @@ def forward(self, input_: Tensor) -> Tensor: handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - output_parallel = torch.matmul(input_, self.weight) if self.seq_parallel_mode is None: + output_parallel = torch.matmul(input_, self.weight) output = reduce_forward(output_parallel, self.process_group) elif self.seq_parallel_mode == "1": + output_parallel = torch.matmul(input_, self.weight) + output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + elif self.seq_parallel_mode == "2": + output_parallel = torch.matmul(input_, self.weight) output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) if not self.skip_bias_add: @@ -702,7 +711,6 @@ def from_native_module( # process_group=process_group, # is_transposed=False) # linear_1d.bias.data.copy_(sharded_bias.data) - print(linear_1d.weight.shape) return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index f98a28b24ce4..710e26013388 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -1135,7 +1135,12 @@ def forward( head_mask = self.get_head_mask(head_mask, self.config.n_layer) if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) + if sp_mode in ["2"]: + input_ids = _gather(input_ids, 1, sp_group) + inputs_embeds = self.wte(input_ids) + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + else: + inputs_embeds = self.wte(input_ids) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 2d3a135d6d88..cf4ea9b58206 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -24,6 +24,8 @@ gather_forward_split_backward, reducescatter_forward_gather_backward, split_forward_gather_backward, + gather_forward_reducescatter_backward, + reducescatter_forward_gather_backward, ) from colossalai.shardformer.shard import ShardConfig @@ -443,7 +445,6 @@ def llama_for_sequence_classification_forward( hidden_states = transformer_outputs.get("hidden_states") return {"hidden_states": hidden_states} - def get_llama_flash_attention_forward(shard_config): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb @@ -872,7 +873,7 @@ def forward( # TODO (linshengjie) Block attention with ring #### block_wise = False - seq_len = query_states[2] + seq_len = query_states.shape[2] seq_block = 1024 if block_wise and seq_len > seq_block: assert query_states.shape[2] % seq_block == 0 @@ -1022,6 +1023,7 @@ def _prepare_decoder_attention_mask_partial( return combined_attention_mask + def forward( self, input_ids: torch.LongTensor = None, @@ -1148,6 +1150,7 @@ def custom_forward(*inputs): position_ids, ) else: + layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 2137387cdf6c..02e05681c4e9 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -55,7 +55,7 @@ def module_policy(self): sp_size = self.shard_config.sequence_parallel_size sp_group = self.shard_config.sequence_parallel_process_group overlap = self.shard_config.enable_sequence_overlap - sp_partial_derived = sp_mode in ["1"] + sp_partial_derived = sp_mode in ["1", "2"] if sp_mode == "2": pass From e72bd873ed33db41f40785416193dd405478e0b1 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Fri, 26 Jan 2024 17:47:46 +0800 Subject: [PATCH 20/50] polish code --- colossalai/shardformer/modeling/gpt2.py | 5 - colossalai/shardformer/modeling/llama.py | 21 +-- colossalai/shardformer/policies/gpt2.py | 15 +- colossalai/shardformer/policies/llama.py | 27 ++-- .../test_layer/test_sequence_parallel.py | 5 - .../test_model/test_shard_gpt2.py | 148 +++++++++--------- .../test_model/test_shard_llama.py | 7 - 7 files changed, 96 insertions(+), 132 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 710e26013388..7d53d9cc763d 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -104,11 +104,6 @@ def _get_attention_mask( return attention_mask, encoder_attention_mask -def print_rank(prompt, value, rank=0): - if dist.get_rank() == rank: - print(f"rank-{rank}, {prompt}: {value}") - - class GPT2PipelineForwards: """ This class serves as a micro library for forward function substitution of GPT2 models diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index cf4ea9b58206..96403f74f90c 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -24,8 +24,6 @@ gather_forward_split_backward, reducescatter_forward_gather_backward, split_forward_gather_backward, - gather_forward_reducescatter_backward, - reducescatter_forward_gather_backward, ) from colossalai.shardformer.shard import ShardConfig @@ -40,11 +38,6 @@ LATEST_VERSION = False -def print_rank(prompt, value, rank=0): - if dist.get_rank() == rank: - print(f"rank-{rank}, {prompt}: {value}") - - class LlamaPipelineForwards: """ This class serves as a micro library for forward function substitution of Llama models @@ -445,7 +438,8 @@ def llama_for_sequence_classification_forward( hidden_states = transformer_outputs.get("hidden_states") return {"hidden_states": hidden_states} -def get_llama_flash_attention_forward(shard_config): + +def get_llama_flash_attention_forward(shard_config, sp_mode, sp_size): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb llama_version = 2 @@ -466,11 +460,9 @@ def forward( **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - sp_mode = shard_config.sequence_parallelism_mode - sp_size = shard_config.sequence_parallel_size if sp_mode in ["1", "2"]: - q_len *= shard_config.sequence_parallel_size + q_len *= sp_size assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." query_states = self.q_proj(hidden_states) @@ -484,8 +476,6 @@ def forward( value_states = all_to_all_comm(value_states) bsz, q_len, _ = query_states.size() - if shard_config.sequence_parallel_size < 4: - print(query_states.shape) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -943,9 +933,6 @@ def forward( return forward -import torch.distributed as dist - - def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): logger = logging.get_logger(__name__) @@ -1023,7 +1010,6 @@ def _prepare_decoder_attention_mask_partial( return combined_attention_mask - def forward( self, input_ids: torch.LongTensor = None, @@ -1150,7 +1136,6 @@ def custom_forward(*inputs): position_ids, ) else: - layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 02e05681c4e9..86aa8d0b074d 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List @@ -56,10 +57,16 @@ def module_policy(self): sp_group = self.shard_config.sequence_parallel_process_group overlap = self.shard_config.enable_sequence_overlap sp_partial_derived = sp_mode in ["1", "2"] + use_flash_attention = self.shard_config.enable_flash_attention + # todo: currently sp mode 2 and 3 need to be used with flashattention + if sp_mode in ["2", "3"]: + if not use_flash_attention: + warnings.warn( + f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will enable FlashAttention automatically." + ) + use_flash_attention = True - if sp_mode == "2": - pass - elif sp_mode == "3": + if sp_mode == "3": decoder_attribute_replacement = { "num_heads": self.model.config.num_attention_heads // sp_size, } @@ -160,7 +167,7 @@ def module_policy(self): target_key=GPT2Block, ) - if self.shard_config.enable_flash_attention: + if use_flash_attention: self.append_or_create_method_replacement( description={ "forward": get_gpt2_flash_attention_forward(sp_mode, sp_size, sp_group), diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 2ee28075407a..c1e0c6b68c20 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -14,7 +14,6 @@ get_llama_model_forward_for_flash_attn, get_llama_seq_parallel_attention_forward, get_llama_seq_parallel_model_forward, - get_llama_decoder_seq_parallel_model_forward, get_lm_forward_with_dist_cross_entropy, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -51,9 +50,16 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None sp_size = self.shard_config.sequence_parallel_size sp_group = self.shard_config.sequence_parallel_process_group - # overlap = self.shard_config.enable_sequence_overlap - # sp_partial_derived = sp_mode in ["1"] - # todo: Support SP for LlaMa model + if self.pipeline_stage_manager is not None: + if sp_mode is not None: + warnings.warn( + "Sequence parallelism is not supported under pipeline parallelism setting. " + "Sequence parallelism will be disabled." + ) + sp_mode = None + sp_size = None + sp_group = None + if sp_mode == "1": self.append_or_create_method_replacement( description={ @@ -188,17 +194,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_key=LlamaDecoderLayer, ) - ''' - if sp_mode == "1" and False: - self.append_or_create_method_replacement( - description={ - "forward": get_llama_decoder_seq_parallel_model_forward(sp_mode, sp_size, sp_group), - }, - policy=policy, - target_key=LlamaDecoderLayer, - ) - ''' - self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="norm", @@ -212,7 +207,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_llama_flash_attention_forward(shard_config=self.shard_config), + "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size), }, policy=policy, target_key=LlamaAttention, diff --git a/tests/test_shardformer/test_layer/test_sequence_parallel.py b/tests/test_shardformer/test_layer/test_sequence_parallel.py index fb1471d591ae..740fbaee05e8 100644 --- a/tests/test_shardformer/test_layer/test_sequence_parallel.py +++ b/tests/test_shardformer/test_layer/test_sequence_parallel.py @@ -220,11 +220,6 @@ def seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size): assert_close(o_grad_seq, o_grad) assert_close(x_grad_seq_gather, x_grad) - # print_rank('x_grad', x_grad_seq, 0) - # print_rank('x_grad', x_grad_seq, 1) - # print_rank('x_grad', x_grad_seq, 2) - # print_rank('x_grad', x_grad_seq, 3) - @parameterize("seq_len", [128]) @parameterize("hidden_dim", [64]) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 1501a8e12fb1..19619e6a3eaa 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -1,6 +1,5 @@ import pytest import torch -import torch.distributed as dist import colossalai from colossalai.logging import disable_existing_loggers @@ -20,11 +19,6 @@ ) -def print_rank(prompt, value, rank=0): - if dist.get_rank() == rank: - print(f"rank-{rank}, {prompt}: {value}") - - def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config @@ -164,77 +158,77 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp32", "initial_scale": 1, }, - # { - # "tp_size": 1, - # "pp_size": 1, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "3", - # "enable_flash_attention": True, - # "use_lazy_init": True, - # "precision": "fp32", - # "initial_scale": 1, - # }, - # { - # "tp_size": 2, - # "pp_size": 2, - # "num_microbatches": 4, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 1, - # "pp_size": 2, - # "num_microbatches": 4, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 4, - # "pp_size": 1, - # "enable_all_optimization": False, - # "enable_flash_attention": True, - # "use_lazy_init": False, - # "precision": "fp32", - # }, - # { - # "tp_size": 2, - # "pp_size": 1, - # "enable_all_optimization": True, - # "use_lazy_init": False, - # "precision": "fp32", - # }, - # { - # "tp_size": 2, - # "pp_size": 2, - # "num_microbatches": 4, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "precision": "fp32", - # }, - # { - # "tp_size": 2, - # "pp_size": 1, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 2, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 1, - # "pp_size": 2, - # "num_microbatches": 2, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 1, - # "precision": "fp16", - # "initial_scale": 1, - # }, + { + "tp_size": 1, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "3", + "enable_flash_attention": True, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": False, + "enable_flash_attention": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, ], ) @clear_cache_before_run() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index e1a159ea3048..143e051304ed 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -2,7 +2,6 @@ import pytest import torch -import torch.distributed as dist import colossalai from colossalai.logging import disable_existing_loggers @@ -21,12 +20,6 @@ unwrap_model, ) - -def print_rank(prompt, value, rank=0): - if dist.get_rank() == rank: - print(f"rank-{rank}, {prompt}: {value}") - - os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" From 2076bcf8928d845406b9f30f350489956e95f1ec Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Thu, 1 Feb 2024 11:58:10 +0800 Subject: [PATCH 21/50] enable distributed attn mask when using sp mode 2 and 3 in llama --- colossalai/shardformer/modeling/llama.py | 6 ++++-- tests/test_shardformer/test_model/test_shard_llama.py | 2 -- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 96403f74f90c..743525a84795 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -1071,8 +1071,8 @@ def forward( else: inputs_embeds = self.embed_tokens(input_ids) - # TODO Internal function - use_distributed_mask = False + # TODO use_distributed_mask + use_distributed_mask = True if sp_mode in ["2", "3"] else False # embed positions if sp_mode is None or use_distributed_mask is False: @@ -1120,6 +1120,7 @@ def forward( all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None else None + if (self.gradient_checkpointing or sp_mode in ["2", "3"]) and self.training: def create_custom_forward(module): @@ -1135,6 +1136,7 @@ def custom_forward(*inputs): attention_mask, position_ids, ) + else: layer_outputs = decoder_layer( hidden_states, diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 143e051304ed..4d82d4554bd8 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -192,8 +192,6 @@ def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name != "transformers_llama": - continue check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() From bb1857799a2ce3ca22c21975fbbe13fef53982c9 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Thu, 1 Feb 2024 12:07:36 +0800 Subject: [PATCH 22/50] automatically enable flash attn when using sp mode 2 and 3 in llama --- colossalai/shardformer/policies/llama.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index c1e0c6b68c20..35c29d6219a8 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -60,6 +60,15 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: sp_size = None sp_group = None + use_flash_attention = self.shard_config.enable_flash_attention + # todo: currently sp mode 2 and 3 need to be used with flashattention + if sp_mode in ["2", "3"]: + if not use_flash_attention: + warnings.warn( + f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will enable FlashAttention automatically." + ) + use_flash_attention = True + if sp_mode == "1": self.append_or_create_method_replacement( description={ @@ -204,7 +213,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) # use flash attention - if self.shard_config.enable_flash_attention: + if use_flash_attention: self.append_or_create_method_replacement( description={ "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size), From 9788fd8622cc1fe869fd23b43d185c3d0886fbb4 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Thu, 1 Feb 2024 15:09:29 +0800 Subject: [PATCH 23/50] inplace attn mask --- colossalai/shardformer/modeling/llama.py | 9 ++++++--- .../test_model/test_shard_llama.py | 18 +++++++++--------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 743525a84795..ad6bbaa00e2a 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -979,9 +979,10 @@ def _expand_mask_partial(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Option expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len * world_size, src_len).to(dtype) - inverted_mask = 1.0 - expanded_mask + # inverted_mask = 1.0 - expanded_mask + inverted_mask = expanded_mask.mul_(-1).add_(1.0) - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + return inverted_mask.masked_fill_(inverted_mask.to(torch.bool), torch.finfo(dtype).min) # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask def _prepare_decoder_attention_mask_partial( @@ -1005,7 +1006,9 @@ def _prepare_decoder_attention_mask_partial( attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1], sp_group=sp_group ).to(inputs_embeds.device) combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask.add_(combined_attention_mask) ) return combined_attention_mask diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 4d82d4554bd8..b56c60add5cd 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -135,15 +135,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp32", "initial_scale": 1, }, - { - "tp_size": 2, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": True, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - }, + # { + # "tp_size": 2, + # "pp_size": 2, + # "num_microbatches": 2, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # }, { "tp_size": 1, "pp_size": 2, From 544a06da80b98277b8d97c79fe0eb48bbe9e5b35 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Mon, 19 Feb 2024 15:42:40 +0800 Subject: [PATCH 24/50] add zero2 support for sequence parallel --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 6c8a7ff4d2bc..a48c20bfe1fe 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -748,7 +748,6 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]: # Get all working gradients and gradients to be synchronized. all_working_grads = _get_all_working_grads() grads_to_sync = _get_grads_to_sync(all_working_grads) - if self.require_grad_sync and grads_to_sync is not None: # Synchronize sequence parallelism gradients if required. SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync) @@ -1200,7 +1199,13 @@ def configure( tp_process_group=self.tp_group, ) else: - if self.dp_size == 1: + if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "3": + self.zero_dp_size = self.sp_size + self.zero_dp_group = self.sp_group + else: + self.zero_dp_size = self.dp_size + self.zero_dp_group = self.dp_group + if self.zero_dp_size == 1: warnings.warn( "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " "If you are not intended to use cpu_offload, please consider set zero_stage=0." @@ -1212,7 +1217,7 @@ def configure( model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info, - dp_process_group=self.dp_group, + dp_process_group=self.zero_dp_group, tp_process_group=self.tp_group, pp_process_group=self.pp_group, verbose=True, From c3d0e83ed84b44d166bb2507f086a0b80931aa4c Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Tue, 27 Feb 2024 13:47:15 +0800 Subject: [PATCH 25/50] polish code --- .../booster/plugin/hybrid_parallel_plugin.py | 14 ++++---- colossalai/shardformer/modeling/bert.py | 4 +-- colossalai/shardformer/modeling/bloom.py | 4 +-- colossalai/shardformer/modeling/chatglm2.py | 4 +-- colossalai/shardformer/modeling/gpt2.py | 20 +++++------ colossalai/shardformer/modeling/llama.py | 36 +++++++++---------- colossalai/shardformer/policies/bert.py | 4 +-- colossalai/shardformer/policies/bloom.py | 4 +-- colossalai/shardformer/policies/chatglm2.py | 4 +-- colossalai/shardformer/policies/gpt2.py | 6 ++-- colossalai/shardformer/policies/llama.py | 17 ++++++--- colossalai/shardformer/shard/shard_config.py | 8 ++--- tests/test_shardformer/test_model/_utils.py | 15 +------- .../test_model/test_shard_gpt2.py | 4 +-- .../test_model/test_shard_llama.py | 26 +++++++------- 15 files changed, 82 insertions(+), 88 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index a48c20bfe1fe..72dbb6201b50 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -35,7 +35,7 @@ from .pp_plugin_base import PipelinePluginBase DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3 -SUPPORT_SP_MODE = ["1", "2", "3"] +SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} @@ -173,12 +173,12 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None): """ if self.shard_config.enable_sequence_parallelism: - if self.shard_config.sequence_parallelism_mode in ["1", "2"]: + if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: # If sequence parallelism is enabled and mode is 1 or 2, gradients are synchronized # across the tensor parallelism group. group = self.tp_group require_flag = True - elif self.shard_config.sequence_parallelism_mode == "3": + elif self.shard_config.sequence_parallelism_mode == "all_to_all": # If sequence parallelism is enabled and mode is 3, gradients are synchronized # across the sequence parallelism group. group = self.sp_group @@ -1003,7 +1003,7 @@ def __init__( assert ( self.sequence_parallelism_mode in SUPPORT_SP_MODE ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" - if self.sequence_parallelism_mode in ["1", "2"]: + if self.sequence_parallelism_mode in ["split_gather", "ring"]: assert ( tp_size > 1 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" @@ -1013,7 +1013,7 @@ def __init__( ) self.sp_size = 1 self.dp_size = dist.get_world_size() // (tp_size * pp_size) - elif self.sequence_parallelism_mode in ["3"]: + elif self.sequence_parallelism_mode in ["all_to_all"]: assert ( tp_size == 1 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism" @@ -1077,7 +1077,7 @@ def __init__( self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) - if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["1", "2"]: + if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) else: self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS) @@ -1199,7 +1199,7 @@ def configure( tp_process_group=self.tp_group, ) else: - if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "3": + if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": self.zero_dp_size = self.sp_size self.zero_dp_group = self.sp_group else: diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 99cd7acd312e..0838fcee682e 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -186,7 +186,7 @@ def bert_model_forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config is not None and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "1": + if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = split_forward_gather_backward( hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group ) @@ -241,7 +241,7 @@ def custom_forward(*inputs): # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config is not None and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "1": + if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = gather_forward_split_backward( hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group ) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 370017384616..fe70376e144d 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -214,7 +214,7 @@ def bloom_model_forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "1": + if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = split_forward_gather_backward( hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group ) @@ -263,7 +263,7 @@ def custom_forward(*inputs): # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "1": + if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = gather_forward_split_backward( hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group ) diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index 9bfb7053943b..9207b34d0d1c 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -192,7 +192,7 @@ def chatglm_model_forward( start_idx, end_idx = stage_index[0], stage_index[1] if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "1": + if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = split_forward_gather_backward( hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group ) @@ -222,7 +222,7 @@ def chatglm_model_forward( presents = presents + (kv_cache,) if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "1": + if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = gather_forward_split_backward( hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group ) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 7d53d9cc763d..e12c826b21a0 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -225,7 +225,7 @@ def gpt2_model_forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "1": + if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = split_forward_gather_backward( hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group ) @@ -284,7 +284,7 @@ def custom_forward(*inputs): # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "1": + if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = gather_forward_split_backward( hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group ) @@ -823,7 +823,7 @@ def forward( else: query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - if sp_mode == "3": + if sp_mode == "all_to_all": query = all_to_all_comm(query) key = all_to_all_comm(key) value = all_to_all_comm(value) @@ -850,7 +850,7 @@ def forward( dropout_p = self.attn_dropout.p if self.training else 0.0 attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) - if sp_mode == "3": + if sp_mode == "all_to_all": attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) attn_output = self.c_proj(attn_output) @@ -904,7 +904,7 @@ def forward( # use variable seq_len to replace input_shape[-1] seq_len = input_shape[-1] - if sp_mode in ["2", "3"]: + if sp_mode in ["ring", "all_to_all"]: seq_len *= sp_size if token_type_ids is not None: @@ -917,7 +917,7 @@ def forward( past_key_values = tuple([None] * len(self.h)) else: past_length = past_key_values[0][0].size(-2) - if sp_mode in ["2", "3"]: + if sp_mode in ["ring", "all_to_all"]: past_length *= sp_size if position_ids is None: position_ids = torch.arange( @@ -929,10 +929,10 @@ def forward( position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) # split position ids when using sequence parallel - if sp_mode in ["2", "3"]: + if sp_mode in ["ring", "all_to_all"]: position_ids = torch.chunk(position_ids.clone(), sp_size, dim=1)[dist.get_rank(sp_group)] - if sp_mode in ["2", "3"]: + if sp_mode in ["ring", "all_to_all"]: attention_mask = _gather(attention_mask, 1, sp_group) # Prepare head mask if needed @@ -1130,7 +1130,7 @@ def forward( head_mask = self.get_head_mask(head_mask, self.config.n_layer) if inputs_embeds is None: - if sp_mode in ["2"]: + if sp_mode in ["ring"]: input_ids = _gather(input_ids, 1, sp_group) inputs_embeds = self.wte(input_ids) inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) @@ -1172,7 +1172,7 @@ def forward( all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None - if sp_mode == "1": + if sp_mode == "split_gather": # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] hidden_states = split_forward_gather_backward(hidden_states, dim=1, process_group=sp_group) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index ad6bbaa00e2a..35fac95a50bd 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -461,7 +461,7 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - if sp_mode in ["1", "2"]: + if sp_mode in ["split_gather", "ring"]: q_len *= sp_size assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." @@ -470,7 +470,7 @@ def forward( value_states = self.v_proj(hidden_states) # sp: all-to-all comminucation when introducing sequence parallel - if sp_mode == "3": + if sp_mode == "all_to_all": query_states = all_to_all_comm(query_states) key_states = all_to_all_comm(key_states) value_states = all_to_all_comm(value_states) @@ -529,7 +529,7 @@ def forward( # ) # sp: all-to-all comminucation when introducing sequence parallel - if sp_mode == "3": + if sp_mode == "all_to_all": attn_output = all_to_all_comm(attn_output, None, scatter_dim=1, gather_dim=2) attn_output = self.o_proj(attn_output) @@ -807,8 +807,8 @@ def forward( use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - # sp: modify sp_len when sequence parallel mode is 2 - if sp_mode in ["1", "2"]: + # sp: modify sp_len when sequence parallel mode is ring + if sp_mode in ["split_gather", "ring"]: q_len *= sp_size if self.config.pretraining_tp > 1: key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp @@ -833,7 +833,7 @@ def forward( value_states = self.v_proj(hidden_states) # sp: all-to-all comminucation when introducing sequence parallel - if sp_mode == "3": + if sp_mode == "all_to_all": query_states = all_to_all_comm(query_states, sp_group) key_states = all_to_all_comm(key_states, sp_group) value_states = all_to_all_comm(value_states, sp_group) @@ -913,7 +913,7 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() # sp: all-to-all comminucation when introducing sequence parallel - if sp_mode == "3": + if sp_mode == "all_to_all": attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) else: @@ -1044,7 +1044,7 @@ def forward( raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") # sp: modify seq_length when using sequence parallel - if sp_mode in ["2", "3"]: + if sp_mode in ["ring", "all_to_all"]: seq_length *= sp_size seq_length_with_past = seq_length @@ -1066,7 +1066,7 @@ def forward( position_ids = position_ids.view(-1, seq_length).long() if inputs_embeds is None: - if sp_mode == "2": + if sp_mode == "ring": input_ids = _gather(input_ids, 1, sp_group) inputs_embeds = self.embed_tokens(input_ids) input_ids = input_ids.chunk(sp_size, dim=1)[torch.distributed.get_rank(sp_group)] @@ -1075,7 +1075,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) # TODO use_distributed_mask - use_distributed_mask = True if sp_mode in ["2", "3"] else False + use_distributed_mask = True if sp_mode in ["ring", "all_to_all"] else False # embed positions if sp_mode is None or use_distributed_mask is False: @@ -1084,7 +1084,7 @@ def forward( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) - if sp_mode in ["2", "3"]: + if sp_mode in ["ring", "all_to_all"]: attention_mask = _gather(attention_mask, 1, sp_group) attention_mask = self._prepare_decoder_attention_mask( @@ -1103,10 +1103,10 @@ def forward( attention_mask = _gather(attention_mask, 1, sp_group) hidden_states = inputs_embeds - if sp_mode == "1": + if sp_mode == "split_gather": hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) - if (self.gradient_checkpointing or sp_mode in ["2", "3"]) and self.training: + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." @@ -1124,7 +1124,7 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if (self.gradient_checkpointing or sp_mode in ["2", "3"]) and self.training: + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -1314,7 +1314,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) - if sp_mode == "1": + if sp_mode == "split_gather": hidden_states = gather_forward_reducescatter_backward(hidden_states, sp_group, 1) # Self Attention @@ -1327,19 +1327,19 @@ def forward( use_cache=use_cache, ) - if sp_mode == "1": + if sp_mode == "split_gather": hidden_states = reducescatter_forward_gather_backward(hidden_states, sp_group, 1) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - if sp_mode == "1": + if sp_mode == "split_gather": hidden_states = gather_forward_reducescatter_backward(hidden_states, sp_group, 1) hidden_states = self.mlp(hidden_states) - if sp_mode == "1": + if sp_mode == "split_gather": hidden_states = reducescatter_forward_gather_backward(hidden_states, sp_group, 1) hidden_states = residual + hidden_states diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 1a5dba0e0064..c161f7d1f934 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -68,7 +68,7 @@ def module_policy(self): sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None overlap = self.shard_config.enable_sequence_overlap - sp_partial_derived = sp_mode in ["1"] + sp_partial_derived = sp_mode == "split_gather" if self.shard_config.enable_tensor_parallelism: policy[BertLayer] = ModulePolicyDescription( @@ -141,7 +141,7 @@ def module_policy(self): ] ) - if sp_mode == "1": + if sp_mode == "split_gather": self.append_or_create_method_replacement( description={"forward": bert_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 2d235b8a0085..28909697d2ba 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -58,7 +58,7 @@ def module_policy(self): sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None overlap = self.shard_config.enable_sequence_overlap - sp_partial_derived = sp_mode in ["1"] + sp_partial_derived = sp_mode == "split_gather" if self.shard_config.enable_tensor_parallelism: policy[BloomBlock] = ModulePolicyDescription( @@ -147,7 +147,7 @@ def module_policy(self): target_key=BloomBlock, ) - if sp_mode == "1": + if sp_mode == "split_gather": self.append_or_create_method_replacement( description={"forward": get_bloom_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 7d1ff3f8e59e..117c6ef3adad 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -58,7 +58,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None overlap = self.shard_config.enable_sequence_overlap - sp_partial_derived = sp_mode in ["1"] + sp_partial_derived = sp_mode == "split_gather" if self.shard_config.enable_tensor_parallelism: policy[ChatGLMModel] = ModulePolicyDescription( @@ -148,7 +148,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) # use sequence parallel - if sp_mode == "1": + if sp_mode == "split_gather": self.append_or_create_method_replacement( description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 86aa8d0b074d..48745e048d03 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -56,17 +56,17 @@ def module_policy(self): sp_size = self.shard_config.sequence_parallel_size sp_group = self.shard_config.sequence_parallel_process_group overlap = self.shard_config.enable_sequence_overlap - sp_partial_derived = sp_mode in ["1", "2"] + sp_partial_derived = sp_mode in ["split_gather", "ring"] use_flash_attention = self.shard_config.enable_flash_attention # todo: currently sp mode 2 and 3 need to be used with flashattention - if sp_mode in ["2", "3"]: + if sp_mode in ["ring", "all_to_all"]: if not use_flash_attention: warnings.warn( f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will enable FlashAttention automatically." ) use_flash_attention = True - if sp_mode == "3": + if sp_mode == "all_to_all": decoder_attribute_replacement = { "num_heads": self.model.config.num_attention_heads // sp_size, } diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 35c29d6219a8..4c4d9e4dcef8 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -47,6 +47,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: else: norm_cls = RMSNorm + if self.pipeline_stage_manager is not None: + self.shard_config.enable_sequence_parallelism = False + self.shard_config.enable_sequence_overlap = False + warnings.warn( + f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False" + ) + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None sp_size = self.shard_config.sequence_parallel_size sp_group = self.shard_config.sequence_parallel_process_group @@ -61,15 +68,15 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: sp_group = None use_flash_attention = self.shard_config.enable_flash_attention - # todo: currently sp mode 2 and 3 need to be used with flashattention - if sp_mode in ["2", "3"]: + # Currently sp mode ring and all_to_all need to be used with flashattention + if sp_mode in ["ring", "all_to_all"]: if not use_flash_attention: warnings.warn( f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will enable FlashAttention automatically." ) use_flash_attention = True - if sp_mode == "1": + if sp_mode == "split_gather": self.append_or_create_method_replacement( description={ "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group), @@ -84,7 +91,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=LlamaAttention, ) - elif sp_mode == "2": + elif sp_mode == "ring": self.append_or_create_method_replacement( description={ "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), @@ -99,7 +106,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=LlamaModel, ) - elif sp_mode == "3": + elif sp_mode == "all_to_all": decoder_attribute_replacement = { "num_heads": self.model.config.num_attention_heads // sp_size, # "head_dim": self.model.config.hidden_size // self.model.config.num_attention_heads, diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index b7310d81b314..2858311329e0 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -8,7 +8,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager __all__ = ["ShardConfig"] -SUPPORT_SP_MODE = ["1", "2", "3"] +SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] @dataclass @@ -65,11 +65,11 @@ def __post_init__(self): assert ( self.sequence_parallelism_mode in SUPPORT_SP_MODE ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" - if self.sequence_parallelism_mode in ["1", "2"]: + if self.sequence_parallelism_mode in ["split_gather", "ring"]: assert ( self.enable_tensor_parallelism ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True" - elif self.sequence_parallelism_mode in ["3"]: + elif self.sequence_parallelism_mode in ["all_to_all"]: assert ( not self.enable_tensor_parallelism ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False" @@ -112,7 +112,7 @@ def _turn_on_all_optimization(self): self.enable_sequence_parallelism = True self.enable_sequence_overlap = True # todo modify default sequence parallelism mode and process group - self.sequence_parallelism_mode = "1" + self.sequence_parallelism_mode = "split_gather" self.sequence_parallel_process_group = self.tensor_parallel_process_group def _infer(self): diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 4d4e54e0178c..2c9f2ceefb09 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -165,19 +165,6 @@ def _criterion(outputs, inputs): data = data_gen_fn() - # if ( - # booster.plugin.shard_config.enable_sequence_parallelism - # and booster.plugin.shard_config.sequence_parallelism_mode in ["1", "2"] - # and booster.plugin.tp_size != 0 - # ): - # seq_len = data["input_ids"].shape[-1] - # lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) - # times = lcm // seq_len - # input_shape = data["input_ids"].shape - # for k, v in data.items(): - # if v.shape == input_shape: - # data[k] = v.repeat((1,) * (v.dim() - 1) + (times,)) - shard_test_data = {} for k, v in data.items(): if k not in ["input_ids", "attention_mask"]: @@ -189,7 +176,7 @@ def _criterion(outputs, inputs): dist.get_rank(booster.plugin.shard_config.sequence_parallel_process_group) ] if booster.plugin.shard_config.enable_sequence_parallelism - and booster.plugin.shard_config.sequence_parallelism_mode in ["2", "3"] + and booster.plugin.shard_config.sequence_parallelism_mode in ["ring", "all_to_all"] else data[k].clone() ) unshard_test_data = {} diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 19619e6a3eaa..ad1cbb7bc95a 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -152,7 +152,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "enable_all_optimization": False, "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "2", + "sequence_parallelism_mode": "ring", "enable_flash_attention": True, "use_lazy_init": True, "precision": "fp32", @@ -163,7 +163,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 1, "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "3", + "sequence_parallelism_mode": "all_to_all", "enable_flash_attention": True, "use_lazy_init": True, "precision": "fp32", diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index b56c60add5cd..ff4b75f7236e 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -98,7 +98,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 1, "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "2", + "sequence_parallelism_mode": "ring", "enable_flash_attention": True, "use_lazy_init": True, "precision": "fp32", @@ -109,7 +109,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 1, "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "2", + "sequence_parallelism_mode": "ring", "use_lazy_init": True, "precision": "fp32", "initial_scale": 1, @@ -119,7 +119,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 1, "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "3", + "sequence_parallelism_mode": "all_to_all", "use_lazy_init": True, "precision": "fp32", "initial_scale": 1, @@ -129,21 +129,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 1, "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "3", + "sequence_parallelism_mode": "all_to_all", "enable_flash_attention": True, "use_lazy_init": True, "precision": "fp32", "initial_scale": 1, }, - # { - # "tp_size": 2, - # "pp_size": 2, - # "num_microbatches": 2, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 1, "pp_size": 2, From 9f2f1fe2226aa9d9e08a178f9414edfaa29b9c9a Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Tue, 27 Feb 2024 14:42:21 +0800 Subject: [PATCH 26/50] fix bugs --- colossalai/shardformer/layer/linear.py | 19 ++++++++++--------- .../shardformer/layer/qkv_fused_linear.py | 8 ++++---- .../test_gpt2_qkv_fused_linear_1d.py | 2 +- .../test_layer/test_linear_1d.py | 2 +- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index a773783b9f19..3014e97b0573 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -23,9 +23,8 @@ ) from ._operation import ( - gather_forward_split_backward, gather_forward_reducescatter_backward, - reducescatter_forward_gather_backward, + gather_forward_split_backward, linear_gather_forward_reducescatter_backward, linear_reducescatter_forward_gather_backward, linear_with_async_comm, @@ -202,10 +201,12 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: if self.seq_parallel_mode is None: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) - elif self.seq_parallel_mode == "1": - input_parallel = gather_forward_reducescatter_backward(input_parallel, self.process_group, self.seq_parallel_dim) + elif self.seq_parallel_mode == "split_gather": + input_parallel = gather_forward_reducescatter_backward( + input_parallel, self.process_group, self.seq_parallel_dim + ) output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False) - elif self.seq_parallel_mode == "2": + elif self.seq_parallel_mode == "ring": output_parallel = linear_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True ) @@ -419,13 +420,13 @@ def forward(self, input_: Tensor) -> Tensor: if self.seq_parallel_mode is None: output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output = reduce_forward(output_parallel, self.process_group) - elif self.seq_parallel_mode == "1": - #output = linear_with_async_comm(input_, self.weight, None, None, False) + elif self.seq_parallel_mode == "split_gather": + # output = linear_with_async_comm(input_, self.weight, None, None, False) output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output = reducescatter_forward_gather_backward( - output_parallel, self.process_group, self.seq_parallel_dim + output_parallel, self.process_group, self.seq_parallel_dim ) - elif self.seq_parallel_mode == "2": + elif self.seq_parallel_mode == "ring": output = linear_reducescatter_forward_gather_backward( input_, self.weight, diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index a5d75db8a740..dc3634238f74 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -318,12 +318,12 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: output_parallel = matmul_with_async_comm( input_parallel, self.weight, bias, self.process_group, self.async_communication ) - elif self.seq_parallel_mode == "1": + elif self.seq_parallel_mode == "split_gather": input_parallel = input_ output_parallel = matmul_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap ) - elif self.seq_parallel_mode == "2": + elif self.seq_parallel_mode == "ring": input_parallel = input_ output_parallel = matmul_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap, True @@ -536,10 +536,10 @@ def forward(self, input_: Tensor) -> Tensor: if self.seq_parallel_mode is None: output_parallel = torch.matmul(input_, self.weight) output = reduce_forward(output_parallel, self.process_group) - elif self.seq_parallel_mode == "1": + elif self.seq_parallel_mode == "split_gather": output_parallel = torch.matmul(input_, self.weight) output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) - elif self.seq_parallel_mode == "2": + elif self.seq_parallel_mode == "ring": output_parallel = torch.matmul(input_, self.weight) output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index e4351ddae7f4..e9aa0dbedbc8 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -135,7 +135,7 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool): @parameterize("lazy_init", [False, True]) -@parameterize("seq_parallel_mode", ["1", None]) +@parameterize("seq_parallel_mode", ["split_gather", None]) @parameterize("overlap", [True]) def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool): check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap) diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index 28b4c36f7e89..21d3190de7ae 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -176,7 +176,7 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: @parameterize("lazy_init", [False, True]) -@parameterize("seq_parallel_mode", [None, "1"]) +@parameterize("seq_parallel_mode", [None, "split_gather"]) @parameterize("overlap", [True]) def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap): check_linear_1d_col(lazy_init, seq_parallel_mode, overlap) From 33963a36e586da50401e6351d121124cd4f77e81 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Tue, 27 Feb 2024 16:21:05 +0800 Subject: [PATCH 27/50] fix gemini checkpoint io --- colossalai/testing/comparison.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index e415b5fc3aa3..e62604bfd41a 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -78,7 +78,7 @@ def check_state_dict_equal( v2 = v2.to("cpu") if ignore_dtype: v1 = v1.to(v2.dtype) - assert_close_loose(v1, v2) + assert_close_loose(v1, v2, rtol=2e-3, atol=2e-3) else: assert v1 == v2, f"{v1} not equals to {v2}" From 700c26dcf1013866fb4d31b3edee4ba951ae839c Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 28 Feb 2024 10:49:58 +0800 Subject: [PATCH 28/50] loose tensor checking atol and rtol --- colossalai/testing/comparison.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index e62604bfd41a..07d2731df923 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -78,7 +78,7 @@ def check_state_dict_equal( v2 = v2.to("cpu") if ignore_dtype: v1 = v1.to(v2.dtype) - assert_close_loose(v1, v2, rtol=2e-3, atol=2e-3) + assert_close_loose(v1, v2, rtol=3e-3, atol=3e-3) else: assert v1 == v2, f"{v1} not equals to {v2}" From 9a36add95c23d541ac22e62769a46ec750628306 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Mon, 11 Mar 2024 11:00:36 +0800 Subject: [PATCH 29/50] add comment --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 72dbb6201b50..330bb9897676 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -911,6 +911,7 @@ class HybridParallelPlugin(PipelinePluginBase): Args: tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. + sp_size (int): The size of sequence parallelism. precision (str, optional): Specifies the precision of parameters during training. Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. Defaults to 'fp16'. @@ -923,6 +924,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. + sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather". enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True. num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. @@ -1017,6 +1019,9 @@ def __init__( assert ( tp_size == 1 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism" + assert ( + pp_size == 1 + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with pipeline parallelism" self.sp_size = dist.get_world_size() if sp_size is None else sp_size self.dp_size = dist.get_world_size() // (self.sp_size * pp_size) else: @@ -1199,6 +1204,7 @@ def configure( tp_process_group=self.tp_group, ) else: + # Here we bind the ZeRO group with sp group when user enable both ZeRO and all_to_all sp. if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": self.zero_dp_size = self.sp_size self.zero_dp_group = self.sp_group From 0e0ac1889153375163bb793ce120156022de5bb1 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 13 Mar 2024 09:59:05 +0800 Subject: [PATCH 30/50] fix llama layernorm grad --- colossalai/shardformer/policies/llama.py | 19 ++++++++----------- .../test_model/test_shard_llama.py | 17 +++++++++++++++++ 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 4c4d9e4dcef8..aaebec312c5a 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -55,17 +55,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None - sp_size = self.shard_config.sequence_parallel_size - sp_group = self.shard_config.sequence_parallel_process_group - if self.pipeline_stage_manager is not None: - if sp_mode is not None: - warnings.warn( - "Sequence parallelism is not supported under pipeline parallelism setting. " - "Sequence parallelism will be disabled." - ) - sp_mode = None - sp_size = None - sp_group = None + sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None + sp_group = ( + self.shard_config.sequence_parallel_process_group if self.shard_config.enable_sequence_parallelism else None + ) + sp_partial_derived = sp_mode in ["split_gather", "ring"] use_flash_attention = self.shard_config.enable_flash_attention # Currently sp mode ring and all_to_all need to be used with flashattention @@ -200,10 +194,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="input_layernorm", target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="post_attention_layernorm", target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, @@ -214,6 +210,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: description=SubModuleReplacementDescription( suffix="norm", target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, ), policy=policy, target_key=LlamaModel, diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index ff4b75f7236e..5d654d579429 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -41,6 +41,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] col_layer_for_check = ["layers[0].self_attn.o_proj"] + # Here we check the grad of layernorm because an all-reduce operation should be performed during sequence parallelism + norm_layer_for_check = ["layers[0].input_layernorm", "layers[0].post_attention_layernorm"] + + # During pipeline parallelism, we cannot get the grad of norm layer during first stage, so we only check this when pp is not enbaled + if stage_manager is None: + norm_layer_for_check.append("norm") # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} @@ -55,8 +61,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, col_layer_grads = get_grad_tensors_for_check( llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False ) + norm_layer_grads = get_grad_tensors_for_check( + llama_model, + shard_llama_model, + norm_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) + grads_to_check.update(norm_layer_grads) # optimizer executes step org_optimizer.step() From cbb3025c19813670ba267a0c7b7ddb196c3ec609 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 13 Mar 2024 13:54:21 +0800 Subject: [PATCH 31/50] fix zero grad --- .../test_model/test_shard_llama.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 5d654d579429..3bfd8606f1cc 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -2,6 +2,8 @@ import pytest import torch +import torch.distributed as dist +from torch.testing import assert_close import colossalai from colossalai.logging import disable_existing_loggers @@ -48,6 +50,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if stage_manager is None: norm_layer_for_check.append("norm") + # Check the grad when using ZeRO-1 and ZeRO-2 + if ( + booster.plugin.zero_stage in [1, 2] + and booster.plugin.shard_config.enable_sequence_parallelism + and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" + ): + for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): + working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)] + grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p)) + grad_index = 0 if sharded_optimizer._partition_grads else sharded_optimizer._local_rank + grad = grads[grad_index] + sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] + assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0: @@ -138,7 +154,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "all_to_all", "use_lazy_init": True, - "precision": "fp32", + "zero_stage": 2, + "precision": "fp16", "initial_scale": 1, }, { @@ -149,7 +166,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "sequence_parallelism_mode": "all_to_all", "enable_flash_attention": True, "use_lazy_init": True, - "precision": "fp32", + "precision": "fp16", "initial_scale": 1, }, { From 3391d3e7cd81cf2e6f2d66102c3e29431513a5b5 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 13 Mar 2024 13:54:58 +0800 Subject: [PATCH 32/50] fix zero grad --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 8 ++++++++ colossalai/zero/low_level/low_level_optim.py | 7 +++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 330bb9897676..50ed250933de 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -667,6 +667,10 @@ def __init__( self.dp_pg = dp_process_group self.tp_pg = tp_process_group self.pp_pg = pp_process_group + self.use_all_to_all_sequence_parallel = ( + self.model.shard_config.enable_sequence_parallelism + and self.model.shard_config.sequence_parallelism_mode == "all_to_all" + ) if use_pipeline: init_pipeline_optimizer(optimizer, model) super().__init__( @@ -687,6 +691,7 @@ def __init__( cpu_offload=cpu_offload, dp_process_group=dp_process_group, forced_dtype=forced_dtype, + enable_sequence_parallel=self.use_all_to_all_sequence_parallel, ) def sync_dp_grads(self): @@ -1006,6 +1011,9 @@ def __init__( self.sequence_parallelism_mode in SUPPORT_SP_MODE ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" if self.sequence_parallelism_mode in ["split_gather", "ring"]: + assert ( + zero_stage == 0 + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with ZeRO-1 / ZeRO-2" assert ( tp_size > 1 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index a2433d1b261c..70f366dd716e 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -77,8 +77,11 @@ def __init__( forced_dtype: Optional[torch.dtype] = None, moe_extra_dp_process_group: Optional[ProcessGroup] = None, master_weights: bool = True, # master weights + enable_sequence_parallel: bool = False, ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) + self._enable_sequence_parallel = enable_sequence_parallel + self._dtype = self.optim.param_groups[0]["params"][0].dtype self._logger = get_dist_logger() self._verbose = verbose @@ -297,7 +300,8 @@ def _run_reduction(self): if self.moe_extra_dp_pg is None: flat_grads = self._bucket_store.get_flatten_grad() - flat_grads /= self._world_size + if not self._enable_sequence_parallel: + flat_grads /= self._world_size else: # record moe and non moe param moe_list = [] @@ -494,7 +498,6 @@ def backward(self, loss, retain_graph=False): # clear reduced grads if self._overlap_communication: get_accelerator().synchronize() - self.zero_grad() def backward_by_grad(self, tensor, grad): From cc28bd4edc7c683ec0267bb757dea2494289bd03 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 13 Mar 2024 14:43:18 +0800 Subject: [PATCH 33/50] fix conflict --- colossalai/shardformer/policies/gpt2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 48745e048d03..28dba5a0e955 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -115,9 +115,7 @@ def module_policy(self): SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel_mode": sp_mode, - }, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", From 1a3825d01ff9aadb4d62c33486c839643addf5e6 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Mon, 18 Mar 2024 11:33:58 +0800 Subject: [PATCH 34/50] update split and gather auto grad func --- colossalai/shardformer/layer/_operation.py | 62 ++++++++++------ colossalai/shardformer/modeling/llama.py | 84 ++-------------------- colossalai/shardformer/policies/llama.py | 2 +- 3 files changed, 46 insertions(+), 102 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 1e2b29340c8f..76a2718babcd 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -172,7 +172,7 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group= group_size = dist.get_world_size(process_group) cur_rank = dist.get_rank(process_group) - #output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)] + # output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)] # initialization of ring communication recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 @@ -277,7 +277,7 @@ def forward(ctx, input_, process_group, dim): @staticmethod def backward(ctx, grad_output): dim = ctx.dim - process_group = ctx.process_group + process_group = ctx.process_group # do reduce-scatter new_shape = list(grad_output.shape) @@ -285,7 +285,9 @@ def backward(ctx, grad_output): new_shape[dim] % dist.get_world_size(process_group) == 0 ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) - grad_list = [item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)] + grad_list = [ + item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim) + ] output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device) dist.reduce_scatter(output, grad_list, group=process_group) @@ -314,13 +316,13 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, if ring is True: input_to_gather = {} input_local = {} - input_to_gather['input'] = input_ - input_local['weight'] = weight + input_to_gather["input"] = input_ + input_local["weight"] = weight output = _ring_as_gather( - F.linear, - input_to_gather=input_to_gather, - input_local=input_local, + F.linear, + input_to_gather=input_to_gather, + input_local=input_local, process_group=process_group, ) @@ -437,7 +439,9 @@ def backward(ctx, grad_output): return output, grad_weight, grad_bias, None, None, None, None, None -def _ring_as_reducescatter(func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1): +def _ring_as_reducescatter( + func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1 +): # currently only support one single tensor as output group_size = dist.get_world_size(process_group) cur_rank = dist.get_rank(process_group) @@ -510,16 +514,16 @@ def forward(ctx, input_, weight, bias, process_group, dim, ring): if ring is True: input_to_reducescatter = {} input_local = {} - input_to_reducescatter['input'] = input_ - input_local['weight'] = weight + input_to_reducescatter["input"] = input_ + input_local["weight"] = weight if bias is not None: - input_to_reducescatter['bias'] = bias + input_to_reducescatter["bias"] = bias output = _ring_as_reducescatter( - F.linear, - input_to_reducescatter=input_to_reducescatter, - input_local=input_local, + F.linear, + input_to_reducescatter=input_to_reducescatter, + input_local=input_local, process_group=process_group, ) else: @@ -626,15 +630,15 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, if ring is True: input_to_gather = {} input_local = {} - input_to_gather['input'] = input_ - input_local['other'] = weight + input_to_gather["input"] = input_ + input_local["other"] = weight output = _ring_as_gather( - torch.matmul, - input_to_gather=input_to_gather, - input_local=input_local, + torch.matmul, + input_to_gather=input_to_gather, + input_local=input_local, process_group=process_group, - gather_dim=dim + gather_dim=dim, ) else: @@ -735,13 +739,19 @@ class _SplitForwardGatherBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, dim, process_group): + def forward(ctx, input_, dim, process_group, grad_scale=None): ctx.process_group = process_group ctx.dim = dim + ctx.grad_scale = grad_scale return _split(input_, dim, process_group) @staticmethod def backward(ctx, grad_output): + if ctx.grad_scale is not None: + if ctx.grad_scale == "up": + grad_output = grad_output * dist.get_world_size(ctx.process_group) + elif ctx.grad_scale == "down": + grad_output = grad_output / dist.get_world_size(ctx.process_group) return _gather(grad_output, ctx.dim, ctx.process_group), None, None @@ -792,13 +802,19 @@ class _GatherForwardSplitBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, dim, process_group): + def forward(ctx, input_, dim, process_group, grad_scale=None): ctx.process_group = process_group ctx.dim = dim + ctx.grad_scale = grad_scale return _gather(input_, dim, process_group) @staticmethod def backward(ctx, grad_output): + if ctx.grad_scale is not None: + if ctx.grad_scale == "up": + grad_output = grad_output * dist.get_world_size(ctx.process_group) + elif ctx.grad_scale == "down": + grad_output = grad_output / dist.get_world_size(ctx.process_group) return _split(grad_output, ctx.dim, ctx.process_group), None, None diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 35fac95a50bd..4179a3d0cb9c 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -20,9 +20,7 @@ from colossalai.shardformer.layer._operation import ( _gather, all_to_all_comm, - gather_forward_reducescatter_backward, gather_forward_split_backward, - reducescatter_forward_gather_backward, split_forward_gather_backward, ) from colossalai.shardformer.shard import ShardConfig @@ -439,7 +437,7 @@ def llama_for_sequence_classification_forward( return {"hidden_states": hidden_states} -def get_llama_flash_attention_forward(shard_config, sp_mode, sp_size): +def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb llama_version = 2 @@ -471,9 +469,9 @@ def forward( # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": - query_states = all_to_all_comm(query_states) - key_states = all_to_all_comm(key_states) - value_states = all_to_all_comm(value_states) + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) bsz, q_len, _ = query_states.size() query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -530,7 +528,7 @@ def forward( # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": - attn_output = all_to_all_comm(attn_output, None, scatter_dim=1, gather_dim=2) + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value @@ -1161,7 +1159,7 @@ def custom_forward(*inputs): hidden_states = self.norm(hidden_states) # Todo: Maybe this line can be optimized - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale="up") # add hidden states from the last decoder layer if output_hidden_states: @@ -1284,73 +1282,3 @@ def forward( ) return forward - - -def get_llama_decoder_seq_parallel_model_forward(sp_mode, sp_size, sp_group): - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - if sp_mode == "split_gather": - hidden_states = gather_forward_reducescatter_backward(hidden_states, sp_group, 1) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - if sp_mode == "split_gather": - hidden_states = reducescatter_forward_gather_backward(hidden_states, sp_group, 1) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - if sp_mode == "split_gather": - hidden_states = gather_forward_reducescatter_backward(hidden_states, sp_group, 1) - - hidden_states = self.mlp(hidden_states) - - if sp_mode == "split_gather": - hidden_states = reducescatter_forward_gather_backward(hidden_states, sp_group, 1) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - return forward diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index aaebec312c5a..e122554fa227 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -220,7 +220,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if use_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size), + "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size), }, policy=policy, target_key=LlamaAttention, From 76a22da092a992e699c771a121da53404f49c09b Mon Sep 17 00:00:00 2001 From: linsj20 Date: Wed, 20 Mar 2024 15:45:25 +0800 Subject: [PATCH 35/50] sequence parallel: inside text split (#6) --- .../booster/plugin/hybrid_parallel_plugin.py | 6 +- colossalai/shardformer/layer/_operation.py | 12 +- colossalai/shardformer/modeling/gpt2.py | 211 +++--------------- colossalai/shardformer/modeling/llama.py | 26 +-- colossalai/shardformer/policies/gpt2.py | 4 +- colossalai/shardformer/policies/llama.py | 2 +- colossalai/shardformer/shard/shard_config.py | 1 + colossalai/zero/low_level/low_level_optim.py | 5 +- tests/test_shardformer/test_model/_utils.py | 13 +- 9 files changed, 59 insertions(+), 221 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 50ed250933de..1d93297c500a 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -667,10 +667,6 @@ def __init__( self.dp_pg = dp_process_group self.tp_pg = tp_process_group self.pp_pg = pp_process_group - self.use_all_to_all_sequence_parallel = ( - self.model.shard_config.enable_sequence_parallelism - and self.model.shard_config.sequence_parallelism_mode == "all_to_all" - ) if use_pipeline: init_pipeline_optimizer(optimizer, model) super().__init__( @@ -691,7 +687,6 @@ def __init__( cpu_offload=cpu_offload, dp_process_group=dp_process_group, forced_dtype=forced_dtype, - enable_sequence_parallel=self.use_all_to_all_sequence_parallel, ) def sync_dp_grads(self): @@ -1108,6 +1103,7 @@ def __init__( sequence_parallelism_mode=sequence_parallelism_mode, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, + zero_stage=zero_stage, ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 76a2718babcd..8b0bff8f7404 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -752,7 +752,7 @@ def backward(ctx, grad_output): grad_output = grad_output * dist.get_world_size(ctx.process_group) elif ctx.grad_scale == "down": grad_output = grad_output / dist.get_world_size(ctx.process_group) - return _gather(grad_output, ctx.dim, ctx.process_group), None, None + return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None class _ReduceForward(torch.autograd.Function): @@ -815,7 +815,7 @@ def backward(ctx, grad_output): grad_output = grad_output * dist.get_world_size(ctx.process_group) elif ctx.grad_scale == "down": grad_output = grad_output / dist.get_world_size(ctx.process_group) - return _split(grad_output, ctx.dim, ctx.process_group), None, None + return _split(grad_output, ctx.dim, ctx.process_group), None, None, None class _AllToAll(torch.autograd.Function): @@ -1016,12 +1016,12 @@ def matmul_gather_forward_reducescatter_backward( ) -def gather_forward_split_backward(input_, dim, process_group): - return _GatherForwardSplitBackward.apply(input_, dim, process_group) +def gather_forward_split_backward(input_, dim, process_group, grad_scale=None): + return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale) -def split_forward_gather_backward(input_, dim, process_group): - return _SplitForwardGatherBackward.apply(input_, dim, process_group) +def split_forward_gather_backward(input_, dim, process_group, grad_scale=None): + return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale) def reduce_forward(input_, process_group): diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index e12c826b21a0..f09568f2aa6b 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -904,8 +904,6 @@ def forward( # use variable seq_len to replace input_shape[-1] seq_len = input_shape[-1] - if sp_mode in ["ring", "all_to_all"]: - seq_len *= sp_size if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, seq_len) @@ -917,8 +915,6 @@ def forward( past_key_values = tuple([None] * len(self.h)) else: past_length = past_key_values[0][0].size(-2) - if sp_mode in ["ring", "all_to_all"]: - past_length *= sp_size if position_ids is None: position_ids = torch.arange( past_length, @@ -932,173 +928,34 @@ def forward( if sp_mode in ["ring", "all_to_all"]: position_ids = torch.chunk(position_ids.clone(), sp_size, dim=1)[dist.get_rank(sp_group)] - if sp_mode in ["ring", "all_to_all"]: - attention_mask = _gather(attention_mask, 1, sp_group) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - - hidden_states = self.drop(hidden_states) - - output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) - - attention_mask, encoder_attention_mask = _get_attention_mask( - self, - shard_config, - hidden_states, - past_key_values, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) - # Ensure that attention_mask is always on the same device as hidden_states - if torch.is_tensor(attention_mask): - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - hidden_states = self.ln_f(hidden_states) - - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - return forward - - -def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - inputs_embeds.shape[0] + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: raise ValueError("You have to specify either input_ids or inputs_embeds") @@ -1130,12 +987,12 @@ def forward( head_mask = self.get_head_mask(head_mask, self.config.n_layer) if inputs_embeds is None: - if sp_mode in ["ring"]: - input_ids = _gather(input_ids, 1, sp_group) - inputs_embeds = self.wte(input_ids) - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) - else: - inputs_embeds = self.wte(input_ids) + inputs_embeds = self.wte(input_ids) + if sp_mode == "ring": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 'down') + position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 4179a3d0cb9c..19b3f6618e35 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -931,7 +931,7 @@ def forward( return forward -def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): +def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, zero_stage=0): logger = logging.get_logger(__name__) # Copied from transformers.models.bart.modeling_bart._make_causal_mask @@ -1041,10 +1041,6 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # sp: modify seq_length when using sequence parallel - if sp_mode in ["ring", "all_to_all"]: - seq_length *= sp_size - seq_length_with_past = seq_length past_key_values_length = 0 @@ -1064,13 +1060,12 @@ def forward( position_ids = position_ids.view(-1, seq_length).long() if inputs_embeds is None: - if sp_mode == "ring": - input_ids = _gather(input_ids, 1, sp_group) - inputs_embeds = self.embed_tokens(input_ids) - input_ids = input_ids.chunk(sp_size, dim=1)[torch.distributed.get_rank(sp_group)] - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) - else: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input_ids) + + if sp_mode in ["ring", "split_gather"]: + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 'down') # TODO use_distributed_mask use_distributed_mask = True if sp_mode in ["ring", "all_to_all"] else False @@ -1101,8 +1096,6 @@ def forward( attention_mask = _gather(attention_mask, 1, sp_group) hidden_states = inputs_embeds - if sp_mode == "split_gather": - hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: if use_cache: @@ -1159,7 +1152,10 @@ def custom_forward(*inputs): hidden_states = self.norm(hidden_states) # Todo: Maybe this line can be optimized - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale="up") + if sp_mode == "ring" or sp_mode == "split_gather" or (sp_mode == "all_to_all" and zero_stage == 0): + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all" and zero_stage in [1, 2]: + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale="up") # add hidden states from the last decoder layer if output_hidden_states: diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 28dba5a0e955..48745e048d03 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -115,7 +115,9 @@ def module_policy(self): SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={ + "seq_parallel_mode": sp_mode, + }, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index e122554fa227..98154b4a4c3b 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -122,7 +122,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group), + "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, self.shard_config.zero_stage), }, policy=policy, target_key=LlamaModel, diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 2858311329e0..31830989f2d8 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -37,6 +37,7 @@ class ShardConfig: enable_jit_fused: bool = False enable_sequence_parallelism: bool = False sequence_parallelism_mode: str = None + zero_stage: int = 0 enable_sequence_overlap: bool = False parallel_output: bool = True # TODO padding vocab diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 70f366dd716e..bbbaf13b53ef 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -77,10 +77,8 @@ def __init__( forced_dtype: Optional[torch.dtype] = None, moe_extra_dp_process_group: Optional[ProcessGroup] = None, master_weights: bool = True, # master weights - enable_sequence_parallel: bool = False, ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) - self._enable_sequence_parallel = enable_sequence_parallel self._dtype = self.optim.param_groups[0]["params"][0].dtype self._logger = get_dist_logger() @@ -300,8 +298,7 @@ def _run_reduction(self): if self.moe_extra_dp_pg is None: flat_grads = self._bucket_store.get_flatten_grad() - if not self._enable_sequence_parallel: - flat_grads /= self._world_size + flat_grads /= self._world_size else: # record moe and non moe param moe_list = [] diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 2c9f2ceefb09..d7a35fcf8d0e 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -167,18 +167,7 @@ def _criterion(outputs, inputs): shard_test_data = {} for k, v in data.items(): - if k not in ["input_ids", "attention_mask"]: - shard_test_data[k] = data[k].clone() - else: - # todo: check the correctness of using dim=-1: to be compatible with date_gen_for_double_heads() - shard_test_data[k] = ( - torch.chunk(data[k].clone(), booster.plugin.shard_config.sequence_parallel_size, dim=-1)[ - dist.get_rank(booster.plugin.shard_config.sequence_parallel_process_group) - ] - if booster.plugin.shard_config.enable_sequence_parallelism - and booster.plugin.shard_config.sequence_parallelism_mode in ["ring", "all_to_all"] - else data[k].clone() - ) + shard_test_data[k] = data[k].clone() unshard_test_data = {} for k, v in data.items(): unshard_test_data[k] = data[k].clone() From 7e80cc4fd5fc736a0fa22e723a6439a55cef4132 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Mon, 25 Mar 2024 13:49:25 +0800 Subject: [PATCH 36/50] polish code (part 1) --- .../booster/plugin/hybrid_parallel_plugin.py | 12 +- colossalai/shardformer/layer/linear.py | 6 +- colossalai/shardformer/layer/utils.py | 11 +- colossalai/shardformer/modeling/gpt2.py | 9 +- .../test_layer/test_sequence_parallel.py | 163 ++++++------------ tests/test_shardformer/test_model/_utils.py | 5 - 6 files changed, 66 insertions(+), 140 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1d93297c500a..0e317b71576a 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -174,27 +174,27 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None): if self.shard_config.enable_sequence_parallelism: if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: - # If sequence parallelism is enabled and mode is 1 or 2, gradients are synchronized + # If sequence parallelism is enabled and mode is split_gather or ring, gradients are synchronized # across the tensor parallelism group. group = self.tp_group - require_flag = True + only_sp_partial = True elif self.shard_config.sequence_parallelism_mode == "all_to_all": - # If sequence parallelism is enabled and mode is 3, gradients are synchronized + # If sequence parallelism is enabled and mode is all_to_all, gradients are synchronized # across the sequence parallelism group. group = self.sp_group - require_flag = False + only_sp_partial = False else: raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}") if grads is not None: # Synchronize provided gradient tensors across the tensor parallelism group. SeqParallelUtils.allreduce_partial_data_grad( - process_group=group, grads=grads, require_flag=require_flag + process_group=group, grads=grads, only_sp_partial=only_sp_partial ) else: # Synchronize gradients from the model across the tensor parallelism group. SeqParallelUtils.allreduce_partial_data_grad( - process_group=group, model=self.module, require_flag=require_flag + process_group=group, model=self.module, only_sp_partial=only_sp_partial ) def forward(self, *args, **kwargs): diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 3014e97b0573..7c8619ad8f5c 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -233,7 +233,8 @@ class Linear1D_Row(ParallelModule): dtype (`torch.dtype`): The dtype of parameters, defaults to None. parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. - seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None. + seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): @@ -253,7 +254,6 @@ def __init__( dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, - # seq_parallel: bool = False, seq_parallel_mode: str = None, seq_parallel_dim: int = 1, parallel_input: bool = True, @@ -412,7 +412,6 @@ def forward(self, input_: Tensor) -> Tensor: output_parallel_list[i], group=self.process_group, async_op=True ) handle_list.append(handle) - # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) for handle in handle_list: handle.wait() output = torch.cat(output_parallel_list, dim=-1) @@ -421,7 +420,6 @@ def forward(self, input_: Tensor) -> Tensor: output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output = reduce_forward(output_parallel, self.process_group) elif self.seq_parallel_mode == "split_gather": - # output = linear_with_async_comm(input_, self.weight, None, None, False) output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output = reducescatter_forward_gather_backward( output_parallel, self.process_group, self.seq_parallel_dim diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 0f7206ec1c27..ca7a47a9346d 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -39,7 +39,7 @@ def allreduce_partial_data_grad( process_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None, - require_flag: bool = True, + only_sp_partial: bool = True, ): """ Allreduce partial derived gradients across the specified process group. @@ -47,10 +47,10 @@ def allreduce_partial_data_grad( This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism. Args: - tp_group (ProcessGroup): The process group for gradient synchronization. + process_group (ProcessGroup): The process group for gradient synchronization. model (nn.Module): The model from which gradients will be synchronized. grads (List[torch.Tensor]): The list of gradients to be synchronized. - + only_sp_partial (bool): Whether handle all the parameters or only parameters marked as partial derived. Raises: AssertionError: If both `model` and `grads` are provided or neither is provided. """ @@ -67,11 +67,10 @@ def allreduce_partial_data_grad( if model is not None: # If `model` is provided, extract partial derived gradients from the model's parameters. grads = [] + for p in model.parameters(): if p.grad is not None: - if require_flag and SeqParallelUtils.is_sp_partial_derived_param(p): - grads.append(p.grad.data) - elif not require_flag: + if only_sp_partial and SeqParallelUtils.is_sp_partial_derived_param(p) or not only_sp_partial: grads.append(p.grad.data) # Flatten and reduce the gradients using the specified process group. diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index f09568f2aa6b..5a30cf786f85 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -24,7 +24,6 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.layer._operation import ( - _gather, all_to_all_comm, gather_forward_split_backward, split_forward_gather_backward, @@ -824,9 +823,9 @@ def forward( query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) if sp_mode == "all_to_all": - query = all_to_all_comm(query) - key = all_to_all_comm(key) - value = all_to_all_comm(value) + query = all_to_all_comm(query, sp_group) + key = all_to_all_comm(key, sp_group) + value = all_to_all_comm(value, sp_group) query = split_heads(query, self.num_heads, self.head_dim) key = split_heads(key, self.num_heads, self.head_dim) @@ -991,7 +990,7 @@ def forward( if sp_mode == "ring": inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 'down') + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, "down") position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds diff --git a/tests/test_shardformer/test_layer/test_sequence_parallel.py b/tests/test_shardformer/test_layer/test_sequence_parallel.py index 740fbaee05e8..13b1a13e7f94 100644 --- a/tests/test_shardformer/test_layer/test_sequence_parallel.py +++ b/tests/test_shardformer/test_layer/test_sequence_parallel.py @@ -12,7 +12,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -class DistributedAttention(torch.nn.Module): +class SequenceParallelAttention(torch.nn.Module): """Initialization. Arguments: @@ -24,17 +24,14 @@ class DistributedAttention(torch.nn.Module): def __init__( self, - heads_num, - hidden_dim, - q_proj, - k_proj, - v_proj, - out_proj, + heads_num: torch.Tensor, + hidden_dim: torch.Tensor, + enable_sequence_parallellism: bool = False, sequence_process_group: dist.ProcessGroup = None, scatter_idx: int = 2, gather_idx: int = 1, ) -> None: - super(DistributedAttention, self).__init__() + super(SequenceParallelAttention, self).__init__() self.spg = sequence_process_group self.scatter_idx = scatter_idx self.gather_idx = gather_idx @@ -42,22 +39,18 @@ def __init__( self.hidden_dim = hidden_dim assert hidden_dim % heads_num == 0 self.head_dim = hidden_dim // heads_num + self.enable_sequence_parallellism = enable_sequence_parallellism - self.q = q_proj - self.k = k_proj - self.v = v_proj - self.out = out_proj + self.q = nn.Linear(hidden_dim, hidden_dim) + self.k = nn.Linear(hidden_dim, hidden_dim) + self.v = nn.Linear(hidden_dim, hidden_dim) + self.out = nn.Linear(hidden_dim, hidden_dim) def attn(self, q, k, v): batch_size, seq_len = q.shape[0], q.shape[1] scale = self.head_dim**0.5 qk = torch.matmul(q, k.transpose(-2, -1)) / scale - - # if attn_mask is not None: - # mask = attn_mask == 0 - # qk[mask] = torch.tensor(float('-inf')) - weights = F.softmax(qk, dim=-1) attention_score = torch.matmul(weights, v) @@ -65,84 +58,39 @@ def attn(self, q, k, v): return attention_score def forward(self, x) -> Tensor: - """forward + bsz, q_len, _ = x.size() - Arguments: - query (Tensor): query input to the layer - key (Tensor): key input to the layer - value (Tensor): value input to the layer - args: other args - - Returns: - * output (Tensor): context output - """ + seq_len = q_len * dist.get_world_size(self.spg) if self.enable_sequence_parallellism else q_len + num_heads = ( + self.heads_num // dist.get_world_size(self.spg) if self.enable_sequence_parallellism else self.heads_num + ) # in shape : e.g., [s/p:h:] - query = self.q(x) - key = self.k(x) - value = self.v(x) - # TODO Merge three alltoall calls into one - query_layer = all_to_all_comm(query, self.spg, self.scatter_idx, self.gather_idx) - key_layer = all_to_all_comm(key, self.spg, self.scatter_idx, self.gather_idx) - value_layer = all_to_all_comm(value, self.spg, self.scatter_idx, self.gather_idx) - + query_states = self.q(x) + key_states = self.k(x) + value_states = self.v(x) + + if self.enable_sequence_parallellism: + query_states = all_to_all_comm(query_states, self.spg, self.scatter_idx, self.gather_idx) + key_states = all_to_all_comm(key_states, self.spg, self.scatter_idx, self.gather_idx) + value_states = all_to_all_comm(value_states, self.spg, self.scatter_idx, self.gather_idx) + + query_states = query_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2) # out shape : e.g., [s:h/p:] - attn_score = self.attn(query_layer, key_layer, value_layer) - - output = all_to_all_comm(attn_score, self.spg, self.gather_idx, self.scatter_idx) + attn_score = self.attn(query_states, key_states, value_states) + attn_score = attn_score.transpose(1, 2).contiguous() + attn_score = attn_score.reshape(bsz, seq_len, num_heads * self.head_dim) + if self.enable_sequence_parallellism: + attn_score = all_to_all_comm(attn_score, self.spg, self.gather_idx, self.scatter_idx) # output e.g., [s/p::h] - output = self.out(output) + output = self.out(attn_score) return output -class MultiHeadAttn(nn.Module): - def __init__(self, head_num, hidden_dim, q_proj, k_proj, v_proj, out_proj): - super(MultiHeadAttn, self).__init__() - self.head_num = head_num - self.hidden_dim = hidden_dim - assert hidden_dim % head_num == 0 - self.head_dim = hidden_dim // head_num - - self.q = q_proj - self.k = k_proj - self.v = v_proj - self.out = out_proj - - def attn(self, q, k, v): - batch_size, seq_len = q.shape[0], q.shape[1] - - scale = self.head_dim**0.5 - qk = torch.matmul(q, k.transpose(-2, -1)) / scale - - # if attn_mask is not None: - # mask = attn_mask == 0 - # qk[mask] = torch.tensor(float('-inf')) - - weights = F.softmax(qk, dim=-1) - - attention_score = torch.matmul(weights, v) - - return attention_score - - def split(self, x, batch_size, seq_len): - res = x.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2) - return res - - def forward(self, x): - batch_size, seq_len, hidden_dim = x.shape - assert hidden_dim == self.hidden_dim, "hidden_dim should be equal to self.hidden_dim" - query_mha = self.split(self.q(x), batch_size, seq_len) - key_mha = self.split(self.k(x), batch_size, seq_len) - value_mha = self.split(self.v(x), batch_size, seq_len) - score_mha = self.attn(query_mha, key_mha, value_mha) - score_mha_final = score_mha.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_dim) - output_mha = self.out(score_mha_final) - - return output_mha - - def seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size): seq_len = seq_len hidden_dim = hidden_dim @@ -150,58 +98,45 @@ def seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size): batch_size = batch_size world_size = dist.get_world_size() - q_proj = nn.Linear(hidden_dim, hidden_dim) - k_proj = nn.Linear(hidden_dim, hidden_dim) - v_proj = nn.Linear(hidden_dim, hidden_dim) - out_proj = nn.Linear(hidden_dim, hidden_dim) - - q_proj_copy = copy.deepcopy(q_proj) - k_proj_copy = copy.deepcopy(k_proj) - v_proj_copy = copy.deepcopy(v_proj) - out_proj_copy = copy.deepcopy(out_proj) - x = torch.randn(batch_size, seq_len, hidden_dim).cuda() x_unshard = x.clone() x_unshard.requires_grad_(True) x_input = torch.chunk(x.clone(), world_size, dim=1)[dist.get_rank()] x_input.requires_grad_(True) - # x_unshard = torch.randn(batch_size, seq_len, hidden_dim).cuda() - # x_unshard.requires_grad_(True) - # x_input = torch.chunk(x_unshard.clone(), world_size, dim=1)[dist.get_rank()] - # x_input.requires_grad_(True) - # Multi-head Attention - mhn = MultiHeadAttn(head_num, hidden_dim, q_proj, k_proj, v_proj, out_proj).cuda() + mha = SequenceParallelAttention(head_num, hidden_dim).cuda() # Multi-head Attention forward - mhn_out = mhn(x_unshard) + mha_out = mha(x_unshard) # Sequence parallel Attention - dist_attn = DistributedAttention(head_num, hidden_dim, q_proj_copy, k_proj_copy, v_proj_copy, out_proj_copy).cuda() + sp_attn = SequenceParallelAttention(head_num, hidden_dim, True).cuda() + sp_attn.load_state_dict(copy.deepcopy(mha.state_dict())) # Sequence parallel Attention forward - dist_attn_out = dist_attn(x_input) + dist_attn_out = sp_attn(x_input) + # gather the output of sequence parallel attention out_list = [torch.empty_like(dist_attn_out) for _ in range(world_size)] dist.all_gather(out_list, dist_attn_out) seq_out = torch.cat(out_list, dim=1) # forward result check - assert_close(seq_out, mhn_out) + assert_close(seq_out, mha_out) # Multi-head Attention backward - mhn_out.sum().backward() - q_grad = mhn.q.weight.grad - k_grad = mhn.k.weight.grad - v_grad = mhn.v.weight.grad - o_grad = mhn.out.weight.grad + mha_out.sum().backward() + q_grad = mha.q.weight.grad + k_grad = mha.k.weight.grad + v_grad = mha.v.weight.grad + o_grad = mha.out.weight.grad x_grad = x_unshard.grad # Sequence parallel Attention backward dist_attn_out.sum().backward() - q_grad_seq = dist_attn.q.weight.grad - k_grad_seq = dist_attn.k.weight.grad - v_grad_seq = dist_attn.v.weight.grad - o_grad_seq = dist_attn.out.weight.grad + q_grad_seq = sp_attn.q.weight.grad + k_grad_seq = sp_attn.k.weight.grad + v_grad_seq = sp_attn.v.weight.grad + o_grad_seq = sp_attn.out.weight.grad x_grad_seq = x_input.grad # all_reduce the grad of sequence parallel attention weight dist.all_reduce(q_grad_seq) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index d7a35fcf8d0e..d5fc2c30f294 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -22,11 +22,6 @@ from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -def print_rank(prompt, value, rank=0): - if dist.get_rank() == rank: - print(f"rank-{rank}, {prompt}: {value}") - - def build_model( model_fn, enable_fused_normalization=True, From eff69788968e3bf8bb0cec493b1a0cdaba64274a Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Mon, 25 Mar 2024 13:50:05 +0800 Subject: [PATCH 37/50] polish code (part 2) --- colossalai/shardformer/layer/_operation.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 8b0bff8f7404..ef026cf3cd2d 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -957,7 +957,6 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim): if scatter_dim < 2: input_t = input_.reshape([seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]).contiguous() else: - # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! input_t = ( input_.reshape([-1, seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]) .transpose(0, 1) @@ -967,7 +966,6 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim): output = torch.empty_like(input_t) dist.all_to_all_single(output, input_t, group=group) - # if scattering the seq-dim, transpose the heads back to the original dimension if scatter_dim < 2: output = output.transpose(0, 1).contiguous() From 26f7bf8a8329f22ff5be7947615a213f2e07d9ac Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Mon, 25 Mar 2024 14:12:44 +0800 Subject: [PATCH 38/50] polish code (part 2.5) --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 0e317b71576a..41ef1dcbbc52 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1006,9 +1006,6 @@ def __init__( self.sequence_parallelism_mode in SUPPORT_SP_MODE ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" if self.sequence_parallelism_mode in ["split_gather", "ring"]: - assert ( - zero_stage == 0 - ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with ZeRO-1 / ZeRO-2" assert ( tp_size > 1 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" From 2beac05acf8d7946144f5d3202005c4ad8a9ccd9 Mon Sep 17 00:00:00 2001 From: linsj20 Date: Tue, 26 Mar 2024 14:59:12 +0800 Subject: [PATCH 39/50] polish code (part 3) * sequence parallel: inside text split * miscellaneous minor fixes --- colossalai/shardformer/layer/_operation.py | 73 ++++--------------- colossalai/shardformer/modeling/gpt2.py | 2 +- colossalai/shardformer/modeling/llama.py | 6 +- colossalai/shardformer/shard/shard_config.py | 2 +- .../test_model/test_shard_gpt2.py | 12 +++ .../test_model/test_shard_llama.py | 12 +++ 6 files changed, 44 insertions(+), 63 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index ef026cf3cd2d..82d37bb4cf94 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -177,10 +177,13 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group= # initialization of ring communication recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 + rank_map = list(dist.get_process_group_ranks(process_group)) + recv_rank = rank_map[recv_rank] + send_rank = rank_map[send_rank] recv_tensors = {} send_tensors = {} for k, v in input_to_gather.items(): - recv_tensors[k] = v.clone() + recv_tensors[k] = torch.empty_like(v) send_tensors[k] = v.clone() def communicate_step(): @@ -192,9 +195,7 @@ def communicate_step(): def switch_step(): for k in recv_tensors: - tmp_tensor = send_tensors[k] - send_tensors[k] = recv_tensors[k] - recv_tensors[k] = tmp_tensor + send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k] output_tensors = [] @@ -257,43 +258,6 @@ def backward(ctx, grad_output): return output, None, None -class _GatherForwardReduceScatterBackward(torch.autograd.Function): - """Gather input from sequence parallel in forward and reduce-scatter gradient in backward - - Args: - input_ (`torch.Tensor`): The input tensor from sequence parallel region. - process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. - overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. - - """ - - @staticmethod - def forward(ctx, input_, process_group, dim): - ctx.process_group = process_group - ctx.dim = dim - - return _gather(input_, dim, process_group) - - @staticmethod - def backward(ctx, grad_output): - dim = ctx.dim - process_group = ctx.process_group - - # do reduce-scatter - new_shape = list(grad_output.shape) - assert ( - new_shape[dim] % dist.get_world_size(process_group) == 0 - ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " - new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) - grad_list = [ - item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim) - ] - output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device) - dist.reduce_scatter(output, grad_list, group=process_group) - - return output, None, None - - class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """Gather input from sequence parallel in forward and reduce-scatter gradient in backward @@ -314,10 +278,8 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, ctx.overlap = overlap if ring is True: - input_to_gather = {} - input_local = {} - input_to_gather["input"] = input_ - input_local["weight"] = weight + input_to_gather = {"input": input_} + input_local = {"weight": weight} output = _ring_as_gather( F.linear, @@ -449,6 +411,9 @@ def _ring_as_reducescatter( # initialization of ring communication recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 + rank_map = list(dist.get_process_group_ranks(process_group)) + recv_rank = rank_map[recv_rank] + send_rank = rank_map[send_rank] input_tensors = [] for _ in range(group_size): input_tensors.append({}) @@ -462,7 +427,7 @@ def _ring_as_reducescatter( input_tensors.reverse() output_tensor = func(**input_tensors[0], **input_local) - recv_tensor = output_tensor.clone() + recv_tensor = torch.empty_like(output_tensor) send_tensor = output_tensor.clone() def communicate_step(): @@ -512,10 +477,8 @@ def forward(ctx, input_, weight, bias, process_group, dim, ring): ctx.dim = dim if ring is True: - input_to_reducescatter = {} - input_local = {} - input_to_reducescatter["input"] = input_ - input_local["weight"] = weight + input_to_reducescatter = {"input": input_} + input_local = {"weight": weight} if bias is not None: input_to_reducescatter["bias"] = bias @@ -748,10 +711,7 @@ def forward(ctx, input_, dim, process_group, grad_scale=None): @staticmethod def backward(ctx, grad_output): if ctx.grad_scale is not None: - if ctx.grad_scale == "up": - grad_output = grad_output * dist.get_world_size(ctx.process_group) - elif ctx.grad_scale == "down": - grad_output = grad_output / dist.get_world_size(ctx.process_group) + grad_output = grad_output * ctx.grad_scale return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None @@ -811,10 +771,7 @@ def forward(ctx, input_, dim, process_group, grad_scale=None): @staticmethod def backward(ctx, grad_output): if ctx.grad_scale is not None: - if ctx.grad_scale == "up": - grad_output = grad_output * dist.get_world_size(ctx.process_group) - elif ctx.grad_scale == "down": - grad_output = grad_output / dist.get_world_size(ctx.process_group) + grad_output = grad_output * ctx.grad_scale return _split(grad_output, ctx.dim, ctx.process_group), None, None, None diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 5a30cf786f85..02bc0824c7bc 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -990,7 +990,7 @@ def forward( if sp_mode == "ring": inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, "down") + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 19b3f6618e35..af0cee4403fd 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -953,7 +953,7 @@ def _make_causal_mask_partial( mask_cond = torch.arange(mask.size(-1) * world_size, device=device) block_size = tgt_len // world_size - idx = dist.get_rank() + idx = dist.get_rank(sp_group) off = idx * block_size mask.masked_fill_(mask_cond[off : off + block_size] < (mask_cond + 1).view(mask.size(-1) * world_size, 1), 0) @@ -1065,7 +1065,7 @@ def forward( if sp_mode in ["ring", "split_gather"]: inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 'down') + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) # TODO use_distributed_mask use_distributed_mask = True if sp_mode in ["ring", "all_to_all"] else False @@ -1155,7 +1155,7 @@ def custom_forward(*inputs): if sp_mode == "ring" or sp_mode == "split_gather" or (sp_mode == "all_to_all" and zero_stage == 0): hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) elif sp_mode == "all_to_all" and zero_stage in [1, 2]: - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale="up") + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) # add hidden states from the last decoder layer if output_hidden_states: diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 31830989f2d8..529d31a0801e 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -61,7 +61,7 @@ def __post_init__(self): if self.enable_sequence_parallelism: self.sequence_parallelism_mode = ( - "1" if self.sequence_parallelism_mode is None else self.sequence_parallelism_mode + "split_gather" if self.sequence_parallelism_mode is None else self.sequence_parallelism_mode ) assert ( self.sequence_parallelism_mode in SUPPORT_SP_MODE diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index ad1cbb7bc95a..1eec4cb704b2 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -131,6 +131,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 2, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 4, "pp_size": 1, diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 3bfd8606f1cc..ea47fae8e316 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -126,6 +126,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 2, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 4, "pp_size": 1, From e5dcd93f5c9e8d0ff139afebdeefab6a2b602ec0 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 27 Mar 2024 13:42:58 +0800 Subject: [PATCH 40/50] polish code --- tests/kit/model_zoo/transformers/llama.py | 13 ++----------- .../test_shardformer/test_model/test_shard_llama.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 26d4a6125788..36617f2ef3d6 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -27,26 +27,17 @@ def data_gen(): # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') # ----------------------------------- - # input_ids = torch.Tensor( - # [[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]] - # ).long() - input_ids = torch.Tensor( [ [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], - # [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], - # [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], - # [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], + [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], ] ).long() - # attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long() attention_mask = torch.Tensor( [ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], ] ).long() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index ea47fae8e316..0dd51891c907 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -149,6 +149,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp32", "initial_scale": 1, }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 4, "pp_size": 1, From ace07c984a677ccd29de4d91ae4c3ad7bd456f97 Mon Sep 17 00:00:00 2001 From: linsj20 Date: Wed, 27 Mar 2024 15:17:37 +0800 Subject: [PATCH 41/50] fix ulysses style ZeRO * sequence parallel: inside text split * miscellaneous minor fixes * disaggregate sp group and dp group for sp --- .../booster/plugin/hybrid_parallel_plugin.py | 4 +-- colossalai/cluster/process_group_mesh.py | 36 +++++++++++++++---- .../test_model/test_shard_gpt2.py | 13 +++++++ .../test_model/test_shard_llama.py | 12 +++++++ 4 files changed, 56 insertions(+), 9 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 41ef1dcbbc52..cc1de248382a 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1207,8 +1207,8 @@ def configure( else: # Here we bind the ZeRO group with sp group when user enable both ZeRO and all_to_all sp. if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": - self.zero_dp_size = self.sp_size - self.zero_dp_group = self.sp_group + self.zero_dp_size = self.sp_size * self.dp_size + self.zero_dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS]) else: self.zero_dp_size = self.dp_size self.zero_dp_group = self.dp_group diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index 1f32541a7b21..ccf1226958f5 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -161,7 +161,7 @@ def get_ranks_in_group(self, group: ProcessGroup) -> List[int]: @staticmethod def get_coords_along_axis( - base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int] + base_coord: Tuple[int, ...], axis: Union[int, List[int]], indices_at_axis: Union[List[int], List[List[int]]] ) -> List[Tuple[int, ...]]: """Get coordinates along the given axis. @@ -173,13 +173,28 @@ def get_coords_along_axis( Returns: List[Tuple[int, ...]]: Coordinates along the axis. """ - coords_in_group = [] - for idx in indices_at_axis: - coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :]) + if isinstance(axis, int): + axis = [axis,] + assert isinstance(indices_at_axis[0], int) + indices_at_axis = [indices_at_axis,] + + def add_index(base_coord, axis, indices_at_axis): + coords_in_group = [] + for idx in indices_at_axis: + coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :]) + return coords_in_group + + coords_in_group = [base_coord] + for ax, indices_at_ax in zip(axis, indices_at_axis): + new_coords_in_group = [] + for coords in coords_in_group: + new_coords_in_group += add_index(coords, ax, indices_at_ax) + coords_in_group = new_coords_in_group + return coords_in_group def create_group_along_axis( - self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None + self, axis: Union[int, List[int]], indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, backend: Optional[str] = None ) -> ProcessGroup: """Create all process groups along the given axis, and return the one which the current process belongs to. @@ -191,10 +206,17 @@ def create_group_along_axis( Returns: ProcessGroup: The process group along the given axis which the current process belongs to. """ - indices_at_axis = indices_at_axis or list(range(self._shape[axis])) + if isinstance(axis, int): + axis = [axis,] + if indices_at_axis is not None: + assert isinstance(indices_at_axis[0], int) + indices_at_axis = [indices_at_axis,] + + indices_at_axis = indices_at_axis or [list(range(self._shape[ax])) for ax in axis] reduced_shape = list(self._shape) # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis` - reduced_shape[axis] = 1 + for ax in axis: + reduced_shape[ax] = 1 target_group = None # use Cartesian product to generate all combinations of coordinates for base_coord in itertools.product(*[range(s) for s in reduced_shape]): diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 1eec4cb704b2..ccd3737a6faa 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -181,6 +181,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp32", "initial_scale": 1, }, + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 0dd51891c907..93503359beb6 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -192,6 +192,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, From 56a5ba858e806a391c065128ff741951bb0a7c8a Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Thu, 28 Mar 2024 16:42:45 +0800 Subject: [PATCH 42/50] fix llama and gpt sp --- colossalai/shardformer/modeling/gpt2.py | 262 +++- colossalai/shardformer/modeling/llama.py | 114 +- colossalai/shardformer/modeling/sp_gpt2.py | 1391 +++++++++++++++++ colossalai/shardformer/policies/gpt2.py | 48 +- colossalai/shardformer/policies/llama.py | 72 +- .../shardformer/policies/main_llama_policy.py | 353 +++++ colossalai/shardformer/policies/sp_gpt2.py | 511 ++++++ .../test_model/test_shard_gpt2.py | 68 +- .../test_model/test_shard_llama.py | 31 +- 9 files changed, 2513 insertions(+), 337 deletions(-) create mode 100644 colossalai/shardformer/modeling/sp_gpt2.py create mode 100644 colossalai/shardformer/policies/main_llama_policy.py create mode 100644 colossalai/shardformer/policies/sp_gpt2.py diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 02bc0824c7bc..1306c8aa6299 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -1,7 +1,6 @@ from typing import Dict, List, Optional, Tuple, Union import torch -import torch.distributed as dist from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -23,11 +22,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention -from colossalai.shardformer.layer._operation import ( - all_to_all_comm, - gather_forward_split_backward, - split_forward_gather_backward, -) +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig from ..layer import cross_entropy_1d @@ -226,7 +221,9 @@ def gpt2_model_forward( if shard_config and shard_config.enable_sequence_parallelism: if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) # Going through held blocks. @@ -285,7 +282,9 @@ def custom_forward(*inputs): if shard_config and shard_config.enable_sequence_parallelism: if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) if stage_manager.is_last_stage(): @@ -794,7 +793,7 @@ def gpt2_for_sequence_classification_forward( ) -def get_gpt2_flash_attention_forward(sp_mode, sp_size, sp_group): +def get_gpt2_flash_attention_forward(): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention def forward( @@ -821,15 +820,9 @@ def forward( attention_mask = encoder_attention_mask else: query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - - if sp_mode == "all_to_all": - query = all_to_all_comm(query, sp_group) - key = all_to_all_comm(key, sp_group) - value = all_to_all_comm(value, sp_group) - - query = split_heads(query, self.num_heads, self.head_dim) - key = split_heads(key, self.num_heads, self.head_dim) - value = split_heads(value, self.num_heads, self.head_dim) + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) if layer_past is not None: past_key, past_value = layer_past @@ -849,9 +842,6 @@ def forward( dropout_p = self.attn_dropout.p if self.training else 0.0 attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) - if sp_mode == "all_to_all": - attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) - attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) outputs = (attn_output, present, None) @@ -861,8 +851,7 @@ def forward( return forward -# def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig): -def gpt2_sequence_parallel_forward_fn(sp_mode, sp_size, sp_group): +def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig): def forward( self: GPT2Model, input_ids: Optional[torch.LongTensor] = None, @@ -901,13 +890,10 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device - # use variable seq_len to replace input_shape[-1] - seq_len = input_shape[-1] - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, seq_len) + token_type_ids = token_type_ids.view(-1, input_shape[-1]) if position_ids is not None: - position_ids = position_ids.view(-1, seq_len) + position_ids = position_ids.view(-1, input_shape[-1]) if past_key_values is None: past_length = 0 @@ -923,38 +909,170 @@ def forward( ) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - # split position ids when using sequence parallel - if sp_mode in ["ring", "all_to_all"]: - position_ids = torch.chunk(position_ids.clone(), sp_size, dim=1)[dist.get_rank(sp_group)] + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) - # GPT2Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + attention_mask, encoder_attention_mask = _get_attention_mask( + self, + shard_config, + hidden_states, + past_key_values, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if torch.is_tensor(attention_mask): + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + return forward + + +def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + inputs_embeds.shape[0] else: raise ValueError("You have to specify either input_ids or inputs_embeds") @@ -987,11 +1105,6 @@ def forward( if inputs_embeds is None: inputs_embeds = self.wte(input_ids) - if sp_mode == "ring": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) - elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) - position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds @@ -1011,9 +1124,6 @@ def forward( encoder_hidden_states, encoder_attention_mask, ) - # output_shape = input_shape + (hidden_states.size(-1),) - # output_shape = input_shape[:-1] + (seq_len, ) + (hidden_states.size(-1),) - output_shape = (-1,) + input_shape[1:-1] + (seq_len,) + (hidden_states.size(-1),) if self.gradient_checkpointing and self.training: if use_cache: @@ -1028,10 +1138,13 @@ def forward( all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None - if sp_mode == "split_gather": - # split the input tensor along sequence dimension - # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - hidden_states = split_forward_gather_backward(hidden_states, dim=1, process_group=sp_group) + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=shard_config.sequence_parallel_process_group, + ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): # Model parallel @@ -1094,10 +1207,13 @@ def custom_forward(*inputs): hidden_states = hidden_states.to("cuda:" + str(k + 1)) # When sequence parallelism done, gather the output tensor in forward and split it in backward - hidden_states = gather_forward_split_backward(hidden_states, dim=1, process_group=sp_group) + hidden_states = gather_forward_split_backward( + hidden_states, + dim=1, + process_group=shard_config.sequence_parallel_process_group, + ) hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(output_shape) # Add last hidden state if output_hidden_states: diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index af0cee4403fd..86a753504d8b 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -26,7 +26,6 @@ from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, cross_entropy_1d -from colossalai.shardformer.layer._operation import all_to_all_comm, gather_forward_split_backward try: from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask @@ -931,7 +930,7 @@ def forward( return forward -def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, zero_stage=0): +def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, use_flash_attention=False, zero_stage=0): logger = logging.get_logger(__name__) # Copied from transformers.models.bart.modeling_bart._make_causal_mask @@ -1067,8 +1066,7 @@ def forward( elif sp_mode == "all_to_all": inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) - # TODO use_distributed_mask - use_distributed_mask = True if sp_mode in ["ring", "all_to_all"] else False + use_distributed_mask = True if sp_mode in ["ring", "all_to_all"] and use_flash_attention else False # embed positions if sp_mode is None or use_distributed_mask is False: @@ -1077,9 +1075,6 @@ def forward( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) - if sp_mode in ["ring", "all_to_all"]: - attention_mask = _gather(attention_mask, 1, sp_group) - attention_mask = self._prepare_decoder_attention_mask( attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length ) @@ -1173,108 +1168,3 @@ def custom_forward(*inputs): ) return forward - - -def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): - from transformers import LlamaForCausalLM - - def forward( - self: LlamaForCausalLM, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - return forward diff --git a/colossalai/shardformer/modeling/sp_gpt2.py b/colossalai/shardformer/modeling/sp_gpt2.py new file mode 100644 index 000000000000..e84cf6470921 --- /dev/null +++ b/colossalai/shardformer/modeling/sp_gpt2.py @@ -0,0 +1,1391 @@ +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.models.gpt2.modeling_gpt2 import ( + GPT2DoubleHeadsModel, + GPT2DoubleHeadsModelOutput, + GPT2ForQuestionAnswering, + GPT2ForSequenceClassification, + GPT2ForTokenClassification, + GPT2LMHeadModel, + GPT2Model, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import ColoAttention +from colossalai.shardformer.layer._operation import ( + all_to_all_comm, + gather_forward_split_backward, + split_forward_gather_backward, +) +from colossalai.shardformer.shard import ShardConfig + +from ..layer import cross_entropy_1d +from ..layer._operation import gather_forward_split_backward + +logger = logging.get_logger(__name__) + + +def _get_attention_mask( + self: GPT2Model, + shard_config: ShardConfig, + hidden_states: torch.Tensor, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]], + attention_mask: Optional[torch.FloatTensor], + encoder_hidden_states: Optional[torch.Tensor], + encoder_attention_mask: Optional[torch.FloatTensor], +) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]: + batch_size, seq_len = hidden_states.shape[:2] + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + if shard_config.enable_flash_attention: + encoder_attention_mask = ColoAttention.prepare_attn_kwargs( + (encoder_batch_size, 1, seq_len, encoder_sequence_length), + dtype=hidden_states.dtype, + dtype2=encoder_hidden_states.dtype, + q_padding_mask=attention_mask, + kv_padding_mask=encoder_attention_mask, + ) + else: + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + if shard_config.enable_flash_attention: + encoder_attention_mask = {"attention_mask": None} + else: + encoder_attention_mask = None + # GPT2Attention mask. + past_key_values_length = 0 + if past_key_values is not None and past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] + if shard_config.enable_flash_attention: + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = ColoAttention.prepare_attn_kwargs( + (batch_size, 1, seq_len, seq_len + past_key_values_length), + hidden_states.dtype, + hidden_states.device, + attention_mask, + is_causal=True, + ) + elif attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + return attention_mask, encoder_attention_mask + + +class GPT2PipelineForwards: + """ + This class serves as a micro library for forward function substitution of GPT2 models + under pipeline setting. + """ + + @staticmethod + def gpt2_model_forward( + self: GPT2Model, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. + # Please refer to original code of transformers for more details. + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + logger = logging.get_logger(__name__) + + # Preprocess passed in arguments + # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") + past_key_values = None + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + else: + if hidden_states is None: + raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") + input_shape = hidden_states.size()[:-1] + device = hidden_states.device + hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:]) + batch_size = hidden_states.shape[0] + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if stage_manager.is_first_stage(): + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + else: + position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + ) + + # Going through held blocks. + start_idx, end_idx = stage_index[0], stage_index[1] + for i in range(start_idx, end_idx): + block = self.h[i] + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=None, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = gather_forward_split_backward( + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + ) + + if stage_manager.is_last_stage(): + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + else: + # always return dict for intermediate stage + return {"hidden_states": hidden_states} + + @staticmethod + def gpt2_lmhead_model_forward( + self: GPT2LMHeadModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. + Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {"hidden_states": outputs["hidden_states"]} + + hidden_states = outputs[0] + lm_logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + @staticmethod + def gpt2_double_heads_model_forward( + self: GPT2DoubleHeadsModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel.forward. + Please refer to original code of transformers for more details. + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {"hidden_states": outputs["hidden_states"]} + + hidden_states = outputs[0] + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @staticmethod + def gpt2_for_question_answering_forward( + self: GPT2ForQuestionAnswering, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering.forward. + # Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {"hidden_states": outputs["hidden_states"]} + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @staticmethod + def gpt2_for_token_classification_forward( + self: GPT2ForTokenClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification.forward. + # Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {"hidden_states": outputs["hidden_states"]} + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @staticmethod + def gpt2_for_sequence_classification_forward( + self: GPT2ForSequenceClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification.forward. + # Please refer to original code of transformers for more details. + """ + logger = logging.get_logger(__name__) + + if input_ids is not None: + batch_size, _ = input_ids.shape[:2] + else: + batch_size, _ = hidden_states.shape[:2] + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {"hidden_states": outputs["hidden_states"]} + + hidden_states = outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def get_gpt2_flash_attention_forward(sp_mode, sp_size, sp_group): + from transformers.models.gpt2.modeling_gpt2 import GPT2Attention + + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention + + def split_heads(tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor + + def forward( + self: GPT2Attention, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + if sp_mode == "all_to_all": + query = all_to_all_comm(query, sp_group) + key = all_to_all_comm(key, sp_group) + value = all_to_all_comm(value, sp_group) + + query = split_heads(query, self.num_heads, self.head_dim) + key = split_heads(key, self.num_heads, self.head_dim) + value = split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) + + if use_cache is True: + present = (key, value) + else: + present = None + + if not self.is_cross_attention: + attn_mask_type = AttnMaskType.causal + flash_attention_mask = None + if attention_mask != None: + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + if not torch.all(flash_attention_mask): + if attn_mask_type == AttnMaskType.causal: + attn_mask_type == AttnMaskType.paddedcausal + else: + attn_mask_type = AttnMaskType.padding + + scale = value.size(-1) ** -0.5 + if self.scale_attn_by_inverse_layer_idx: + scale = scale * (1 / float(self.layer_idx + 1)) + + # use coloattention + attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale + ) + + attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) + if sp_mode == "all_to_all": + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) + + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + outputs = (attn_output, present, None) + + return outputs + + return forward + + +def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig): + def forward( + self: GPT2Model, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + attention_mask, encoder_attention_mask = _get_attention_mask( + self, + shard_config, + hidden_states, + past_key_values, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if torch.is_tensor(attention_mask): + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + return forward + + +def gpt2_sequence_parallel_forward_fn(sp_mode, sp_size, sp_group): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # use variable seq_len to replace input_shape[-1] + seq_len = input_shape[-1] + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, seq_len) + if position_ids is not None: + position_ids = position_ids.view(-1, seq_len) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, seq_len + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_len) + + # split position ids when using sequence parallel + if sp_mode in ["ring", "all_to_all"]: + position_ids = torch.chunk(position_ids.clone(), sp_size, dim=1)[dist.get_rank(sp_group)] + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + if sp_mode == "ring": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + # output_shape = input_shape + (hidden_states.size(-1),) + # output_shape = input_shape[:-1] + (seq_len, ) + (hidden_states.size(-1),) + output_shape = (-1,) + input_shape[1:-1] + (seq_len,) + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger = logging.get_logger(__name__) + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + if sp_mode == "split_gather": + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward(hidden_states, dim=1, process_group=sp_group) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + hidden_states = gather_forward_split_backward(hidden_states, dim=1, process_group=sp_group) + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + return forward + + +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): + from transformers import GPT2LMHeadModel + + def forward( + self: GPT2LMHeadModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + shift_logits = shift_logits.view(-1, shift_logits.size(-1)) + shift_labels = shift_labels.view(-1) + loss = cross_entropy_1d( + shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + ) + + if not shard_config.parallel_output: + lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 48745e048d03..2ba35fbbf229 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -53,27 +53,17 @@ def module_policy(self): norm_cls = col_nn.LayerNorm sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None - sp_size = self.shard_config.sequence_parallel_size - sp_group = self.shard_config.sequence_parallel_process_group overlap = self.shard_config.enable_sequence_overlap sp_partial_derived = sp_mode in ["split_gather", "ring"] use_flash_attention = self.shard_config.enable_flash_attention - # todo: currently sp mode 2 and 3 need to be used with flashattention - if sp_mode in ["ring", "all_to_all"]: - if not use_flash_attention: + # todo: currently sp cannot be used with flashattention + if sp_mode in ["split_gather", "ring", "all_to_all"]: + if use_flash_attention: warnings.warn( - f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will enable FlashAttention automatically." + f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention, will disable FlashAttention automatically." ) - use_flash_attention = True - - if sp_mode == "all_to_all": - decoder_attribute_replacement = { - "num_heads": self.model.config.num_attention_heads // sp_size, - } - policy[GPT2Attention] = ModulePolicyDescription( - attribute_replacement=decoder_attribute_replacement, - ) - + self.shard_config.enable_flash_attention = False + use_flash_attention = False if self.shard_config.enable_tensor_parallelism: policy[GPT2Model] = ModulePolicyDescription( sub_module_replacement=[ @@ -98,7 +88,11 @@ def module_policy(self): SubModuleReplacementDescription( suffix="attn.c_attn", target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={"n_fused": 3, "seq_parallel_mode": sp_mode, "overlap": overlap}, + kwargs={ + "n_fused": 3, + "seq_parallel_mode": sp_mode, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="attn.c_proj", @@ -110,7 +104,11 @@ def module_policy(self): SubModuleReplacementDescription( suffix="mlp.c_fc", target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={"n_fused": 1, "seq_parallel_mode": sp_mode, "overlap": overlap}, + kwargs={ + "n_fused": 1, + "seq_parallel_mode": sp_mode, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="mlp.c_proj", @@ -170,7 +168,7 @@ def module_policy(self): if use_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_gpt2_flash_attention_forward(sp_mode, sp_size, sp_group), + "forward": get_gpt2_flash_attention_forward(), }, policy=policy, target_key=GPT2Attention, @@ -181,9 +179,7 @@ def module_policy(self): } if sp_mode is not None: - policy[GPT2Model].method_replacement = { - "forward": gpt2_sequence_parallel_forward_fn(sp_mode, sp_size, sp_group) - } + policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} return policy @@ -206,7 +202,7 @@ def get_held_layers(self) -> List[nn.Module]: layers_per_stage = self.distribute_layers( len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks ) - stage_indices = Policy.get_stage_index( + stage_indices = self.get_stage_index( layers_per_stage, stage_manager.stage, num_model_chunks=stage_manager.num_model_chunks, @@ -247,7 +243,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = self.distribute_layers( len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks ) - stage_manager.stage_indices = Policy.get_stage_index( + stage_manager.stage_indices = self.get_stage_index( layers_per_stage, stage_manager.stage, num_model_chunks=stage_manager.num_model_chunks, @@ -261,8 +257,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli ) } else: - layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( new_forward, diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 98154b4a4c3b..aa3a08662cdb 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -50,10 +50,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.pipeline_stage_manager is not None: self.shard_config.enable_sequence_parallelism = False self.shard_config.enable_sequence_overlap = False + self.shard_config.sequence_parallelism_mode = None warnings.warn( f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False" ) - + zero_stage = self.shard_config.zero_stage sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None sp_group = ( @@ -62,18 +63,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: sp_partial_derived = sp_mode in ["split_gather", "ring"] use_flash_attention = self.shard_config.enable_flash_attention - # Currently sp mode ring and all_to_all need to be used with flashattention - if sp_mode in ["ring", "all_to_all"]: - if not use_flash_attention: + # Currently sp cannot to be used with flashattention + if sp_mode in ["split_gather", "ring", "all_to_all"]: + if use_flash_attention: warnings.warn( - f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will enable FlashAttention automatically." + f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will disable FlashAttention automatically." ) - use_flash_attention = True + use_flash_attention = False - if sp_mode == "split_gather": + if sp_mode in ["split_gather", "ring"]: self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group), + "forward": get_llama_seq_parallel_model_forward( + sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group, use_flash_attention=use_flash_attention + ), }, policy=policy, target_key=LlamaModel, @@ -85,31 +88,28 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=LlamaAttention, ) - elif sp_mode == "ring": - self.append_or_create_method_replacement( - description={ - "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), - }, - policy=policy, - target_key=LlamaAttention, - ) - self.append_or_create_method_replacement( - description={ - "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group), - }, - policy=policy, - target_key=LlamaModel, - ) + # elif sp_mode == "ring": + # self.append_or_create_method_replacement( + # description={ + # "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), + # }, + # policy=policy, + # target_key=LlamaAttention, + # ) + # self.append_or_create_method_replacement( + # description={ + # "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group), + # }, + # policy=policy, + # target_key=LlamaModel, + # ) elif sp_mode == "all_to_all": decoder_attribute_replacement = { "num_heads": self.model.config.num_attention_heads // sp_size, - # "head_dim": self.model.config.hidden_size // self.model.config.num_attention_heads, } if getattr(self.model.config, "num_key_value_heads", False): decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size - decoder_attribute_replacement["num_key_value_groups"] = ( - self.model.config.num_attention_heads // self.model.config.num_key_value_heads - ) + policy[LlamaAttention] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) @@ -122,7 +122,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, self.shard_config.zero_stage), + "forward": get_llama_seq_parallel_model_forward( + sp_mode=sp_mode, + sp_size=sp_size, + sp_group=sp_group, + use_flash_attention=use_flash_attention, + zero_stage=zero_stage, + ), }, policy=policy, target_key=LlamaModel, @@ -256,7 +262,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = self.distribute_layers( len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks ) - stage_manager.stage_indices = Policy.get_stage_index( + stage_manager.stage_indices = self.get_stage_index( layers_per_stage, stage_manager.stage, num_model_chunks=stage_manager.num_model_chunks, @@ -267,8 +273,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli } else: - layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config @@ -296,7 +302,7 @@ def get_held_layers(self) -> List[Module]: layers_per_stage = self.distribute_layers( len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks ) - stage_indices = Policy.get_stage_index( + stage_indices = self.get_stage_index( layers_per_stage, stage_manager.stage, num_model_chunks=stage_manager.num_model_chunks, @@ -349,8 +355,6 @@ def module_policy(self): policy = super().module_policy() - setattr(self.shard_config, "causal_lm", True) - if self.shard_config.enable_tensor_parallelism and not self.shard_config.enable_sequence_parallelism: # add a new item for casual lm new_item = { diff --git a/colossalai/shardformer/policies/main_llama_policy.py b/colossalai/shardformer/policies/main_llama_policy.py new file mode 100644 index 000000000000..daa7708c8fdf --- /dev/null +++ b/colossalai/shardformer/policies/main_llama_policy.py @@ -0,0 +1,353 @@ +import warnings +from functools import partial +from typing import Callable, Dict, List, Union + +import torch.nn as nn +from torch import Tensor +from torch.nn import Module + +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D + +from ..modeling.llama import ( + LlamaPipelineForwards, + get_llama_flash_attention_forward, + get_llama_model_forward_for_flash_attn, + get_lm_forward_with_dist_cross_entropy, +) +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"] + + +class LlamaPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel + + policy = {} + + if self.shard_config.enable_fused_normalization: + norm_cls = FusedRMSNorm + else: + norm_cls = RMSNorm + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + + if self.shard_config.enable_tensor_parallelism: + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( + self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size + ) + + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=Linear1D_Row, + ), + ], + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key=LlamaModel, + ) + + # optimization configuration + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=norm_cls, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=norm_cls, + ), + ], + policy=policy, + target_key=LlamaDecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + target_module=norm_cls, + ), + policy=policy, + target_key=LlamaModel, + ) + + # use flash attention + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_llama_flash_attention_forward(self.shard_config), + }, + policy=policy, + target_key=LlamaAttention, + ) + if self.pipeline_stage_manager is None: + # replace llama model forward method + self.append_or_create_method_replacement( + description={ + "forward": get_llama_model_forward_for_flash_attn(self.shard_config), + }, + policy=policy, + target_key=LlamaModel, + ) + + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager is None: + return + + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "LlamaModel": + module = self.model + else: + module = self.model.model + + if stage_manager.is_interleave: + layers_per_stage = self.distribute_layers( + len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks + ) + stage_manager.stage_indices = self.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages, + ) + method_replacement = { + "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + } + + else: + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) + + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "LlamaModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = self.distribute_layers( + len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks + ) + stage_indices = self.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages, + ) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(module.norm) + + else: + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + + return held_layers + + +class LlamaModelPolicy(LlamaPolicy): + def module_policy(self): + policy = super().module_policy() + from transformers.models.llama.modeling_llama import LlamaModel + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=LlamaModel, new_forward=LlamaPipelineForwards.llama_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama model""" + return [] + + +class LlamaForCausalLMPolicy(LlamaPolicy): + def module_policy(self): + from transformers import LlamaForCausalLM + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + LlamaForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=Linear1D_Col, + kwargs={"gather_output": not self.shard_config.parallel_output}, + ) + ], + ) + } + if self.shard_config.parallel_output: + new_item[LlamaForCausalLM].method_replacement = { + "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + llama_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if ( + id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): + # tie weights + return [ + { + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] + return [] + + +class LlamaForSequenceClassificationPolicy(LlamaPolicy): + def module_policy(self): + from transformers import LlamaForSequenceClassification + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for sequence classification + new_item = { + LlamaForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) + } + policy.update(new_item) + # to be confirmed + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=LlamaForSequenceClassification, + new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama for sequence classification model""" + return [] diff --git a/colossalai/shardformer/policies/sp_gpt2.py b/colossalai/shardformer/policies/sp_gpt2.py new file mode 100644 index 000000000000..efc92dd4cd04 --- /dev/null +++ b/colossalai/shardformer/policies/sp_gpt2.py @@ -0,0 +1,511 @@ +import warnings +from functools import partial +from typing import Callable, Dict, List + +from torch import Tensor, nn + +import colossalai.shardformer.layer as col_nn + +from ..modeling.gpt2 import ( + GPT2PipelineForwards, + get_gpt2_flash_attention_forward, + get_gpt_model_forward_for_flash_attn, + get_lm_forward_with_dist_cross_entropy, + gpt2_sequence_parallel_forward_fn, +) +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = [ + "GPT2Policy", + "GPT2ModelPolicy", + "GPT2LMHeadModelPolicy", + "GPT2DoubleHeadsModelPolicy", + "GPT2ForTokenClassificationPolicy", + "GPT2ForSequenceClassificationPolicy", +] + + +class GPT2Policy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model + + policy = {} + + if self.shard_config.enable_fused_normalization: + norm_cls = col_nn.FusedLayerNorm + else: + norm_cls = col_nn.LayerNorm + + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + sp_size = self.shard_config.sequence_parallel_size + sp_group = self.shard_config.sequence_parallel_process_group + overlap = self.shard_config.enable_sequence_overlap + sp_partial_derived = sp_mode in ["split_gather", "ring"] + use_flash_attention = self.shard_config.enable_flash_attention + # todo: currently sp cannot be used with flashattention + if sp_mode in ["split_gather", "ring", "all_to_all"]: + if use_flash_attention: + warnings.warn( + f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention, will disable FlashAttention automatically." + ) + use_flash_attention = False + + if sp_mode == "all_to_all": + decoder_attribute_replacement = { + "num_heads": self.model.config.num_attention_heads // sp_size, + } + policy[GPT2Attention] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + ) + + if self.shard_config.enable_tensor_parallelism: + policy[GPT2Model] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wte", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="drop", + target_module=col_nn.DropoutForParallelInput, + ), + ] + ) + + policy[GPT2Block] = ModulePolicyDescription( + attribute_replacement={ + "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 3, + "seq_parallel_mode": sp_mode, + "overlap": overlap, + }, + ), + SubModuleReplacementDescription( + suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + kwargs={ + "seq_parallel_mode": sp_mode, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 1, + "seq_parallel_mode": sp_mode, + "overlap": overlap, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + kwargs={ + "seq_parallel_mode": sp_mode, + }, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + + # optimization configuration + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="ln_f", + target_module=norm_cls, + ), + policy=policy, + target_key=GPT2Model, + ) + + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="ln_1", + target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, + ), + SubModuleReplacementDescription( + suffix="ln_2", + target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, + ), + SubModuleReplacementDescription( + suffix="ln_cross_attn", + target_module=norm_cls, + ignore_if_not_exist=True, + kwargs={"sp_partial_derived": sp_partial_derived}, + ), + ], + policy=policy, + target_key=GPT2Block, + ) + + if use_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_gpt2_flash_attention_forward(sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=GPT2Attention, + ) + if not self.shard_config.pipeline_stage_manager: + policy[GPT2Model].method_replacement = { + "forward": get_gpt_model_forward_for_flash_attn(self.shard_config) + } + + if sp_mode is not None: + policy[GPT2Model].method_replacement = { + "forward": gpt2_sequence_parallel_forward_fn(sp_mode, sp_size, sp_group) + } + + return policy + + def postprocess(self): + return self.model + + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "GPT2Model": + module = self.model + else: + module = self.model.transformer + stage_manager = self.pipeline_stage_manager + + held_layers = [] + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = self.distribute_layers( + len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks + ) + stage_indices = Policy.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages, + ) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.wte) + held_layers.append(module.wpe) + held_layers.append(module.drop) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(module.ln_f) + else: + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.wte) + held_layers.append(module.wpe) + held_layers.append(module.drop) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "GPT2Model": + module = self.model + else: + module = self.model.transformer + + if stage_manager.is_interleave: + layers_per_stage = self.distribute_layers( + len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks + ) + stage_manager.stage_indices = Policy.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages, + ) + method_replacement = { + "forward": partial( + new_forward, + stage_manager=stage_manager, + shard_config=self.shard_config, + ) + } + else: + layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + "forward": partial( + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config, + ) + } + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + + +# GPT2Model +class GPT2ModelPolicy(GPT2Policy): + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2Model + + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPT2Model, + new_forward=GPT2PipelineForwards.gpt2_model_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in GPT2Model.""" + return [] + + +# GPT2LMHeadModel +class GPT2LMHeadModelPolicy(GPT2Policy): + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2LMHeadModel: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": not self.shard_config.parallel_output}, + ) + ], + ) + } + if self.shard_config.parallel_output: + addon_module[GPT2LMHeadModel].method_replacement = { + "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + } + module_policy.update(addon_module) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPT2LMHeadModel, + new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, + policy=module_policy, + ) + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """The weights of wte and lm_head are shared.""" + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager is not None: + if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): + first_stage, last_stage = 0, stage_manager.num_stages - 1 + return [ + { + first_stage: module.transformer.wte.weight, + last_stage: module.lm_head.weight, + } + ] + return [] + + +# GPT2DoubleHeadsModel +class GPT2DoubleHeadsModelPolicy(GPT2Policy): + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2DoubleHeadsModel: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}, + ) + ] + ) + } + module_policy.update(addon_module) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPT2DoubleHeadsModel, + new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward, + policy=module_policy, + ) + + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + multiple_choice_head = self.model.multiple_choice_head + held_layers.append(self.model.lm_head) + held_layers.append(multiple_choice_head.summary) + held_layers.append(multiple_choice_head.activation) + held_layers.append(multiple_choice_head.first_dropout) + held_layers.append(multiple_choice_head.last_dropout) + + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """The weights of wte and lm_head are shared.""" + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager is not None: + if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): + first_stage, last_stage = 0, stage_manager.num_stages - 1 + return [ + { + first_stage: module.transformer.wte.weight, + last_stage: module.lm_head.weight, + } + ] + return [] + + +# GPT2ForQuestionAnswering +class GPT2ForQuestionAnsweringPolicy(GPT2Policy): + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering + + module_policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPT2ForQuestionAnswering, + new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward, + policy=module_policy, + ) + + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared_params in gpt2 for QA.""" + return [] + + +# GPT2ForTokenClassification +class GPT2ForTokenClassificationPolicy(GPT2Policy): + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2ForTokenClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ] + ) + } + module_policy.update(addon_module) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPT2ForTokenClassification, + new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward, + policy=module_policy, + ) + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in GPT2ForTokenClassification.""" + return [] + + +# GPT2ForSequenceClassification +class GPT2ForSequenceClassificationPolicy(GPT2Policy): + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification + + module_policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPT2ForSequenceClassification, + new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward, + policy=module_policy, + ) + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in GPT2ForTokenClassification.""" + return [] diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index ccd3737a6faa..d59d7e4ad499 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -131,69 +131,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - { - "tp_size": 2, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": True, - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 4, - "pp_size": 1, - "enable_all_optimization": False, - "use_lazy_init": False, - "precision": "fp32", - }, - { - "tp_size": 2, - "pp_size": 1, - "enable_all_optimization": False, - "use_lazy_init": False, - "precision": "fp32", - }, - { - "tp_size": 2, - "pp_size": 2, - "num_microbatches": 4, - "enable_all_optimization": False, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": True, - "use_lazy_init": True, - "precision": "fp32", - "initial_scale": 1, - }, - { - "tp_size": 1, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_flash_attention": True, - "use_lazy_init": True, - "precision": "fp32", - "initial_scale": 1, - }, - { - "tp_size": 1, - "pp_size": 1, - "sp_size": 2, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_flash_attention": True, - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - "initial_scale": 1, - }, { "tp_size": 2, "pp_size": 2, @@ -216,14 +153,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 4, "pp_size": 1, "enable_all_optimization": False, - "enable_flash_attention": True, "use_lazy_init": False, "precision": "fp32", }, { "tp_size": 2, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, @@ -231,7 +167,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": True, "precision": "fp32", }, diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 93503359beb6..581a81ef59da 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -144,7 +144,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "ring", - "enable_flash_attention": True, + "enable_flash_attention": False, "use_lazy_init": True, "precision": "fp32", "initial_scale": 1, @@ -155,29 +155,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": True, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 4, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "use_lazy_init": True, - "precision": "fp32", - "initial_scale": 1, - }, - { - "tp_size": 1, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": False, "use_lazy_init": True, - "zero_stage": 2, "precision": "fp16", "initial_scale": 1, }, @@ -187,7 +166,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "all_to_all", - "enable_flash_attention": True, + "enable_flash_attention": False, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, @@ -223,7 +202,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 4, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, @@ -235,7 +214,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": False, "precision": "fp32", }, - {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32"}, { "tp_size": 2, "pp_size": 1, From 93c958f109956095741efccf24e8e434d619e5f8 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Mon, 1 Apr 2024 16:24:01 +0800 Subject: [PATCH 43/50] polish code --- .../booster/plugin/hybrid_parallel_plugin.py | 6 +- colossalai/shardformer/modeling/llama.py | 142 +---- colossalai/shardformer/policies/llama.py | 3 +- .../shardformer/policies/main_llama_policy.py | 353 ------------ colossalai/shardformer/policies/sp_gpt2.py | 511 ------------------ colossalai/shardformer/shard/shard_config.py | 1 + tests/kit/model_zoo/transformers/gpt.py | 2 +- tests/kit/model_zoo/transformers/llama.py | 6 +- 8 files changed, 38 insertions(+), 986 deletions(-) delete mode 100644 colossalai/shardformer/policies/main_llama_policy.py delete mode 100644 colossalai/shardformer/policies/sp_gpt2.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index cc1de248382a..fac7e7b6799e 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1169,11 +1169,15 @@ def configure( param_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 + if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": + dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS]) + else: + dp_group = self.dp_group model = HybridParallelModule( model, precision=self.precision, shard_config=self.shard_config, - dp_group=self.dp_group, + dp_group=dp_group, tp_group=self.tp_group, sp_group=self.sp_group, use_ddp=use_ddp, diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 86a753504d8b..a970cdec2fa3 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -13,12 +13,17 @@ CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel +from transformers.models.llama.modeling_llama import ( + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaModel, + apply_rotary_pos_emb, + repeat_kv, +) from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import ( - _gather, all_to_all_comm, gather_forward_split_backward, split_forward_gather_backward, @@ -499,31 +504,6 @@ def forward( attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - # me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) - # query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) - # key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) - # value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape) - - # flash_attention_mask = None - # attn_mask_type = AttnMaskType.causal - # if not getattr(shard_config, "causal_lm", False) and attention_mask != None: - # if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - # raise ValueError( - # f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - # ) - # flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - # attn_mask_type = AttnMaskType.paddedcausal - # hidden_size = self.hidden_size // sp_size if sp_mode == "3" else self.hidden_size - - # attention = ColoAttention(embed_dim=hidden_size, num_heads=self.num_heads) - # attn_output = attention( - # query_states, - # key_states, - # value_states, - # attn_mask=flash_attention_mask, - # attn_mask_type=attn_mask_type, - # origin_attn_mask=attention_mask, - # ) # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": @@ -767,33 +747,6 @@ def forward( def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): - def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - def forward( self, hidden_states: torch.Tensor, @@ -857,49 +810,24 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - # TODO (linshengjie) Block attention with ring - #### - block_wise = False - seq_len = query_states.shape[2] - seq_block = 1024 - if block_wise and seq_len > seq_block: - assert query_states.shape[2] % seq_block == 0 - block_num = query_states.shape[2] // seq_block - - query_states_chunks = query_states.chunk(block_num, dim=2) - if attention_mask is not None: - attention_mask_chunks = attention_mask.chunk(block_num, dim=2) - attn_output_chunks = [] - - for i in range(block_num): - attn_weights = torch.matmul(query_states_chunks[i], key_states.transpose(2, 3)) / math.sqrt( - self.head_dim - ) - if attention_mask is not None: - attn_weights = attn_weights + attention_mask_chunks[i] - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output_chunks.append(torch.matmul(attn_weights, value_states)) - attn_output = torch.cat(attn_output_chunks, dim=2) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - else: - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) + attn_weights = attn_weights + attention_mask - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) #### if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): @@ -930,7 +858,7 @@ def forward( return forward -def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, use_flash_attention=False, zero_stage=0): +def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, zero_stage=0): logger = logging.get_logger(__name__) # Copied from transformers.models.bart.modeling_bart._make_causal_mask @@ -1066,29 +994,14 @@ def forward( elif sp_mode == "all_to_all": inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) - use_distributed_mask = True if sp_mode in ["ring", "all_to_all"] and use_flash_attention else False - - # embed positions - if sp_mode is None or use_distributed_mask is False: - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length - ) - else: - world_size = dist.get_world_size(sp_group) - assert seq_length_with_past % world_size == 0 + if attention_mask is None: attention_mask = torch.ones( - (batch_size, seq_length_with_past // world_size), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = _prepare_decoder_attention_mask_partial( - attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length, sp_group + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) - attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - attention_mask = _gather(attention_mask, 1, sp_group) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length + ) hidden_states = inputs_embeds @@ -1146,7 +1059,6 @@ def custom_forward(*inputs): hidden_states = self.norm(hidden_states) - # Todo: Maybe this line can be optimized if sp_mode == "ring" or sp_mode == "split_gather" or (sp_mode == "all_to_all" and zero_stage == 0): hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) elif sp_mode == "all_to_all" and zero_stage in [1, 2]: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index aa3a08662cdb..ca9e74f8df63 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -75,7 +75,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: self.append_or_create_method_replacement( description={ "forward": get_llama_seq_parallel_model_forward( - sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group, use_flash_attention=use_flash_attention + sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group ), }, policy=policy, @@ -126,7 +126,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group, - use_flash_attention=use_flash_attention, zero_stage=zero_stage, ), }, diff --git a/colossalai/shardformer/policies/main_llama_policy.py b/colossalai/shardformer/policies/main_llama_policy.py deleted file mode 100644 index daa7708c8fdf..000000000000 --- a/colossalai/shardformer/policies/main_llama_policy.py +++ /dev/null @@ -1,353 +0,0 @@ -import warnings -from functools import partial -from typing import Callable, Dict, List, Union - -import torch.nn as nn -from torch import Tensor -from torch.nn import Module - -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D - -from ..modeling.llama import ( - LlamaPipelineForwards, - get_llama_flash_attention_forward, - get_llama_model_forward_for_flash_attn, - get_lm_forward_with_dist_cross_entropy, -) -from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription - -__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"] - - -class LlamaPolicy(Policy): - def config_sanity_check(self): - pass - - def preprocess(self): - if self.shard_config.enable_tensor_parallelism: - # Resize embedding - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - - return self.model - - def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel - - policy = {} - - if self.shard_config.enable_fused_normalization: - norm_cls = FusedRMSNorm - else: - norm_cls = RMSNorm - - if self.shard_config.enable_sequence_parallelism: - self.shard_config.enable_sequence_parallelism = False - warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") - - if self.shard_config.enable_tensor_parallelism: - decoder_attribute_replacement = { - "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - } - if getattr(self.model.config, "num_key_value_heads", False): - decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( - self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size - ) - - policy[LlamaDecoderLayer] = ModulePolicyDescription( - attribute_replacement=decoder_attribute_replacement, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attn.q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.o_proj", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.gate_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="mlp.up_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="mlp.down_proj", - target_module=Linear1D_Row, - ), - ], - ) - - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ), - policy=policy, - target_key=LlamaModel, - ) - - # optimization configuration - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="input_layernorm", - target_module=norm_cls, - ), - SubModuleReplacementDescription( - suffix="post_attention_layernorm", - target_module=norm_cls, - ), - ], - policy=policy, - target_key=LlamaDecoderLayer, - ) - - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription( - suffix="norm", - target_module=norm_cls, - ), - policy=policy, - target_key=LlamaModel, - ) - - # use flash attention - if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement( - description={ - "forward": get_llama_flash_attention_forward(self.shard_config), - }, - policy=policy, - target_key=LlamaAttention, - ) - if self.pipeline_stage_manager is None: - # replace llama model forward method - self.append_or_create_method_replacement( - description={ - "forward": get_llama_model_forward_for_flash_attn(self.shard_config), - }, - policy=policy, - target_key=LlamaModel, - ) - - return policy - - def postprocess(self): - return self.model - - def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: - """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" - if self.pipeline_stage_manager is None: - return - - stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == "LlamaModel": - module = self.model - else: - module = self.model.model - - if stage_manager.is_interleave: - layers_per_stage = self.distribute_layers( - len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks - ) - stage_manager.stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) - method_replacement = { - "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) - } - - else: - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - "forward": partial( - new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config - ) - } - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=model_cls - ) - - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) - - def get_held_layers(self) -> List[Module]: - """Get pipeline layers for current stage.""" - assert self.pipeline_stage_manager is not None - - if self.model.__class__.__name__ == "LlamaModel": - module = self.model - else: - module = self.model.model - stage_manager = self.pipeline_stage_manager - - held_layers = [] - if stage_manager.is_interleave: - assert stage_manager.num_model_chunks is not None - layers_per_stage = self.distribute_layers( - len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks - ) - stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) - if stage_manager.is_first_stage(ignore_chunk=True): - held_layers.append(module.embed_tokens) - for start_idx, end_idx in stage_indices: - held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(module.norm) - - else: - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.embed_tokens) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.norm) - - return held_layers - - -class LlamaModelPolicy(LlamaPolicy): - def module_policy(self): - policy = super().module_policy() - from transformers.models.llama.modeling_llama import LlamaModel - - if self.pipeline_stage_manager: - # set None as default - self.set_pipeline_forward( - model_cls=LlamaModel, new_forward=LlamaPipelineForwards.llama_model_forward, policy=policy - ) - return policy - - def get_held_layers(self) -> List[Module]: - """Get pipeline layers for current stage.""" - held_layers = super().get_held_layers() - return held_layers - - def get_shared_params(self) -> List[Dict[int, Tensor]]: - """No shared params in llama model""" - return [] - - -class LlamaForCausalLMPolicy(LlamaPolicy): - def module_policy(self): - from transformers import LlamaForCausalLM - - policy = super().module_policy() - - if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm - new_item = { - LlamaForCausalLM: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", - target_module=Linear1D_Col, - kwargs={"gather_output": not self.shard_config.parallel_output}, - ) - ], - ) - } - if self.shard_config.parallel_output: - new_item[LlamaForCausalLM].method_replacement = { - "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) - } - policy.update(new_item) - - if self.pipeline_stage_manager: - # set None as default - self.set_pipeline_forward( - model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy - ) - - return policy - - def get_held_layers(self) -> List[Module]: - """Get pipeline layers for current stage.""" - stage_manager = self.pipeline_stage_manager - held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.lm_head) - return held_layers - - def get_shared_params(self) -> List[Dict[int, Tensor]]: - llama_model = self.model.model - if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: - if ( - id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) - and self.pipeline_stage_manager.num_stages > 1 - ): - # tie weights - return [ - { - 0: llama_model.embed_tokens.weight, - self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, - } - ] - return [] - - -class LlamaForSequenceClassificationPolicy(LlamaPolicy): - def module_policy(self): - from transformers import LlamaForSequenceClassification - - policy = super().module_policy() - - if self.shard_config.enable_tensor_parallelism: - # add a new item for sequence classification - new_item = { - LlamaForSequenceClassification: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) - ) - ] - ) - } - policy.update(new_item) - # to be confirmed - if self.pipeline_stage_manager: - # set None as default - self.set_pipeline_forward( - model_cls=LlamaForSequenceClassification, - new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward, - policy=policy, - ) - return policy - - def get_held_layers(self) -> List[Module]: - """Get pipeline layers for current stage.""" - stage_manager = self.pipeline_stage_manager - held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.score) - return held_layers - - def get_shared_params(self) -> List[Dict[int, Tensor]]: - """No shared params in llama for sequence classification model""" - return [] diff --git a/colossalai/shardformer/policies/sp_gpt2.py b/colossalai/shardformer/policies/sp_gpt2.py deleted file mode 100644 index efc92dd4cd04..000000000000 --- a/colossalai/shardformer/policies/sp_gpt2.py +++ /dev/null @@ -1,511 +0,0 @@ -import warnings -from functools import partial -from typing import Callable, Dict, List - -from torch import Tensor, nn - -import colossalai.shardformer.layer as col_nn - -from ..modeling.gpt2 import ( - GPT2PipelineForwards, - get_gpt2_flash_attention_forward, - get_gpt_model_forward_for_flash_attn, - get_lm_forward_with_dist_cross_entropy, - gpt2_sequence_parallel_forward_fn, -) -from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription - -__all__ = [ - "GPT2Policy", - "GPT2ModelPolicy", - "GPT2LMHeadModelPolicy", - "GPT2DoubleHeadsModelPolicy", - "GPT2ForTokenClassificationPolicy", - "GPT2ForSequenceClassificationPolicy", -] - - -class GPT2Policy(Policy): - def config_sanity_check(self): - pass - - def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - return self.model - - def module_policy(self): - from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model - - policy = {} - - if self.shard_config.enable_fused_normalization: - norm_cls = col_nn.FusedLayerNorm - else: - norm_cls = col_nn.LayerNorm - - sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None - sp_size = self.shard_config.sequence_parallel_size - sp_group = self.shard_config.sequence_parallel_process_group - overlap = self.shard_config.enable_sequence_overlap - sp_partial_derived = sp_mode in ["split_gather", "ring"] - use_flash_attention = self.shard_config.enable_flash_attention - # todo: currently sp cannot be used with flashattention - if sp_mode in ["split_gather", "ring", "all_to_all"]: - if use_flash_attention: - warnings.warn( - f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention, will disable FlashAttention automatically." - ) - use_flash_attention = False - - if sp_mode == "all_to_all": - decoder_attribute_replacement = { - "num_heads": self.model.config.num_attention_heads // sp_size, - } - policy[GPT2Attention] = ModulePolicyDescription( - attribute_replacement=decoder_attribute_replacement, - ) - - if self.shard_config.enable_tensor_parallelism: - policy[GPT2Model] = ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wte", - target_module=col_nn.VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="drop", - target_module=col_nn.DropoutForParallelInput, - ), - ] - ) - - policy[GPT2Block] = ModulePolicyDescription( - attribute_replacement={ - "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attn.c_attn", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 3, - "seq_parallel_mode": sp_mode, - "overlap": overlap, - }, - ), - SubModuleReplacementDescription( - suffix="attn.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel_mode": sp_mode, - }, - ), - SubModuleReplacementDescription( - suffix="mlp.c_fc", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 1, - "seq_parallel_mode": sp_mode, - "overlap": overlap, - }, - ), - SubModuleReplacementDescription( - suffix="mlp.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel_mode": sp_mode, - }, - ), - SubModuleReplacementDescription( - suffix="attn.attn_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attn.resid_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ], - ) - - # optimization configuration - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription( - suffix="ln_f", - target_module=norm_cls, - ), - policy=policy, - target_key=GPT2Model, - ) - - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="ln_1", - target_module=norm_cls, - kwargs={"sp_partial_derived": sp_partial_derived}, - ), - SubModuleReplacementDescription( - suffix="ln_2", - target_module=norm_cls, - kwargs={"sp_partial_derived": sp_partial_derived}, - ), - SubModuleReplacementDescription( - suffix="ln_cross_attn", - target_module=norm_cls, - ignore_if_not_exist=True, - kwargs={"sp_partial_derived": sp_partial_derived}, - ), - ], - policy=policy, - target_key=GPT2Block, - ) - - if use_flash_attention: - self.append_or_create_method_replacement( - description={ - "forward": get_gpt2_flash_attention_forward(sp_mode, sp_size, sp_group), - }, - policy=policy, - target_key=GPT2Attention, - ) - if not self.shard_config.pipeline_stage_manager: - policy[GPT2Model].method_replacement = { - "forward": get_gpt_model_forward_for_flash_attn(self.shard_config) - } - - if sp_mode is not None: - policy[GPT2Model].method_replacement = { - "forward": gpt2_sequence_parallel_forward_fn(sp_mode, sp_size, sp_group) - } - - return policy - - def postprocess(self): - return self.model - - def get_held_layers(self) -> List[nn.Module]: - """Get pipeline layers for current stage.""" - assert self.pipeline_stage_manager is not None - - if self.model.__class__.__name__ == "GPT2Model": - module = self.model - else: - module = self.model.transformer - stage_manager = self.pipeline_stage_manager - - held_layers = [] - if stage_manager.is_interleave: - assert stage_manager.num_model_chunks is not None - layers_per_stage = self.distribute_layers( - len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks - ) - stage_indices = Policy.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) - if stage_manager.is_first_stage(ignore_chunk=True): - held_layers.append(module.wte) - held_layers.append(module.wpe) - held_layers.append(module.drop) - for start_idx, end_idx in stage_indices: - held_layers.extend(module.h[start_idx:end_idx]) - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(module.ln_f) - else: - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.wte) - held_layers.append(module.wpe) - held_layers.append(module.drop) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.h[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.ln_f) - return held_layers - - def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: - """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" - if not self.pipeline_stage_manager: - raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") - stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == "GPT2Model": - module = self.model - else: - module = self.model.transformer - - if stage_manager.is_interleave: - layers_per_stage = self.distribute_layers( - len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks - ) - stage_manager.stage_indices = Policy.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) - method_replacement = { - "forward": partial( - new_forward, - stage_manager=stage_manager, - shard_config=self.shard_config, - ) - } - else: - layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - "forward": partial( - new_forward, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=self.shard_config, - ) - } - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) - - -# GPT2Model -class GPT2ModelPolicy(GPT2Policy): - def module_policy(self): - from transformers.models.gpt2.modeling_gpt2 import GPT2Model - - policy = super().module_policy() - - if self.pipeline_stage_manager is not None: - self.set_pipeline_forward( - model_cls=GPT2Model, - new_forward=GPT2PipelineForwards.gpt2_model_forward, - policy=policy, - ) - return policy - - def get_held_layers(self) -> List[nn.Module]: - return super().get_held_layers() - - def get_shared_params(self) -> List[Dict[int, Tensor]]: - """No shared params in GPT2Model.""" - return [] - - -# GPT2LMHeadModel -class GPT2LMHeadModelPolicy(GPT2Policy): - def module_policy(self): - from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel - - module_policy = super().module_policy() - - if self.shard_config.enable_tensor_parallelism: - addon_module = { - GPT2LMHeadModel: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": not self.shard_config.parallel_output}, - ) - ], - ) - } - if self.shard_config.parallel_output: - addon_module[GPT2LMHeadModel].method_replacement = { - "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) - } - module_policy.update(addon_module) - - if self.pipeline_stage_manager is not None: - self.set_pipeline_forward( - model_cls=GPT2LMHeadModel, - new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, - policy=module_policy, - ) - return module_policy - - def get_held_layers(self) -> List[nn.Module]: - held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.lm_head) - return held_layers - - def get_shared_params(self) -> List[Dict[int, Tensor]]: - """The weights of wte and lm_head are shared.""" - module = self.model - stage_manager = self.pipeline_stage_manager - if stage_manager is not None: - if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): - first_stage, last_stage = 0, stage_manager.num_stages - 1 - return [ - { - first_stage: module.transformer.wte.weight, - last_stage: module.lm_head.weight, - } - ] - return [] - - -# GPT2DoubleHeadsModel -class GPT2DoubleHeadsModelPolicy(GPT2Policy): - def module_policy(self): - from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel - - module_policy = super().module_policy() - - if self.shard_config.enable_tensor_parallelism: - addon_module = { - GPT2DoubleHeadsModel: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, - ) - ] - ) - } - module_policy.update(addon_module) - - if self.pipeline_stage_manager is not None: - self.set_pipeline_forward( - model_cls=GPT2DoubleHeadsModel, - new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward, - policy=module_policy, - ) - - return module_policy - - def get_held_layers(self) -> List[nn.Module]: - held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - multiple_choice_head = self.model.multiple_choice_head - held_layers.append(self.model.lm_head) - held_layers.append(multiple_choice_head.summary) - held_layers.append(multiple_choice_head.activation) - held_layers.append(multiple_choice_head.first_dropout) - held_layers.append(multiple_choice_head.last_dropout) - - return held_layers - - def get_shared_params(self) -> List[Dict[int, Tensor]]: - """The weights of wte and lm_head are shared.""" - module = self.model - stage_manager = self.pipeline_stage_manager - if stage_manager is not None: - if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): - first_stage, last_stage = 0, stage_manager.num_stages - 1 - return [ - { - first_stage: module.transformer.wte.weight, - last_stage: module.lm_head.weight, - } - ] - return [] - - -# GPT2ForQuestionAnswering -class GPT2ForQuestionAnsweringPolicy(GPT2Policy): - def module_policy(self): - from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering - - module_policy = super().module_policy() - - if self.pipeline_stage_manager is not None: - self.set_pipeline_forward( - model_cls=GPT2ForQuestionAnswering, - new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward, - policy=module_policy, - ) - - return module_policy - - def get_held_layers(self) -> List[nn.Module]: - held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.qa_outputs) - return held_layers - - def get_shared_params(self) -> List[Dict[int, Tensor]]: - """No shared_params in gpt2 for QA.""" - return [] - - -# GPT2ForTokenClassification -class GPT2ForTokenClassificationPolicy(GPT2Policy): - def module_policy(self): - from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification - - module_policy = super().module_policy() - - if self.shard_config.enable_tensor_parallelism: - addon_module = { - GPT2ForTokenClassification: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ] - ) - } - module_policy.update(addon_module) - - if self.pipeline_stage_manager is not None: - self.set_pipeline_forward( - model_cls=GPT2ForTokenClassification, - new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward, - policy=module_policy, - ) - return module_policy - - def get_held_layers(self) -> List[nn.Module]: - held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.dropout) - held_layers.append(self.model.classifier) - return held_layers - - def get_shared_params(self) -> List[Dict[int, Tensor]]: - """No shared params in GPT2ForTokenClassification.""" - return [] - - -# GPT2ForSequenceClassification -class GPT2ForSequenceClassificationPolicy(GPT2Policy): - def module_policy(self): - from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification - - module_policy = super().module_policy() - - if self.pipeline_stage_manager is not None: - self.set_pipeline_forward( - model_cls=GPT2ForSequenceClassification, - new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward, - policy=module_policy, - ) - return module_policy - - def get_held_layers(self) -> List[nn.Module]: - held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.score) - return held_layers - - def get_shared_params(self) -> List[Dict[int, Tensor]]: - """No shared params in GPT2ForTokenClassification.""" - return [] diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 529d31a0801e..5bf5964d25c9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -40,6 +40,7 @@ class ShardConfig: zero_stage: int = 0 enable_sequence_overlap: bool = False parallel_output: bool = True + extra_kwargs: Dict[str, Any] = field(default_factory=dict) # TODO padding vocab # make_vocab_size_divisible_by: int = 128 # pipeline_parallel_size: int diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index b7372b6f9607..ab5d97420292 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -130,7 +130,7 @@ def date_gen_for_double_heads(): hidden_dropout=0, problem_type="single_label_classification", pad_token_id=50256, - tie_word_embeddings=False, + tie_word_embeddings=True, ) config_for_token_classification = copy.deepcopy(config) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 36617f2ef3d6..58b5b0487a82 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -60,9 +60,9 @@ def data_gen_for_casual_lm(): config = LlamaConfig( num_hidden_layers=8, - hidden_size=64, - intermediate_size=256, - num_attention_heads=8, + hidden_size=32, + intermediate_size=64, + num_attention_heads=4, max_position_embeddings=128, num_labels=16, ) From 48580c7217960e31395f197d997708583b25d17f Mon Sep 17 00:00:00 2001 From: linsj20 Date: Tue, 2 Apr 2024 12:33:12 +0800 Subject: [PATCH 44/50] move ulysses grad sync to ddp (#9) --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 6 +++--- colossalai/shardformer/modeling/llama.py | 4 ++-- colossalai/shardformer/policies/llama.py | 2 -- colossalai/shardformer/shard/shard_config.py | 1 - .../test_shardformer/test_model/test_shard_llama.py | 13 ++++++++++++- 5 files changed, 17 insertions(+), 9 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index fac7e7b6799e..87dfcca8ef21 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -182,7 +182,7 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None): # If sequence parallelism is enabled and mode is all_to_all, gradients are synchronized # across the sequence parallelism group. group = self.sp_group - only_sp_partial = False + only_sp_partial = True else: raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}") @@ -1100,7 +1100,6 @@ def __init__( sequence_parallelism_mode=sequence_parallelism_mode, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, - zero_stage=zero_stage, ) self.amp_config = dict( initial_scale=initial_scale, @@ -1168,7 +1167,8 @@ def configure( ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: param_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): - use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 + use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or \ + (self.dp_size == 1 and self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all") if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS]) else: diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index a970cdec2fa3..a6de5948564a 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -1059,9 +1059,9 @@ def custom_forward(*inputs): hidden_states = self.norm(hidden_states) - if sp_mode == "ring" or sp_mode == "split_gather" or (sp_mode == "all_to_all" and zero_stage == 0): + if sp_mode == "ring" or sp_mode == "split_gather": hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) - elif sp_mode == "all_to_all" and zero_stage in [1, 2]: + elif sp_mode == "all_to_all": hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) # add hidden states from the last decoder layer diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index ca9e74f8df63..55454b6f37c2 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -54,7 +54,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: warnings.warn( f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False" ) - zero_stage = self.shard_config.zero_stage sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None sp_group = ( @@ -126,7 +125,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group, - zero_stage=zero_stage, ), }, policy=policy, diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 5bf5964d25c9..07239b545229 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -37,7 +37,6 @@ class ShardConfig: enable_jit_fused: bool = False enable_sequence_parallelism: bool = False sequence_parallelism_mode: str = None - zero_stage: int = 0 enable_sequence_overlap: bool = False parallel_output: bool = True extra_kwargs: Dict[str, Any] = field(default_factory=dict) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 581a81ef59da..611f7864e834 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -163,10 +163,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 1, "pp_size": 1, + "sp_size": 2, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "all_to_all", - "enable_flash_attention": False, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, @@ -183,6 +183,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, + { + "tp_size": 1, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, From aea4fb6296c646febf8bdf4fae23ed7b99224175 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Tue, 2 Apr 2024 13:50:04 +0800 Subject: [PATCH 45/50] remove zero_stage and unbind the grad sync for alltoall sp --- .../booster/plugin/hybrid_parallel_plugin.py | 24 ++++++++----------- colossalai/shardformer/layer/utils.py | 3 +-- colossalai/shardformer/modeling/llama.py | 2 +- 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 87dfcca8ef21..f94630e00209 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -173,29 +173,22 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None): """ if self.shard_config.enable_sequence_parallelism: + if self.shard_config.sequence_parallelism_mode == "all_to_all": + return + if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: # If sequence parallelism is enabled and mode is split_gather or ring, gradients are synchronized # across the tensor parallelism group. group = self.tp_group - only_sp_partial = True - elif self.shard_config.sequence_parallelism_mode == "all_to_all": - # If sequence parallelism is enabled and mode is all_to_all, gradients are synchronized - # across the sequence parallelism group. - group = self.sp_group - only_sp_partial = True else: raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}") if grads is not None: # Synchronize provided gradient tensors across the tensor parallelism group. - SeqParallelUtils.allreduce_partial_data_grad( - process_group=group, grads=grads, only_sp_partial=only_sp_partial - ) + SeqParallelUtils.allreduce_partial_data_grad(process_group=group, grads=grads) else: # Synchronize gradients from the model across the tensor parallelism group. - SeqParallelUtils.allreduce_partial_data_grad( - process_group=group, model=self.module, only_sp_partial=only_sp_partial - ) + SeqParallelUtils.allreduce_partial_data_grad(process_group=group, model=self.module) def forward(self, *args, **kwargs): if self.convert_fn is not None: @@ -1167,8 +1160,11 @@ def configure( ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: param_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): - use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or \ - (self.dp_size == 1 and self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all") + use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( + self.dp_size == 1 + and self.enable_sequence_parallelism + and self.sequence_parallelism_mode == "all_to_all" + ) if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS]) else: diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index ca7a47a9346d..9c6ced4454dc 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -39,7 +39,6 @@ def allreduce_partial_data_grad( process_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None, - only_sp_partial: bool = True, ): """ Allreduce partial derived gradients across the specified process group. @@ -70,7 +69,7 @@ def allreduce_partial_data_grad( for p in model.parameters(): if p.grad is not None: - if only_sp_partial and SeqParallelUtils.is_sp_partial_derived_param(p) or not only_sp_partial: + if SeqParallelUtils.is_sp_partial_derived_param(p): grads.append(p.grad.data) # Flatten and reduce the gradients using the specified process group. diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index a6de5948564a..fff7b246b88a 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -858,7 +858,7 @@ def forward( return forward -def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, zero_stage=0): +def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): logger = logging.get_logger(__name__) # Copied from transformers.models.bart.modeling_bart._make_causal_mask From 07ae37b3e0eec53c0cffb3ba4769edcdd8ba861c Mon Sep 17 00:00:00 2001 From: linsj20 Date: Wed, 3 Apr 2024 14:11:05 +0800 Subject: [PATCH 46/50] add 2d group creation test * move ulysses grad sync to ddp * add 2d group creation test --- tests/test_cluster/test_process_group_mesh.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/test_cluster/test_process_group_mesh.py b/tests/test_cluster/test_process_group_mesh.py index 08542d1f64fa..3d206622d644 100644 --- a/tests/test_cluster/test_process_group_mesh.py +++ b/tests/test_cluster/test_process_group_mesh.py @@ -84,6 +84,30 @@ def check_process_group_mesh_with_cases(): 2: [2], 3: [3], } + TPxPP_RANKS_IN_GROUP = { + 0: [0, 1, 2, 3], + 1: [0, 1, 2, 3], + 2: [0, 1, 2, 3], + 3: [0, 1, 2, 3], + } + DPxTP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + TPxPP_PARTIAL_INDICES = { + 0: [[0, 1], [0]], + 1: [[1], [0, 1]], + 2: [[0], [0, 1]], + 3: [[0, 1], [1]], + } + TPxPP_RANKS_IN_GROUP_PARTIAL = { + 0: [0, 1], + 1: [1, 3], + 2: [0, 2], + 3: [2, 3], + } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE, TP_SIZE) @@ -107,6 +131,12 @@ def check_process_group_mesh_with_cases(): assert pg_mesh.get_ranks_in_group(pp_group) == PP_RANKS_IN_GROUP[rank] dp_group = pg_mesh.get_group_along_axis(DP_DIM) assert pg_mesh.get_ranks_in_group(dp_group) == DP_RANKS_IN_GROUP[rank] + dpxtp_group = pg_mesh.create_group_along_axis([DP_DIM, TP_DIM]) + assert pg_mesh.get_ranks_in_group(dpxtp_group) == DPxTP_RANKS_IN_GROUP[rank] + tpxpp_group = pg_mesh.create_group_along_axis([TP_DIM, PP_DIM]) + assert pg_mesh.get_ranks_in_group(tpxpp_group) == TPxPP_RANKS_IN_GROUP[rank] + tpxpp_group_partial = pg_mesh.create_group_along_axis([TP_DIM, PP_DIM], TPxPP_PARTIAL_INDICES[rank]) + assert pg_mesh.get_ranks_in_group(tpxpp_group_partial) == TPxPP_RANKS_IN_GROUP_PARTIAL[rank] # check prev rank if RANK_TO_COORDINATE[rank][TP_DIM] != 0: From 145e8792c507f19148288ac13f11bd29b0627e1d Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 3 Apr 2024 14:12:58 +0800 Subject: [PATCH 47/50] remove useless code --- .../booster/plugin/hybrid_parallel_plugin.py | 13 +- colossalai/shardformer/modeling/llama.py | 78 - colossalai/shardformer/modeling/sp_gpt2.py | 1391 ----------------- 3 files changed, 4 insertions(+), 1478 deletions(-) delete mode 100644 colossalai/shardformer/modeling/sp_gpt2.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index f94630e00209..1b48178919ee 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1162,6 +1162,7 @@ def configure( if not isinstance(model, ModelWrapper): use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( self.dp_size == 1 + and self.pp_size == 1 and self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all" ) @@ -1205,14 +1206,8 @@ def configure( tp_process_group=self.tp_group, ) else: - # Here we bind the ZeRO group with sp group when user enable both ZeRO and all_to_all sp. - if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": - self.zero_dp_size = self.sp_size * self.dp_size - self.zero_dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS]) - else: - self.zero_dp_size = self.dp_size - self.zero_dp_group = self.dp_group - if self.zero_dp_size == 1: + zero_dp_size = dist.get_world_size(dp_group) + if zero_dp_size == 1: warnings.warn( "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " "If you are not intended to use cpu_offload, please consider set zero_stage=0." @@ -1224,7 +1219,7 @@ def configure( model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info, - dp_process_group=self.zero_dp_group, + dp_process_group=dp_group, tp_process_group=self.tp_group, pp_process_group=self.pp_group, verbose=True, diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index fff7b246b88a..463344446722 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -3,7 +3,6 @@ from typing import List, Optional, Tuple, Union import torch -import torch.distributed as dist import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -861,83 +860,6 @@ def forward( def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): logger = logging.get_logger(__name__) - # Copied from transformers.models.bart.modeling_bart._make_causal_mask - def _make_causal_mask_partial( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, - sp_group=None, - ): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - world_size = dist.get_world_size(sp_group) - tgt_len *= world_size - - mask = torch.full((tgt_len, tgt_len // world_size), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1) * world_size, device=device) - - block_size = tgt_len // world_size - idx = dist.get_rank(sp_group) - off = idx * block_size - - mask.masked_fill_(mask_cond[off : off + block_size] < (mask_cond + 1).view(mask.size(-1) * world_size, 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat( - [torch.zeros(tgt_len // world_size, past_key_values_length, dtype=dtype, device=device), mask], dim=-1 - ) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, (tgt_len + past_key_values_length) // world_size) - - # Copied from transformers.models.bart.modeling_bart._expand_mask - def _expand_mask_partial(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None, sp_group=None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - world_size = dist.get_world_size(sp_group) - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len * world_size, src_len).to(dtype) - - # inverted_mask = 1.0 - expanded_mask - inverted_mask = expanded_mask.mul_(-1).add_(1.0) - - return inverted_mask.masked_fill_(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask_partial( - attention_mask, input_shape, inputs_embeds, past_key_values_length, sp_group=None - ): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask_partial( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - sp_group=sp_group, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask_partial( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1], sp_group=sp_group - ).to(inputs_embeds.device) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask.add_(combined_attention_mask) - ) - - return combined_attention_mask - def forward( self, input_ids: torch.LongTensor = None, diff --git a/colossalai/shardformer/modeling/sp_gpt2.py b/colossalai/shardformer/modeling/sp_gpt2.py deleted file mode 100644 index e84cf6470921..000000000000 --- a/colossalai/shardformer/modeling/sp_gpt2.py +++ /dev/null @@ -1,1391 +0,0 @@ -from typing import Dict, List, Optional, Tuple, Union - -import torch -import torch.distributed as dist -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - CausalLMOutputWithCrossAttentions, - QuestionAnsweringModelOutput, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, -) -from transformers.models.gpt2.modeling_gpt2 import ( - GPT2DoubleHeadsModel, - GPT2DoubleHeadsModelOutput, - GPT2ForQuestionAnswering, - GPT2ForSequenceClassification, - GPT2ForTokenClassification, - GPT2LMHeadModel, - GPT2Model, -) -from transformers.utils import logging - -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.layer import ColoAttention -from colossalai.shardformer.layer._operation import ( - all_to_all_comm, - gather_forward_split_backward, - split_forward_gather_backward, -) -from colossalai.shardformer.shard import ShardConfig - -from ..layer import cross_entropy_1d -from ..layer._operation import gather_forward_split_backward - -logger = logging.get_logger(__name__) - - -def _get_attention_mask( - self: GPT2Model, - shard_config: ShardConfig, - hidden_states: torch.Tensor, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]], - attention_mask: Optional[torch.FloatTensor], - encoder_hidden_states: Optional[torch.Tensor], - encoder_attention_mask: Optional[torch.FloatTensor], -) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]: - batch_size, seq_len = hidden_states.shape[:2] - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - if shard_config.enable_flash_attention: - encoder_attention_mask = ColoAttention.prepare_attn_kwargs( - (encoder_batch_size, 1, seq_len, encoder_sequence_length), - dtype=hidden_states.dtype, - dtype2=encoder_hidden_states.dtype, - q_padding_mask=attention_mask, - kv_padding_mask=encoder_attention_mask, - ) - else: - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - if shard_config.enable_flash_attention: - encoder_attention_mask = {"attention_mask": None} - else: - encoder_attention_mask = None - # GPT2Attention mask. - past_key_values_length = 0 - if past_key_values is not None and past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] - if shard_config.enable_flash_attention: - if attention_mask is not None: - attention_mask = attention_mask.view(batch_size, -1) - attention_mask = ColoAttention.prepare_attn_kwargs( - (batch_size, 1, seq_len, seq_len + past_key_values_length), - hidden_states.dtype, - hidden_states.device, - attention_mask, - is_causal=True, - ) - elif attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - return attention_mask, encoder_attention_mask - - -class GPT2PipelineForwards: - """ - This class serves as a micro library for forward function substitution of GPT2 models - under pipeline setting. - """ - - @staticmethod - def gpt2_model_forward( - self: GPT2Model, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None, - ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: - # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. - # Please refer to original code of transformers for more details. - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - logger = logging.get_logger(__name__) - - # Preprocess passed in arguments - # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. - if past_key_values: - logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") - past_key_values = None - if output_attentions: - logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") - output_attentions = False - if output_hidden_states: - logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") - output_hidden_states = False - if use_cache: - logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") - use_cache = False - - if stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - else: - if hidden_states is None: - raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") - input_shape = hidden_states.size()[:-1] - device = hidden_states.device - hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:]) - batch_size = hidden_states.shape[0] - - # GPT2Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if stage_manager.is_first_stage(): - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - else: - position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - hidden_states = self.drop(hidden_states) - - output_shape = input_shape + (hidden_states.size(-1),) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - - # split the input tensor along sequence dimension - # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "split_gather": - hidden_states = split_forward_gather_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - ) - - # Going through held blocks. - start_idx, end_idx = stage_index[0], stage_index[1] - for i in range(start_idx, end_idx): - block = self.h[i] - torch.cuda.set_device(hidden_states.device) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=None, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - # When sequence parallelism done, gather the output tensor in forward and split it in backward - if shard_config and shard_config.enable_sequence_parallelism: - if shard_config.sequence_parallelism_mode == "split_gather": - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - ) - - if stage_manager.is_last_stage(): - hidden_states = self.ln_f(hidden_states) - - hidden_states = hidden_states.view(output_shape) - - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if stage_manager.is_last_stage(): - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - else: - # always return dict for intermediate stage - return {"hidden_states": hidden_states} - - @staticmethod - def gpt2_lmhead_model_forward( - self: GPT2LMHeadModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None, - ) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - - This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. - Please refer to original code of transformers for more details. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = GPT2PipelineForwards.gpt2_model_forward( - self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config, - ) - - # If not at the last stage, return hidden_states as in GPT2Model - if not stage_manager.is_last_stage(): - return {"hidden_states": outputs["hidden_states"]} - - hidden_states = outputs[0] - lm_logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - @staticmethod - def gpt2_double_heads_model_forward( - self: GPT2DoubleHeadsModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - mc_token_ids: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - mc_labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None, - ) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]: - r""" - mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): - Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - - 1]`. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to - `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` - mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) - - This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel.forward. - Please refer to original code of transformers for more details. - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = GPT2PipelineForwards.gpt2_model_forward( - self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config, - ) - - # If not at the last stage, return hidden_states as in GPT2Model - if not stage_manager.is_last_stage(): - return {"hidden_states": outputs["hidden_states"]} - - hidden_states = outputs[0] - lm_logits = self.lm_head(hidden_states) - mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) - - mc_loss = None - if mc_labels is not None: - loss_fct = CrossEntropyLoss() - mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) - lm_loss = None - if labels is not None: - labels = labels.to(lm_logits.device) - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - loss_fct = CrossEntropyLoss() - lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - if not return_dict: - output = (lm_logits, mc_logits) + outputs[1:] - if mc_loss is not None: - output = (mc_loss,) + output - return ((lm_loss,) + output) if lm_loss is not None else output - - return GPT2DoubleHeadsModelOutput( - loss=lm_loss, - mc_loss=mc_loss, - logits=lm_logits, - mc_logits=mc_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - @staticmethod - def gpt2_for_question_answering_forward( - self: GPT2ForQuestionAnswering, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None, - ) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - - # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering.forward. - # Please refer to original code of transformers for more details. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = GPT2PipelineForwards.gpt2_model_forward( - self.transformer, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config, - ) - - # If not at the last stage, return hidden_states as in GPT2Model - if not stage_manager.is_last_stage(): - return {"hidden_states": outputs["hidden_states"]} - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - total_loss = None - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1).to(start_logits.device) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1).to(end_logits.device) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions = start_positions.clamp(0, ignored_index) - end_positions = end_positions.clamp(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - - return QuestionAnsweringModelOutput( - loss=total_loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - @staticmethod - def gpt2_for_token_classification_forward( - self: GPT2ForTokenClassification, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None, - ) -> Union[Dict, Tuple, TokenClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification.forward. - # Please refer to original code of transformers for more details. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = GPT2PipelineForwards.gpt2_model_forward( - self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config, - ) - - # If not at the last stage, return hidden_states as in GPT2Model - if not stage_manager.is_last_stage(): - return {"hidden_states": outputs["hidden_states"]} - - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states) - logits = self.classifier(hidden_states) - - loss = None - if labels is not None: - labels = labels.to(logits.device) - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - @staticmethod - def gpt2_for_sequence_classification_forward( - self: GPT2ForSequenceClassification, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None, - ) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification.forward. - # Please refer to original code of transformers for more details. - """ - logger = logging.get_logger(__name__) - - if input_ids is not None: - batch_size, _ = input_ids.shape[:2] - else: - batch_size, _ = hidden_states.shape[:2] - assert ( - self.config.pad_token_id is not None or batch_size == 1 - ), "Cannot handle batch sizes > 1 if no padding token is defined." - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = GPT2PipelineForwards.gpt2_model_forward( - self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config, - ) - - # If not at the last stage, return hidden_states as in GPT2Model - if not stage_manager.is_last_stage(): - return {"hidden_states": outputs["hidden_states"]} - - hidden_states = outputs[0] - logits = self.score(hidden_states) - - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def get_gpt2_flash_attention_forward(sp_mode, sp_size, sp_group): - from transformers.models.gpt2.modeling_gpt2 import GPT2Attention - - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - - def split_heads(tensor, num_heads, attn_head_size): - """ - Splits hidden_size dim into attn_head_size and num_heads - """ - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(new_shape) - return tensor - - def forward( - self: GPT2Attention, - hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - if encoder_hidden_states is not None: - if not hasattr(self, "q_attn"): - raise ValueError( - "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." - ) - - query = self.q_attn(hidden_states) - key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) - attention_mask = encoder_attention_mask - else: - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - - if sp_mode == "all_to_all": - query = all_to_all_comm(query, sp_group) - key = all_to_all_comm(key, sp_group) - value = all_to_all_comm(value, sp_group) - - query = split_heads(query, self.num_heads, self.head_dim) - key = split_heads(key, self.num_heads, self.head_dim) - value = split_heads(value, self.num_heads, self.head_dim) - - if layer_past is not None: - past_key, past_value = layer_past - key = torch.cat((past_key, key), dim=1) - value = torch.cat((past_value, value), dim=1) - - if use_cache is True: - present = (key, value) - else: - present = None - - if not self.is_cross_attention: - attn_mask_type = AttnMaskType.causal - flash_attention_mask = None - if attention_mask != None: - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - if not torch.all(flash_attention_mask): - if attn_mask_type == AttnMaskType.causal: - attn_mask_type == AttnMaskType.paddedcausal - else: - attn_mask_type = AttnMaskType.padding - - scale = value.size(-1) ** -0.5 - if self.scale_attn_by_inverse_layer_idx: - scale = scale * (1 / float(self.layer_idx + 1)) - - # use coloattention - attention = ColoAttention( - embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale - ) - - attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) - if sp_mode == "all_to_all": - attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) - - attn_output = self.c_proj(attn_output) - attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, present, None) - - return outputs - - return forward - - -def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig): - def forward( - self: GPT2Model, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - if position_ids is None: - position_ids = torch.arange( - past_length, - input_shape[-1] + past_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - - hidden_states = self.drop(hidden_states) - - output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) - - attention_mask, encoder_attention_mask = _get_attention_mask( - self, - shard_config, - hidden_states, - past_key_values, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) - # Ensure that attention_mask is always on the same device as hidden_states - if torch.is_tensor(attention_mask): - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - hidden_states = self.ln_f(hidden_states) - - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - return forward - - -def gpt2_sequence_parallel_forward_fn(sp_mode, sp_size, sp_group): - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - # use variable seq_len to replace input_shape[-1] - seq_len = input_shape[-1] - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, seq_len) - if position_ids is not None: - position_ids = position_ids.view(-1, seq_len) - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - if position_ids is None: - position_ids = torch.arange(past_length, seq_len + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, seq_len) - - # split position ids when using sequence parallel - if sp_mode in ["ring", "all_to_all"]: - position_ids = torch.chunk(position_ids.clone(), sp_size, dim=1)[dist.get_rank(sp_group)] - - # GPT2Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - if sp_mode == "ring": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) - elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) - - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - - hidden_states = self.drop(hidden_states) - - # output_shape = input_shape + (hidden_states.size(-1),) - # output_shape = input_shape[:-1] + (seq_len, ) + (hidden_states.size(-1),) - output_shape = (-1,) + input_shape[1:-1] + (seq_len,) + (hidden_states.size(-1),) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger = logging.get_logger(__name__) - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - - if sp_mode == "split_gather": - # split the input tensor along sequence dimension - # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - hidden_states = split_forward_gather_backward(hidden_states, dim=1, process_group=sp_group) - - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - # When sequence parallelism done, gather the output tensor in forward and split it in backward - hidden_states = gather_forward_split_backward(hidden_states, dim=1, process_group=sp_group) - - hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - return forward - - -def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): - from transformers import GPT2LMHeadModel - - def forward( - self: GPT2LMHeadModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - shift_logits = shift_logits.view(-1, shift_logits.size(-1)) - shift_labels = shift_labels.view(-1) - loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group - ) - - if not shard_config.parallel_output: - lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - cross_attentions=transformer_outputs.cross_attentions, - ) - - return forward From 7c314551bca21816c3c49a7efae0a2f7d66439bc Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 3 Apr 2024 14:15:43 +0800 Subject: [PATCH 48/50] change shard config not to enable sp when enable_all_optimizations --- colossalai/shardformer/shard/shard_config.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 07239b545229..03d0e4e2e840 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -109,12 +109,10 @@ def _turn_on_all_optimization(self): self.enable_fused_normalization = True self.enable_flash_attention = True self.enable_jit_fused = True - if self.enable_tensor_parallelism: - self.enable_sequence_parallelism = True - self.enable_sequence_overlap = True - # todo modify default sequence parallelism mode and process group - self.sequence_parallelism_mode = "split_gather" - self.sequence_parallel_process_group = self.tensor_parallel_process_group + # This can cause non-in-place param sharding when used without ZeRO. + # It may also slow down training when seq len is small. Plz enable manually. + # self.enable_sequence_parallelism = True + # self.enable_sequence_overlap = True def _infer(self): """ From 794800adbe8b514a5c78b8f955c7002ec7559723 Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 3 Apr 2024 15:22:50 +0800 Subject: [PATCH 49/50] add sp warnings for several model --- colossalai/shardformer/policies/bert.py | 8 +++++++ colossalai/shardformer/policies/bloom.py | 8 +++++++ colossalai/shardformer/policies/chatglm2.py | 7 ++++++ colossalai/shardformer/policies/gpt2.py | 6 +++++ colossalai/testing/comparison.py | 2 +- .../test_gemini_checkpoint_io.py | 17 ++++++++++---- .../test_model/test_shard_bert.py | 22 +++++++++++++++++++ .../test_model/test_shard_bloom.py | 22 +++++++++++++++++++ .../test_model/test_shard_chatglm2.py | 22 +++++++++++++++++++ .../test_model/test_shard_gpt2.py | 22 +++++++++++++++++++ 10 files changed, 131 insertions(+), 5 deletions(-) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index f01382878f17..142f47e2b468 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List @@ -67,6 +68,13 @@ def module_policy(self): norm_cls = col_nn.LayerNorm sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for Bert" + if sp_mode == "ring": + warnings.warn( + f"For Bert, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" + ) + sp_mode = "split_gather" + overlap = self.shard_config.enable_sequence_overlap sp_partial_derived = sp_mode == "split_gather" diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 9898f4fb388a..3bb3f2c04628 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List @@ -57,6 +58,13 @@ def module_policy(self): norm_cls = col_nn.LayerNorm sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for BLOOM" + if sp_mode == "ring": + warnings.warn( + f"For BLOOM, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" + ) + sp_mode = "split_gather" + overlap = self.shard_config.enable_sequence_overlap sp_partial_derived = sp_mode == "split_gather" diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 249999b06528..c22e5d6839ec 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List, Union @@ -57,6 +58,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: norm_cls = col_nn.LayerNorm sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for ChatGLM2" + if sp_mode == "ring": + warnings.warn( + f"For ChatGLM2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" + ) + sp_mode = "split_gather" overlap = self.shard_config.enable_sequence_overlap sp_partial_derived = sp_mode == "split_gather" diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 2ba35fbbf229..f38f54b1e19b 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -53,6 +53,12 @@ def module_policy(self): norm_cls = col_nn.LayerNorm sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for GPT2" + if sp_mode == "ring": + warnings.warn( + f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" + ) + sp_mode = "split_gather" overlap = self.shard_config.enable_sequence_overlap sp_partial_derived = sp_mode in ["split_gather", "ring"] use_flash_attention = self.shard_config.enable_flash_attention diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index 07d2731df923..e415b5fc3aa3 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -78,7 +78,7 @@ def check_state_dict_equal( v2 = v2.to("cpu") if ignore_dtype: v1 = v1.to(v2.dtype) - assert_close_loose(v1, v2, rtol=3e-3, atol=3e-3) + assert_close_loose(v1, v2) else: assert v1 == v2, f"{v1} not equals to {v2}" diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index ece3b40360e8..ac6f8caef816 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -44,7 +44,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) bert_model = model_fn() - enable_all_optimization = True if tp_size > 1 else False + + enable_flash_attention = True if tp_size > 1 else False + enable_fused_normalization = True if tp_size > 1 else False + enable_jit_fused = True if tp_size > 1 else False with shared_tempdir() as tempdir: pretrained_path = os.path.join(tempdir, "pretrained") @@ -54,7 +57,9 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b plugin = GeminiPlugin( **placement_config, tp_size=tp_size, - enable_all_optimization=enable_all_optimization, + enable_flash_attention=enable_flash_attention, + enable_fused_normalization=enable_fused_normalization, + enable_jit_fused=enable_jit_fused, extra_dp_size=extra_dp_size, ) booster = Booster(plugin=plugin) @@ -80,7 +85,9 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() - enable_all_optimization = True if tp_size > 1 else False + enable_flash_attention = True if tp_size > 1 else False + enable_fused_normalization = True if tp_size > 1 else False + enable_jit_fused = True if tp_size > 1 else False extra_dp_size = dist.get_world_size() // (zero_size * tp_size) plugin = GeminiPlugin( **placement_config, @@ -88,7 +95,9 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, - enable_all_optimization=enable_all_optimization, + enable_flash_attention=enable_flash_attention, + enable_fused_normalization=enable_fused_normalization, + enable_jit_fused=enable_jit_fused, ) booster = Booster(plugin=plugin) diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 9cd0b57365df..919557797fcd 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -100,6 +100,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 1, diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index b70cba8b4a53..cc0786618853 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -99,6 +99,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 78d752b69003..405ceba328df 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -135,6 +135,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index d59d7e4ad499..4aac7f3d4ed7 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -131,6 +131,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, From daec9e83633125bb183072c50f6759d15a5930ff Mon Sep 17 00:00:00 2001 From: KKZ20 Date: Wed, 3 Apr 2024 16:23:35 +0800 Subject: [PATCH 50/50] remove useless code --- colossalai/shardformer/modeling/llama.py | 1 - colossalai/shardformer/policies/llama.py | 15 --------------- 2 files changed, 16 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 463344446722..484fed95fad9 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -827,7 +827,6 @@ def forward( # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) - #### if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 55454b6f37c2..d41bf2bd1c1d 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -87,21 +87,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=LlamaAttention, ) - # elif sp_mode == "ring": - # self.append_or_create_method_replacement( - # description={ - # "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), - # }, - # policy=policy, - # target_key=LlamaAttention, - # ) - # self.append_or_create_method_replacement( - # description={ - # "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group), - # }, - # policy=policy, - # target_key=LlamaModel, - # ) elif sp_mode == "all_to_all": decoder_attribute_replacement = { "num_heads": self.model.config.num_attention_heads // sp_size,