From 09abd376506495fd094f14eed90ee303987e907c Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 6 Feb 2026 15:31:14 +0800 Subject: [PATCH] update ep cookbook --- cookbook/sft/ep_fsdp_qwen3_moe.py | 26 ++++++-------------------- src/twinkle/utils/platform.py | 5 +++++ 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/cookbook/sft/ep_fsdp_qwen3_moe.py b/cookbook/sft/ep_fsdp_qwen3_moe.py index 9b669a1f..68c6d15a 100644 --- a/cookbook/sft/ep_fsdp_qwen3_moe.py +++ b/cookbook/sft/ep_fsdp_qwen3_moe.py @@ -1,12 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os -import numpy as np import torch.distributed as dist from transformers import AutoConfig import twinkle -from twinkle import DeviceGroup, DeviceMesh, Platform, get_device_placement, get_logger +from twinkle import DeviceMesh, Platform, get_device_placement, get_logger from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel @@ -19,29 +18,18 @@ 'QWEN3_DATASET_ID', '/path/to/alpaca/dataset') TEMPLATE_ID = os.environ.get('QWEN3_TEMPLATE_ID', 'Template') PROCESSOR_ID = os.environ.get('QWEN3_PROCESSOR_ID', 'AlpacaProcessor') -REMOTE_GROUP = 'model' NUM_LAYERS = int(os.environ.get('QWEN3_NUM_LAYERS', '1')) -device_group = [ - DeviceGroup( - name=REMOTE_GROUP, - ranks=[0, 1, 2, 3], - device_type=Platform.get_platform().device_prefix(), - ) -] - # 4 GPUs: dp=2, ep=2 -device_mesh = DeviceMesh( +device_mesh = DeviceMesh.from_sizes( device_type=Platform.get_platform().device_prefix(), - mesh=np.array([[0, 1], [2, 3]]), - mesh_dim_names=('dp', 'ep'), + dp_size=2, + ep_size=2, ) -os.environ.setdefault("RAY_DEDUP_LOGS", "0") twinkle.initialize( - mode='ray', + mode='local', nproc_per_node=4, - groups=device_group, global_device_mesh=device_mesh, ) @@ -78,7 +66,6 @@ def train(): dataloader = DataLoader( dataset=dataset, batch_size=4, - remote_group=REMOTE_GROUP, device_mesh=device_mesh, ) @@ -86,7 +73,6 @@ def train(): model = TransformersModel( model_id=MODEL_ID, config=config, - remote_group=REMOTE_GROUP, device_mesh=device_mesh, fsdp_config={ "expert_parallel": { @@ -115,7 +101,7 @@ def train(): metric = metric() logger.info( f'Current is step {step // grad_accum_steps}, metric: {metric}') - if step % 50 == 0: + if step > 1 and step % 50 == 0: model.save('./output') diff --git a/src/twinkle/utils/platform.py b/src/twinkle/utils/platform.py index f922e0c5..986c4efc 100644 --- a/src/twinkle/utils/platform.py +++ b/src/twinkle/utils/platform.py @@ -86,6 +86,11 @@ def from_sizes(*, world_size: int = 1, dp_size: int = 1, fsdp_size: int = None, mesh_dim_sizes.append(dp_size) else: mesh_dim_sizes.append(-1) + if ep_size is not None: + mesh_dim_sizes.append(ep_size) + mesh_dim_names.append("ep") + if origin_world_size == 1: + world_size *= ep_size if cp_size is not None: mesh_dim_sizes.append(cp_size) mesh_dim_names.append("cp")