diff --git a/.github/workflows/doc_check_on_pr.yml b/.github/workflows/doc_check_on_pr.yml
index ee8a82128dd7..68e13a971e7e 100644
--- a/.github/workflows/doc_check_on_pr.yml
+++ b/.github/workflows/doc_check_on_pr.yml
@@ -58,6 +58,7 @@ jobs:
         # there is no main branch, so it's safe to checkout the main branch from the merged branch
         # docer will rebase the remote main branch to the merged branch, so we have to config user
       - name: Make the merged branch main
+        run: |
           cd ColossalAI
           git checkout -b main
diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 7cec5f003bae..edbb7118aa1a 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -38,6 +38,19 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:


 class ZeroBubbleVPipeScheduler(PipelineSchedule):
+    r"""
+    ZeroBubbleVPipeScheduler
+
+    Args:
+        stage_manager (PipelineStageManager): The pipeline stage manager used for inter-process communication in pipeline parallelism. Defaults to None, which means pipeline parallelism is not used.
+        schedule (List[ScheduledNode]): Schedule for ZeroBubbleVPipe.
+        num_model_chunks (int): The number of model chunks on each device.
+        num_microbatch (Optional[int]): The number of microbatches.
+        microbatch_size (Optional[int]): The size of each microbatch.
+        enable_metadata_cache (bool): Whether to enable the metadata cache to accelerate communication.
+        overlap_p2p (bool): Whether to overlap p2p communication.
+    """
+
     def __init__(
         self,
         stage_manager: PipelineStageManager,
diff --git a/colossalai/pipeline/weight_grad_store.py b/colossalai/pipeline/weight_grad_store.py
index c51c45085ea2..1a9ef142156d 100644
--- a/colossalai/pipeline/weight_grad_store.py
+++ b/colossalai/pipeline/weight_grad_store.py
@@ -8,7 +8,6 @@ class WeightGradStore:
     @classmethod
     def put(cls, total_input, grad_output, weight, func):
-        # func(total_input, grad_output, weight.main_grad)
         cls.cache.append((total_input, grad_output, weight, func))

     @classmethod
@@ -18,15 +17,26 @@ def flush(cls, chunk=0):
     @classmethod
     def pop(cls, chunk=0):
-        # print(f"chunk id {chunk} queue size {cls.weight_grad_queue[chunk].qsize()}")
         if cls.weight_grad_queue[chunk].qsize() > 0:
             stored_grads = cls.weight_grad_queue[chunk].get()
             for total_input, grad_output, weight, func in stored_grads:
-                if weight.grad is not None:
-                    func(total_input, grad_output, weight.grad)
-                # for first bwd; weight.grad is None, assign grad_weight to weight.grad
+                if isinstance(weight, tuple):
+                    # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
+ # View will lead to weight ptr change + # weight_cal & weight_origin in tuple, weight_cal use to cal dw, weight_origin use to update + _, weight_origin = weight + if weight_origin.grad is not None: + func(total_input, grad_output, weight_origin.grad) + # for first bwd; weight.grad is None, assign grad_weight to weight.grad + else: + grad_weight = func(total_input, grad_output) + weight_origin.grad = grad_weight else: - grad_weight = func(total_input, grad_output) - weight.grad = grad_weight + if weight.grad is not None: + func(total_input, grad_output, weight.grad) + # for first bwd; weight.grad is None, assign grad_weight to weight.grad + else: + grad_weight = func(total_input, grad_output) + weight.grad = grad_weight else: raise Exception("Pop empty queue.") diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index da5363840848..0bd1b60923e9 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -6,7 +6,14 @@ from .loss import cross_entropy_1d, dist_cross_entropy from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule -from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row +from .qkv_fused_linear import ( + FusedLinear, + FusedLinear1D_Col, + FusedLinear1D_Row, + GPT2FusedLinearConv, + GPT2FusedLinearConv1D_Col, + GPT2FusedLinearConv1D_Row, +) __all__ = [ "Embedding1D", @@ -14,8 +21,9 @@ "LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row", - "GPT2FusedLinearConv1D_Col", + "GPT2FusedLinearConv", "GPT2FusedLinearConv1D_Row", + "GPT2FusedLinearConv1D_Col", "DropoutForParallelInput", "DropoutForReplicatedInput", "cross_entropy_1d", @@ -26,6 +34,7 @@ "FusedLayerNorm", "FusedRMSNorm", "FusedLinear1D_Col", + "FusedLinear", "ParallelModule", "PaddingEmbedding", "PaddingLMHead", diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 8c2e6e7c5d92..0252f90e1c27 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -6,7 +6,13 @@ from colossalai.pipeline.weight_grad_store import WeightGradStore -from .utils import is_share_sp_tp +from .utils import ( + execute_conv1d_w_pass, + execute_conv1d_w_pass_grad_accum, + execute_w_pass, + execute_w_pass_grad_accum, + is_share_sp_tp, +) try: import fused_mix_prec_layer_norm_cuda @@ -73,12 +79,13 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce ctx.fp8_communication = fp8_communication + ctx.use_zbv = use_zbv output = torch.matmul(input_, weight) @@ -92,8 +99,10 @@ def backward(ctx, grad_output): input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias fp8_communication = ctx.fp8_communication + use_zbv = ctx.use_zbv # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. 
+ weight_origin = weight weight = weight.view(weight.shape) if bias is not None: bias = bias.view(bias.shape) @@ -114,7 +123,42 @@ def backward(ctx, grad_output): # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py - grad_weight = total_input.t().matmul(grad_output) + # split dx & dw + if _grad_accum_fusion_available and weight.grad is not None: + grad = weight.grad + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + (weight, weight_origin), + functools.partial( + execute_conv1d_w_pass_grad_accum, + ), + ) + grad_weight = None + else: + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None + else: + grad_weight = total_input.t().matmul(grad_output) + else: + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + (weight, weight_origin), + functools.partial( + execute_conv1d_w_pass, + wgrad_gemm_func=torch.matmul, + ), + ) + grad_weight = None + else: + grad_weight = total_input.t().matmul(grad_output) grad_bias = grad_output.sum(dim=0) if use_bias else None if ctx.async_grad_allreduce and not fp8_communication: @@ -123,6 +167,87 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias, None, None, None, None +class MatmulWithGradAccum(torch.autograd.Function): + """ + Linear layer execution with grad accum in backprop. (no tp version) + """ + + @staticmethod + def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False): + ctx.save_for_backward(input_, weight, bias) + ctx.use_bias = bias is not None + ctx.async_grad_allreduce = async_grad_allreduce + ctx.use_zbv = use_zbv + + output = torch.matmul(input_, weight) + if bias is not None: + output = output + bias + + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + use_bias = ctx.use_bias + use_zbv = ctx.use_zbv + + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. 
+ weight_origin = weight + weight = weight.view(weight.shape) + if bias is not None: + bias = bias.view(bias.shape) + + total_input = input + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + # split dx & dw + if _grad_accum_fusion_available and weight.grad is not None: + grad = weight.grad + + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + (weight, weight_origin), + functools.partial( + execute_conv1d_w_pass_grad_accum, + ), + ) + grad_weight = None + else: + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None + else: + grad_weight = total_input.t().matmul(grad_output) + else: + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + (weight, weight_origin), + functools.partial( + execute_conv1d_w_pass, + wgrad_gemm_func=torch.matmul, + ), + ) + grad_weight = None + else: + grad_weight = total_input.t().matmul(grad_output) + + grad_bias = grad_output.sum(dim=0) if use_bias else None + + return grad_input, grad_weight, grad_bias, None, None, None, None + + class LinearWithAsyncCommunication(torch.autograd.Function): """ Linear layer execution with asynchronous communication in backprop. @@ -150,12 +275,6 @@ def backward(ctx, grad_output): fp8_communication = ctx.fp8_communication use_zbv = ctx.use_zbv - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): - wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) - - def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): - return wgrad_gemm_func(_grad_output_.t(), _input_) - # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. 
if use_bias: bias.view(bias.shape) @@ -179,31 +298,15 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad if use_zbv: - # TODO: append input, grad_output_, weight, grad func to WeightGradStore - if grad.dtype == torch.float32: - WeightGradStore.put( - total_input, - grad_output, - weight, - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, - ), - ) - grad_weight = None - elif grad.dtype in (torch.float16, torch.bfloat16): - WeightGradStore.put( - total_input, - grad_output, - weight, - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, - ), - ) - grad_weight = None - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + ), + ) + grad_weight = None else: if grad.dtype == torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) @@ -259,12 +362,6 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias use_zbv = ctx.use_zbv - def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): - wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) - - def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): - return wgrad_gemm_func(_grad_output_.t(), _input_) - # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. if use_bias: bias.view(bias.shape) @@ -280,31 +377,15 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad if use_zbv: - # TODO: append input, grad_output_, weight, grad func to WeightGradStore - if grad.dtype == torch.float32: - WeightGradStore.put( - total_input, - grad_output, - weight, - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, - ), - ) - grad_weight = None - elif grad.dtype in (torch.float16, torch.bfloat16): - WeightGradStore.put( - total_input, - grad_output, - weight, - functools.partial( - execute_w_pass_grad_accum, - wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, - ), - ) - grad_weight = None - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + ), + ) + grad_weight = None else: if grad.dtype == torch.float32: fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) @@ -454,12 +535,13 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, use_zbv=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter ctx.dim = dim + ctx.use_zbv = use_zbv if ring is True: input_to_gather = {"input": input_} @@ -491,6 +573,7 @@ def backward(ctx, grad_output): use_bias 
= ctx.use_bias dim = ctx.dim process_group = ctx.process_group + use_zbv = ctx.use_zbv # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm if use_bias: @@ -518,23 +601,46 @@ def backward(ctx, grad_output): if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad - if grad.dtype == torch.float32: - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + ), + ) grad_weight = None - elif grad.dtype == torch.float16: - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + else: + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + else: + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass, + wgrad_gemm_func=torch.matmul, + ), + ) grad_weight = None else: grad_weight = grad_output.t().matmul(total_input) - else: - grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None if ctx.async_grad_reduce_scatter: handle.wait() - return output, grad_weight, grad_bias, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None def _ring_as_reducescatter( @@ -606,11 +712,12 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, dim, ring): + def forward(ctx, input_, weight, bias, process_group, dim, ring, use_zbv=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.dim = dim + ctx.use_zbv = use_zbv if ring is True: input_to_reducescatter = {"input": input_} @@ -651,7 +758,7 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias dim = ctx.dim process_group = ctx.process_group - + use_zbv = ctx.use_zbv # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. 
Used in FusedLayerNorm if use_bias: bias = bias.view(bias.shape) @@ -666,10 +773,47 @@ def backward(ctx, grad_output): if len(grad_output.shape) > 2: grad_output = grad_output.view(-1, grad_output.shape[-1]) total_input = total_input.reshape(-1, total_input.shape[-1]) - grad_weight = grad_output.t().matmul(total_input) + + if _grad_accum_fusion_available and weight.grad is not None: + grad = weight.grad + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + ), + ) + grad_weight = None + else: + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + else: + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass, + wgrad_gemm_func=torch.matmul, + ), + ) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + + # grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None - return grad_input, grad_weight, grad_bias, None, None, None + return grad_input, grad_weight, grad_bias, None, None, None, None class _ReduceScatterForwardGatherBackward(torch.autograd.Function): @@ -723,13 +867,16 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication): + def forward( + ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv=False + ): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter ctx.dim = dim ctx.fp8_communication = fp8_communication + ctx.use_zbv = use_zbv if ring is True: input_to_gather = {"input": input_} @@ -759,8 +906,10 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias dim = ctx.dim process_group = ctx.process_group + use_zbv = ctx.use_zbv # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. 
Used in FusedLayerNorm + weight_origin = weight weight = weight.view(weight.shape) if use_bias: bias = bias.view(bias.shape) @@ -785,13 +934,49 @@ def backward(ctx, grad_output): # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated - grad_weight = total_input.t().matmul(grad_output) + # split dx & dw + if _grad_accum_fusion_available and weight.grad is not None: + grad = weight.grad + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + (weight, weight_origin), + functools.partial( + execute_conv1d_w_pass_grad_accum, + ), + ) + grad_weight = None + else: + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None + else: + grad_weight = total_input.t().matmul(grad_output) + else: + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + (weight, weight_origin), + functools.partial( + execute_conv1d_w_pass, + wgrad_gemm_func=torch.matmul, + ), + ) + grad_weight = None + else: + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None if ctx.async_grad_reduce_scatter: handle.wait() - return output, grad_weight, grad_bias, None, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): @@ -1108,12 +1293,18 @@ def _all_to_all_single( ).contiguous() -def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): +def matmul_with_async_comm( + input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False +): return MatmulWithAsyncCommunication.apply( - input_, weight, bias, process_group, async_grad_allreduce, fp8_communication + input_, weight, bias, process_group, async_grad_allreduce, fp8_communication, use_zbv ) +def matmul_with_grad_comm(input_, weight, bias, async_grad_allreduce, use_zbv=False): + return MatmulWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv) + + def linear_with_async_comm( input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False ): @@ -1127,10 +1318,10 @@ def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=F def linear_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, use_zbv=False ): return _LinearWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, use_zbv ) @@ -1142,15 +1333,25 @@ def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_commun return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim, fp8_communication) -def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False): - return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring) +def linear_reducescatter_forward_gather_backward( + input_, weight, bias=None, process_group=None, dim=1, ring=False, use_zbv=False +): + return 
_LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring, use_zbv) def matmul_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False + input_, + weight, + bias, + process_group, + async_grad_reduce_scatter, + dim, + ring=False, + fp8_communication=False, + use_zbv=False, ): return _MatmulWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv ) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index d39d6e997af8..fe195d6987da 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -350,6 +350,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: True, self.seq_parallel_dim, ring=self.seq_parallel_mode == "ring", + use_zbv=self.use_zbv, ) else: output_parallel = linear_with_async_comm( @@ -580,6 +581,7 @@ def forward(self, input_: Tensor) -> Tensor: process_group=self.process_group, dim=self.seq_parallel_dim, ring=self.seq_parallel_mode == "ring", + use_zbv=self.use_zbv, ) else: output_parallel = F.linear(input_, self.weight) diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 6e469686b403..e3aaa3a4635c 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -7,7 +7,6 @@ import torch import torch.distributed as dist import torch.nn as nn -import torch.nn.functional as F from torch import Tensor from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter @@ -28,8 +27,10 @@ linear_gather_forward_reducescatter_backward, linear_reducescatter_forward_gather_backward, linear_with_async_comm, + linear_with_grad_accum, matmul_gather_forward_reducescatter_backward, matmul_with_async_comm, + matmul_with_grad_comm, reduce_forward, reducescatter_forward_gather_backward, split_forward_gather_backward, @@ -37,7 +38,14 @@ from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset, is_share_sp_tp -__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"] +__all__ = [ + "FusedLinear1D_Col", + "FusedLinear1D_Row", + "FusedLinear", + "GPT2FusedLinearConv1D_Col", + "GPT2FusedLinearConv1D_Row", + "GPT2FusedLinearConv", +] # ==================================== # For GPT Only @@ -228,6 +236,7 @@ def __init__( weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), fp8_communication: bool = False, + use_zbv: bool = False, ): super().__init__() @@ -241,6 +250,7 @@ def __init__( self.split_sizes = split_sizes self.process_group = process_group self.fp8_communication = fp8_communication + self.use_zbv = use_zbv assert ( sum(split_sizes) == out_features @@ -375,6 +385,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: 1, ring=self.seq_parallel_mode == "ring", fp8_communication=self.fp8_communication, + use_zbv=self.use_zbv, ) elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn": # Set up backprop all-reduce. 
@@ -386,6 +397,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: self.process_group, True, fp8_communication=self.fp8_communication, + use_zbv=self.use_zbv, ) else: raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!") @@ -441,6 +453,7 @@ def __init__( bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), stream_chunk_num: int = 1, fp8_communication: bool = False, + use_zbv: bool = False, ): super().__init__() @@ -455,6 +468,7 @@ def __init__( self.seq_parallel_mode = seq_parallel_mode self.num_partitions = dist.get_world_size(self.process_group) self.fp8_communication = fp8_communication + self.use_zbv = use_zbv if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -620,6 +634,152 @@ def forward(self, input_: Tensor) -> Tensor: return output, self.bias +class GPT2FusedLinearConv(ParallelModule): + r"""Linear layer without parallelism. + This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None. + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + seq_parallel_mode: str = None, + seq_parallel_dim: int = 1, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + use_zbv: bool = False, + ): + super().__init__() + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.seq_parallel_mode = seq_parallel_mode + self.seq_parallel_dim = seq_parallel_dim + self.skip_bias_add = skip_bias_add + self.device = device + self.use_zbv = use_zbv + + if skip_bias_add and not bias: + raise ValueError("cannot skip bias addition if bias is None") + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, None) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" + else: + assert bias_ is None, "bias_ must be None if weight is None" + + # Parameters. + if weight is None: + # Initialize weight. 
+ factory_kwargs = {"device": device, "dtype": dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + else: + self.bias = None + + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module( + module: nn.Module, + *args, + **kwargs, + ) -> ParallelModule: + r""" + Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. + + Args: + module (`nn.Linear`): The module to be converted. + split_sizes (List[int]): The sizes of the split tensor. In GPT2, Q,K,V are fused in one weight. + """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.weight.shape[0] + out_features = module.weight.shape[1] + bias = module.bias is not None + device = module.weight.device + + linear_1d = GPT2FusedLinearConv( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn": + # Set up backprop all-reduce. 
+ input_parallel = input_ + output_parallel = matmul_with_grad_comm( + input_parallel, + self.weight, + bias, + False, + self.use_zbv, + ) + else: + raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!") + + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + # ==================================== # For Fused torch.nn.Linear # ==================================== @@ -671,6 +831,7 @@ def __init__( weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), fp8_communication: bool = False, + use_zbv: bool = False, ): super().__init__() # Keep input parameters @@ -684,6 +845,7 @@ def __init__( self.split_sizes = split_sizes self.process_group = process_group self.fp8_communication = fp8_communication + self.use_zbv = use_zbv assert ( sum(split_sizes) == out_features @@ -811,10 +973,17 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: True, self.seq_parallel_dim, ring=self.seq_parallel_mode == "ring", + use_zbv=self.use_zbv, ) else: output_parallel = linear_with_async_comm( - input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication + input_parallel, + self.weight, + bias, + self.process_group, + True, + fp8_communication=self.fp8_communication, + use_zbv=self.use_zbv, ) if self.gather_output: @@ -870,6 +1039,7 @@ def __init__( weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), fp8_communication: bool = False, + use_zbv: bool = False, ): super().__init__() # Keep input parameters @@ -883,6 +1053,7 @@ def __init__( self.seq_parallel_dim = seq_parallel_dim self.num_partitions = dist.get_world_size(self.process_group) self.fp8_communication = fp8_communication + self.use_zbv = use_zbv assert ( sum(split_sizes) == in_features @@ -1009,9 +1180,18 @@ def forward(self, input_: Tensor) -> Tensor: process_group=self.process_group, dim=self.seq_parallel_dim, ring=self.seq_parallel_mode == "ring", + use_zbv=self.use_zbv, ) else: - output_parallel = F.linear(input_, self.weight) + # output_parallel = F.linear(input_, self.weight) # Replace to LinearWithGradAccum + output_parallel = linear_with_grad_accum( + input_, + self.weight, + None, + False, + use_zbv=self.use_zbv, + ) + output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication) if not self.skip_bias_add: @@ -1020,3 +1200,156 @@ def forward(self, input_: Tensor) -> Tensor: return output else: return output, self.bias + + +class FusedLinear(ParallelModule): + r"""Fused Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `torch.nn.Linear` layer (Fused QKV) in normal torch layer of huggingface, like SAM. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + split_sizes (List[int]): The sizes of the split tensor. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. 
+ gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + seq_parallel_mode: str = None, + seq_parallel_dim: int = 1, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + use_zbv: bool = False, + ): + super().__init__() + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.seq_parallel_mode = seq_parallel_mode + self.seq_parallel_dim = seq_parallel_dim + self.skip_bias_add = skip_bias_add + self.device = device + self.use_zbv = use_zbv + + if skip_bias_add and not bias: + raise ValueError("cannot skip bias addition if bias is None") + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=None) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" + else: + assert bias_ is None, "bias_ must be None if weight is None" + + # Parameters. + if weight is None: + # Initialize weight. + factory_kwargs = {"device": device, "dtype": dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + else: + self.bias = None + + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module( + module: nn.Module, + *args, + **kwargs, + ) -> ParallelModule: + r""" + Convert a fused `torch.nn.linear` layer to a parallelized linear layer. + + Args: + module (`nn.Linear`): The module to be converted. + process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication. + split_sizes (List[int]): The sizes of the split tensor. In common, Q,K,V are fused in one weight. 
+ """ + LazyInitContext.materialize(module) + + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + linear_1d = FusedLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) + # Set up backprop all-reduce. + input_parallel = input_ + + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + + output_parallel = linear_with_grad_accum(input_parallel, self.weight, bias, True, use_zbv=self.use_zbv) + + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 2df68e18c64d..6ce1eb79df3f 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -9,6 +9,43 @@ from colossalai.accelerator import get_accelerator +try: + import fused_weight_gradient_mlp_cuda + + _grad_accum_fusion_available = True +except ImportError: + _grad_accum_fusion_available = False + + +# execute_w_pass_grad_accum & execute_conv1d_w_pass for GPT2FusedLinearConv1D +def execute_conv1d_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): + if _input_.dtype == torch.float32: + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 + elif _input_.dtype in (torch.float16, torch.bfloat16): + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_) + + +def execute_conv1d_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): + return wgrad_gemm_func(_input_.t(), _grad_output_) + + +# execute_w_pass_grad_accum & execute_w_pass for Linear (except GPT2FusedLinearConv1D) +def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_): + if _input_.dtype == torch.float32: + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 + elif _input_.dtype in (torch.float16, torch.bfloat16): + wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) + + +def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): + return wgrad_gemm_func(_grad_output_.t(), _input_) + class SeqParallelUtils: @staticmethod diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 2e73d5c2a637..a2f582e9daa6 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -51,6 +51,8 
@@ def module_policy(self): else: norm_cls = col_nn.LayerNorm + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -73,6 +75,7 @@ def module_policy(self): kwargs={ "split_sizes": [self.model.config.vision_config.hidden_size] * 3, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -80,6 +83,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -88,6 +92,7 @@ def module_policy(self): kwargs={ "skip_bias_add": self.enable_bias_gelu_fused, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -95,6 +100,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], @@ -126,6 +132,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -133,6 +140,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -140,6 +148,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -151,6 +160,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -162,6 +172,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -169,6 +180,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -176,6 +188,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -187,6 +200,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -198,6 +212,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -205,6 +220,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -227,6 +243,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -234,6 +251,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": 
self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -241,6 +259,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -248,6 +267,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -255,6 +275,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -262,6 +283,226 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + + policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()}) + if self.enable_bias_gelu_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_blip2_mlp_forward(), + }, + policy=policy, + target_key=Blip2MLP, + ) + elif use_zbv: + policy[Blip2EncoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="self_attn.qkv", + target_module=col_nn.FusedLinear, + kwargs={ + "split_sizes": [self.model.config.vision_config.hidden_size] * 3, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.projection", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.fc1", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "skip_bias_add": self.enable_bias_gelu_fused, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.fc2", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + + policy[Blip2QFormerModel] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ] + ) + + policy[Blip2QFormerLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": 
self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.query", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.key", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.value", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate_query.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="output_query.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="output_query.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + + policy[OPTDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 7c6259e850c2..c7691698bed2 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -59,6 +59,8 @@ def module_policy(self): sp_partial_derived = sp_mode == 
"split_gather" + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.n_head % self.shard_config.tensor_parallel_size == 0 @@ -78,6 +80,7 @@ def module_policy(self): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -86,6 +89,7 @@ def module_policy(self): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -98,6 +102,7 @@ def module_policy(self): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -106,6 +111,7 @@ def module_policy(self): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], @@ -120,6 +126,52 @@ def module_policy(self): }, ) + if use_zbv: + policy[BloomBlock] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=[ @@ -247,14 +299,27 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = stage_manager.distribute_layers(len(module.h)) - if stage_manager.is_first_stage(): - held_layers.append(module.word_embeddings) - held_layers.append(module.word_embeddings_layernorm) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.h[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.ln_f) + if stage_manager.is_interleave: + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.word_embeddings) + held_layers.append(module.word_embeddings_layernorm) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.h[start_idx:end_idx]) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(module.ln_f) + else: + layers_per_stage = 
stage_manager.distribute_layers(len(module.h)) + if stage_manager.is_first_stage(): + held_layers.append(module.word_embeddings) + held_layers.append(module.word_embeddings_layernorm) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) return held_layers @@ -328,8 +393,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): - held_layers.append(self.model.lm_head) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -351,6 +422,7 @@ def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification policy = super().module_policy() + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv # handle tensor parallelism if self.shard_config.enable_tensor_parallelism: @@ -363,6 +435,18 @@ def module_policy(self): policy=policy, target_key=BloomForSequenceClassification, ) + elif use_zbv: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="score", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict( + gather_output=True, fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv + ), + ), + policy=policy, + target_key=BloomForSequenceClassification, + ) if self.pipeline_stage_manager: self.set_pipeline_forward( model_cls=BloomForSequenceClassification, @@ -375,8 +459,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): - held_layers.append(self.model.score) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.score) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.score) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -389,6 +479,7 @@ def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForTokenClassification policy = super().module_policy() + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv # handle tensor parallelism if self.shard_config.enable_tensor_parallelism: @@ -407,6 +498,24 @@ def module_policy(self): policy=policy, target_key=BloomForTokenClassification, ) + elif use_zbv: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="classifier", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict( + gather_output=True, fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv + ), + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ], + policy=policy, + target_key=BloomForTokenClassification, + ) if 
self.pipeline_stage_manager: self.set_pipeline_forward( model_cls=BloomForTokenClassification, @@ -420,9 +529,16 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): - held_layers.append(self.model.dropout) - held_layers.append(self.model.classifier) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -448,8 +564,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): - held_layers.append(self.model.qa_outputs) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.qa_outputs) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 4ddcf8bfce6b..1fed4e9e8647 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -83,6 +83,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: attribute_replacement=decoder_attribute_replacement, ) + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -145,6 +147,35 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ), ], ) + elif use_zbv: + policy["GLMBlock"] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + kwargs={ + "seq_parallel_mode": sp_mode, + "seq_parallel_dim": 0, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + kwargs={ + "seq_parallel_mode": sp_mode, + "seq_parallel_dim": 0, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attention.core_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) if embedding_cls is not None: self.append_or_create_submodule_replacement( @@ -261,17 +292,30 @@ def get_held_layers(self) -> List[nn.Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = stage_manager.distribute_layers(module.num_layers) - if stage_manager.is_first_stage(): - held_layers.append(module.embedding) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - 
held_layers.extend(module.encoder.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(): - if module.encoder.post_layer_norm: - held_layers.append(module.encoder.final_layernorm) - - # rotary_pos_emb is needed for all stages - held_layers.append(module.rotary_pos_emb) + if stage_manager.is_interleave: + layers_per_stage = stage_manager.distribute_layers(module.num_layers) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embedding) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.encoder.layers[start_idx:end_idx]) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + if module.encoder.post_layer_norm: + held_layers.append(module.encoder.final_layernorm) + else: + layers_per_stage = stage_manager.distribute_layers(module.num_layers) + if stage_manager.is_first_stage(): + held_layers.append(module.embedding) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.encoder.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + if module.encoder.post_layer_norm: + held_layers.append(module.encoder.final_layernorm) + + # rotary_pos_emb is needed for all stages + held_layers.append(module.rotary_pos_emb) return held_layers @@ -335,8 +379,15 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.transformer.output_layer) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.transformer.output_layer) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.transformer.output_layer) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index 323480d6d084..e6e741d34a3a 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -10,6 +10,7 @@ LayerNorm, Linear1D_Col, Linear1D_Row, + LinearWithGradAccum, PaddingEmbedding, PaddingLMHead, VocabParallelEmbedding1D, @@ -107,6 +108,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_key=CohereModel, ) + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( num_q_heads % tp_size == 0 @@ -128,41 +131,137 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, +
use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + ], + ) + elif use_zbv: + policy[CohereDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), ], ) - if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -258,7 +357,9 @@ def get_held_layers(self) -> List[Module]: 
held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(ignore_chunk=True): + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): held_layers.append(module.norm) else: @@ -351,8 +452,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.lm_head) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py index bd54e6f2db9e..9baf068aec9f 100644 --- a/colossalai/shardformer/policies/deepseek.py +++ b/colossalai/shardformer/policies/deepseek.py @@ -6,7 +6,7 @@ from torch.nn import Module from transformers.utils import is_flash_attn_greater_or_equal_2_10 -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, LinearWithGradAccum from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D from colossalai.shardformer.layer.linear import Linear1D_Row from colossalai.shardformer.modeling.deepseek import ( @@ -107,6 +107,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.tie_weight: embedding_cls = PaddingEmbedding + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: # tensor parallelism for non-moe params assert ( @@ -133,22 +135,58 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication": self.shard_config.fp8_communication}, + kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv}, ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication": self.shard_config.fp8_communication}, + kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv}, ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication": self.shard_config.fp8_communication}, + kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv}, ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs={"fp8_communication": self.shard_config.fp8_communication}, + kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv}, + ), + SubModuleReplacementDescription( + suffix="mlp.gate", + target_module=DeepseekMoEGate_Col, + kwargs={ + "gather_output": True, + "fp8_communication": self.shard_config.fp8_communication, + "config": self.model.config, + }, + ignore_if_not_exist=True, + ), + ], + ) + elif 
use_zbv: + policy["DeepseekDecoderLayer"] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=LinearWithGradAccum, + kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv}, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=LinearWithGradAccum, + kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv}, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=LinearWithGradAccum, + kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv}, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=LinearWithGradAccum, + kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv}, ), SubModuleReplacementDescription( suffix="mlp.gate", @@ -162,7 +200,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ), ], ) - if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -291,13 +328,26 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = stage_manager.distribute_layers(len(module.layers)) - if stage_manager.is_first_stage(): - held_layers.append(module.embed_tokens) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.norm) + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(module.norm) + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) return held_layers @@ -330,6 +380,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class DeepseekForCausalLMPolicy(DeepseekPolicy): def module_policy(self): policy = super().module_policy() + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv # TODO: assign pg mesh from plugin to all modules if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm @@ -339,7 +390,29 @@ def module_policy(self): SubModuleReplacementDescription( suffix="lm_head", target_module=Linear1D_Col, - kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ) + ] + ) + } + policy.update(new_item) + elif use_zbv: + # add a new item for casual lm + new_item = { + "DeepseekForCausalLM": 
ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=LinearWithGradAccum, + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ) ] ) @@ -360,8 +433,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): - held_layers.append(self.model.lm_head) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index e20fb1568505..68a548aee869 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -51,6 +51,8 @@ def module_policy(self): if self.tie_weight: embedding_cls = col_nn.PaddingEmbedding + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -73,10 +75,16 @@ def module_policy(self): SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, + kwargs=dict( + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, + kwargs=dict( + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attention.attention_dropout", @@ -85,8 +93,17 @@ def module_policy(self): SubModuleReplacementDescription( suffix="mlp.dense_h_to_4h", target_module=col_nn.Linear1D_Col, + kwargs=dict( + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.Linear1D_Row, + kwargs=dict( + use_zbv=use_zbv, + ), ), - SubModuleReplacementDescription(suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row), ], ) @@ -98,6 +115,44 @@ def module_policy(self): "build_alibi_tensor": build_falcon_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group) }, ) + elif use_zbv: + policy[FalconDecoderLayer] = ModulePolicyDescription( + method_replacement={"forward": get_tp_falcon_decoder_layer_forward()}, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict( + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict( + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict( + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict( + use_zbv=use_zbv, + ), + ), + ], + ) if embedding_cls is not None: self.append_or_create_submodule_replacement( @@ -191,13 +246,26 @@ def 
get_held_layers(self) -> List[Module]: module = self.model.transformer stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = stage_manager.distribute_layers(len(module.h)) - if stage_manager.is_first_stage(): - held_layers.append(module.word_embeddings) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.h[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.ln_f) + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.word_embeddings) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.h[start_idx:end_idx]) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(module.ln_f) + else: + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + if stage_manager.is_first_stage(): + held_layers.append(module.word_embeddings) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) return held_layers @@ -281,8 +349,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): - held_layers.append(self.model.lm_head) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -308,11 +382,23 @@ def module_policy(self): policy = super().module_policy() + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + # handle tensor parallelism if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True, use_zbv=use_zbv) + ), + policy=policy, + target_key=FalconForSequenceClassification, + ) + elif use_zbv: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="score", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict(gather_output=True, use_zbv=use_zbv), ), policy=policy, target_key=FalconForSequenceClassification, @@ -330,8 +416,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): - held_layers.append(self.model.score) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.score) + else: + if 
stage_manager.is_last_stage(): + held_layers.append(self.model.score) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -348,12 +440,32 @@ def module_policy(self): policy = super().module_policy() + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + # handle tensor parallelism if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=[ SubModuleReplacementDescription( - suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + suffix="classifier", + target_module=col_nn.Linear1D_Col, + kwargs=dict(gather_output=True, use_zbv=use_zbv), + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ], + policy=policy, + target_key=FalconForTokenClassification, + ) + elif use_zbv: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="classifier", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict(gather_output=True, use_zbv=use_zbv), ), SubModuleReplacementDescription( suffix="dropout", @@ -375,9 +487,16 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): - held_layers.append(self.model.dropout) - held_layers.append(self.model.classifier) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -394,11 +513,25 @@ def module_policy(self): policy = super().module_policy() + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + # handle tensor parallelism if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="qa_outputs", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + suffix="qa_outputs", + target_module=col_nn.Linear1D_Col, + kwargs=dict(gather_output=True, use_zbv=use_zbv), + ), + policy=policy, + target_key=FalconForQuestionAnswering, + ) + elif use_zbv: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="qa_outputs", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict(gather_output=True, use_zbv=use_zbv), ), policy=policy, target_key=FalconForQuestionAnswering, @@ -415,8 +548,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): - held_layers.append(self.model.qa_outputs) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.qa_outputs) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]:
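The stage check that recurs through these get_held_layers overrides encodes a single rule: a zero-bubble V schedule runs its last model chunk back on the first pipeline rank, so output-side layers (ln_f/norm, lm_head, score, qa_outputs) are held by the first stage when use_zbv is set and by the last stage otherwise. The sketch below is illustrative only and is not part of the patch; it assumes nothing beyond the stage-manager attributes already used above (use_zbv, is_first_stage, is_last_stage), and the helper name holds_output_head is hypothetical.

def holds_output_head(stage_manager) -> bool:
    # Minimal sketch: decide which pipeline rank keeps the head / final-norm layers.
    if stage_manager.use_zbv:
        # V-shaped zero-bubble schedule: the final chunk ends on the first rank.
        return stage_manager.is_first_stage(ignore_chunk=True)
    # Other schedules: the head stays on the last rank.
    return stage_manager.is_last_stage(ignore_chunk=True)

diff --git 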
a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 08accaaea279..c57d33826a39 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -67,6 +67,8 @@ def module_policy(self): self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather" sp_partial_derived = sp_mode in ["split_gather", "ring"] use_flash_attention = self.shard_config.enable_flash_attention + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -94,12 +96,17 @@ def module_policy(self): "split_sizes": [self.model.config.hidden_size] * 3, "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( suffix="attn.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication}, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, ), SubModuleReplacementDescription( suffix="mlp.c_fc", @@ -109,12 +116,88 @@ def module_policy(self): "seq_parallel_mode": sp_mode, "skip_bias_add": self.enable_bias_gelu_fused, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication}, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + if self.enable_bias_gelu_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_gpt2_mlp_forward(), + }, + policy=policy, + target_key=GPT2MLP, + ) + elif use_zbv: + policy[GPT2Model] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="drop", + target_module=col_nn.DropoutForParallelInput, + ), + ] + ) + + policy[GPT2Block] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.GPT2FusedLinearConv, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.GPT2FusedLinearConv, + kwargs={ + "seq_parallel_mode": sp_mode, + "skip_bias_add": self.enable_bias_gelu_fused, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv, + kwargs={ + "seq_parallel_mode": 
sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", @@ -138,6 +221,7 @@ def module_policy(self): policy=policy, target_key=GPT2MLP, ) + if embedding_cls is not None: # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by self.append_or_create_submodule_replacement( @@ -352,8 +436,17 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.lm_head) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) + # if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True): + # held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -420,13 +513,24 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - multiple_choice_head = self.model.multiple_choice_head - held_layers.append(self.model.lm_head) - held_layers.append(multiple_choice_head.summary) - held_layers.append(multiple_choice_head.activation) - held_layers.append(multiple_choice_head.first_dropout) - held_layers.append(multiple_choice_head.last_dropout) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + held_layers.append(self.model.multiple_choice_head.summary) + held_layers.append(self.model.multiple_choice_head.activation) + held_layers.append(self.model.multiple_choice_head.first_dropout) + held_layers.append(self.model.multiple_choice_head.last_dropout) + else: + if self.pipeline_stage_manager.is_last_stage(): + multiple_choice_head = self.model.multiple_choice_head + held_layers.append(self.model.lm_head) + held_layers.append(multiple_choice_head.summary) + held_layers.append(multiple_choice_head.activation) + held_layers.append(multiple_choice_head.first_dropout) + held_layers.append(multiple_choice_head.last_dropout) return held_layers @@ -464,8 +568,17 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.qa_outputs) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.qa_outputs) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) + # if self.pipeline_stage_manager.is_last_stage(): + # held_layers.append(self.model.qa_outputs) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]:
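The interleaved branches in these get_held_layers overrides all collect their transformer blocks the same way: distribute_layers splits the block list across pipeline stages, and get_stage_index returns one (start, end) slice per model chunk held by this rank (a single pair in the non-interleaved case). Below is a minimal sketch of that pattern, using only the stage-manager calls that appear in this patch; collect_held_blocks is an illustrative name, not an API in the repository.

def collect_held_blocks(stage_manager, blocks):
    # blocks: the model's decoder/encoder layer list, e.g. module.h or module.layers
    layers_per_stage = stage_manager.distribute_layers(len(blocks))
    if stage_manager.is_interleave:
        # one (start, end) slice per chunk assigned to this rank
        held = []
        for start_idx, end_idx in stage_manager.get_stage_index(layers_per_stage):
            held.extend(blocks[start_idx:end_idx])
        return held
    start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
    return list(blocks[start_idx:end_idx])

@@ -503,9 +616,20 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: 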
held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.dropout) - held_layers.append(self.model.classifier) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + # if self.pipeline_stage_manager.is_last_stage(): + # held_layers.append(self.model.dropout) + # held_layers.append(self.model.classifier) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -530,8 +654,18 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.score) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.score) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.score) + + # if self.pipeline_stage_manager.is_last_stage(): + # held_layers.append(self.model.score) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index 9fcca1385f79..891ebbdcc693 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -51,6 +51,8 @@ def module_policy(self): self.shard_config.enable_sequence_parallelism = False warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -76,6 +78,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -83,6 +86,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -90,6 +94,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -97,6 +102,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -104,6 +110,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -111,6 +118,72 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + 
SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + elif use_zbv: + policy[GPTJBlock] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attn.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attn.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attn.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.fc_in", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.fc_out", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -127,7 +200,6 @@ def module_policy(self): ), ], ) - if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -200,13 +272,25 @@ def get_held_layers(self) -> List[nn.Module]: held_layers = [] layers_per_stage = stage_manager.distribute_layers(len(module.h)) - if stage_manager.is_first_stage(): - held_layers.append(module.wte) - held_layers.append(module.drop) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.h[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.ln_f) + if stage_manager.is_interleave: + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.wte) + held_layers.append(module.drop) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.h[start_idx:end_idx]) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(module.ln_f) + else: + if stage_manager.is_first_stage(): + held_layers.append(module.wte) + held_layers.append(module.drop) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) return held_layers def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: @@ -309,8 +393,15 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.lm_head) + stage_manager = self.pipeline_stage_manager + if 
stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -349,8 +440,15 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.score) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.score) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.score) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -378,8 +476,15 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.qa_outputs) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.qa_outputs) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index b4b87df923a3..f9c9a9404e72 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -324,9 +324,10 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(ignore_chunk=True): + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): held_layers.append(module.norm) - else: layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): @@ -419,8 +420,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.lm_head) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -475,8 +482,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): - 
held_layers.append(self.model.score) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.score) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.score) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index dd64ce652f86..50742b850b24 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -10,6 +10,7 @@ LayerNorm, Linear1D_Col, Linear1D_Row, + LinearWithGradAccum, PaddingEmbedding, PaddingLMHead, VocabParallelEmbedding1D, @@ -76,6 +77,8 @@ def module_policy(self): self.shard_config.enable_sequence_parallelism = False warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -85,10 +88,16 @@ def module_policy(self): SubModuleReplacementDescription( suffix="fc1", target_module=Linear1D_Col, + kwargs=dict( + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="fc2", target_module=Linear1D_Row, + kwargs=dict( + use_zbv=use_zbv, + ), ), ] ) @@ -104,6 +113,7 @@ def module_policy(self): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -111,6 +121,7 @@ def module_policy(self): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -118,6 +129,7 @@ def module_policy(self): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -125,11 +137,67 @@ def module_policy(self): target_module=Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], ) + elif use_zbv: + policy[OPTDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=LinearWithGradAccum, + kwargs=dict( + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=LinearWithGradAccum, + kwargs=dict( + use_zbv=use_zbv, + ), + ), + ] + ) + policy[attn_cls] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) if embedding_cls is not 
None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -221,15 +289,30 @@ def get_held_layers(self) -> List[nn.Module]: held_layers = [] layers_per_stage = stage_manager.distribute_layers(len(module.layers)) - if stage_manager.is_first_stage(): - held_layers.append(module.embed_tokens) - held_layers.append(module.embed_positions) - held_layers.append(module.project_in) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.final_layer_norm) - held_layers.append(module.project_out) + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + held_layers.append(module.embed_positions) + held_layers.append(module.project_in) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(module.final_layer_norm) + held_layers.append(module.project_out) + else: + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + held_layers.append(module.embed_positions) + held_layers.append(module.project_in) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.final_layer_norm) + held_layers.append(module.project_out) return held_layers def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: @@ -323,8 +406,15 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.lm_head) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -395,8 +485,15 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.qa_outputs) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.qa_outputs) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 1b066200de64..84d2b2fdbd99 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -9,6 +9,7 @@ FusedRMSNorm, Linear1D_Col, 
Linear1D_Row, + LinearWithGradAccum, PaddingEmbedding, RMSNorm, VocabParallelEmbedding1D, @@ -96,6 +97,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: attribute_replacement=decoder_attribute_replacement, ) + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -119,37 +122,134 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + ], + ) + elif use_zbv: + policy[Qwen2DecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + 
fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ), ], ) @@ -278,7 +378,9 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(ignore_chunk=True): + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): held_layers.append(module.norm) else: @@ -318,6 +420,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy): def module_policy(self): policy = super().module_policy() setattr(self.shard_config, "causal_lm", True) + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm @@ -327,7 +430,22 @@ def module_policy(self): SubModuleReplacementDescription( suffix="lm_head", target_module=Linear1D_Col, - kwargs=dict(fp8_communication=self.shard_config.fp8_communication), + kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv), + ) + ], + method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, + ) + } + policy.update(new_item) + elif use_zbv: + # add a new item for casual lm + new_item = { + Qwen2ForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=LinearWithGradAccum, + kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv), ) ], method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, @@ -347,8 +465,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.lm_head) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -371,6 +495,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class Qwen2ForSequenceClassificationPolicy(Qwen2Policy): def module_policy(self): policy = super().module_policy() + use_zbv = 
self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv if self.shard_config.enable_tensor_parallelism: # add a new item for sequence classification new_item = { @@ -379,7 +504,28 @@ def module_policy(self): SubModuleReplacementDescription( suffix="score", target_module=Linear1D_Col, - kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ) + ] + ) + } + policy.update(new_item) + elif use_zbv: + new_item = { + Qwen2ForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", + target_module=LinearWithGradAccum, + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ) ] ) @@ -399,8 +545,14 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.score) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.score) + else: + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.score) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index a94cc9119356..f37167afffff 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -27,6 +27,7 @@ def module_policy(self): norm_cls = col_nn.FusedLayerNorm else: norm_cls = col_nn.LayerNorm + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv if self.shard_config.enable_tensor_parallelism: assert ( @@ -44,6 +45,7 @@ def module_policy(self): kwargs={ "split_sizes": [self.model.config.vision_config.hidden_size] * 3, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -51,6 +53,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -58,6 +61,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -65,6 +69,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], @@ -80,6 +85,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -87,6 +93,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -94,6 +101,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -101,6 +109,7 @@ def 
module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -108,6 +117,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -115,6 +125,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -122,6 +133,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -129,6 +141,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -136,6 +149,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -143,6 +157,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -150,6 +165,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -157,6 +173,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -164,6 +181,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -171,6 +189,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], @@ -186,6 +205,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -193,6 +213,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -200,6 +221,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -207,6 +229,209 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + + # add `DropoutForParallelInput` layer to replace the useage of `nn.functional.dropout` + policy[SamVisionAttention] = ModulePolicyDescription( + attribute_replacement={ + "dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout) + }, + method_replacement={"forward": forward_fn()}, + sub_module_replacement=[], + ) + elif use_zbv: + policy[SamVisionLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + 
suffix="attn.qkv", + target_module=col_nn.FusedLinear, + kwargs={ + "split_sizes": [self.model.config.vision_config.hidden_size] * 3, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attn.proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.lin1", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.lin2", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + policy[SamTwoWayAttentionBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.num_attention_heads": self.model.config.mask_decoder_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.lin1", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.lin2", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.k_proj", + target_module=col_nn.LinearWithGradAccum, 
+ kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + policy[SamTwoWayTransformer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 84b5d95947f0..6320a1668b09 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -13,6 +13,7 @@ FusedRMSNorm, Linear1D_Col, Linear1D_Row, + LinearWithGradAccum, PaddingEmbedding, PaddingLMHead, RMSNorm, @@ -77,6 +78,8 @@ def module_policy(self): self.shard_config.enable_sequence_parallelism = False warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_heads % self.shard_config.tensor_parallel_size == 0 @@ -119,6 +122,7 @@ def module_policy(self): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -126,6 +130,7 @@ def module_policy(self): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -133,6 +138,7 @@ def module_policy(self): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -140,6 +146,7 @@ def module_policy(self): target_module=Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -168,6 +175,7 @@ def module_policy(self): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -175,6 +183,7 @@ def module_policy(self): target_module=Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -183,6 
+192,7 @@ def module_policy(self): kwargs=dict( gather_output=True, fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, ), ), SubModuleReplacementDescription( @@ -198,6 +208,7 @@ def module_policy(self): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -205,6 +216,142 @@ def module_policy(self): target_module=Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + elif use_zbv: + policy[T5Stack] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + policy[T5LayerSelfAttention] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + policy[T5LayerCrossAttention] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ] + ) + policy[T5Attention] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="k", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="v", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="o", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="relative_attention_bias", + target_module=Embedding1D, + kwargs=dict( + gather_output=False, + fp8_communication=self.shard_config.fp8_communication, + ), + ignore_if_not_exist=True, + ), + ], + ) + policy[T5LayerFF] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + policy[T5DenseGatedActDense] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi_0 ", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="wi_1", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="wo", + target_module=LinearWithGradAccum, + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + policy[T5DenseActDense] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": 
use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="wo", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -213,7 +360,6 @@ def module_policy(self): ), ] ) - if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -369,30 +515,61 @@ def get_held_layers(self) -> List[nn.Module]: num_decoder_layers = len(decoder.block) if decoder else 0 held_layers = [] - layers_per_stage, decoder_starting_stage = self.distribute_t5_layers( - num_encoder_layers, num_decoder_layers, stage_manager.num_stages - ) - start_idx, end_idx = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) - - if stage_manager.stage < decoder_starting_stage: - # current stage is in t5's encoder - if stage_manager.is_first_stage(): - held_layers.append(model.shared) - held_layers.append(encoder.embed_tokens) - held_layers.append(encoder.dropout) - if stage_manager.stage == decoder_starting_stage - 1: - held_layers.append(encoder.final_layer_norm) - held_layers.append(encoder.dropout) - held_layers.extend(encoder.block[start_idx:end_idx]) + if stage_manager.is_interleave: + layers_per_stage, decoder_starting_stage = self.distribute_t5_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + stage_indices = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) + if stage_manager.stage < decoder_starting_stage: + # current stage is in t5's encoder + if stage_manager.is_first_stage(): + held_layers.append(model.shared) + held_layers.append(encoder.embed_tokens) + held_layers.append(encoder.dropout) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(encoder.final_layer_norm) + held_layers.append(encoder.dropout) + for start_idx, end_idx in stage_indices: + held_layers.extend(encoder.block[start_idx:end_idx]) + else: + # current stage is in t5's decoder + if stage_manager.stage == decoder_starting_stage: + held_layers.append(decoder.embed_tokens) + held_layers.append(decoder.dropout) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(decoder.final_layer_norm) + held_layers.append(decoder.dropout) + for start_idx, end_idx in stage_indices: + held_layers.extend(decoder.block[start_idx:end_idx]) else: - # current stage is in t5's decoder - if stage_manager.stage == decoder_starting_stage: - held_layers.append(decoder.embed_tokens) - held_layers.append(decoder.dropout) - if stage_manager.is_last_stage(): - held_layers.append(decoder.final_layer_norm) - held_layers.append(decoder.dropout) - held_layers.extend(decoder.block[start_idx:end_idx]) + layers_per_stage, decoder_starting_stage = self.distribute_t5_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + start_idx, end_idx = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) + + if stage_manager.stage < decoder_starting_stage: + # current stage is in t5's encoder + if stage_manager.is_first_stage(): + held_layers.append(model.shared) + held_layers.append(encoder.embed_tokens) + held_layers.append(encoder.dropout) + if stage_manager.stage == decoder_starting_stage - 
1: + held_layers.append(encoder.final_layer_norm) + held_layers.append(encoder.dropout) + held_layers.extend(encoder.block[start_idx:end_idx]) + else: + # current stage is in t5's decoder + if stage_manager.stage == decoder_starting_stage: + held_layers.append(decoder.embed_tokens) + held_layers.append(decoder.dropout) + if stage_manager.is_last_stage(): + held_layers.append(decoder.final_layer_norm) + held_layers.append(decoder.dropout) + held_layers.extend(decoder.block[start_idx:end_idx]) return held_layers def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: @@ -545,8 +722,15 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.lm_head) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.lm_head) + else: + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -652,9 +836,16 @@ def get_held_layers(self) -> List[nn.Module]: """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.dropout) - held_layers.append(self.model.classifier) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + else: + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 07202094f1f3..7b7dbf5557aa 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -43,6 +43,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: self.shard_config.enable_sequence_parallelism = False warnings.warn("Vit doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -72,6 +74,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -79,6 +82,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -86,6 +90,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": 
use_zbv, }, ), SubModuleReplacementDescription( @@ -97,6 +102,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -109,6 +115,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: kwargs={ "skip_bias_add": self.enable_bias_gelu_fused, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -116,6 +123,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -132,7 +140,92 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=ViTIntermediate, ) + elif use_zbv: + policy[ViTEmbeddings] = ModulePolicyDescription( + attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForReplicatedInput, + ) + ], + ) + policy[ViTLayer] = ModulePolicyDescription( + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "skip_bias_add": self.enable_bias_gelu_fused, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ], + ) + if self.enable_bias_gelu_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_vit_intermediate_forward(), + }, + policy=policy, + target_key=ViTIntermediate, + ) # use flash attention if self.shard_config.enable_flash_attention: self.append_or_create_method_replacement( @@ -173,11 +266,20 @@ def get_held_layers(self) -> List[nn.Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = 
stage_manager.distribute_layers(len(module.encoder.layer)) - if stage_manager.is_first_stage(): - held_layers.append(module.embeddings) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.encoder.layer[start_idx:end_idx]) + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embeddings) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.encoder.layer[start_idx:end_idx]) + else: + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) + if stage_manager.is_first_stage(): + held_layers.append(module.embeddings) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.encoder.layer[start_idx:end_idx]) return held_layers def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict): @@ -213,9 +315,16 @@ def get_held_layers(self) -> List[nn.Module]: module = self.model stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): - held_layers.append(module.layernorm) - held_layers.append(module.pooler) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(module.layernorm) + held_layers.append(module.pooler) + else: + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(module.pooler) return held_layers @@ -226,6 +335,9 @@ def module_policy(self): from transformers.models.vit.modeling_vit import ViTForImageClassification, ViTModel policy = super().module_policy() + + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: new_item = { ViTForImageClassification: ModulePolicyDescription( @@ -233,13 +345,33 @@ def module_policy(self): SubModuleReplacementDescription( suffix="classifier", target_module=Linear1D_Col, - kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ) + ] + ) + } + policy.update(new_item) + elif use_zbv: + new_item = { + ViTForImageClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="classifier", + target_module=col_nn.LinearWithGradAccum, + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ) ] ) } policy.update(new_item) - if self.shard_config.pipeline_stage_manager is not None: self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) self.set_pipeline_forward( @@ -256,9 +388,16 @@ def get_held_layers(self) -> List[nn.Module]: module = self.model.vit stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): - held_layers.append(module.layernorm) - held_layers.append(self.model.classifier) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + 
held_layers.append(module.layernorm) + held_layers.append(self.model.classifier) + else: + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(self.model.classifier) return held_layers @@ -285,8 +424,15 @@ def get_held_layers(self) -> List[nn.Module]: module = self.model.vit stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): - held_layers.append(module.layernorm) - held_layers.append(self.model.decoder) + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(module.layernorm) + held_layers.append(self.model.decoder) + else: + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(self.model.decoder) return held_layers diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 7a1f146d5bb8..5d9b38e6fe98 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -72,6 +72,8 @@ def module_policy(self): "Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag." ) + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + # TODO using the jit fused add_and_dropout affect the accuracy if self.shard_config.enable_jit_fused: self.shard_config.enable_jit_fused = False @@ -93,6 +95,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -100,6 +103,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -107,6 +111,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -114,6 +119,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -121,6 +127,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -128,6 +135,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], @@ -148,6 +156,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -155,6 +164,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -162,6 +172,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -169,6 +180,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), 
SubModuleReplacementDescription( @@ -176,6 +188,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -183,6 +196,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -190,6 +204,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -197,6 +212,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -204,6 +220,7 @@ def module_policy(self): target_module=col_nn.Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -211,6 +228,145 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + elif use_zbv: + policy[WhisperEncoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + + policy[WhisperDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + 
SubModuleReplacementDescription( + suffix="encoder_attn.q_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.k_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.v_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.out_proj", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], @@ -460,30 +616,66 @@ def get_held_layers(self) -> List[nn.Module]: num_decoder_layers = 0 held_layers = [] - layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers( - num_encoder_layers, num_decoder_layers, stage_manager.num_stages - ) - start_idx, end_idx = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) - - if stage_manager.stage < decoder_starting_stage: - # current stage is in whisper's encoder - if stage_manager.is_first_stage(): - held_layers.append(encoder.embed_positions) - held_layers.append(encoder.conv1) - held_layers.append(encoder.conv2) - if stage_manager.stage == decoder_starting_stage - 1: - held_layers.append(encoder.layer_norm) - held_layers.extend(encoder.layers[start_idx:end_idx]) + if stage_manager.is_interleave: + layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + stage_indices = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) + + if stage_manager.stage < decoder_starting_stage: + # current stage is in whisper's encoder + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(encoder.embed_positions) + held_layers.append(encoder.conv1) + held_layers.append(encoder.conv2) + # interleaved: not use_zbv & stage_manager.stage == decoder_starting_stage - 1 + # zbv: use_zbv & stage_manager.stage == first stage + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and decoder_starting_stage - 1 + ): + held_layers.append(encoder.layer_norm) + for start_idx, end_idx in stage_indices: + held_layers.extend(encoder.layers[start_idx:end_idx]) + else: + # current stage is in whisper's decoder + # TODO:(Jianghai) We divide encoder and decoder layers into different parts here, + # the case encoder and decoder put in same stage should be add in the future. 
+ if stage_manager.stage == decoder_starting_stage: + held_layers.append(decoder.embed_tokens) + held_layers.append(decoder.embed_positions) + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(decoder.layer_norm) + for start_idx, end_idx in stage_indices: + held_layers.extend(encoder.layers[start_idx:end_idx]) else: - # current stage is in whisper's decoder - # TODO:(Jianghai) We divide encoder and decoder layers into different parts here, - # the case encoder and decoder put in same stage should be add in the future. - if stage_manager.stage == decoder_starting_stage: - held_layers.append(decoder.embed_tokens) - held_layers.append(decoder.embed_positions) - if stage_manager.is_last_stage(): - held_layers.append(decoder.layer_norm) - held_layers.extend(decoder.layers[start_idx:end_idx]) + layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + start_idx, end_idx = self.get_whisper_stage_index( + layers_per_stage, stage_manager.stage, decoder_starting_stage + ) + + if stage_manager.stage < decoder_starting_stage: + # current stage is in whisper's encoder + if stage_manager.is_first_stage(): + held_layers.append(encoder.embed_positions) + held_layers.append(encoder.conv1) + held_layers.append(encoder.conv2) + if stage_manager.stage == decoder_starting_stage - 1: + held_layers.append(encoder.layer_norm) + held_layers.extend(encoder.layers[start_idx:end_idx]) + else: + # current stage is in whisper's decoder + # TODO:(Jianghai) We divide encoder and decoder layers into different parts here, + # the case encoder and decoder put in same stage should be add in the future. 
+ if stage_manager.stage == decoder_starting_stage: + held_layers.append(decoder.embed_tokens) + held_layers.append(decoder.embed_positions) + if stage_manager.is_last_stage(): + held_layers.append(decoder.layer_norm) + held_layers.extend(decoder.layers[start_idx:end_idx]) return held_layers def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: @@ -575,8 +767,15 @@ def postprocess(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.proj_out) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.proj_out) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.proj_out) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -629,9 +828,17 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): - held_layers.append(self.model.projector) - held_layers.append(self.model.classifier) + stage_manager = self.pipeline_stage_manager + if stage_manager.is_interleave: + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) + ): + held_layers.append(self.model.projector) + held_layers.append(self.model.classifier) + else: + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.projector) + held_layers.append(self.model.classifier) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/docs/source/en/features/zerobubble_pipeline_parallelism.md b/docs/source/en/features/zerobubble_pipeline_parallelism.md new file mode 100644 index 000000000000..1f88815fcbb1 --- /dev/null +++ b/docs/source/en/features/zerobubble_pipeline_parallelism.md @@ -0,0 +1,238 @@ +# ZeroBubble Pipeline Parallelism +Author: [Junwen Duan](https://github.com/duanjunwen), [Hongxin Liu](https://github.com/ver217) + +**Related Paper** +- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241) + +## Introduction +ZeroBubble (V Schedule): +Crucially, splitting B into two stages (also known as an activation gradient and a weight gradient) and a scheme like 1F1B1W can further reduce the bubble compared to the 1F1B scheme in earlier work. + +## Hands-On Practice +We now demonstrate how to use ZeroBubble with booster API with 4 GPUs. + +### step 1. Import libraries +```python +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaModel + +import colossalai +from colossalai.booster.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin +from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler +``` + +### step 2. Initialize Distributed Environment and Parallism Group +```python +colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") +``` + +### step 3. 
Initialize Module and Optimizer
+Build the model and optimizer. Here we create a Llama model with 8 decoder layers, then initialize the PipelineGraph and the pipeline schedule with the `get_v_schedule()` function.
+```python
+# Global Param
+NUM_BATCH = 8
+NUM_TOK_PER_BATCH = 4
+NUM_LAYERS = 8
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS = 4
+# Init Llama from huggingface
+configuration = LlamaConfig(
+    hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
+    intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+    num_hidden_layers=NUM_LAYERS,
+    num_attention_heads=NUM_HEADS,
+    num_key_value_heads=NUM_HEADS,
+    attn_implementation="flash_attention_2",
+)
+model = LlamaModel(configuration).cuda()
+optimizer = torch.optim.Adam(model.parameters(), lr=1)
+```
+### step 4. Initialize the Pipeline Schedule
+Next, create the PipelineGraph and the pipeline schedule with the `get_v_schedule()` function. The PipelineGraph is initialized with the following parameters:
+`x_cost` represents the runtime consumed by operation `x` of each model chunk.
+`x_mem` represents the amount of memory consumed by operation `x` of each model chunk.
+These parameters are estimates filled in before the pipeline starts; better schedules can be obtained from the runtime and memory cost measured during real computation of the model.
+In the following example, we assume that the computation times of the forward pass (F), the activation-gradient backward pass (B), and the weight-gradient backward pass (W) are all 1, and that the p2p communication time is also 1.
+```python
+# Init schedule
+h, a, s = configuration.hidden_size, configuration.num_attention_heads, 1024
+mem_f = 34 * h + 5 * a * s
+mem_w = -32 * h
+mem_b = -mem_w - mem_f
+graph = PipelineGraph(
+    n_stage=pp_size,            # number of pipeline stages, 4 in this example
+    n_micro=num_microbatches,   # number of microbatches, 4 in this example
+    f_cost=1,
+    b_cost=1,
+    w_cost=1,
+    c_cost=1,
+    f_mem=mem_f,
+    b_mem=mem_b,
+    w_mem=mem_w,
+)
+zbv_schedule = graph.get_v_schedule()
+```
+
+### step 5. Initialize the Booster
+Pass `pp_style="zbv"` when initializing the plugin to use the ZeroBubble pipeline.
+```python
+plugin = HybridParallelPlugin(
+    pp_size=4,
+    num_microbatches=4,
+    tp_size=1,
+    sp_size=1,
+    zero_stage=1,
+    initial_scale=1,
+    find_unused_parameters=True,
+    pp_style="zbv",
+    scheduler_nodes=zbv_schedule,
+    num_model_chunks=2,
+)
+
+dp_size = plugin.dp_size
+booster = Booster(plugin=plugin)
+```
+
+### step 6. Train Your Model
+```python
+steps = 10
+for step in range(steps):
+    input_embeddings = torch.rand(
+        NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+    ).cuda()
+    dist.all_reduce(
+        input_embeddings, group=plugin.pp_group
+    )
+    data_iter = iter([{"inputs_embeds": input_embeddings}])
+    output = booster.execute_pipeline(
+        data_iter,
+        model,
+        lambda x, y: x.last_hidden_state.mean(),
+        optimizer,
+        return_loss=True,
+        return_outputs=True,
+    )
+    optimizer.step()
+    optimizer.zero_grad()
+```
+
+## Advanced Practice
+In ColossalAI, you can get better training performance by combining ZeroBubble with metadata caching and hybrid parallelism.
+### 1. Use MetaCache with ZeroBubble
+Pass `enable_metadata_cache=True` when initializing the plugin to use the metadata cache with the ZeroBubble pipeline.
+```python
+plugin = HybridParallelPlugin(
+    pp_size=2,
+    num_microbatches=4,
+    tp_size=2,
+    sp_size=2,
+    zero_stage=1,
+    initial_scale=1,
+    enable_metadata_cache=True,
+    find_unused_parameters=True,
+    pp_style="zbv",
+    scheduler_nodes=zbv_schedule,
+    num_model_chunks=2,
+)
+```
+
+### 2. HybridParallel with ZeroBubble
+Pass `pp_size`, `tp_size`, and `sp_size` when initializing the plugin to combine hybrid parallelism with the ZeroBubble pipeline.
+```python
+plugin = HybridParallelPlugin(
+    pp_size=2,
+    num_microbatches=2,
+    tp_size=2,
+    sp_size=2,
+    zero_stage=1,
+    initial_scale=1,
+    find_unused_parameters=True,
+    pp_style="zbv",
+    scheduler_nodes=zbv_schedule,
+    num_model_chunks=2,
+)
+```
+Performance Benchmark
+
+| HybridParallel Strategy | Pipeline Parallel | Sequence Parallel + Pipeline Parallel | Data Parallel + Pipeline Parallel |
+| ----------------------- | ----------------- | ------------------------------------- | --------------------------------- |
+| With 1F1B | 15.27 samples/sec | 17.22 samples/sec | 14.06 samples/sec |
+| With Zero Bubble | 17.36 samples/sec | 18.38 samples/sec | 14.44 samples/sec |
+
+### 3. Fine-tuning scheduler parameters
+
+The quality of the generated schedule depends on the cost and memory estimates passed to `PipelineGraph`. As noted in step 4, better schedules can be obtained by replacing the placeholder unit costs with values measured from an actual run; a sketch of this follows below.
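+The original guide leaves the example for this section empty. The following is a minimal, hypothetical sketch (not part of the official documentation) of how measured costs might be plugged into `PipelineGraph` in place of the unit costs used in step 4; the concrete timing values below are assumptions for illustration only.
+```python
+# Hypothetical sketch: replace the unit costs from step 4 with measured values.
+# Any consistent time unit should work, since the schedule depends on relative costs.
+f_cost = 6000  # assumed measured forward (F) time per model chunk, e.g. in microseconds
+b_cost = 6000  # assumed measured activation-gradient (B) time per chunk
+w_cost = 6000  # assumed measured weight-gradient (W) time per chunk
+c_cost = 6000  # assumed measured p2p communication time
+h, a, s = configuration.hidden_size, configuration.num_attention_heads, 1024
+mem_f = 34 * h + 5 * a * s  # same activation-memory estimate as in step 4
+mem_w = -32 * h
+mem_b = -mem_w - mem_f
+graph = PipelineGraph(
+    n_stage=pp_size,
+    n_micro=num_microbatches,
+    f_cost=f_cost,
+    b_cost=b_cost,
+    w_cost=w_cost,
+    c_cost=c_cost,
+    f_mem=mem_f,
+    b_mem=mem_b,
+    w_mem=mem_w,
+)
+zbv_schedule = graph.get_v_schedule()
+```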
+
+## Model compatibility
+
+| Shardformer/Model | Bert | Blip2 | Bloom | Chatglm2 | Command | Deepseek | Falcon | GPT2 | Gptj | Llama | Mistral | Opt | Qwen2 | Sam | T5 | Vit | Whisper |
+| ----------------- | ---- | ----- | ----- | -------- | ------- | -------- | ------ | ---- | ---- | ----- | ------- | --- | ----- | --- | -- | --- | ------- |
+| ZeroBubble | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+
+ +## API Reference +{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }} + + diff --git a/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md new file mode 100644 index 000000000000..70e9e4c98631 --- /dev/null +++ b/docs/source/zh-Hans/features/zerobubble_pipeline_parallelism.md @@ -0,0 +1,237 @@ +# 零气泡流水线并行 +作者: [Junwen Duan](https://github.com/duanjunwen), [Hongxin Liu](https://github.com/ver217) + +**相关论文** +- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241) + +## 介绍 +零气泡(V Schedule): +与早期工作中的1F1B方案相比,零气泡流水线并行将B分成两个阶段(也称为激活梯度和权重梯度),形如1F1B1W这样的方案可以进一步减少气泡。 + +## 使用 +我们将演示如何在 4 个 GPU 上使用带有 booster API 的 ZeroBubble + +### step 1. 引用仓库 +```python +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaModel + +import colossalai +from colossalai.booster.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin +from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler +``` + +### step 2. 初始化分布式环境 +```python +colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") +``` + +### step 3. 初始化模型优化器 +建立我们的模型和优化器 我们创建了一个带有8层Decoder-Layer的 Llama。然后,使用get_v_schedule()函数创建PipelineGraph和Pipeline schedule。 + +```python +# Global Param +NUM_BATCH = 8 +NUM_TOK_PER_BATCH = 4 +NUM_LAYERS = 8 +HIDDEN_SIZE_PER_HEAD = 4 +NUM_HEADS = 4 +# Init Llama from huggingface +configuration = LlamaConfig( + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=NUM_LAYERS, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + attn_implementation="flash_attention_2", +) +model = LlamaModel(configuration).cuda() +optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) +``` +### step 4.初始化流水线Schedule +然后,我们需要使用 get_v_schedule() 函数创建 PipelineGraph 和 PipelineSchedule。我们需要用以下参数初始化 PipelineGraph。 +x_cost 表示每个模型块的操作 x 所消耗的运行时间。 +x_mem 表示每个模型块的操作 x 所消耗的内存量。 +这些参数都是在流水线启动前估算并填入的。事实上,在模型的实际计算过程中,根据运行时间和内存成本可以获得更好的结果。 +在下面的例子中,我们假设模型的正向、反向 B 和反向 W 的计算时间分别为 1、1、1,p2p 通信时间为 1。 +```python +# Init schedule +h, a, s = config.hidden_size, config.num_attention_heads, 1024 +mem_f = 34 * h + 5 * a * s +mem_w = -32 * h +mem_b = -mem_w - mem_f +graph = PipelineGraph( + n_stage=pp_size, + n_micro=num_microbatches, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, +) +zbv_schedule = graph.get_v_schedule() +``` + +### step 5.初始化Booster +在初始化Plugin时输入pp_style="zbv",以使用ZeroBubble流水线并行。 +```python +plugin = HybridParallelPlugin( + pp_size=4, + num_microbatches=4, + tp_size=1, + sp_size=1, + zero_stage=1, + initial_scale=1, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) + +dp_size = plugin.dp_size +booster = Booster(plugin=plugin) +``` + +### step 6.训练模型 +```python +steps = 10 +for step in range(steps): + input_embeddings = torch.rand( + NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True + ).cuda() + dist.all_reduce( + input_embeddings, group=plugin.pp_group + ) + data_iter = iter([{"inputs_embeds": input_embeddings}]) + output = booster.execute_pipeline( + data_iter, + model, + lambda x, y: 
x.last_hidden_state.mean(), + optimizer, + return_loss=True, + return_outputs=True, + ) + optimizer.step() + optimizer.zero_grad() +``` + +## 进阶使用技巧 +在 ColossalAI 中,通过使用MetaCache和混合并行的ZeroBubble,可以获得更好的训练性能。 + +### 1.在ZeroBubble中使用元数据缓存 +在初始化Plugin时输入 "enable_metadata_cache=True",以便在ZeroBubble管道中使用元数据缓存。 +```python +plugin = HybridParallelPlugin( + pp_size=2, + num_microbatches=4, + tp_size=2, + sp_size=2, + zero_stage=1, + initial_scale=1, + enable_metadata_cache=True, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) +``` + +### 2.同时使用ZeroBubble和混合并行 +在初始化插件时传递 pp_size, tp_size, sp_size, 以便使用零气泡混合并行管道(HybridParallel with ZeroBubble Pipeline)。 +```python +plugin = HybridParallelPlugin( + pp_size=2, + num_microbatches=2, + tp_size=2, + sp_size=2, + zero_stage=1, + initial_scale=1, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, +) +``` +性能指标 + + + + + + + + + + + + + + + + + + + + + + +
+
+| HybridParallel Strategy | Pipeline Parallel | Sequence Parallel + Pipeline Parallel | Data Parallel + Pipeline Parallel |
+| ----------------------- | ----------------- | ------------------------------------- | --------------------------------- |
+| With 1F1B | 15.27 samples/sec | 17.22 samples/sec | 14.06 samples/sec |
+| With Zero Bubble | 17.36 samples/sec | 18.38 samples/sec | 14.44 samples/sec |
+
+## 模型兼容性
+
+| Shardformer/Model | Bert | Blip2 | Bloom | Chatglm2 | Command | Deepseek | Falcon | GPT2 | Gptj | Llama | Mistral | Opt | Qwen2 | Sam | T5 | Vit | Whisper |
+| ----------------- | ---- | ----- | ----- | -------- | ------- | -------- | ------ | ---- | ---- | ----- | ------- | --- | ----- | --- | -- | --- | ------- |
+| ZeroBubble | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+
+
+## API Reference
+{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }}
+
+
diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
index a45beb77108f..be1c24818424 100644
--- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
@@ -8,7 +8,8 @@
 import colossalai
 from colossalai.lazy import LazyInitContext
-from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
+from colossalai.pipeline.weight_grad_store import WeightGradStore
+from colossalai.shardformer.layer import GPT2FusedLinearConv, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
 from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

@@ -118,11 +119,82 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool):
     assert_close(target_grad, linear_row.weight.grad)


+def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: str):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    linear = Conv1D(192, 48).cuda()
+    with ctx:
+        linear_copy = Conv1D(192, 48).cuda()
+        linear_base = GPT2FusedLinearConv.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode)
+
+    assert linear.weight.shape == torch.Size([48, 192])
+    assert linear_base.weight.shape == torch.Size([48, 192])
+    assert linear_base.bias.shape == torch.Size([192])
+    assert linear_copy.weight is linear_base.weight
+    assert linear_copy.bias is linear_base.bias
+
+    # ensure weights are reversibly loadable
+    linear_base.load_state_dict(linear.state_dict())
+    linear.load_state_dict(linear_base.state_dict())
+
+    # check computation correctness
+    x = torch.rand(1, 4, 48).cuda()
+    out = linear(x)
+    gather_out = linear_base(x)
+    assert_close(out, gather_out)
+
+    # check backward correctness
+    out.sum().backward()
+    gather_out.sum().backward()
+
+    # check the input gradients & weight gradients
+    assert_close(out.grad, gather_out.grad)
+    assert_close(linear.weight.grad, linear_base.weight.grad)
+
+
+def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: str):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    linear = Conv1D(192, 48).cuda()
+    with ctx:
+        linear_copy = Conv1D(192, 48).cuda()
+        linear_base = GPT2FusedLinearConv.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode, use_zbv=True)
+
+    assert linear.weight.shape == torch.Size([48, 192])
+    assert linear_base.weight.shape == torch.Size([48, 192])
+    assert linear_base.bias.shape == torch.Size([192])
+    assert linear_copy.weight is linear_base.weight
+    assert linear_copy.bias is linear_base.bias
+
+    # ensure weights are reversibly loadable
+    linear_base.load_state_dict(linear.state_dict())
+    linear.load_state_dict(linear_base.state_dict())
+
+    # check computation correctness
+    x = torch.rand(1, 4, 48).cuda()
+    out = linear(x)
+    gather_out = linear_base(x)
+    assert_close(out, gather_out)
+
+    # check backward correctness
+    out.sum().backward()
+    gather_out.sum().backward()
+
+    WeightGradStore.flush(chunk=0)  # flush buffer to chunk 0 Queue
+    WeightGradStore.pop(chunk=0)  # trigger the deferred weight-gradient computation
+
+    # check the input gradients & weight gradients
+    assert_close(out.grad, gather_out.grad)
+    assert_close(linear.weight.grad, linear_base.weight.grad)
+
+
 @parameterize("lazy_init", [False, True])
@parameterize("seq_parallel_mode", ["split_gather", None]) def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool): check_linear_conv_1d_col(lazy_init, seq_parallel_mode) check_linear_conv_1d_row(lazy_init, seq_parallel_mode) + check_linear_conv_1d_without_weight_grad_store(lazy_init, None) + check_linear_conv_1d_with_weight_grad_store(lazy_init, None) def run_dist(rank, world_size, port): diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index fccba564f7c7..b31342cb30af 100644 --- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -7,7 +7,7 @@ import colossalai from colossalai.lazy import LazyInitContext -from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row +from colossalai.shardformer.layer import FusedLinear, FusedLinear1D_Col, FusedLinear1D_Row from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -120,12 +120,45 @@ def check_linear_1d_col_row(lazy_init: bool): assert_close(target_grad2, linear_row.weight.grad) +@parameterize("lazy_init", [False, True]) +def check_linear_1d_base(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + linear = nn.Linear(8, 80).cuda() + with ctx: + linear_copy = nn.Linear(8, 80).cuda() + linear_base = FusedLinear.from_native_module(linear_copy) + + assert linear.weight.shape == torch.Size([80, 8]) + assert linear.bias.shape == torch.Size([80]) + assert linear_base.weight.shape == torch.Size([80, 8]) + assert linear_base.bias.shape == torch.Size([80]) + assert linear_copy.weight is linear_base.weight + assert linear_copy.bias is linear_base.bias + + # ensure weights are reversibly loadable + linear_base.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_base.state_dict()) + + # check computation correctness + x = torch.rand(4, 8).cuda() + out = linear(x) + base_out = linear_base(x) + assert_close(out, base_out) + + # check backward correctness + out.sum().backward() + base_out.sum().backward() + + assert_close(linear.weight.grad, linear_base.weight.grad) + + def run_dist(rank, world_size, port): colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_linear_1d_col() check_linear_1d_row() check_linear_1d_col_row() + check_linear_1d_base() @rerun_if_address_is_in_use()