diff --git a/cookbook/rl/short_math_grpo_moe.py b/cookbook/rl/short_math_grpo_moe.py
new file mode 100644
index 00000000..9d870eac
--- /dev/null
+++ b/cookbook/rl/short_math_grpo_moe.py
@@ -0,0 +1,275 @@
+"""GRPO training script for GSM8K dataset.
+
+Converted from the Tinker client version to Ray-based training.
+Uses short reasoning format: shorter thinking gets higher format reward.
+Answer extracted from \\boxed{} or #### format.
+"""
+import os
+import re
+from typing import List, Tuple, Dict, Any
+
+from peft import LoraConfig
+
+import twinkle
+from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
+from twinkle.advantage import GRPOAdvantage
+from twinkle.checkpoint_engine import CheckpointEngineManager
+from twinkle.data_format import SamplingParams
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.metric import CompletionRewardMetric
+from twinkle.model import TransformersModel
+from twinkle.processor import InputProcessor
+from twinkle.reward import GSM8KAccuracyReward
+from twinkle.reward.base import Reward
+from twinkle.sampler import vLLMSampler
+from twinkle.preprocessor.llm import GSM8KProcessor
+
+logger = get_logger()
+
+# ========== Configuration ==========
+MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B')
+USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))
+
+MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
+MODEL_EP = int(os.environ.get('MODEL_EP', 2))
+MODEL_TP = int(os.environ.get('MODEL_TP', 2))
+MODEL_PP = int(os.environ.get('MODEL_PP', 2))
+
+SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2))
+SAMPLER_TP = int(os.environ.get('SAMPLER_TP', 2))
+NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
+
+NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
+MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
+LEARNING_RATE = float(os.environ.get('LR', 5e-5))
+MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))
+MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4))
+MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1))
+GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
+ADAPTER_NAME = 'default'
+SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000))
+LORA_RANK = int(os.environ.get('LORA_RANK', 16))
+
+SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning '
+                 'and put your final answer within \\boxed{}.')
+
+# ========== Reward Functions ==========
+class GSM8KBrevityReward(Reward):
+    """Brevity reward: rewards shorter completions that contain a valid answer.
+
+    Returns 0.0 if no valid answer format (\\boxed{} or ####).
+    Otherwise returns higher score for shorter completions (1.0 at <=200 chars).
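+
+    For example, a 1,700-character completion that does contain a valid
+    answer scores max(0.0, 1.0 - (1700 - 200) / 3000) = 0.5.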
+ """ + + def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: + rewards = [] + for traj in trajectories: + messages = traj.get('messages', []) + completion = '' + for msg in reversed(messages): + if msg.get('role') == 'assistant': + completion = msg.get('content', '') + break + + has_answer = bool( + re.search(r'\\boxed\{[^}]+\}', completion) + or re.search(r'####\s*[\-\d,\.]+', completion) + ) + + if not has_answer: + rewards.append(0.0) + else: + length = len(completion) + if length <= 200: + rewards.append(1.0) + else: + rewards.append(max(0.0, 1.0 - (length - 200) / 3000)) + return rewards + + +# ========== Dataset ========== +def create_gsm8k_dataset(): + dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=4096, truncation_strategy='delete', enable_thinking=False) + dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT)) + dataset.encode(add_generation_prompt=True) + return dataset + + +def compute_rewards( + trajectories: List[Dict[str, Any]], +) -> Tuple[List[float], List[float], List[float]]: + accuracy_reward_fn = GSM8KAccuracyReward() + brevity_reward_fn = GSM8KBrevityReward() + + accuracy_rewards = accuracy_reward_fn(trajectories) + brevity_rewards = brevity_reward_fn(trajectories) + total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)] + return total_rewards, brevity_rewards, accuracy_rewards + + +# ========== Main ========== +def main(): + device_groups = [ + DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'), + DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU', gpus_per_worker=SAMPLER_TP), + ] + dp_size = MODEL_GPUS // (MODEL_TP * MODEL_PP) + model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=dp_size, tp_size=MODEL_TP, pp_size=MODEL_PP, ep_size=MODEL_EP, sequence_parallel=True) + sampler_dp_size = SAMPLER_GPUS // (SAMPLER_TP) + sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=sampler_dp_size, tp_size=SAMPLER_TP) + twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) + + lora_config = LoraConfig( + target_modules=['all-linear'], + r=LORA_RANK, + lora_alpha=LORA_RANK * 2, + lora_dropout=0.05, + ) + + if USE_MEGATRON: + from twinkle.model.megatron import MegatronModel + model = MegatronModel( + model_id=MODEL_ID, + device_mesh=model_mesh, + remote_group='model', + mixed_precision='bf16', + ) + else: + model = TransformersModel( + model_id=MODEL_ID, + device_mesh=model_mesh, + remote_group='model', + ) + + model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) + if USE_MEGATRON: + model.set_optimizer('default', lr=LEARNING_RATE) + model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE) + else: + model.set_optimizer('AdamW', lr=LEARNING_RATE) + model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0) + + model.set_loss('GRPOLoss', epsilon=0.2) + model.set_processor(InputProcessor) + model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False) + + sampler = vLLMSampler( + model_id=MODEL_ID, + engine_args={ + 'tensor_parallel_size': SAMPLER_TP, + 'gpu_memory_utilization': 0.7, + 'max_model_len': 10000, + 'max_lora_rank': LORA_RANK, # save as lora_config + # NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976 + # 
+            # enable_lora=True pairs with ckpt_manager.sync_weights(merge_and_sync=False),
+            # which syncs only the LoRA weights; with merge_and_sync=True the LoRA would be
+            # merged into the base model and all weights synced to vLLM.
+            'enable_lora': True,
+            'enable_tower_connector_lora': True,
+        },
+        device_mesh=sampler_mesh,
+        remote_group='sampler',
+    )
+    sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False)
+
+    ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)
+
+    GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
+    dataloader = DataLoader(
+        dataset=create_gsm8k_dataset,
+        batch_size=GLOBAL_BATCH_SIZE,
+        min_batch_size=GLOBAL_BATCH_SIZE,
+        device_mesh=model_mesh,
+        remote_group='model',
+    )
+
+    advantage_fn = GRPOAdvantage()
+    metrics = CompletionRewardMetric()
+    sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1, temperature=1.0, top_p=0.95)
+
+    optim_step = 0
+    logger.info('Starting GSM8K GRPO training (short reasoning)')
+    logger.info(get_device_placement())
+
+    for batch in dataloader:
+        if optim_step >= MAX_STEPS:
+            break
+
+        metrics.reset()
+        expand_prompts = []
+        for prompt in batch:
+            expand_prompts.extend([prompt] * NUM_GENERATIONS)
+
+        # merge_and_sync=False syncs only the LoRA weights to vLLM
+        # (see the NOTE on enable_lora in engine_args above).
+        ckpt_manager.sync_weights(merge_and_sync=False)
+        sampler.reset_prefix_cache()
+
+        sample_responses = sampler.sample(
+            expand_prompts,
+            sampling_params,
+        )
+        if sample_responses and sample_responses[0].sequences:
+            first_decoded = sample_responses[0].sequences[0].decoded
+            if isinstance(first_decoded, str):
+                logger.info('[sample_debug] first_generation=%r', first_decoded[:512])
+
+        all_input_data: List[Dict[str, Any]] = []
+        all_old_logps: List[List[float]] = []
+        all_completion_lengths: List[int] = []
+
+        for sample_response in sample_responses:
+            for sequence in sample_response.sequences:
+                all_input_data.append(sequence.new_input_feature)
+                all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs])
+                all_completion_lengths.append(len(sequence.tokens))
+
+        total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(all_input_data)
+
+        metrics.accumulate(
+            completion_lengths=all_completion_lengths,
+            rewards={
+                'total': total_rewards,
+                'brevity': brevity_rewards,
+                'accuracy': accuracy_rewards,
+            },
+        )
+
+        advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist()
+
+        total_completions = len(all_input_data)
+        for mb_start in range(0, total_completions, MINI_BATCH_SIZE):
+            mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions)
+            mb_inputs = all_input_data[mb_start:mb_end]
+            mb_old_logps = all_old_logps[mb_start:mb_end]
+            mb_advantages = advantages[mb_start:mb_end]
+
+            model.forward_backward(
+                inputs=mb_inputs,
+                old_logps=mb_old_logps,
+                advantages=mb_advantages,
+                micro_batch_size=MICRO_BATCH_SIZE,
+            )
+            model.clip_grad_and_step()
+            optim_step += 1
+
+            if optim_step >= MAX_STEPS:
+                break
+            if optim_step % SAVE_STEPS == 0:
+                model.save(f'math-grpo-checkpoint-{optim_step}')
+
+        log_dict = metrics.calculate()
+        log_dict.update(model.calculate_metric(is_training=True))
+        metrics.reset()
+        logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}')
+
+    logger.info(f'Training completed. optim_steps={optim_step}')
+    model.save('math-grpo-final')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/cookbook/rl/short_math_grpo_multi_lora.py b/cookbook/rl/short_math_grpo_multi_lora.py
index fbbdcc27..9dad8df3 100644
--- a/cookbook/rl/short_math_grpo_multi_lora.py
+++ b/cookbook/rl/short_math_grpo_multi_lora.py
@@ -35,19 +35,21 @@ logger = get_logger()
 
 # ========== Configuration ==========
-MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-27B')
+MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B')
 MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
-SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
+SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2))
+SAMPLER_TP = int(os.environ.get('SAMPLER_TP', 2))
+
 NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
 
 NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
 MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
-LEARNING_RATE = float(os.environ.get('LR', 5e-6))
+LEARNING_RATE = float(os.environ.get('LR', 5e-5))
 MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
 BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))
 MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4))
-MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2))
+MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1))
 GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
 ADAPTER_NAME = 'default_0'
 SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000))
@@ -124,23 +126,19 @@ def main():
     device_groups = [
         DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
         DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU',
-                    gpus_per_worker=2),
+                    gpus_per_worker=SAMPLER_TP),
     ]
     # Model mesh: tp=2, ep=2, pp=2, sequence_parallel (ref: server_config.yaml)
     model_mesh = DeviceMesh.from_sizes(
         world_size=MODEL_GPUS,
         tp_size=2,
-        # ep_size=2,
+        ep_size=2,
         pp_size=2,
-        # sequence_parallel=True,
-    )
-    # Sampler mesh: dp=2, tp=2
-    sampler_mesh = DeviceMesh.from_sizes(
-        world_size=SAMPLER_GPUS,
-        dp_size=2,
-        tp_size=2,
+        sequence_parallel=True,
     )
+    sampler_dp_size = SAMPLER_GPUS // SAMPLER_TP
+    sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=sampler_dp_size, tp_size=SAMPLER_TP)
 
     twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)
@@ -179,9 +177,10 @@ def main():
     sampler = vLLMSampler(
         model_id=MODEL_ID,
         engine_args={
+            'tensor_parallel_size': SAMPLER_TP,
             'gpu_memory_utilization': 0.8,
             'max_model_len': 8192,
-            'max_lora_rank': 32,
+            'max_lora_rank': LORA_RANK,
             # NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976
             'enable_lora': True,
             'enable_tower_connector_lora': True,
@@ -241,6 +240,10 @@ def main():
             sampling_params,
             adapter_path=lora_sync_path,
         )
+        if sample_responses and sample_responses[0].sequences:
+            first_decoded = sample_responses[0].sequences[0].decoded
+            if isinstance(first_decoded, str):
+                logger.info('[sample_debug] first_generation=%r', first_decoded[:512])
 
         all_input_data: List[Dict[str, Any]] = []
         all_old_logps: List[List[float]] = []
diff --git a/src/twinkle/checkpoint_engine/manager.py b/src/twinkle/checkpoint_engine/manager.py
index bd61b4f7..cde5c519 100644
--- a/src/twinkle/checkpoint_engine/manager.py
+++ b/src/twinkle/checkpoint_engine/manager.py
@@ -161,4 +161,5 @@ def _expand_keys(keys):
 
         if not self.base_sync_done:
             self.base_sync_done = True
-            logger.info('Base model sync completed, subsequent syncs will be LoRA-only')
+            if not merge_and_sync:
+                logger.info('Base model sync completed, subsequent syncs will be LoRA-only')
diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py
index b2aa4690..68b6f39a 100644
--- a/src/twinkle/model/megatron/megatron.py
+++ b/src/twinkle/model/megatron/megatron.py
@@ -1444,6 +1444,7 @@ def _print_weight_example(names):
                 logger.info(f'Sync weight: {name}')
 
     def _add_base_layer_suffix(name):
+        base_layer_name = None
         if name.endswith('.weight'):
            base_layer_name = f'{name[:-7]}.base_layer.weight'
            if not model_keys or base_layer_name in model_keys:
                name = base_layer_name
        elif name.endswith('.bias'):
            base_layer_name = f'{name[:-5]}.base_layer.bias'
            if not model_keys or base_layer_name in model_keys:
                name = base_layer_name
-        if 'experts' in name:
+        if 'experts' in name and base_layer_name is not None:
             return base_layer_name
         return name
diff --git a/src/twinkle/patch/vllm_lora_weights.py b/src/twinkle/patch/vllm_lora_weights.py
index 7e33419e..558c0389 100644
--- a/src/twinkle/patch/vllm_lora_weights.py
+++ b/src/twinkle/patch/vllm_lora_weights.py
@@ -29,6 +29,10 @@ def embeddings(self):
 
 class VLLMLoraWeights(Patch):
 
     def __call__(self, sampler, **kwargs):
+        from twinkle.patch.vllm_moe_loader import patch_qwen35_moe_is_3d_moe_weight_false
+
+        patch_qwen35_moe_is_3d_moe_weight_false()
+
         _sampler_ref = sampler
 
         def _get_tokenizer():
diff --git a/src/twinkle/patch/vllm_moe_loader.py b/src/twinkle/patch/vllm_moe_loader.py
index 5d064c21..45b30103 100644
--- a/src/twinkle/patch/vllm_moe_loader.py
+++ b/src/twinkle/patch/vllm_moe_loader.py
@@ -58,6 +58,16 @@
         pass
 
 
+def patch_qwen35_moe_is_3d_moe_weight_false() -> None:
+    # With is_3d_moe_weight=False, expected_lora_modules use per-expert
+    # names, e.g. up_proj -> experts.0.up_proj.
+    try:
+        from vllm.model_executor.models.qwen3_5 import Qwen3_5MoeForConditionalGeneration
+
+        Qwen3_5MoeForConditionalGeneration.is_3d_moe_weight = False
+    except ImportError:
+        pass
+
+
 class VLLMMoEWeights(Patch):
 
     def __call__(self, model, **kwargs):
diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py
index 2023e9e9..719acf1c 100644
--- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py
+++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py
@@ -1,6 +1,8 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
+import contextlib
 import inspect
 import os
+import re
 import torch
 import uuid
 from typing import Any, Dict, List, Optional, Union
@@ -95,6 +97,18 @@ def __init__(
         # ``list_loras()`` per request.
         self._synced_lora_request: Optional[Any] = None
 
+        # Long-lived CUDA IPC bucket reused across all update_weights()
+        # calls. Allocating a new IPC buffer (and hence a new IPC handle)
+        # per sync forces every worker to create a new CUDA IPC mapping via
+        # ``rebuild_cuda_tensor`` because PyTorch's ``shared_cache`` cannot
+        # hit on unseen storage handles. The driver reclaims those mappings
+        # lazily, which is the root cause of the slow GPU memory drift we
+        # observed under frequent LoRA syncs. By pinning a single buffer
+        # and its handle we guarantee the worker-side cache always hits.
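+        # (Mechanically: ``reduce_tensor`` pickles the storage's CUDA IPC
+        # handle, and the receiving worker's ``rebuild_cuda_tensor`` first
+        # consults a process-wide ``shared_cache`` keyed by that handle
+        # before opening a new CUDA IPC mapping, so a stable handle means
+        # at most one mapping per worker.)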
+        self._ipc_buffer: Optional[torch.Tensor] = None
+        self._ipc_handle: Any = None
+        self._ipc_buffer_size: int = 0
+
         # Initialize engine
         self.engine = self._create_engine()
@@ -521,7 +535,14 @@ async def _sync_iter():
         sync_id = uuid.uuid4().hex
         zmq_handle = f'ipc:///tmp/twinkle-ipc-{device_uuid}-{os.getpid()}-{sync_id}.sock'
 
+        env_bucket_mb = os.environ.get('TWINKLE_VLLM_BUCKET_SIZE_MB')
+        if env_bucket_mb is not None:
+            bucket_size_mb = int(env_bucket_mb)
+            if bucket_size_mb <= 0:
+                raise ValueError(f'bucket_size_mb must be > 0, got {bucket_size_mb}')
+            bucket_size = bucket_size_mb << 20
+
         lora_mode = bool(base_sync_done and peft_config)
         # Create transfer buffer
         buffer = None
         if use_gpu_ipc:
             from torch.multiprocessing.reductions import reduce_tensor
-            buffer = torch.empty(bucket_size, dtype=torch.uint8, device=first_tensor.device)
-            ipc_handle = reduce_tensor(buffer)
+
+            # Reuse a long-lived IPC bucket whenever the requested size
+            # fits. The handle is produced once and shipped to every
+            # subsequent sync so each worker's ``shared_cache`` stays warm
+            # and no new CUDA IPC mapping is created per sync.
+            need_realloc = (
+                self._ipc_buffer is None or self._ipc_buffer_size < bucket_size
+                or self._ipc_buffer.device != first_tensor.device)
+            if need_realloc:
+                # Drop the old handle/buffer before allocating a bigger one
+                # so we do not briefly hold both and double the peak usage.
+                self._ipc_buffer = None
+                self._ipc_handle = None
+                self._ipc_buffer_size = 0
+                self._ipc_buffer = torch.empty(bucket_size, dtype=torch.uint8, device=first_tensor.device)
+                self._ipc_handle = reduce_tensor(self._ipc_buffer)
+                self._ipc_buffer_size = bucket_size
+            buffer = self._ipc_buffer
+            ipc_handle = self._ipc_handle
         else:
             from multiprocessing import shared_memory
             shm_name = f'twinkle_weights_{uuid.uuid4().hex}'
@@ -558,6 +596,8 @@ def _zmq_send_recv(payload, where: str):
             except zmq.error.Again as e:
                 raise RuntimeError(f'IPC timeout ({zmq_timeout_s}s) during {where} on {zmq_handle}') from e
 
+        n_weights = 0
+        worker_task: Optional['asyncio.Future'] = None
         try:
             # Launch worker side concurrently
             worker_task = asyncio.ensure_future(
@@ -584,10 +624,14 @@ async def _chain_first():
 
             offset = 0
             bucket_meta: list[dict] = []
-            n_weights = 0
+            current_expert_layer: Optional[str] = None
+
+            def _get_expert_layer_prefix(weight_name: str) -> Optional[str]:
+                m = re.match(r'^(.*\.mlp\.experts)\.\d+\.', weight_name)
+                return m.group(1) if m else None
 
             async def _flush_bucket(is_last: bool) -> None:
-                nonlocal offset, bucket_meta
+                nonlocal offset, bucket_meta, current_expert_layer
                 if not bucket_meta and not is_last:
                     return
                 if buffer.device.type != 'cpu':
@@ -603,6 +647,7 @@ async def _flush_bucket(is_last: bool) -> None:
                 )
                 offset = 0
                 bucket_meta = []
+                current_expert_layer = None
 
             async for name, weight in _chain_first():
                 if use_shm and weight.device.type != 'cpu':
@@ -612,6 +657,15 @@ async def _flush_bucket(is_last: bool) -> None:
                 weight_u8 = weight.view(-1).view(torch.uint8)
                 total_nbytes = int(weight_u8.numel())
+                expert_layer_prefix = _get_expert_layer_prefix(name) if lora_mode else None
+                if lora_mode and offset > 0:
+                    # Keep each expert layer in an isolated bucket to avoid sending
+                    # partial expert-layer weights.
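+                    # E.g. a (hypothetical) name such as
+                    # 'model.layers.3.mlp.experts.17.up_proj.lora_A.weight'
+                    # yields the prefix 'model.layers.3.mlp.experts', so all
+                    # experts of one layer share a bucket and a layer switch
+                    # forces a flush.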
+                    if current_expert_layer != expert_layer_prefix:
+                        await _flush_bucket(is_last=False)
+                if lora_mode:
+                    current_expert_layer = expert_layer_prefix
+
                 chunk_offset = 0
                 while chunk_offset < total_nbytes:
                     if offset >= bucket_size:
                         await _flush_bucket(is_last=False)
@@ -642,6 +696,10 @@
             await _flush_bucket(is_last=False)
 
             await worker_task
         finally:
             # Clean up — always release resources regardless of exceptions
+            if worker_task is not None and not worker_task.done():
+                worker_task.cancel()
+                with contextlib.suppress(BaseException):
+                    await worker_task
             socket.close()
             zmq_ctx.term()
             if zmq_handle.startswith('ipc://'):
@@ -673,6 +731,10 @@ async def shutdown(self) -> None:
 
         logger.info('Shutting down VLLMEngine...')
 
+        self._ipc_buffer = None
+        self._ipc_handle = None
+        self._ipc_buffer_size = 0
+
         if self.engine is not None:
             try:
                 # vLLM v1 AsyncLLM has shutdown() method
@@ -693,10 +755,7 @@ async def shutdown(self) -> None:
         # Force garbage collection
         gc.collect()
 
-        # Clear CUDA cache if available
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            if hasattr(torch.cuda, 'ipc_collect'):
-                torch.cuda.ipc_collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
 
         logger.info('VLLMEngine shutdown complete')
diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py
index c707479e..5f1dae58 100644
--- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py
+++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py
@@ -435,6 +435,7 @@ async def _receive_and_load():
 
         self._run_in_loop(_receive_and_load())
 
+    @remote_function(dispatch='all', collect='first', lazy_collect=False)
     def shutdown(self):
         """Gracefully shutdown the vLLM engine and background event loop.
diff --git a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py
index ac58f92f..9879c638 100644
--- a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py
+++ b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py
@@ -59,6 +59,33 @@ def _rebuild_ipc(handle, device_id: Optional[int] = None) -> torch.Tensor:
     return rebuild_cuda_tensor(*list_args)
 
 
+def _ipc_handle_signature(handle) -> Optional[tuple]:
+    """Derive a stable signature for a CUDA IPC handle.
+
+    ``reduce_tensor`` returns ``(func, args)`` where ``args`` contains the
+    CUDA IPC storage handle bytes, storage size, ref-counter handle, etc.
+    Two handles are equivalent (i.e. map the same CUDA memory region) when
+    these inner fields match. We hash only the parts that are picklable and
+    comparable to avoid accidental mismatches due to local objects.
+    """
+    try:
+        _, args = handle
+    except Exception:
+        return None
+    sig = []
+    for v in args:
+        if isinstance(v, (bytes, bytearray)):
+            sig.append(('bytes', bytes(v)))
+        elif isinstance(v, (int, float, bool, str)) or v is None:
+            sig.append(('scalar', v))
+        else:
+            try:
+                sig.append(('repr', repr(v)))
+            except Exception:
+                return None
+    return tuple(sig)
+
+
 def _rebuild_shared_memory(name: str, size: int):
     """Rebuild tensor from shared memory. Returns (tensor, shm)."""
     from multiprocessing import shared_memory
@@ -129,9 +156,6 @@ def update_weights_from_ipc(
         logger.info(f'vLLM worker bind device: local_rank={local_rank}, device={device_str}')
         self.device = torch.device(device_str)
 
-        if peft_config and base_sync_done:
-            self.remove_lora(VLLM_LORA_INT_ID)
-
         # Detect TP rank — vLLM sets self.rank on each worker.
         tp_rank = getattr(self, 'rank', 0)
         tp_size = 1
@@ -187,7 +211,27 @@ def _broadcast_obj(obj):
             handle = comm_metadata
             # All TP ranks rebuild the IPC buffer from the same handle.
             # CUDA IPC allows any process on the same node to map the memory.
-            buffer = _rebuild_ipc(handle, self.device.index)
+            # Reuse a cached buffer across syncs when the sender reuses the
+            # same IPC handle: this avoids creating a fresh CUDA IPC mapping
+            # per sync, which the driver releases lazily and is the root
+            # cause of the apparent GPU memory growth under frequent syncs.
+            handle_signature = _ipc_handle_signature(handle)
+            cached_buffer = getattr(self, '_twinkle_ipc_buffer', None)
+            cached_signature = getattr(self, '_twinkle_ipc_handle_signature', None)
+            if cached_buffer is not None and cached_signature == handle_signature:
+                buffer = cached_buffer
+            else:
+                # Drop the previous mapping before creating a new one so the
+                # driver can reclaim the old shared memory region.
+                if cached_buffer is not None:
+                    self._twinkle_ipc_buffer = None
+                    self._twinkle_ipc_handle_signature = None
+                    del cached_buffer
+                    gc.collect()
+                    torch.cuda.ipc_collect()
+                buffer = _rebuild_ipc(handle, self.device.index)
+                self._twinkle_ipc_buffer = buffer
+                self._twinkle_ipc_handle_signature = handle_signature
         else:
             from multiprocessing import shared_memory
             buffer, shm = _rebuild_shared_memory(
@@ -200,6 +244,8 @@ def _broadcast_obj(obj):
         # ── Step 3: Receive and process weight buckets ──
         partial_tensors: dict = {}
+        lora_bucket_accum: list[tuple[str, torch.Tensor]] = []
+        lora_mode = bool(peft_config and base_sync_done)
         while True:
             # Only the driver receives bucket metadata from VLLMEngine.
             if is_driver:
@@ -292,7 +338,10 @@ def _broadcast_obj(obj):
             if tp_size > 1:
                 dist.barrier(group=cpu_group)
 
-            self._load_weights(weights, peft_config=peft_config, base_sync_done=base_sync_done)
+            if lora_mode:
+                lora_bucket_accum.extend(weights)
+            else:
+                self._load_weights(weights, peft_config=peft_config, base_sync_done=base_sync_done)
             del weights
 
             if metadata['is_last']:
@@ -300,9 +349,16 @@ def _broadcast_obj(obj):
                     pending = ', '.join(sorted(partial_tensors.keys())[:8])
                     raise RuntimeError(
                         f'Incomplete chunked weights at stream end: pending {len(partial_tensors)} ({pending})')
+                if lora_mode:
+                    self._load_weights(
+                        lora_bucket_accum,
+                        peft_config=peft_config,
+                        base_sync_done=base_sync_done,
+                    )
                 break
 
             partial_tensors.clear()
+            lora_bucket_accum.clear()
             metadata = None
             raw_u8 = None
             cpu_u8 = None
@@ -312,7 +368,6 @@ def _broadcast_obj(obj):
         if is_driver and socket is not None:
             socket.close()
         del buffer
-        gc.collect()
         if shm is not None:
             try:
                 shm.close()
             except BufferError as e:
                 logger.warning(f'SharedMemory close skipped due to exported pointers: {e}')
             del shm
+        gc.collect()
         torch.cuda.ipc_collect()
         torch.cuda.empty_cache()
@@ -403,9 +459,6 @@ def _load_weights(
         here.
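+        With ``load_inplace=True`` the request is expected to overwrite the
+        tensors of the adapter already registered under
+        ``VLLM_LORA_INT_ID`` rather than registering a new one, which is
+        why the explicit ``remove_lora`` call was dropped in this change.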
""" if peft_config and base_sync_done: - # Remove existing LoRA before replacing - self.remove_lora(VLLM_LORA_INT_ID) - from twinkle.patch.vllm_lora_weights import TensorLoRARequest converted = {self._convert_peft_to_vllm_lora_name(n): t for n, t in weights} @@ -415,6 +468,7 @@ def _load_weights( lora_path=VLLM_LORA_PATH, peft_config=peft_config, lora_tensors=converted, + load_inplace=True, ) self.add_lora(lora_request) else: diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py index b0d85080..0f08427b 100644 --- a/src/twinkle/utils/torch_utils.py +++ b/src/twinkle/utils/torch_utils.py @@ -77,7 +77,7 @@ def selective_log_softmax(logits, index) -> 'torch.Tensor': if mpu.get_tensor_model_parallel_world_size() > 1: # clone to avoid modifying the original logits return _vocab_parallel_selective_log_softmax(logits.clone(), index) - except Exception: + except (ImportError, AssertionError): pass if logits.dtype in [torch.float32, torch.float64]: diff --git a/tests/sampler/test_megatron_weight_sync.py b/tests/sampler/test_megatron_weight_sync.py index 2d32b5f3..6cfb2c87 100644 --- a/tests/sampler/test_megatron_weight_sync.py +++ b/tests/sampler/test_megatron_weight_sync.py @@ -41,7 +41,7 @@ os.environ['NCCL_CUMEM_ENABLE'] = '0' # Model configuration — use a small model for testing -MODEL_ID = os.environ.get('TEST_MODEL_ID', 'Qwen/Qwen2.5-0.5B-Instruct') +MODEL_ID = os.environ.get('TEST_MODEL_ID', 'Qwen/Qwen3.5-27B') logger = logging.getLogger(__name__) @@ -169,9 +169,15 @@ def test_megatron_weight_sync( model_id=model_path, device_mesh=model_device_mesh, mixed_precision='bf16', - sequence_parallel=(tp_size > 1), remote_group='model', ) + lora_config = { + 'target_modules': 'all-linear', + 'r': 8, + 'lora_alpha': 32, + 'lora_dropout': 0.05, + } + model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1) log(' MegatronModel created successfully') # ── Create Sampler (dummy weights) ──────────────────────────────── @@ -180,11 +186,11 @@ def test_megatron_weight_sync( model_id=model_path, engine_args={ 'load_format': 'dummy', - 'gpu_memory_utilization': 0.3, - 'max_model_len': 256, - 'enforce_eager': True, + 'gpu_memory_utilization': 0.7, + 'max_model_len': 1024, + 'enforce_eager': False, 'enable_sleep_mode': True, - 'enable_lora': False, + 'enable_lora': True, }, device_mesh=DeviceMesh.from_sizes(world_size=sampler_gpus, dp_size=sampler_gpus), remote_group='sampler', @@ -197,9 +203,12 @@ def test_megatron_weight_sync( time.sleep(5) # ── Helper: sample one prompt ───────────────────────────────────── + dp = sampler.device_mesh.data_world_size + def do_sample(prompt: str, max_tokens: int = 32) -> str: - traj = Trajectory(messages=[{'role': 'user', 'content': prompt}]) - responses = wait_result(sampler.sample(traj, SamplingParams(max_tokens=max_tokens, temperature=0.0))) + one = Trajectory(messages=[{'role': 'user', 'content': prompt}]) + batch = [one for _ in range(dp)] + responses = wait_result(sampler.sample(batch, SamplingParams(max_tokens=max_tokens, temperature=0.0))) for response in responses: if response and response.sequences: tokens = response.sequences[0].tokens @@ -221,7 +230,7 @@ def do_sample(prompt: str, max_tokens: int = 32) -> str: ) sync_start = time.time() - manager.sync_weights() + manager.sync_weights(merge_and_sync=False) sampler.reset_prefix_cache() sync_time = time.time() - sync_start log(f' Weight sync completed in {sync_time:.2f}s') @@ -257,10 +266,10 @@ def do_sample(prompt: str, max_tokens: int = 32) -> str: def main(): 
     parser = argparse.ArgumentParser(description='Test Megatron standalone weight synchronization')
-    parser.add_argument('--model-gpus', type=int, default=2, help='Number of GPUs for Megatron model (default: 2)')
+    parser.add_argument('--model-gpus', type=int, default=4, help='Number of GPUs for Megatron model (default: 4)')
     parser.add_argument('--sampler-gpus', type=int, default=2, help='Number of GPUs for vLLM sampler (default: 2)')
     parser.add_argument('--tp-size', type=int, default=2, help='Tensor parallel size (default: 2)')
-    parser.add_argument('--pp-size', type=int, default=1, help='Pipeline parallel size (default: 1)')
+    parser.add_argument('--pp-size', type=int, default=2, help='Pipeline parallel size (default: 2)')
     args = parser.parse_args()
 
     log('Starting Megatron standalone weight sync test...')