From fb2d4fdfb44b6a03f2c97adc0ecf8b6e78d1ec59 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 15 Aug 2024 09:13:14 +0000 Subject: [PATCH] fix the moe --- colossalai/booster/plugin/moe_hybrid_parallel_plugin.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index ca3a68373184..374fc653556e 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -215,6 +215,7 @@ def __init__( overlap_p2p: bool = True, overlap_allgather: bool = False, fp8_communication: bool = False, + use_fp8: bool = False, ) -> None: if overlap_communication or zero_stage == 2: overlap_communication = False @@ -324,7 +325,7 @@ def __init__( self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) else: self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) - + self.use_fp8 = use_fp8 self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, sequence_parallel_process_group=self.sp_group, @@ -428,6 +429,7 @@ def configure( use_ddp=use_ddp, ddp_config=self.ddp_config, custom_policy=self.custom_policy, + use_fp8=self.use_fp8, ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.ep_size > 1: