Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 6 additions & 20 deletions cookbook/sft/ep_fsdp_qwen3_moe.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os

import numpy as np
import torch.distributed as dist
from transformers import AutoConfig

import twinkle
from twinkle import DeviceGroup, DeviceMesh, Platform, get_device_placement, get_logger
from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
Expand All @@ -19,29 +18,18 @@
'QWEN3_DATASET_ID', '/path/to/alpaca/dataset')
TEMPLATE_ID = os.environ.get('QWEN3_TEMPLATE_ID', 'Template')
PROCESSOR_ID = os.environ.get('QWEN3_PROCESSOR_ID', 'AlpacaProcessor')
REMOTE_GROUP = 'model'
NUM_LAYERS = int(os.environ.get('QWEN3_NUM_LAYERS', '1'))

device_group = [
DeviceGroup(
name=REMOTE_GROUP,
ranks=[0, 1, 2, 3],
device_type=Platform.get_platform().device_prefix(),
)
]

# 4 GPUs: dp=2, ep=2
device_mesh = DeviceMesh(
device_mesh = DeviceMesh.from_sizes(
device_type=Platform.get_platform().device_prefix(),
mesh=np.array([[0, 1], [2, 3]]),
mesh_dim_names=('dp', 'ep'),
dp_size=2,
ep_size=2,
)

os.environ.setdefault("RAY_DEDUP_LOGS", "0")
twinkle.initialize(
mode='ray',
mode='local',
nproc_per_node=4,
groups=device_group,
global_device_mesh=device_mesh,
)

Expand Down Expand Up @@ -78,15 +66,13 @@ def train():
dataloader = DataLoader(
dataset=dataset,
batch_size=4,
remote_group=REMOTE_GROUP,
device_mesh=device_mesh,
)

grad_accum_steps = 4
model = TransformersModel(
model_id=MODEL_ID,
config=config,
remote_group=REMOTE_GROUP,
device_mesh=device_mesh,
fsdp_config={
"expert_parallel": {
Expand Down Expand Up @@ -115,7 +101,7 @@ def train():
metric = metric()
logger.info(
f'Current is step {step // grad_accum_steps}, metric: {metric}')
if step % 50 == 0:
if step > 1 and step % 50 == 0:
model.save('./output')


Expand Down
5 changes: 5 additions & 0 deletions src/twinkle/utils/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ def from_sizes(*, world_size: int = 1, dp_size: int = 1, fsdp_size: int = None,
mesh_dim_sizes.append(dp_size)
else:
mesh_dim_sizes.append(-1)
if ep_size is not None:
mesh_dim_sizes.append(ep_size)
mesh_dim_names.append("ep")
if origin_world_size == 1:
world_size *= ep_size
Comment on lines +89 to +93
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve robustness, it's a good practice to validate input parameters. The ep_size should be a positive integer. Consider adding a check to ensure ep_size > 0 to prevent potential runtime errors during mesh creation if an invalid value is passed.

Suggested change
if ep_size is not None:
mesh_dim_sizes.append(ep_size)
mesh_dim_names.append("ep")
if origin_world_size == 1:
world_size *= ep_size
if ep_size is not None:
if ep_size <= 0:
raise ValueError(f'ep_size must be positive, but got {ep_size}')
mesh_dim_sizes.append(ep_size)
mesh_dim_names.append("ep")
if origin_world_size == 1:
world_size *= ep_size

if cp_size is not None:
mesh_dim_sizes.append(cp_size)
mesh_dim_names.append("cp")
Expand Down