diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index 101f66c7..f2f34534 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -15,10 +15,12 @@ class NativeFSDPStrategy:
     def __init__(self,
                  device_mesh: Optional[DeviceMesh] = None,
                  mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
-                 fsdp_config: Dict[str, Any] = None):
+                 fsdp_config: Dict[str, Any] = None,
+                 enable_ep: bool = True):
         self.device_mesh = device_mesh
         self.mixed_precision = mixed_precision
         self.fsdp_config = fsdp_config or {}
+        self.enable_ep = enable_ep
 
     def wrap_model(self, model, optimizer=None):
         if self.device_mesh is None:
@@ -26,11 +28,12 @@ def wrap_model(self, model, optimizer=None):
         fsdp_mesh = _build_fsdp_mesh(self.device_mesh)
 
         if fsdp_mesh is not None:
-            _ensure_moe_patched_if_needed(model, self.device_mesh)
-            _place_ep_experts_on_local_device(model, self.device_mesh)
+            if self.enable_ep:
+                _ensure_moe_patched_if_needed(model, self.device_mesh)
+                _place_ep_experts_on_local_device(model, self.device_mesh)
             mp_policy = _build_mp_policy(self.mixed_precision)
             reshard_after_forward = self.fsdp_config.get("reshard_after_forward", True)
-            ignored_params = _collect_expert_params(model)
+            ignored_params = _collect_expert_params(model) if self.enable_ep else None
 
             _maybe_shard_layers(
                 model,
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 675f82bb..3960d33d 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -209,6 +209,7 @@ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']):
                 mixed_precision=self.mixed_precision,
                 fsdp_config=self._fsdp_config,
                 device_mesh=self.device_mesh,
+                enable_ep=self._enable_expert_parallel,
             )
         else:
             self.strategy = AccelerateStrategy(mixed_precision=self.mixed_precision,
                                                ddp_config=self._ddp_config,
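
A minimal usage sketch of the new flag, not part of the patch itself: only NativeFSDPStrategy, wrap_model, and the enable_ep / mixed_precision / fsdp_config / device_mesh parameters come from the change above; device_mesh, model, and optimizer are assumed to be constructed elsewhere.

    # Sketch: wrap a model with the native FSDP strategy while explicitly
    # disabling expert parallelism. With enable_ep=False, MoE patching and
    # local expert placement are skipped, and expert parameters are no longer
    # excluded from FSDP sharding (ignored_params is None).
    strategy = NativeFSDPStrategy(
        device_mesh=device_mesh,                       # assumed: a prebuilt DeviceMesh
        mixed_precision='bf16',
        fsdp_config={'reshard_after_forward': True},
        enable_ep=False,
    )
    strategy.wrap_model(model, optimizer=optimizer)    # assumed: model/optimizer built earlier

When going through the higher-level transformers.py path, the same switch is driven by self._enable_expert_parallel rather than being passed by hand, as the second hunk shows.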