From 75d1d1cca44f1d0855c49fb15a0ccb16683ba323 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 7 Aug 2024 09:51:18 +0000 Subject: [PATCH 01/15] fix --- colossalai/quantization/fp8.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index bc8c3ced4cdd..805824a896c7 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -376,28 +376,6 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"): output_list[i].copy_(cast_from_fp8(tensor, scale, input_type)) -def all_to_all_single_fp8(output_tensor, input_tensor, group=None, fp8_format="e5m2"): - - world_size = dist.get_world_size(group) - - per_slice_len = input_tensor.size(0) // world_size - input_type = input_tensor.dtype - ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format) - fp8_type = ret.dtype - input_tensor = ret.view(torch.uint8) - tensor = torch.empty_like(input_tensor) - scale_list = [torch.empty_like(scale) for _ in range(world_size)] - dist.all_to_all_single(tensor, input_tensor, group=group) - dist.all_gather(scale_list, scale, group=group) - cast_tensor_list = [] - - for i in range(world_size): - output_part = tensor[per_slice_len * i : per_slice_len * (i + 1)].view(fp8_type) - output_part = cast_from_fp8(output_part, scale_list[i], input_type) - cast_tensor_list.append(output_part) - output_tensor.copy_(torch.concatenate(cast_tensor_list, dim=0)) - - def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"): world_size = dist.get_world_size(group) From 244f72ea52bd3ce5a56176f3846edb7bde1f424c Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 7 Aug 2024 09:57:07 +0000 Subject: [PATCH 02/15] support moe fp8 --- .../plugin/moe_hybrid_parallel_plugin.py | 2 + colossalai/moe/_operation.py | 24 ++++++-- colossalai/quantization/fp8.py | 58 ++++++++++++++++++- colossalai/shardformer/modeling/mixtral.py | 13 +++-- colossalai/shardformer/policies/mixtral.py | 7 ++- .../test_plugin/test_torch_ddp_plugin.py | 1 + tests/test_moe/test_mixtral_layer.py | 3 +- .../test_model/test_shard_mixtral.py | 6 +- 8 files changed, 100 insertions(+), 14 deletions(-) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index b3415af0eed6..2b957c580c82 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -214,6 +214,7 @@ def __init__( moe_dp_outside: bool = True, overlap_p2p: bool = True, overlap_allgather: bool = False, + fp8_communication: bool = False ) -> None: if overlap_communication or zero_stage == 2: overlap_communication = False @@ -341,6 +342,7 @@ def __init__( parallel_output=parallel_output, make_vocab_size_divisible_by=make_vocab_size_divisible_by, gradient_checkpoint_config=gradient_checkpoint_config, + fp8_communication=fp8_communication ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index ac422a4da98f..5c1c2dd9c67f 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -5,6 +5,7 @@ from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup +from colossalai.quantization.fp8 import all_to_all_single_fp8 MOE_KERNEL = None @@ -380,6 +381,7 @@ def _all_to_all( output_split_sizes: Optional[List[int]] = None, group=None, async_op: 
bool = False, + fp8_communication: bool = False ): """ Returns: @@ -392,9 +394,12 @@ def _all_to_all( outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device) inputs = inputs.contiguous() outputs = outputs.contiguous() - handle = dist.all_to_all_single( - outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op - ) + if fp8_communication: + handle = all_to_all_single_fp8(outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=False) + else: + handle = dist.all_to_all_single( + outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op + ) return outputs, handle @@ -407,6 +412,7 @@ def forward( output_split_sizes=None, group=None, overlap: bool = False, + fp8_communication: bool = False ): """ Returns: @@ -416,7 +422,7 @@ def forward( ctx.input_split_sizes = input_split_sizes ctx.output_split_sizes = output_split_sizes ctx.group = group - return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap) + return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication) @staticmethod def backward(ctx: Any, *grad_outputs): @@ -426,6 +432,7 @@ def backward(ctx: Any, *grad_outputs): None, None, None, + None, ) @@ -435,8 +442,15 @@ def all_to_all_uneven( output_split_sizes: Optional[List[int]] = None, group=None, overlap: bool = False, + fp8_communication: bool=False ): assert ( inputs.requires_grad ), "Input must require grad to assure that backward is executed, otherwise it might hang the program." - return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap) + return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication) + +def all_to_all_single(output, input, group=None, fp8_communication: bool=False): + if fp8_communication: + all_to_all_single_fp8(output, input, group=group) + else: + dist.all_to_all_single(output, input, group=group) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 52bb8cc9bc33..890639bcb124 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -18,7 +18,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) - Returns: Tuples: A tuple (fp8_tensor, scale) """ - + print("inp.dtype", inp.dtype) if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]: raise TypeError("Only float16, bfloat16, and float32 are allowed.") @@ -115,6 +115,62 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro tensor.copy_(out[:input_size].view(input_shape).to(input_type)) +def all_to_all_single_fp8( + output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False +) -> None: + r""" + This is an in-place operation for compressed all_reduce using fp8. + It works like dist.all_to_all_single but during communication the data is cast to fp8 format. + Args: + tensor: torch.Tensor in fp32, fp16, bf16 datatype. 
+ fp8_format: e4m3 or e5m2 + Returns: + None + """ + world_size = dist.get_world_size(group=group) + input_type = input.dtype + input_shape = input.shape + input_device = input.device + input = input.flatten() + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + + ret, scale = cast_to_fp8(input, fp8_format=fp8_format) + + inp = ret.view(torch.uint8) + if input_split_sizes is not None: + input_split_sizes = [input_split_sizes[i] * np.prod(input_shape[1:]) for i in range(world_size)] + input_chunks = list(torch.split(inp, input_split_sizes)) + else: + input_chunks = list(torch.chunk(inp, world_size, dim=0)) + + if output_split_sizes is not None: + output_chunks = [ + torch.empty((output_split_sizes[i] * np.prod(input_shape[1:]),), device=input_device, dtype=inp.dtype) + for i in range(world_size) + ] + else: + if dist.get_rank() == world_size - 1: + output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)] + else: + output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)] + + dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op) + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] + dist.all_gather(scale_list, scale, group=group, async_op=async_op) + cast_output_chunk = [ + cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks) + ] + + tensor_out = torch.cat(cast_output_chunk, dim=0) + outputs_shape = list(input_shape) + if output_split_sizes is not None: + outputs_shape[0] = sum(output_split_sizes) + else: + outputs_shape = input_shape + output.data = tensor_out.view(outputs_shape).to(input_type) + + def all_to_all_single_fp8( output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False ) -> None: diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index d30ce5ea85cc..90c86f7552a4 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -29,6 +29,7 @@ EPGradScalerIn, EPGradScalerOut, all_to_all_uneven, + all_to_all_single ) from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import ( @@ -53,7 +54,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): def __init__(self, *args, **kwargs): raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") - def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup): + def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, fp8_communication: bool = False): assert tp_group is not None assert moe_dp_group is not None assert ep_group is not None @@ -84,6 +85,7 @@ def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGrou expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group) expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group) + self.fp8_communication = fp8_communication for p in self.experts.parameters(): set_moe_tensor_ep_group(p, ep_group) @@ -99,7 +101,8 @@ def from_native_module( # TODO: better init LazyInitContext.materialize(module) module.__class__ = EPMixtralSparseMoeBlock - module.setup_process_groups(tp_group, moe_dp_group, ep_group) + fp8_communication = kwargs.get("fp8_communication", False) + module.setup_process_groups(tp_group, moe_dp_group, 
ep_group, fp8_communication) return module def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -120,6 +123,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_split_sizes = selected_experts.bincount(minlength=self.num_experts) output_split_sizes = torch.zeros_like(input_split_sizes) + + # all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group, fp8_communication=False) dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group) with torch.no_grad(): @@ -132,7 +137,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() - output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group) + output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group, self.fp8_communication) # compute expert output output_states = EPGradScalerIn.apply(output_states, self.ep_size) if output_states.size(0) > 0: @@ -162,7 +167,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: output_states = torch.cat(output_states_list) output_states = EPGradScalerOut.apply(output_states, self.ep_size) - dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group) + dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group, self.fp8_communication) recover_experts_idx = torch.empty_like(selected_experts_idx) recover_experts_idx[selected_experts_idx] = torch.arange( diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 10df143c99da..b03bf89a32b9 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -114,21 +114,25 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, + kwargs={"fp8_communication": self.shard_config.fp8_communication} ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, + kwargs={"fp8_communication": self.shard_config.fp8_communication} ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, + kwargs={"fp8_communication": self.shard_config.fp8_communication} ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, + kwargs={"fp8_communication": self.shard_config.fp8_communication} ), SubModuleReplacementDescription( # or replicate? 
- suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True} + suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication} ), ], ) @@ -155,6 +159,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, + "fp8_communication": self.shard_config.fp8_communication }, ) ], diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index f92b5c6e5675..3573d86a84ee 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -26,6 +26,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + model = model.to(torch.float16) assert isinstance(model.module, DDP) assert isinstance(optimizer, OptimizerWrapper) diff --git a/tests/test_moe/test_mixtral_layer.py b/tests/test_moe/test_mixtral_layer.py index bc41ac4f33e9..8b06cfa90aa0 100644 --- a/tests/test_moe/test_mixtral_layer.py +++ b/tests/test_moe/test_mixtral_layer.py @@ -42,6 +42,7 @@ def check_mixtral_moe_layer(): ep_group=plugin.ep_group, tp_group=plugin.tp_group, moe_dp_group=plugin.moe_dp_group, + fp8_communication=True ) ep_output, ep_logits = model(x) assert_close(orig_logits, ep_logits) @@ -63,7 +64,7 @@ def run_dist(rank: int, world_size: int, port: int): check_mixtral_moe_layer() -@pytest.mark.skip("tested in corresponding sharderformer") +# @pytest.mark.skip("tested in corresponding sharderformer") @pytest.mark.parametrize("world_size", [2]) def test_mixtral_moe_layer(world_size: int): spawn(run_dist, world_size) diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py index de09eedcbed5..bf2331ec9755 100644 --- a/tests/test_shardformer/test_model/test_shard_mixtral.py +++ b/tests/test_shardformer/test_model/test_shard_mixtral.py @@ -40,8 +40,8 @@ "config", [ (1, 2, 2, 1, 1), - (1, 2, 1, 2, 1), - (1, 2, 1, 1, 2), + # (1, 2, 1, 2, 1), + # (1, 2, 1, 1, 2), ], ) def run_zero_with_original_model(config: Tuple[int, ...]): @@ -64,6 +64,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]): initial_scale=1, precision=precision, find_unused_parameters=True, + fp8_communication=True ) dp_size = plugin.dp_size @@ -153,6 +154,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]): torch_optimizer.step() torch_optimizer.zero_grad() + print("parallel_output", parallel_output.dtype, dtype) assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) # use checkpoint to load sharded zero model From e7cbe74c2a9c042df588d0005449606552886b75 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 7 Aug 2024 11:12:05 +0000 Subject: [PATCH 03/15] fix --- colossalai/shardformer/layer/embedding.py | 4 +++- colossalai/shardformer/modeling/deepseek.py | 17 ++++++++++++----- colossalai/shardformer/policies/deepseek.py | 9 +++++++-- colossalai/shardformer/policies/mixtral.py | 6 +++--- .../test_model/test_shard_mixtral.py | 5 ++--- 5 files changed, 27 insertions(+), 14 deletions(-) diff --git 
a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index 9b77774aaeaa..186063503fd4 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -274,6 +274,7 @@ def __init__( weight: Optional[nn.Parameter] = None, weight_initializer: Callable = init.normal_(), make_vocab_size_divisible_by: int = 64, + fp8_communication: bool = False, *args, **kwargs, ): @@ -282,6 +283,7 @@ def __init__( self.embed_args = args self.embed_kwargs = kwargs self.process_group = process_group + self.fp8_communication = fp8_communication tensor_parallel_size = dist.get_world_size(group=process_group) tensor_parallel_rank = dist.get_rank(group=process_group) @@ -390,5 +392,5 @@ def forward(self, input_: Tensor) -> Tensor: embedding_output = output_parallel.clone() embedding_output[input_mask, :] = 0.0 # Reduce across all the model parallel GPUs. - output = reduce_forward(embedding_output, self.process_group) + output = reduce_forward(embedding_output, self.process_group, fp8_communication=self.fp8_communication) return output diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py index a84a3097231a..f6c9bb752f02 100644 --- a/colossalai/shardformer/modeling/deepseek.py +++ b/colossalai/shardformer/modeling/deepseek.py @@ -23,6 +23,7 @@ EPGradScalerOut, all_to_all_uneven, ) +from colossalai.quantization.fp8 import all_reduce_fp8 from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import ( all_to_all_comm, @@ -61,7 +62,7 @@ class EPDeepseekMoE(nn.Module): def __init__(self): raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") - def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup): + def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, fp8_communication: bool = False): assert tp_group is not None assert moe_dp_group is not None assert ep_group is not None @@ -70,6 +71,7 @@ def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGrou self.ep_rank = dist.get_rank(ep_group) self.num_experts = self.config.n_routed_experts assert self.num_experts % self.ep_size == 0 + self.fp8_communication = fp8_communication self.ep_group = ep_group self.num_experts_per_ep = self.num_experts // self.ep_size @@ -106,7 +108,8 @@ def from_native_module( if module.__class__.__name__ == "DeepseekMLP": return module module.__class__ = EPDeepseekMoE - module.setup_process_groups(tp_group, moe_dp_group, ep_group) + fp8_communication = kwargs["fp8_communication"] + module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication) return module def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -137,11 +140,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for i in range(1, self.ep_size): activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep] activate_experts = (activate_experts > 0).float() - dist.all_reduce(activate_experts, group=self.moe_dp_group) + + if self.fp8_communication: + all_reduce_fp8(activate_experts, group=self.moe_dp_group) + else: + dist.all_reduce(activate_experts, group=self.moe_dp_group) input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() output_split_list = output_split_sizes.view(self.ep_size, 
self.num_experts_per_ep).sum(dim=-1).tolist() - output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group) + output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group, self.fp8_communication) output_states = EPGradScalerIn.apply(output_states, self.ep_size) if output_states.size(0) > 0: @@ -167,7 +174,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: output_states_list.append(split_states) output_states = torch.cat(output_states_list) output_states = EPGradScalerOut.apply(output_states, self.ep_size) - dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group) + dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group, self.fp8_communication) recover_token_idx = torch.empty_like(flat_topk_token_idx) recover_token_idx[flat_topk_token_idx] = torch.arange( flat_topk_token_idx.size(0), device=flat_topk_token_idx.device diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py index 605f69c4a632..35954079c4c4 100644 --- a/colossalai/shardformer/policies/deepseek.py +++ b/colossalai/shardformer/policies/deepseek.py @@ -118,18 +118,22 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, + kwargs={"fp8_communication", self.shard_config.fp8_communication} ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, + kwargs={"fp8_communication", self.shard_config.fp8_communication} ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, + kwargs={"fp8_communication", self.shard_config.fp8_communication} ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, + kwargs={"fp8_communication", self.shard_config.fp8_communication} ), ], ) @@ -138,7 +142,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: description=SubModuleReplacementDescription( suffix="embed_tokens", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication}, ), policy=policy, target_key="DeepseekModel", @@ -155,6 +159,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, + "fp8_communication": self.shard_config.fp8_communication }, ) ], @@ -305,7 +310,7 @@ def module_policy(self): SubModuleReplacementDescription( suffix="lm_head", target_module=Linear1D_Col, - kwargs=dict(gather_output=True), + kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), ) ] ) diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index b03bf89a32b9..54b5c0e93d6d 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -142,7 +142,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: description=SubModuleReplacementDescription( suffix="embed_tokens", target_module=embedding_cls, - 
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication}, ), policy=policy, target_key=MixtralModel, @@ -287,7 +287,7 @@ def module_policy(self): SubModuleReplacementDescription( suffix="lm_head", target_module=Linear1D_Col, - kwargs=dict(gather_output=True), + kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), ) ] ) @@ -341,7 +341,7 @@ def module_policy(self): MixtralForSequenceClassification: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication) ) ] ) diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py index bf2331ec9755..425e576cbc37 100644 --- a/tests/test_shardformer/test_model/test_shard_mixtral.py +++ b/tests/test_shardformer/test_model/test_shard_mixtral.py @@ -40,8 +40,8 @@ "config", [ (1, 2, 2, 1, 1), - # (1, 2, 1, 2, 1), - # (1, 2, 1, 1, 2), + (1, 2, 1, 2, 1), + (1, 2, 1, 1, 2), ], ) def run_zero_with_original_model(config: Tuple[int, ...]): @@ -64,7 +64,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]): initial_scale=1, precision=precision, find_unused_parameters=True, - fp8_communication=True ) dp_size = plugin.dp_size From 0eee454d2b2080e3882e00f5136660a3c86540ae Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 7 Aug 2024 11:14:48 +0000 Subject: [PATCH 04/15] fix --- colossalai/quantization/fp8.py | 1 - colossalai/shardformer/modeling/mixtral.py | 1 - tests/test_moe/test_mixtral_layer.py | 2 +- tests/test_shardformer/test_model/test_shard_mixtral.py | 1 - 4 files changed, 1 insertion(+), 4 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 890639bcb124..374719be0700 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -18,7 +18,6 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) - Returns: Tuples: A tuple (fp8_tensor, scale) """ - print("inp.dtype", inp.dtype) if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]: raise TypeError("Only float16, bfloat16, and float32 are allowed.") diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index 90c86f7552a4..bce00507351e 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -124,7 +124,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: output_split_sizes = torch.zeros_like(input_split_sizes) - # all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group, fp8_communication=False) dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group) with torch.no_grad(): diff --git a/tests/test_moe/test_mixtral_layer.py b/tests/test_moe/test_mixtral_layer.py index 8b06cfa90aa0..61c372f09979 100644 --- a/tests/test_moe/test_mixtral_layer.py +++ b/tests/test_moe/test_mixtral_layer.py @@ -64,7 +64,7 @@ def run_dist(rank: int, world_size: int, port: int): check_mixtral_moe_layer() -# @pytest.mark.skip("tested in corresponding sharderformer") +@pytest.mark.skip("tested in corresponding sharderformer") 
@pytest.mark.parametrize("world_size", [2]) def test_mixtral_moe_layer(world_size: int): spawn(run_dist, world_size) diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py index 425e576cbc37..de09eedcbed5 100644 --- a/tests/test_shardformer/test_model/test_shard_mixtral.py +++ b/tests/test_shardformer/test_model/test_shard_mixtral.py @@ -153,7 +153,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]): torch_optimizer.step() torch_optimizer.zero_grad() - print("parallel_output", parallel_output.dtype, dtype) assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) # use checkpoint to load sharded zero model From 4a2559b0d33accc1b9ba8bb391dacb59d83b0c47 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Aug 2024 04:03:25 +0000 Subject: [PATCH 05/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../plugin/moe_hybrid_parallel_plugin.py | 4 ++-- colossalai/moe/_operation.py | 14 +++++++---- colossalai/shardformer/modeling/deepseek.py | 18 +++++++++++---- colossalai/shardformer/modeling/mixtral.py | 19 +++++++++++---- colossalai/shardformer/policies/deepseek.py | 15 +++++++----- colossalai/shardformer/policies/mixtral.py | 23 ++++++++++++------- tests/test_moe/test_mixtral_layer.py | 2 +- 7 files changed, 64 insertions(+), 31 deletions(-) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 2b957c580c82..ca3a68373184 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -214,7 +214,7 @@ def __init__( moe_dp_outside: bool = True, overlap_p2p: bool = True, overlap_allgather: bool = False, - fp8_communication: bool = False + fp8_communication: bool = False, ) -> None: if overlap_communication or zero_stage == 2: overlap_communication = False @@ -342,7 +342,7 @@ def __init__( parallel_output=parallel_output, make_vocab_size_divisible_by=make_vocab_size_divisible_by, gradient_checkpoint_config=gradient_checkpoint_config, - fp8_communication=fp8_communication + fp8_communication=fp8_communication, ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index 5c1c2dd9c67f..f5ef6b6984e9 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -5,6 +5,7 @@ from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup + from colossalai.quantization.fp8 import all_to_all_single_fp8 MOE_KERNEL = None @@ -381,7 +382,7 @@ def _all_to_all( output_split_sizes: Optional[List[int]] = None, group=None, async_op: bool = False, - fp8_communication: bool = False + fp8_communication: bool = False, ): """ Returns: @@ -395,7 +396,9 @@ def _all_to_all( inputs = inputs.contiguous() outputs = outputs.contiguous() if fp8_communication: - handle = all_to_all_single_fp8(outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=False) + handle = all_to_all_single_fp8( + outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=False + ) else: handle = dist.all_to_all_single( outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op @@ -412,7 +415,7 @@ def forward( output_split_sizes=None, group=None, overlap: bool = False, - 
fp8_communication: bool = False + fp8_communication: bool = False, ): """ Returns: @@ -442,14 +445,15 @@ def all_to_all_uneven( output_split_sizes: Optional[List[int]] = None, group=None, overlap: bool = False, - fp8_communication: bool=False + fp8_communication: bool = False, ): assert ( inputs.requires_grad ), "Input must require grad to assure that backward is executed, otherwise it might hang the program." return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication) -def all_to_all_single(output, input, group=None, fp8_communication: bool=False): + +def all_to_all_single(output, input, group=None, fp8_communication: bool = False): if fp8_communication: all_to_all_single_fp8(output, input, group=group) else: diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py index f6c9bb752f02..43ab6847d59c 100644 --- a/colossalai/shardformer/modeling/deepseek.py +++ b/colossalai/shardformer/modeling/deepseek.py @@ -23,8 +23,8 @@ EPGradScalerOut, all_to_all_uneven, ) -from colossalai.quantization.fp8 import all_reduce_fp8 from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.quantization.fp8 import all_reduce_fp8 from colossalai.shardformer.layer._operation import ( all_to_all_comm, gather_forward_split_backward, @@ -62,7 +62,13 @@ class EPDeepseekMoE(nn.Module): def __init__(self): raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") - def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, fp8_communication: bool = False): + def setup_process_groups( + self, + tp_group: ProcessGroup, + moe_dp_group: ProcessGroup, + ep_group: ProcessGroup, + fp8_communication: bool = False, + ): assert tp_group is not None assert moe_dp_group is not None assert ep_group is not None @@ -148,7 +154,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() - output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group, self.fp8_communication) + output_states, _ = all_to_all_uneven( + dispatch_states, input_split_list, output_split_list, self.ep_group, self.fp8_communication + ) output_states = EPGradScalerIn.apply(output_states, self.ep_size) if output_states.size(0) > 0: @@ -174,7 +182,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: output_states_list.append(split_states) output_states = torch.cat(output_states_list) output_states = EPGradScalerOut.apply(output_states, self.ep_size) - dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group, self.fp8_communication) + dispatch_states, _ = all_to_all_uneven( + output_states, output_split_list, input_split_list, self.ep_group, self.fp8_communication + ) recover_token_idx = torch.empty_like(flat_topk_token_idx) recover_token_idx[flat_topk_token_idx] = torch.arange( flat_topk_token_idx.size(0), device=flat_topk_token_idx.device diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index bce00507351e..c739a97ed492 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -29,7 +29,6 @@ EPGradScalerIn, EPGradScalerOut, 
all_to_all_uneven, - all_to_all_single ) from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import ( @@ -54,7 +53,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): def __init__(self, *args, **kwargs): raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") - def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, fp8_communication: bool = False): + def setup_process_groups( + self, + tp_group: ProcessGroup, + moe_dp_group: ProcessGroup, + ep_group: ProcessGroup, + fp8_communication: bool = False, + ): assert tp_group is not None assert moe_dp_group is not None assert ep_group is not None @@ -123,7 +128,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_split_sizes = selected_experts.bincount(minlength=self.num_experts) output_split_sizes = torch.zeros_like(input_split_sizes) - + dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group) with torch.no_grad(): @@ -136,7 +141,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() - output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group, self.fp8_communication) + output_states, _ = all_to_all_uneven( + dispatch_states, input_split_list, output_split_list, self.ep_group, self.fp8_communication + ) # compute expert output output_states = EPGradScalerIn.apply(output_states, self.ep_size) if output_states.size(0) > 0: @@ -166,7 +173,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: output_states = torch.cat(output_states_list) output_states = EPGradScalerOut.apply(output_states, self.ep_size) - dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group, self.fp8_communication) + dispatch_states, _ = all_to_all_uneven( + output_states, output_split_list, input_split_list, self.ep_group, self.fp8_communication + ) recover_experts_idx = torch.empty_like(selected_experts_idx) recover_experts_idx[selected_experts_idx] = torch.arange( diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py index 35954079c4c4..b3f2963fd15f 100644 --- a/colossalai/shardformer/policies/deepseek.py +++ b/colossalai/shardformer/policies/deepseek.py @@ -118,22 +118,22 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication", self.shard_config.fp8_communication} + kwargs={"fp8_communication", self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication", self.shard_config.fp8_communication} + kwargs={"fp8_communication", self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication", self.shard_config.fp8_communication} + kwargs={"fp8_communication", self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs={"fp8_communication", self.shard_config.fp8_communication} + 
kwargs={"fp8_communication", self.shard_config.fp8_communication}, ), ], ) @@ -142,7 +142,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: description=SubModuleReplacementDescription( suffix="embed_tokens", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication}, + kwargs={ + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + }, ), policy=policy, target_key="DeepseekModel", @@ -159,7 +162,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, - "fp8_communication": self.shard_config.fp8_communication + "fp8_communication": self.shard_config.fp8_communication, }, ) ], diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 54b5c0e93d6d..550b7f2360f5 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -114,25 +114,27 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication": self.shard_config.fp8_communication} + kwargs={"fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication": self.shard_config.fp8_communication} + kwargs={"fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication": self.shard_config.fp8_communication} + kwargs={"fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs={"fp8_communication": self.shard_config.fp8_communication} + kwargs={"fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( # or replicate? 
- suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication} + suffix="block_sparse_moe.gate", + target_module=Linear1D_Col, + kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication}, ), ], ) @@ -142,7 +144,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: description=SubModuleReplacementDescription( suffix="embed_tokens", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication}, + kwargs={ + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + }, ), policy=policy, target_key=MixtralModel, @@ -159,7 +164,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, - "fp8_communication": self.shard_config.fp8_communication + "fp8_communication": self.shard_config.fp8_communication, }, ) ], @@ -341,7 +346,9 @@ def module_policy(self): MixtralForSequenceClassification: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication) + suffix="score", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), ) ] ) diff --git a/tests/test_moe/test_mixtral_layer.py b/tests/test_moe/test_mixtral_layer.py index 61c372f09979..b4f9a45ae892 100644 --- a/tests/test_moe/test_mixtral_layer.py +++ b/tests/test_moe/test_mixtral_layer.py @@ -42,7 +42,7 @@ def check_mixtral_moe_layer(): ep_group=plugin.ep_group, tp_group=plugin.tp_group, moe_dp_group=plugin.moe_dp_group, - fp8_communication=True + fp8_communication=True, ) ep_output, ep_logits = model(x) assert_close(orig_logits, ep_logits) From 71b421b7a27f60d2a8ab2b957b989e68141f13a9 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 8 Aug 2024 04:05:03 +0000 Subject: [PATCH 06/15] fix --- colossalai/quantization/fp8.py | 57 +--------------------------------- 1 file changed, 1 insertion(+), 56 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 374719be0700..c999f7752c42 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -18,6 +18,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) - Returns: Tuples: A tuple (fp8_tensor, scale) """ + if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]: raise TypeError("Only float16, bfloat16, and float32 are allowed.") @@ -170,62 +171,6 @@ def all_to_all_single_fp8( output.data = tensor_out.view(outputs_shape).to(input_type) -def all_to_all_single_fp8( - output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False -) -> None: - r""" - This is an in-place operation for compressed all_reduce using fp8. - It works like dist.all_to_all_single but during communication the data is cast to fp8 format. - Args: - tensor: torch.Tensor in fp32, fp16, bf16 datatype. 
- fp8_format: e4m3 or e5m2 - Returns: - None - """ - world_size = dist.get_world_size(group=group) - input_type = input.dtype - input_shape = input.shape - input_device = input.device - input = input.flatten() - - fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 - - ret, scale = cast_to_fp8(input, fp8_format=fp8_format) - - inp = ret.view(torch.uint8) - if input_split_sizes is not None: - input_split_sizes = [input_split_sizes[i] * np.prod(input_shape[1:]) for i in range(world_size)] - input_chunks = list(torch.split(inp, input_split_sizes)) - else: - input_chunks = list(torch.chunk(inp, world_size, dim=0)) - - if output_split_sizes is not None: - output_chunks = [ - torch.empty((output_split_sizes[i] * np.prod(input_shape[1:]),), device=input_device, dtype=inp.dtype) - for i in range(world_size) - ] - else: - if dist.get_rank() == world_size - 1: - output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)] - else: - output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)] - - dist.all_to_all(output_chunks, input_chunks, group=group) - scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] - dist.all_gather(scale_list, scale, group=group) - cast_output_chunk = [ - cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks) - ] - - tensor_out = torch.cat(cast_output_chunk, dim=0) - outputs_shape = list(input_shape) - if output_split_sizes is not None: - outputs_shape[0] = sum(output_split_sizes) - else: - outputs_shape = input_shape - output.data = tensor_out.view(outputs_shape).to(input_type) - - def cast_to_fp8_pipeline(inp: Any) -> None: """ Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline. From e7b16fcb47990af46241553a68f512459977ca2e Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 8 Aug 2024 04:10:01 +0000 Subject: [PATCH 07/15] fix --- colossalai/moe/_operation.py | 7 ------- colossalai/quantization/fp8.py | 4 ++-- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index f5ef6b6984e9..1ab72b01b2db 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -451,10 +451,3 @@ def all_to_all_uneven( inputs.requires_grad ), "Input must require grad to assure that backward is executed, otherwise it might hang the program." 
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication) - - -def all_to_all_single(output, input, group=None, fp8_communication: bool = False): - if fp8_communication: - all_to_all_single_fp8(output, input, group=group) - else: - dist.all_to_all_single(output, input, group=group) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index c999f7752c42..52bb8cc9bc33 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -155,9 +155,9 @@ def all_to_all_single_fp8( else: output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)] - dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op) + dist.all_to_all(output_chunks, input_chunks, group=group) scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] - dist.all_gather(scale_list, scale, group=group, async_op=async_op) + dist.all_gather(scale_list, scale, group=group) cast_output_chunk = [ cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks) ] From ae92057aa67c9f19853953d3032a1a0de7dbf54c Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 8 Aug 2024 06:10:42 +0000 Subject: [PATCH 08/15] fix --- colossalai/shardformer/modeling/deepseek.py | 22 +++++++++---------- colossalai/shardformer/modeling/mixtral.py | 20 ++++++++--------- .../test_plugin/test_torch_ddp_plugin.py | 1 - 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py index 43ab6847d59c..b372427f0f03 100644 --- a/colossalai/shardformer/modeling/deepseek.py +++ b/colossalai/shardformer/modeling/deepseek.py @@ -94,9 +94,9 @@ def setup_process_groups( self.tp_group = tp_group if self.tp_group.size() > 1: for expert in held_experts: - expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group) - expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group) - expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group) + expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group, fp8_communication=self.fp8_communication) + expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group, fp8_communication=self.fp8_communication) + expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group, fp8_communication=self.fp8_communication) for p in self.experts.parameters(): set_moe_tensor_ep_group(p, ep_group) @@ -551,9 +551,9 @@ def forward( # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": - query_states = all_to_all_comm(query_states, sp_group) - key_states = all_to_all_comm(key_states, sp_group) - value_states = all_to_all_comm(value_states, sp_group) + query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication) + key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication) + value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) bsz, q_len, _ = query_states.size() # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim @@ -612,7 +612,7 @@ def forward( # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": attn_output = attn_output.reshape(bsz, q_len, 
self.num_heads * self.head_dim).contiguous() # (1, 8, 128) - attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256) + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication) # (1, 4, 256) else: attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -702,9 +702,9 @@ def forward( ) if sp_mode in ["ring", "split_gather"]: - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication) elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication) # embed positions hidden_states = inputs_embeds @@ -748,9 +748,9 @@ def forward( hidden_states = self.norm(hidden_states) if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication) elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication) # add hidden states from the last decoder layer if output_hidden_states: diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index c739a97ed492..4316e72d2f06 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -86,9 +86,9 @@ def setup_process_groups( self.tp_group = tp_group if self.tp_group.size() > 1: for expert in held_experts: - expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group) - expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group) - expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group) + expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group, fp8_communication=self.fp8_communication) + expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group, fp8_communication=self.fp8_communication) + expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group, fp8_communication=self.fp8_communication) self.fp8_communication = fp8_communication for p in self.experts.parameters(): @@ -579,9 +579,9 @@ def forward( # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": - query_states = all_to_all_comm(query_states, sp_group) - key_states = all_to_all_comm(key_states, sp_group) - value_states = all_to_all_comm(value_states, sp_group) + query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication) + key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication) + value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) bsz, q_len, _ = query_states.size() query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -793,9 +793,9 @@ def forward( ) if sp_mode in ["ring", "split_gather"]: - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + inputs_embeds = 
split_forward_gather_backward(inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication)
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication)
         hidden_states = inputs_embeds

         # decoder layers
@@ -844,9 +844,9 @@ def forward(
         hidden_states = self.norm(hidden_states)

         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication)
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication)

         # add hidden states from the last decoder layer
         if output_hidden_states:

diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
index 3573d86a84ee..f92b5c6e5675 100644
--- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
+++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
@@ -26,7 +26,6 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
     data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}

     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
-    model = model.to(torch.float16)

     assert isinstance(model.module, DDP)
     assert isinstance(optimizer, OptimizerWrapper)

From e8ab0d013f980fef5316d33c4a923cb0c90ba49b Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 8 Aug 2024 06:11:42 +0000
Subject: [PATCH 09/15] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/shardformer/modeling/deepseek.py | 32 +++++++++++++++------
 colossalai/shardformer/modeling/mixtral.py  | 28 +++++++++++++-----
 2 files changed, 45 insertions(+), 15 deletions(-)

diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index b372427f0f03..bd8df0b21ac3 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -94,9 +94,15 @@ def setup_process_groups(
         self.tp_group = tp_group
         if self.tp_group.size() > 1:
             for expert in held_experts:
-                expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group, fp8_communication=self.fp8_communication)
-                expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group, fp8_communication=self.fp8_communication)
-                expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group, fp8_communication=self.fp8_communication)
+                expert.gate_proj = Linear1D_Col.from_native_module(
+                    expert.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.up_proj = Linear1D_Col.from_native_module(
+                    expert.up_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.down_proj = Linear1D_Row.from_native_module(
+                    expert.down_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )

         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
@@ -612,7 +618,9 @@ def forward(
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication)  # (1, 4, 256)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )  # (1, 4, 256)
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

@@ -702,9 +710,13 @@ def forward(
             )

         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )
         # embed positions
         hidden_states = inputs_embeds

@@ -748,9 +760,13 @@ def forward(
         hidden_states = self.norm(hidden_states)

         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )

         # add hidden states from the last decoder layer
         if output_hidden_states:

diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 4316e72d2f06..127aab5a1f75 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -86,9 +86,15 @@ def setup_process_groups(
         self.tp_group = tp_group
         if self.tp_group.size() > 1:
             for expert in held_experts:
-                expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group, fp8_communication=self.fp8_communication)
-                expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group, fp8_communication=self.fp8_communication)
-                expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group, fp8_communication=self.fp8_communication)
+                expert.w1 = Linear1D_Col.from_native_module(
+                    expert.w1, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.w3 = Linear1D_Col.from_native_module(
+                    expert.w3, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.w2 = Linear1D_Row.from_native_module(
+                    expert.w2, self.tp_group, fp8_communication=self.fp8_communication
+                )

         self.fp8_communication = fp8_communication
         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
@@ -793,9 +799,13 @@ def forward(
             )

         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )
         hidden_states = inputs_embeds

         # decoder layers
@@ -844,9 +854,13 @@ def forward(
         hidden_states = self.norm(hidden_states)

         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )

         # add hidden states from the last decoder layer
         if output_hidden_states:

From eab341acd369c22d751dd1948afba0e442fe69eb Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 8 Aug 2024 06:58:53 +0000
Subject: [PATCH 10/15] fix

---
 colossalai/shardformer/modeling/mixtral.py  |  2 +-
 colossalai/shardformer/policies/deepseek.py |  8 ++++----
 colossalai/shardformer/shard/sharder.py     | 11 ++++++-----
 tests/test_moe/test_deepseek_layer.py       |  3 ++-
 4 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 4316e72d2f06..fca16a9d489f 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -68,6 +68,7 @@ def setup_process_groups(
         self.ep_size = dist.get_world_size(ep_group)
         self.ep_rank = dist.get_rank(ep_group)
         self.ep_group = ep_group
+        self.fp8_communication = fp8_communication

         if self.num_experts % self.ep_size != 0:
             raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
@@ -90,7 +91,6 @@ def setup_process_groups(
                 expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group, fp8_communication=self.fp8_communication)
                 expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group, fp8_communication=self.fp8_communication)

-        self.fp8_communication = fp8_communication

         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)

diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index b3f2963fd15f..0b8a602d1164 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -118,22 +118,22 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
                         target_module=Linear1D_Col,
-                        kwargs={"fp8_communication", self.shard_config.fp8_communication},
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.k_proj",
                         target_module=Linear1D_Col,
-                        kwargs={"fp8_communication", self.shard_config.fp8_communication},
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.v_proj",
                         target_module=Linear1D_Col,
-                        kwargs={"fp8_communication", self.shard_config.fp8_communication},
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.o_proj",
                         target_module=Linear1D_Row,
-                        kwargs={"fp8_communication", self.shard_config.fp8_communication},
+                        kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
                 ],
             )

diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index ee2f1f405879..3cd44426409c 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -198,11 +198,12 @@ def _replace_sub_module(
                 native_sub_module, process_group=self.shard_config.tensor_parallel_process_group, **kwargs
             )
         except Exception as e:
-            raise RuntimeError(
-                f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
-                f" with {target_module.__qualname__} with the exception: {e}. "
-                "Please check your model configuration or sharding policy, you can set up an issue for us to help you as well."
-            )
+            # raise RuntimeError(
+            #     f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
+            #     f" with {target_module.__qualname__} with the exception: {e}. "
+            #     "Please check your model configuration or sharding policy, you can set up an issue for us to help you as well."
+            # )
+            raise e

         setattr_(org_layer, suffix, replace_layer)

diff --git a/tests/test_moe/test_deepseek_layer.py b/tests/test_moe/test_deepseek_layer.py
index d18ba2eacd84..98f732880d84 100644
--- a/tests/test_moe/test_deepseek_layer.py
+++ b/tests/test_moe/test_deepseek_layer.py
@@ -48,6 +48,7 @@ def check_deepseek_moe_layer():
         ep_group=plugin.ep_group,
         moe_dp_group=plugin.moe_dp_group,
         tp_group=plugin.tp_group,
+        fp8_communication=True,
     )
     ep_output = model(x)
     assert_close(orig_output, ep_output)
@@ -68,7 +69,7 @@ def run_dist(rank: int, world_size: int, port: int):
    check_deepseek_moe_layer()


-@pytest.mark.skip("tested in corresponding sharderformer")
+# @pytest.mark.skip("tested in corresponding sharderformer")
 @pytest.mark.parametrize("world_size", [2])
 def test_deepseek_moe_layer(world_size: int):
     spawn(run_dist, world_size)

From 57e75fc498ec3188c026bd5ddcb88a9b3330869e Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 8 Aug 2024 07:05:18 +0000
Subject: [PATCH 11/15] fix

---
 tests/test_moe/test_deepseek_layer.py | 3 +--
 tests/test_moe/test_mixtral_layer.py  | 1 -
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/test_moe/test_deepseek_layer.py b/tests/test_moe/test_deepseek_layer.py
index 98f732880d84..d18ba2eacd84 100644
--- a/tests/test_moe/test_deepseek_layer.py
+++ b/tests/test_moe/test_deepseek_layer.py
@@ -48,7 +48,6 @@ def check_deepseek_moe_layer():
         ep_group=plugin.ep_group,
         moe_dp_group=plugin.moe_dp_group,
         tp_group=plugin.tp_group,
-        fp8_communication=True,
     )
     ep_output = model(x)
     assert_close(orig_output, ep_output)
@@ -69,7 +68,7 @@ def run_dist(rank: int, world_size: int, port: int):
    check_deepseek_moe_layer()


-# @pytest.mark.skip("tested in corresponding sharderformer")
+@pytest.mark.skip("tested in corresponding sharderformer")
 @pytest.mark.parametrize("world_size", [2])
 def test_deepseek_moe_layer(world_size: int):
     spawn(run_dist, world_size)

diff --git a/tests/test_moe/test_mixtral_layer.py b/tests/test_moe/test_mixtral_layer.py
index b4f9a45ae892..bc41ac4f33e9 100644
--- a/tests/test_moe/test_mixtral_layer.py
+++ b/tests/test_moe/test_mixtral_layer.py
@@ -42,7 +42,6 @@ def check_mixtral_moe_layer():
         ep_group=plugin.ep_group,
         tp_group=plugin.tp_group,
         moe_dp_group=plugin.moe_dp_group,
-        fp8_communication=True,
     )
     ep_output, ep_logits = model(x)
     assert_close(orig_logits, ep_logits)

From 0fe1ca9a9af17cd7aa300614d50702f523bd5d7c Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Fri, 9 Aug 2024 03:28:36 +0000
Subject: [PATCH 12/15] fix fix fi

---
 colossalai/moe/_operation.py                |  4 +++-
 colossalai/shardformer/modeling/deepseek.py | 12 ++++++++----
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index 1ab72b01b2db..ba087a03b728 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -425,7 +425,9 @@ def forward(
         ctx.input_split_sizes = input_split_sizes
         ctx.output_split_sizes = output_split_sizes
         ctx.group = group
-        return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
+        return _all_to_all(
+            inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication=fp8_communication
+        )

     @staticmethod
     def backward(ctx: Any, *grad_outputs):

diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index bd8df0b21ac3..48a629705a7c 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -120,8 +120,8 @@ def from_native_module(
         if module.__class__.__name__ == "DeepseekMLP":
             return module
         module.__class__ = EPDeepseekMoE
-        fp8_communication = kwargs["fp8_communication"]
-        module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication)
+        fp8_communication = kwargs.get("fp8_communication", False)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication=fp8_communication)
         return module

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -161,7 +161,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_states, _ = all_to_all_uneven(
-            dispatch_states, input_split_list, output_split_list, self.ep_group, self.fp8_communication
+            dispatch_states,
+            input_split_list,
+            output_split_list,
+            self.ep_group,
+            fp8_communication=self.fp8_communication,
         )
         output_states = EPGradScalerIn.apply(output_states, self.ep_size)

@@ -189,7 +193,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             output_states = torch.cat(output_states_list)
         output_states = EPGradScalerOut.apply(output_states, self.ep_size)
         dispatch_states, _ = all_to_all_uneven(
-            output_states, output_split_list, input_split_list, self.ep_group, self.fp8_communication
+            output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
         )
         recover_token_idx = torch.empty_like(flat_topk_token_idx)
         recover_token_idx[flat_topk_token_idx] = torch.arange(

From d3c0624c22e3a2ed754623a9b13df19b76bee81b Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Fri, 9 Aug 2024 07:59:26 +0000
Subject: [PATCH 13/15] fix

---
 colossalai/quantization/fp8.py             | 20 ++++++++++++--------
 colossalai/shardformer/modeling/mixtral.py |  8 ++++++--
 2 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 52bb8cc9bc33..711ee5588883 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -25,14 +25,18 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
     fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
     fp8_max = torch.finfo(fp8_type).max

-    if per_channel_scale:
-        per_channel_max = inp.abs().max(dim=-1).values.float()
-        per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
-        scale = fp8_max / per_channel_max[:, None]
-    else:
-        per_tensor_max = inp.abs().max().float()
-        per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
+    if inp.numel() == 0:
+        per_tensor_max = torch.tensor([1.0], device=inp.device)
         scale = fp8_max / per_tensor_max
+    else:
+        if per_channel_scale:
+            per_channel_max = inp.abs().max(dim=-1).values.float()
+            per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
+            scale = fp8_max / per_channel_max[:, None]
+        else:
+            per_tensor_max = inp.abs().max().float()
+            per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
+            scale = fp8_max / per_tensor_max
     scale_inv = 1.0 / scale

     ret = (scale * inp.float()).to(fp8_type)
@@ -110,7 +114,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
     tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
     dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
     for i in range(world_size):
-        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
+        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i].to(input_device)
     out = torch.cat(tensor_list, dim=0)
     tensor.copy_(out[:input_size].view(input_shape).to(input_type))

diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 9b5b7f9ad74a..76b441fe4258 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -148,7 +148,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()

         output_states, _ = all_to_all_uneven(
-            dispatch_states, input_split_list, output_split_list, self.ep_group, self.fp8_communication
+            dispatch_states,
+            input_split_list,
+            output_split_list,
+            self.ep_group,
+            fp8_communication=self.fp8_communication,
         )
         # compute expert output
         output_states = EPGradScalerIn.apply(output_states, self.ep_size)
@@ -180,7 +184,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             output_states = EPGradScalerOut.apply(output_states, self.ep_size)

         dispatch_states, _ = all_to_all_uneven(
-            output_states, output_split_list, input_split_list, self.ep_group, self.fp8_communication
+            output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
         )
         recover_experts_idx = torch.empty_like(selected_experts_idx)

From cd177e2bf29915a4ac4ca0dcbe2e3868fff755c3 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 9 Aug 2024 08:04:44 +0000
Subject: [PATCH 14/15] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/quantization/fp8.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index d5a0cab960ba..3d190c4f2c5f 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -32,16 +32,16 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
         scale = fp8_max / per_tensor_max
         scale_inv = 1.0 / scale
     else:
-            if per_channel_scale:
-                per_channel_max = inp.abs().max(dim=-1).values.float()
-                per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
-                scale = fp8_max / per_channel_max[:, None]
-                scale_inv = per_channel_max / fp8_max
-            else:
-                per_tensor_max = inp.abs().max().float()
-                per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
-                scale = fp8_max / per_tensor_max
-                scale_inv = 1.0 / scale
+        if per_channel_scale:
+            per_channel_max = inp.abs().max(dim=-1).values.float()
+            per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
+            scale = fp8_max / per_channel_max[:, None]
+            scale_inv = per_channel_max / fp8_max
+        else:
+            per_tensor_max = inp.abs().max().float()
+            per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
+            scale = fp8_max / per_tensor_max
+            scale_inv = 1.0 / scale

     ret = (scale * inp.float()).to(fp8_type)
     return ret, scale_inv

From 977992dbf1aeb7e012cae7583bff8efc1d527fb1 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Fri, 9 Aug 2024 09:46:43 +0000
Subject: [PATCH 15/15] fix

---
 colossalai/quantization/fp8.py          |  4 +---
 colossalai/shardformer/shard/sharder.py | 11 +++++------
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 3d190c4f2c5f..f2bffa09f15d 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -28,9 +28,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
     fp8_max = torch.finfo(fp8_type).max

     if inp.numel() == 0:
-        per_tensor_max = torch.tensor([1.0], device=inp.device)
-        scale = fp8_max / per_tensor_max
-        scale_inv = 1.0 / scale
+        return inp.to(fp8_type), torch.tensor([1.0], device=inp.device)
     else:
         if per_channel_scale:
             per_channel_max = inp.abs().max(dim=-1).values.float()
             per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)

diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 3cd44426409c..ee2f1f405879 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -198,12 +198,11 @@ def _replace_sub_module(
                 native_sub_module, process_group=self.shard_config.tensor_parallel_process_group, **kwargs
             )
         except Exception as e:
-            # raise RuntimeError(
-            #     f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
-            #     f" with {target_module.__qualname__} with the exception: {e}. "
-            #     "Please check your model configuration or sharding policy, you can set up an issue for us to help you as well."
-            # )
-            raise e
+            raise RuntimeError(
+                f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
+                f" with {target_module.__qualname__} with the exception: {e}. "
+                "Please check your model configuration or sharding policy, you can set up an issue for us to help you as well."
+            )

         setattr_(org_layer, suffix, replace_layer)
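
For reference, the empty-tensor guard that PATCH 13/15 through PATCH 15/15 converge on can be read as the minimal per-tensor sketch below (assumptions: PyTorch >= 2.1 for the float8 dtypes; the standalone helper name cast_to_fp8_sketch and the demo call are illustrative only, not part of the library):

import torch

def cast_to_fp8_sketch(inp: torch.Tensor, fp8_format: str = "e4m3"):
    # Per-tensor branch of the cast logic as it stands after PATCH 15/15.
    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_type).max
    if inp.numel() == 0:
        # .abs().max() raises on zero elements, so an empty input (e.g. a rank
        # whose experts received no tokens) returns early with a unit scale.
        return inp.to(fp8_type), torch.tensor([1.0], device=inp.device)
    per_tensor_max = inp.abs().max().float()
    per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
    scale = fp8_max / per_tensor_max
    return (scale * inp.float()).to(fp8_type), 1.0 / scale

# Example: an empty tensor no longer trips the reduction.
ret, scale_inv = cast_to_fp8_sketch(torch.empty(0, 8))
print(ret.dtype, scale_inv)  # torch.float8_e4m3fn tensor([1.])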