Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cookbook/megatron/tp_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from twinkle.model import MegatronModel
from twinkle.preprocessor import SelfCognitionProcessor

# Construct a device_mesh, tp=pp=cp=ep=2, dp=1
device_mesh = DeviceMesh.from_sizes(dp_size=1, tp_size=2, pp_size=2, cp_size=2, ep_size=2)
# Construct a device_mesh, tp=pp=ep=dp=2
device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2, ep_size=2, sequence_parallel=True)
# use torchrun mode
twinkle.initialize(mode='local', global_device_mesh=device_mesh)

Expand Down
2 changes: 1 addition & 1 deletion cookbook/mm/fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def train():
# Print metric
metric = model.calculate_metric(is_training=True)
logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
if step > 0 and step % 40 == 0:
if step > 0 and step % 200 == 0:
metrics = eval(model)
logger.info(f'Eval metric: {metrics}')
metrics['step'] = step
Expand Down
1 change: 0 additions & 1 deletion cookbook/ray/run.sh

This file was deleted.

91 changes: 0 additions & 91 deletions cookbook/ray/single_controller.py

This file was deleted.

4 changes: 2 additions & 2 deletions cookbook/rl/gkd_off_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@
STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3-0.6B')
TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-8B')

MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 8))
MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS

BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16))
Expand Down
132 changes: 89 additions & 43 deletions cookbook/rl/gkd_on_policy.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
"""GKD On-Policy Distillation via Ray.
"""GKD On-Policy Multimodal Distillation via Ray.

On-policy knowledge distillation: student vLLM generates responses,
teacher vLLM provides top-k prompt logprobs, then student model learns
to match the teacher's token distribution.
On-policy knowledge distillation on OlympiadBench multimodal math/physics:
student vLLM generates responses, teacher vLLM provides top-k prompt logprobs,
then student model learns to match the teacher's token distribution.

Supports three OlympiadBench subsets:
- OE_MM_maths_zh_CEE: Multimodal math problems (Chinese CEE)
- OE_MM_physics_zh_CEE: Multimodal physics problems (Chinese CEE)
- OE_TO_maths_zh_CEE: Text-only math problems (Chinese CEE)

Pipeline:
1. Sync student model weights to student vLLM sampler.
Expand All @@ -23,8 +28,8 @@
student + teacher (model GPUs)

Environment variables (all optional):
STUDENT_MODEL_ID – (default: ms://Qwen/Qwen3-0.6B)
TEACHER_MODEL_ID – (default: ms://Qwen/Qwen3-8B)
STUDENT_MODEL_ID – (default: ms://Qwen/Qwen3.5-4B)
TEACHER_MODEL_ID – (default: ms://Qwen/Qwen3.5-9B)
MODEL_GPUS – GPUs for student model (default: 4)
SAMPLER_GPUS – GPUs for each vLLM sampler (default: 2)
MAX_NEW_TOKENS – max completion tokens (default: 2048)
Expand All @@ -48,44 +53,57 @@
from twinkle.checkpoint_engine import CheckpointEngineManager
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.dataset import DatasetMeta, LazyDataset
from twinkle.loss import GKDLoss
from twinkle.model import TransformersModel
from twinkle.preprocessor import GSM8KProcessor
from twinkle.preprocessor.olympiad_bench import OlympiadBenchProcessor
from twinkle.sampler import vLLMSampler
from twinkle.template import Template

logger = get_logger()

# ── Configuration ─────────────────────────────────────────────────────────────
STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3-0.6B')
TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-8B')
STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3.5-9B')
USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))

MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2))
NUM_GPUS = MODEL_GPUS + 2*SAMPLER_GPUS

MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 2048))
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16))
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
LEARNING_RATE = float(os.environ.get('LR', 5e-5))
N_SAMPLES = int(os.environ.get('N_SAMPLES', 1))

GKD_BETA = float(os.environ.get('GKD_BETA', 0.5))
GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0))
GKD_TOPK = int(os.environ.get('GKD_TOPK', 64))
SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem step by step and put '
'your final answer within #### <number>')
ADAPTER_NAME = 'default'

# OlympiadBench subsets
SUBSETS = [
'OE_MM_maths_zh_CEE',
'OE_MM_physics_zh_CEE',
'OE_TO_maths_zh_CEE',
]


# ── Dataset ───────────────────────────────────────────────────────────────────

def create_dataset():
"""Prompt-only dataset; student vLLM will generate completions on-policy."""
dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
dataset.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024)
dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT))
"""OlympiadBench multimodal dataset; student vLLM will generate completions on-policy."""
ds = DatasetMeta('ms://AI-ModelScope/OlympiadBench', subset_name=SUBSETS[0], split='train')
dataset = LazyDataset(ds)
dataset.map(OlympiadBenchProcessor(language='zh'), dataset_meta=ds)

for subset in SUBSETS[1:]:
ds = DatasetMeta('ms://AI-ModelScope/OlympiadBench', subset_name=subset, split='train')
dataset.add_dataset(ds)
dataset.map(OlympiadBenchProcessor(language='zh'), dataset_meta=ds)

dataset.set_template('Qwen3_5Template', model_id=STUDENT_MODEL_ID, max_length=2048, enable_thinking=False)
dataset.mix_dataset(interleave=True)
return dataset


Expand Down Expand Up @@ -155,41 +173,69 @@ def main():
)

# ── Student model (trainable) ──────────────────────────────────────────────
student_model = TransformersModel(
model_id=STUDENT_MODEL_ID,
device_mesh=model_mesh,
remote_group='student_model',
)
student_model.add_adapter_to_model(
ADAPTER_NAME,
LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear'),
gradient_accumulation_steps=1,
)
student_model.set_optimizer('AdamW', lr=LEARNING_RATE)
student_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)
student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE))
student_model.set_template('Template', model_id=STUDENT_MODEL_ID)
lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear')

if USE_MEGATRON:
from twinkle.model.megatron import MegatronModel
student_model = MegatronModel(
model_id=STUDENT_MODEL_ID,
device_mesh=model_mesh,
remote_group='student_model',
)
else:
from transformers import Qwen3_5ForConditionalGeneration
student_model = TransformersModel(
model_id=STUDENT_MODEL_ID,
model_cls=Qwen3_5ForConditionalGeneration,
device_mesh=model_mesh,
remote_group='student_model',
)

student_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)

if USE_MEGATRON:
student_model.set_optimizer('default', lr=LEARNING_RATE, adapter_name=ADAPTER_NAME)
student_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE, adapter_name=ADAPTER_NAME)
else:
student_model.set_optimizer('AdamW', lr=LEARNING_RATE)
student_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)

student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE), adapter_name=ADAPTER_NAME)
student_model.set_template('Qwen3_5Template', model_id=STUDENT_MODEL_ID, adapter_name=ADAPTER_NAME, enable_thinking=False)

# ── Student vLLM sampler (for on-policy generation) ────────────────────────
student_sampler = vLLMSampler(
model_id=STUDENT_MODEL_ID,
# enable_lora=True pairs with ckpt_manager.sync_weights(merge_and_sync=False),
# which syncs only the LoRA weights. With merge_and_sync=True, the LoRA weights
# are merged into the base model and all weights are synced to vLLM.
engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 4096, 'enable_lora': True, 'max_loras': 1},
engine_args={
'gpu_memory_utilization': 0.75,
'max_model_len': 8192,
'enable_lora': True,
'max_loras': 1,
'limit_mm_per_prompt': {'image': 3},
'enable_tower_connector_lora': True,
},
device_mesh=sampler_mesh,
remote_group='student_sampler',
)
student_sampler.set_template(Template, model_id=STUDENT_MODEL_ID)
student_sampler.set_template('Qwen3_5Template', model_id=STUDENT_MODEL_ID, enable_thinking=False)

# ── Teacher vLLM sampler (for prompt logprobs) ───────────────────────────────
teacher_sampler = vLLMSampler(
model_id=TEACHER_MODEL_ID,
engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 4096, 'logprobs_mode': 'raw_logprobs', 'max_logprobs': 64},
engine_args={
'gpu_memory_utilization': 0.75,
'max_model_len': 8192,
'logprobs_mode': 'raw_logprobs',
'max_logprobs': 64,
'limit_mm_per_prompt': {'image': 3},
},
device_mesh=sampler_mesh,
remote_group='teacher_sampler',
)
teacher_sampler.set_template(Template, model_id=TEACHER_MODEL_ID)
teacher_sampler.set_template('Qwen3_5Template', model_id=TEACHER_MODEL_ID, enable_thinking=False)

# ── DataLoader (prompt-only) ───────────────────────────────────────────────
dataloader = DataLoader(
Expand Down Expand Up @@ -236,19 +282,19 @@ def main():
)

# 5. Student forward + GKD backward
student_model.forward_backward(inputs=input_data, **teacher_output)
student_model.clip_grad_and_step()
student_model.forward_backward(inputs=input_data, adapter_name=ADAPTER_NAME, **teacher_output)
student_model.clip_grad_and_step(adapter_name=ADAPTER_NAME)

if optim_step > 0 and optim_step % 10 == 0:
metric = student_model.calculate_metric(is_training=True)
if optim_step > 0 and optim_step % 1 == 0:
metric = student_model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME)
logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}')

if optim_step > 0 and optim_step % 50 == 0:
student_model.save(f'gkd-onpolicy-ckpt-{optim_step}')
student_model.save(f'gkd-onpolicy-ckpt-{optim_step}', adapter_name=ADAPTER_NAME)

optim_step += 1

student_model.save('gkd-onpolicy-final')
student_model.save('gkd-onpolicy-final', adapter_name=ADAPTER_NAME)
logger.info('GKD on-policy training completed.')


Expand Down
6 changes: 3 additions & 3 deletions cookbook/rl/short_math_grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
# ========== Dataset ==========
def create_gsm8k_dataset():
dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=4096, truncation_strategy='delete', enable_thinking=False)
dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=4096, truncation_strategy='delete', enable_thinking=True)
dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT))
dataset.encode(add_generation_prompt=True)
return dataset
Expand Down Expand Up @@ -153,7 +153,7 @@ def main():

model.set_loss('GRPOLoss', epsilon=0.2)
model.set_processor(InputProcessor)
model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False)
model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=True)

sampler = vLLMSampler(
model_id=MODEL_ID,
Expand All @@ -170,7 +170,7 @@ def main():
device_mesh=sampler_mesh,
remote_group='sampler',
)
sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False)
sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=True)

ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)

Expand Down
2 changes: 1 addition & 1 deletion cookbook/transformers/fsdp2.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 fsdp2.py
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 fsdp2.py
6 changes: 3 additions & 3 deletions cookbook/transformers/fsdp2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
def eval(model):
# 100 Samples
dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B')
dataset.set_template('Template', model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507')
dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
dataset.encode()
dataloader = DataLoader(dataset=dataset, batch_size=4)
Expand All @@ -35,15 +35,15 @@ def train():
# 1000 samples
dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
# Set template to prepare encoding
dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B')
dataset.set_template('Template', model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507')
# Preprocess the dataset to standard format
dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
# Encode dataset
dataset.encode()
# Global batch size = 8 (see batch_size below); intended as 1 sample per GPU, i.e. 8 GPUs
dataloader = DataLoader(dataset=dataset, batch_size=8)
# Use a TransformersModel; set transformer_cls_names_to_wrap to Qwen3MoeSparseMoeBlock to avoid an FSDP2 hang
model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B', fsdp_config={'transformer_cls_names_to_wrap':['Qwen3MoeSparseMoeBlock']})
model = TransformersModel(model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507', fsdp_config={'transformer_cls_names_to_wrap':['Qwen3MoeSparseMoeBlock']})
# Patch MoE model to fix the hang bug, support transformers==4.*
model.apply_patch('ms://twinkle-kit/qwen3_moe_transformers4_patch')
lora_config = LoraConfig(
Expand Down
Loading
Loading