275 changes: 275 additions & 0 deletions cookbook/rl/short_math_grpo_moe.py
@@ -0,0 +1,275 @@
"""GRPO training script for GSM8K dataset.

Converted from the Tinker client version to Ray-based training.
Uses a short-reasoning format: shorter reasoning earns a higher format reward.
The final answer is extracted from \\boxed{} or #### format.
"""
import os
import re
from typing import List, Tuple, Dict, Any

from peft import LoraConfig

import twinkle
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
from twinkle.advantage import GRPOAdvantage
from twinkle.checkpoint_engine import CheckpointEngineManager
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.metric import CompletionRewardMetric
from twinkle.model import TransformersModel
from twinkle.processor import InputProcessor
from twinkle.reward import GSM8KAccuracyReward
from twinkle.reward.base import Reward
from twinkle.sampler import vLLMSampler
from twinkle.preprocessor.llm import GSM8KProcessor

logger = get_logger()

# ========== Configuration ==========
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B')
USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))

MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
MODEL_EP = int(os.environ.get('MODEL_EP', 2))
MODEL_TP = int(os.environ.get('MODEL_TP', 2))
MODEL_PP = int(os.environ.get('MODEL_PP', 2))

SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2))
SAMPLER_TP = int(os.environ.get('SAMPLER_TP', 2))
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS

NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
LEARNING_RATE = float(os.environ.get('LR', 5e-5))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))
MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4))
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1))
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
ADAPTER_NAME = 'default'
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000))
LORA_RANK = int(os.environ.get('LORA_RANK', 16))
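
# With the defaults above, the placement works out as follows (a worked
# example only; the script does not enforce these numbers):
#   model workers: 4 GPUs as tp=2 x pp=2 -> dp=1, with ep=2 across expert layers
#   sampler workers: 2 GPUs as tp=2 -> dp=1
#   total: NUM_GPUS = 4 + 2 = 6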

SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning '
'and put your final answer within \\boxed{}.')

# ========== Reward Functions ==========
class GSM8KBrevityReward(Reward):
"""Brevity reward: rewards shorter completions that contain a valid answer.

Returns 0.0 if no valid answer format (\\boxed{} or ####).
Otherwise returns higher score for shorter completions (1.0 at <=200 chars).
"""

def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
rewards = []
for traj in trajectories:
messages = traj.get('messages', [])
completion = ''
for msg in reversed(messages):
if msg.get('role') == 'assistant':
completion = msg.get('content', '')
break

has_answer = bool(
re.search(r'\\boxed\{[^}]+\}', completion)
or re.search(r'####\s*[\-\d,\.]+', completion)
)

if not has_answer:
rewards.append(0.0)
else:
length = len(completion)
if length <= 200:
rewards.append(1.0)
else:
rewards.append(max(0.0, 1.0 - (length - 200) / 3000))
return rewards
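
# For intuition, the ramp above can be exercised in isolation (a standalone
# sketch of the same formula; not used by the training script):
def _brevity_ramp(length: int) -> float:
    """Mirror the length-to-score mapping in GSM8KBrevityReward."""
    return 1.0 if length <= 200 else max(0.0, 1.0 - (length - 200) / 3000)

assert _brevity_ramp(150) == 1.0    # short completions get the full reward
assert _brevity_ramp(1700) == 0.5   # halfway down the linear ramp
assert _brevity_ramp(3200) == 0.0   # reward bottoms out at 3200 characters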


# ========== Dataset ==========
def create_gsm8k_dataset():
dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=4096, truncation_strategy='delete', enable_thinking=False)
dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT))
dataset.encode(add_generation_prompt=True)
return dataset


def compute_rewards(
trajectories: List[Dict[str, Any]],
) -> Tuple[List[float], List[float], List[float]]:
accuracy_reward_fn = GSM8KAccuracyReward()
brevity_reward_fn = GSM8KBrevityReward()

accuracy_rewards = accuracy_reward_fn(trajectories)
brevity_rewards = brevity_reward_fn(trajectories)
total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)]
return total_rewards, brevity_rewards, accuracy_rewards
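
# Hypothetical worked examples of the combined reward: a correct answer in 150
# chars scores 1.0 (accuracy) + 1.0 (brevity) = 2.0; a correct but 1700-char
# answer scores 1.0 + 0.5 = 1.5; a completion with no parseable answer earns
# 0.0 brevity plus whatever GSM8KAccuracyReward assigns (presumably 0.0).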


# ========== Main ==========
def main():
device_groups = [
DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU', gpus_per_worker=SAMPLER_TP),
]
dp_size = MODEL_GPUS // (MODEL_TP * MODEL_PP)
model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=dp_size, tp_size=MODEL_TP, pp_size=MODEL_PP, ep_size=MODEL_EP, sequence_parallel=True)
sampler_dp_size = SAMPLER_GPUS // SAMPLER_TP
sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=sampler_dp_size, tp_size=SAMPLER_TP)
twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)

lora_config = LoraConfig(
target_modules=['all-linear'],
r=LORA_RANK,
lora_alpha=LORA_RANK * 2,
lora_dropout=0.05,
)

if USE_MEGATRON:
from twinkle.model.megatron import MegatronModel
model = MegatronModel(
model_id=MODEL_ID,
device_mesh=model_mesh,
remote_group='model',
mixed_precision='bf16',
)
else:
model = TransformersModel(
model_id=MODEL_ID,
device_mesh=model_mesh,
remote_group='model',
)

model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
if USE_MEGATRON:
model.set_optimizer('default', lr=LEARNING_RATE)
model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE)
else:
model.set_optimizer('AdamW', lr=LEARNING_RATE)
model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)

model.set_loss('GRPOLoss', epsilon=0.2)
model.set_processor(InputProcessor)
model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False)

sampler = vLLMSampler(
model_id=MODEL_ID,
engine_args={
'tensor_parallel_size': SAMPLER_TP,
'gpu_memory_utilization': 0.7,
'max_model_len': 10000,
'max_lora_rank': LORA_RANK,  # same as the rank in lora_config
# NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976
# enable_lora=True is paired with ckpt_manager.sync_weights(merge_and_sync=False),
# which syncs only the LoRA weights. With merge_and_sync=True, the LoRA weights
# are merged into the base model and all weights are synced to vLLM.
'enable_lora': True,
'enable_tower_connector_lora': True,
},
device_mesh=sampler_mesh,
remote_group='sampler',
)
sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False)

ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)

GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
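# With the defaults, one sampling round covers GLOBAL_BATCH_SIZE (4) prompts x
# NUM_GENERATIONS (8) = 32 completions, trained in mini-batches of
# MINI_BATCH_SIZE (4), i.e. 8 optimizer steps per round.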
dataloader = DataLoader(
dataset=create_gsm8k_dataset,
batch_size=GLOBAL_BATCH_SIZE,
min_batch_size=GLOBAL_BATCH_SIZE,
device_mesh=model_mesh,
remote_group='model',
)

advantage_fn = GRPOAdvantage()
metrics = CompletionRewardMetric()
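# logprobs=1 requests per-token logprobs for the sampled tokens; they are
# collected as old_logps below and feed the clipped ratio in GRPOLoss
# (epsilon=0.2).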
sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1, temperature=1.0, top_p=0.95)

optim_step = 0
logger.info('Starting GSM8K GRPO training (short reasoning)')
logger.info(get_device_placement())

for batch in dataloader:
if optim_step >= MAX_STEPS:
break

metrics.reset()
expand_prompts = []
for prompt in batch:
expand_prompts.extend([prompt] * NUM_GENERATIONS)

# merge_and_sync=False syncs only the LoRA weights to vLLM; merge_and_sync=True
# would merge LoRA into the base model and sync all weights (see the NOTE on
# enable_lora in the sampler's engine_args).
ckpt_manager.sync_weights(merge_and_sync=False)
sampler.reset_prefix_cache()

sample_responses = sampler.sample(
expand_prompts,
sampling_params,
)
if sample_responses and sample_responses[0].sequences:
first_decoded = sample_responses[0].sequences[0].decoded
if isinstance(first_decoded, str):
logger.info('[sample_debug] first_generation=%r', first_decoded[:512])

all_input_data: List[Dict[str, Any]] = []
all_old_logps: List[List[float]] = []
all_completion_lengths: List[int] = []

for sample_response in sample_responses:
for sequence in sample_response.sequences:
all_input_data.append(sequence.new_input_feature)
all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs])
all_completion_lengths.append(len(sequence.tokens))

total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(all_input_data)

metrics.accumulate(
completion_lengths=all_completion_lengths,
rewards={
'total': total_rewards,
'brevity': brevity_rewards,
'accuracy': accuracy_rewards,
},
)

advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist()
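# (Advantages above are group-relative: NUM_GENERATIONS completions per prompt
# form a group and each reward is compared to its own group's statistics; with
# scale='group', conventionally (r - group_mean) / group_std.)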

total_completions = len(all_input_data)
for mb_start in range(0, total_completions, MINI_BATCH_SIZE):
mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions)
mb_inputs = all_input_data[mb_start:mb_end]
mb_old_logps = all_old_logps[mb_start:mb_end]
mb_advantages = advantages[mb_start:mb_end]

model.forward_backward(
inputs=mb_inputs,
old_logps=mb_old_logps,
advantages=mb_advantages,
micro_batch_size=MICRO_BATCH_SIZE,
)
model.clip_grad_and_step()
optim_step += 1

if optim_step >= MAX_STEPS:
break
if optim_step % SAVE_STEPS == 0:
model.save(f'math-grpo-checkpoint-{optim_step}')

log_dict = metrics.calculate()
log_dict.update(model.calculate_metric(is_training=True))
metrics.reset()
logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}')

logger.info(f'Training completed. optim_steps={optim_step}')
model.save('math-grpo-final')


if __name__ == '__main__':
main()
31 changes: 17 additions & 14 deletions cookbook/rl/short_math_grpo_multi_lora.py
@@ -35,19 +35,21 @@
logger = get_logger()

# ========== Configuration ==========
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-27B')
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B')

MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2))
SAMPLER_TP = int(os.environ.get('SAMPLER_TP', 2))

NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS

NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
LEARNING_RATE = float(os.environ.get('LR', 5e-6))
LEARNING_RATE = float(os.environ.get('LR', 5e-5))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))
MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4))
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2))
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1))
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
ADAPTER_NAME = 'default_0'
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000))
@@ -124,23 +126,19 @@ def main():
device_groups = [
DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU',
gpus_per_worker=2),
gpus_per_worker=SAMPLER_TP),
]

# Model mesh: tp=2, ep=2, pp=2, sequence_parallel (ref: server_config.yaml)
model_mesh = DeviceMesh.from_sizes(
world_size=MODEL_GPUS,
tp_size=2,
# ep_size=2,
ep_size=2,
pp_size=2,
# sequence_parallel=True,
)
# Sampler mesh: dp=2, tp=2
sampler_mesh = DeviceMesh.from_sizes(
world_size=SAMPLER_GPUS,
dp_size=2,
tp_size=2,
sequence_parallel=True,
)
sampler_dp_size = SAMPLER_GPUS // SAMPLER_TP
sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=sampler_dp_size, tp_size=SAMPLER_TP)

twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)

@@ -179,9 +177,10 @@ def main():
sampler = vLLMSampler(
model_id=MODEL_ID,
engine_args={
'tensor_parallel_size': SAMPLER_TP,
'gpu_memory_utilization': 0.8,
'max_model_len': 8192,
'max_lora_rank': 32,
'max_lora_rank': LORA_RANK,
# NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976
'enable_lora': True,
'enable_tower_connector_lora': True,
@@ -241,6 +240,10 @@ def main():
sampling_params,
adapter_path=lora_sync_path,
)
if sample_responses and sample_responses[0].sequences:
first_decoded = sample_responses[0].sequences[0].decoded
if isinstance(first_decoded, str):
logger.info('[sample_debug] first_generation=%r', first_decoded[:512])

all_input_data: List[Dict[str, Any]] = []
all_old_logps: List[List[float]] = []
3 changes: 2 additions & 1 deletion src/twinkle/checkpoint_engine/manager.py
@@ -161,4 +161,5 @@ def _expand_keys(keys):

if not self.base_sync_done:
self.base_sync_done = True
logger.info('Base model sync completed, subsequent syncs will be LoRA-only')
if not merge_and_sync:
logger.info('Base model sync completed, subsequent syncs will be LoRA-only')
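# (The log message is now emitted only for LoRA-only syncs; with
# merge_and_sync=True, full merged weights continue to be pushed on every sync.)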
3 changes: 2 additions & 1 deletion src/twinkle/model/megatron/megatron.py
@@ -1444,6 +1444,7 @@ def _print_weight_example(names):
logger.info(f'Sync weight: {name}')

def _add_base_layer_suffix(name):
base_layer_name = None
if name.endswith('.weight'):
base_layer_name = f'{name[:-7]}.base_layer.weight'
if not model_keys or base_layer_name in model_keys:
@@ -1452,7 +1453,7 @@ def _add_base_layer_suffix(name):
base_layer_name = f'{name[:-5]}.base_layer.bias'
if not model_keys or base_layer_name in model_keys:
name = base_layer_name
if 'experts' in name:
if 'experts' in name and base_layer_name is not None:
return base_layer_name
return name
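# (Initializing base_layer_name to None and guarding the 'experts' branch
# avoids an UnboundLocalError when a name contains 'experts' but ends in
# neither '.weight' nor '.bias'.)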

4 changes: 4 additions & 0 deletions src/twinkle/patch/vllm_lora_weights.py
@@ -29,6 +29,10 @@ def embeddings(self):
class VLLMLoraWeights(Patch):

def __call__(self, sampler, **kwargs):
from twinkle.patch.vllm_moe_loader import patch_qwen35_moe_is_3d_moe_weight_false

patch_qwen35_moe_is_3d_moe_weight_false()

_sampler_ref = sampler

def _get_tokenizer():