From 820d475632f005c07200f7c5a74162cd838a28b4 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 11 Apr 2026 00:05:28 +0800 Subject: [PATCH 1/8] fix --- cookbook/megatron/tp_moe.py | 4 ++-- cookbook/mm/fsdp2.py | 2 +- cookbook/rl/gkd_off_policy.py | 4 ++-- cookbook/rl/gkd_on_policy.py | 2 +- cookbook/transformers/fsdp2.sh | 2 +- cookbook/transformers/fsdp2_moe.py | 6 +++--- src/twinkle/model/transformers/transformers.py | 14 +++++++++----- src/twinkle/processor/base.py | 2 +- src/twinkle/sampler/vllm_sampler/vllm_engine.py | 4 ++-- src/twinkle/sampler/vllm_sampler/vllm_sampler.py | 7 ++++++- src/twinkle/template/qwen3_5_vl.py | 6 ++++-- 11 files changed, 32 insertions(+), 21 deletions(-) diff --git a/cookbook/megatron/tp_moe.py b/cookbook/megatron/tp_moe.py index b66b109f..a13b0e58 100644 --- a/cookbook/megatron/tp_moe.py +++ b/cookbook/megatron/tp_moe.py @@ -9,8 +9,8 @@ from twinkle.model import MegatronModel from twinkle.preprocessor import SelfCognitionProcessor -# Construct a device_mesh, tp=pp=cp=ep=2, dp=1 -device_mesh = DeviceMesh.from_sizes(dp_size=1, tp_size=2, pp_size=2, cp_size=2, ep_size=2) +# Construct a device_mesh, tp=pp=ep=dp=2 +device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2, ep_size=2, sequence_parallel=True) # use torchrun mode twinkle.initialize(mode='local', global_device_mesh=device_mesh) diff --git a/cookbook/mm/fsdp2.py b/cookbook/mm/fsdp2.py index cbe6f50d..4dc50850 100644 --- a/cookbook/mm/fsdp2.py +++ b/cookbook/mm/fsdp2.py @@ -89,7 +89,7 @@ def train(): # Print metric metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - if step > 0 and step % 40 == 0: + if step > 0 and step % 200 == 0: metrics = eval(model) logger.info(f'Eval metric: {metrics}') metrics['step'] = step diff --git a/cookbook/rl/gkd_off_policy.py b/cookbook/rl/gkd_off_policy.py index 3315c962..204e90f9 100644 --- a/cookbook/rl/gkd_off_policy.py +++ b/cookbook/rl/gkd_off_policy.py @@ -60,8 +60,8 @@ STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3-0.6B') TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-8B') -MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8)) -SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 8)) +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16)) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index f30df2ea..9c792eab 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -62,7 +62,7 @@ TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-8B') MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) -SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2)) NUM_GPUS = MODEL_GPUS + 2*SAMPLER_GPUS MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 2048)) diff --git a/cookbook/transformers/fsdp2.sh b/cookbook/transformers/fsdp2.sh index 46e9f27f..93c531a9 100644 --- a/cookbook/transformers/fsdp2.sh +++ b/cookbook/transformers/fsdp2.sh @@ -1 +1 @@ -CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 fsdp2.py +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 fsdp2.py diff --git a/cookbook/transformers/fsdp2_moe.py b/cookbook/transformers/fsdp2_moe.py index 23a53f4a..3ea649d3 100644 --- a/cookbook/transformers/fsdp2_moe.py +++ b/cookbook/transformers/fsdp2_moe.py @@ -20,7 +20,7 
@@ def eval(model): # 100 Samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507') dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=4) @@ -35,7 +35,7 @@ def train(): # 1000 samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) # Set template to prepare encoding - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507') # Preprocess the dataset to standard format dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) # Encode dataset @@ -43,7 +43,7 @@ def train(): # Global batch size = 4, for GPUs, so 1 sample per GPU dataloader = DataLoader(dataset=dataset, batch_size=8) # Use a TransformersModel, transformer_cls_names_to_wrap=Qwen3MoeSparseMoeBlock to avoid hang of fsdp2 - model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B', fsdp_config={'transformer_cls_names_to_wrap':['Qwen3MoeSparseMoeBlock']}) + model = TransformersModel(model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507', fsdp_config={'transformer_cls_names_to_wrap':['Qwen3MoeSparseMoeBlock']}) # Patch MoE model to fix the hang bug, support transformers==4.* model.apply_patch('ms://twinkle-kit/qwen3_moe_transformers4_patch') lora_config = LoraConfig( diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index ab464811..29450647 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -476,8 +476,12 @@ def calculate_loss(self, **kwargs): optimizer_config = self.optimizer_group[adapter_name] loss_instance: Loss = optimizer_config.loss_instance assert isinstance(loss_instance, Loss), 'Set a loss_instance before calculating loss' - inputs = optimizer_config.train_status.inputs - outputs = optimizer_config.train_status.outputs + if self.model.training: + status = optimizer_config.train_status + else: + status = optimizer_config.eval_status + inputs = status.inputs + outputs = status.outputs assert inputs is not None and outputs is not None, 'Cannot calculate loss of empty inputs and outputs' result = loss_instance(inputs, outputs, **kwargs) loss_value = result['loss'] @@ -505,9 +509,9 @@ def calculate_loss(self, **kwargs): if reduction is not None: self.sp_strategy.sp_config['loss_reduction'] = str(reduction) loss_value = self.sp_strategy.reduce_loss(loss_value, inputs['labels']) - optimizer_config.train_status.loss_value += loss_value - outputs['loss'] = optimizer_config.train_status.loss_value - return optimizer_config.train_status.loss_value.item() + status.loss_value += loss_value + outputs['loss'] = status.loss_value + return status.loss_value.item() @remote_function() def backward(self, **kwargs): diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index 3269a574..eb182bf5 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -130,7 +130,7 @@ def pad_cp_inputs(input_tensor: torch.Tensor, padding_value: int) -> torch.Tenso if input_tensor is None: return input_tensor - seq_len = input_tensor.shape[1] + seq_len = input_tensor.shape[-1] # Calculate required divisor based on parallelism settings if cp_size > 1: diff --git 
a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index ce487436..a1b7123e 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -291,8 +291,8 @@ async def sample(self, continue # Get logprob for the actual token - if i < len(prompt_token_ids): - token_id = prompt_token_ids[i] + if i < len(result.prompt_token_ids): + token_id = result.prompt_token_ids[i] if token_id in lp_dict: lp_obj = lp_dict[token_id] result_prompt_logprobs.append(lp_obj.logprob) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 4c3bc6de..97a684b0 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -166,6 +166,9 @@ def encode_trajectory_for_vllm(self, add_generation_prompt=add_generation_prompt, )[0] encoded['prompt'] = prompt['prompt'] + for key in encoded: + if isinstance(encoded[key], np.ndarray): + encoded[key] = encoded[key].tolist() return encoded def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs) -> None: @@ -235,7 +238,9 @@ async def _sample_single( """ multi_modal_data = self._extract_multi_modal_data(feat) response = await self.engine.sample( - prompt=feat['prompt'] if 'prompt' in feat else feat['input_ids'], + # pick input_ids because prompt may not contain the response + # if vLLM is used sequentially + prompt=feat['input_ids'] if 'input_ids' in feat else feat['prompt'], sampling_params=sampling_params, lora_request=lora_request, multi_modal_data=multi_modal_data, diff --git a/src/twinkle/template/qwen3_5_vl.py b/src/twinkle/template/qwen3_5_vl.py index a7967777..3a8ffe9a 100644 --- a/src/twinkle/template/qwen3_5_vl.py +++ b/src/twinkle/template/qwen3_5_vl.py @@ -26,9 +26,11 @@ def __init__(self, *args, **kwargs): self._patch_size: Optional[int] = None self._merge_size: Optional[int] = None self._init_vision_config() - from transformers.models.qwen3_vl import Qwen3VLModel with torch.device('meta'): - self.dummy_model = Qwen3VLModel(self.config) + import transformers + model_cls = self.config.architectures[0] + model_cls = getattr(transformers, model_cls) + self.dummy_model = model_cls(self.config) self.rope_index_func = self.get_rope_index() def get_rope_index(self): From d55f90adfe4742b8f86ff62069a37767f6e603c9 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 11 Apr 2026 00:37:07 +0800 Subject: [PATCH 2/8] fix --- cookbook/ray/run.sh | 1 - cookbook/ray/single_controller.py | 91 ------------------- .../model/transformers/transformers.py | 2 +- .../sampler/vllm_sampler/vllm_sampler.py | 3 +- 4 files changed, 3 insertions(+), 94 deletions(-) delete mode 100644 cookbook/ray/run.sh delete mode 100644 cookbook/ray/single_controller.py diff --git a/cookbook/ray/run.sh b/cookbook/ray/run.sh deleted file mode 100644 index bbf8a400..00000000 --- a/cookbook/ray/run.sh +++ /dev/null @@ -1 +0,0 @@ -python3 single_controller.py diff --git a/cookbook/ray/single_controller.py b/cookbook/ray/single_controller.py deleted file mode 100644 index edb8d8e6..00000000 --- a/cookbook/ray/single_controller.py +++ /dev/null @@ -1,91 +0,0 @@ -import os -from peft import LoraConfig -from tqdm import tqdm - -import twinkle -from twinkle import DeviceGroup, DeviceMesh, Platform, get_device_placement, get_logger -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import TransformersModel -from
twinkle.preprocessor import SelfCognitionProcessor - -device_group = [DeviceGroup( - name='default', - ranks=8, - device_type='cuda', -)] - -# Construct a device_mesh, fsdp=4, dp=2 -device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2) -# use ray mode -twinkle.initialize(mode='ray', groups=device_group, global_device_mesh=device_mesh) - -logger = get_logger() - - -def eval(model): - # 100 Samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-35B-A3B') - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) - dataset.encode() - dataloader = DataLoader(dataset=dataset, batch_size=8, min_batch_size=8) - for step, batch in tqdm(enumerate(dataloader)): - model.forward_only(inputs=batch) - model.calculate_loss() - metrics = model.calculate_metric(is_training=False) - return metrics - - -def train(): - # 1000 samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) - # Set template to prepare encoding - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') - # Preprocess the dataset to standard format - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) - # Encode dataset - dataset.encode() - # Global batch size = 8, for GPUs, so 1 sample per GPU - dataloader = DataLoader(dataset=dataset, batch_size=8, min_batch_size=8) - # Use a TransformersModel - model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B', remote_group='default') - - lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') - - # Add a lora to model, with name `default` - # Comment this to use full-parameter training - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) - # Add Optimizer for lora `default` - model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) - # Add LRScheduler for lora `default` - model.set_lr_scheduler( - scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) - logger.info(get_device_placement()) - # Print the training config - logger.info(model.get_train_configs()) - logger.info(f'Total steps: {len(dataloader)}') - loss_metric = 99.0 - # lora: 18G * 4 - # full: 50G * 4 - for step, batch in enumerate(dataloader): - # Do forward and backward - model.forward_backward(inputs=batch) - # Step - model.clip_grad_and_step() - if step % 20 == 0: - # Print metric - metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - if step > 0 and step % 40 == 0: - metrics = eval(model) - logger.info(f'Eval metric: {metrics}') - metrics['step'] = step - if loss_metric > float(metrics['loss']): - model.save(f'checkpoint-{step}') - loss_metric = float(metrics['loss']) - model.save(f'last-checkpoint') - - -if __name__ == '__main__': - train() diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 29450647..390666d6 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -503,7 +503,7 @@ def calculate_loss(self, **kwargs): # = global_per_token_grad / dp_world_size = avg_per_token_grad counts = counts / self.device_mesh.data_world_size optimizer_config = self.optimizer_group[adapter_name] - optimizer_config.train_status.num_tokens += counts.item() + status.num_tokens += counts.item() if self.sp_strategy is not None and 'labels' in inputs: 
reduction = getattr(loss_instance, 'reduction', None) if reduction is not None: diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 97a684b0..70b4cf54 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -240,7 +240,8 @@ async def _sample_single( response = await self.engine.sample( # pick input_ids because prompt may not contain the response # if vLLM is used sequentially - prompt=feat['input_ids'] if 'input_ids' in feat else feat['prompt'], + # multi-modal does not support input_ids + prompt=feat['input_ids'] if 'input_ids' in feat and len(multi_modal_data) == 0 else feat['prompt'], sampling_params=sampling_params, lora_request=lora_request, multi_modal_data=multi_modal_data, From 362b7c85644119127d8ca02d4794a7ae6637fbc0 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 11 Apr 2026 00:52:38 +0800 Subject: [PATCH 3/8] fix --- .../sampler/vllm_sampler/vllm_sampler.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 70b4cf54..b52aa6f6 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -221,6 +221,7 @@ async def _sample_single( sampling_params: SamplingParams, lora_request: Optional[Any] = None, *, + multi_modal_data: Optional[Dict[str, Any]] = None, logprobs_only: bool = False, ) -> SampleResponse: """Sample a single input asynchronously. @@ -231,23 +232,23 @@ async def _sample_single( adapter_path: Optional LoRA adapter path (legacy, prefer lora_request). lora_request: Pre-built LoRARequest to attach to the sampling request. Avoids repeated ``_get_or_load_lora`` calls per input. + multi_modal_data: The multi-modal data dict. logprobs_only: Only return logprobs (no generated tokens).
Returns: A SampleResponse object """ - multi_modal_data = self._extract_multi_modal_data(feat) response = await self.engine.sample( - # pick input_ids because prompt may not contain the response + # Pick input_ids first because prompt may not contain the response # if vLLM is used sequentially # multi-modal does not support input_ids - prompt=feat['input_ids'] if 'input_ids' in feat and len(multi_modal_data) == 0 else feat['prompt'], + prompt=feat['input_ids'] if 'input_ids' in feat and multi_modal_data else feat['prompt'], sampling_params=sampling_params, lora_request=lora_request, multi_modal_data=multi_modal_data, mm_processor_kwargs=feat.get('mm_processor_kwargs'), ) - if 'input_ids' not in feat: + if 'input_ids' not in feat or multi_modal_data: feat['input_ids'] = response.prompt_token_ids feat['labels'] = [-100] * len(response.prompt_token_ids) if not logprobs_only: # response.sequences contains num_samples sequences for this prompt sequences = [] @@ -331,7 +332,11 @@ def sample( sampling_params.max_tokens = 1 logprobs_only = True - if is_trajectory: + multi_modal_data_list = [] + for feat in inputs_list: + multi_modal_data_list.append(self._extract_multi_modal_data(feat)) + + if is_trajectory or any(multi_modal_data_list): template = self.template assert template is not None, \ 'Use set_template to add a template when trying to input Trajectory' @@ -355,8 +360,9 @@ async def _sample_all(): feat, sampling_params, lora_request=lora_request, + multi_modal_data=multi_modal_data, logprobs_only=logprobs_only, - ) for feat in encoded_inputs + ) for feat, multi_modal_data in zip(encoded_inputs, multi_modal_data_list) ] return await asyncio.gather(*tasks) From e8f2ffcf818490e89e250fd1b1a8ae0770e36e06 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 11 Apr 2026 12:23:49 +0800 Subject: [PATCH 4/8] fix --- cookbook/rl/gkd_on_policy.py | 86 +++++++++++++------ src/twinkle/infra/__init__.py | 15 ++++ .../sampler/vllm_sampler/vllm_sampler.py | 2 +- 3 files changed, 74 insertions(+), 29 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index 9c792eab..79ba33ea 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -1,8 +1,13 @@ -"""GKD On-Policy Distillation via Ray. +"""GKD On-Policy Multimodal Distillation via Ray. -On-policy knowledge distillation: student vLLM generates responses, -teacher vLLM provides top-k prompt logprobs, then student model learns -to match the teacher's token distribution. +On-policy knowledge distillation on OlympiadBench multimodal math/physics: +student vLLM generates responses, teacher vLLM provides top-k prompt logprobs, +then the student model learns to match the teacher's token distribution. + +Supports three OlympiadBench subsets: +- OE_MM_maths_zh_CEE: Multimodal math problems (Chinese CEE) +- OE_MM_physics_zh_CEE: Multimodal physics problems (Chinese CEE) +- OE_TO_maths_zh_CEE: Text-only math problems (Chinese CEE) Pipeline: 1. Sync student model weights to student vLLM sampler.
@@ -23,8 +28,8 @@ student + teacher (model GPUs) Environment variables (all optional): - STUDENT_MODEL_ID – (default: ms://Qwen/Qwen3-0.6B) - TEACHER_MODEL_ID – (default: ms://Qwen/Qwen3-8B) + STUDENT_MODEL_ID – (default: ms://Qwen/Qwen3.5-4B) + TEACHER_MODEL_ID – (default: ms://Qwen/Qwen3.5-9B) MODEL_GPUS – GPUs for student model (default: 4) SAMPLER_GPUS – GPUs for each vLLM sampler (default: 4) MAX_NEW_TOKENS – max completion tokens (default: 2048) @@ -48,25 +53,24 @@ from twinkle.checkpoint_engine import CheckpointEngineManager from twinkle.data_format import SamplingParams from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta +from twinkle.dataset import DatasetMeta, LazyDataset from twinkle.loss import GKDLoss -from twinkle.model import TransformersModel -from twinkle.preprocessor import GSM8KProcessor +from twinkle.model import MegatronModel +from twinkle.preprocessor.olympiad_bench import OlympiadBenchProcessor from twinkle.sampler import vLLMSampler -from twinkle.template import Template logger = get_logger() # ── Configuration ───────────────────────────────────────────────────────────── -STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3-0.6B') -TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3-8B') +STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3.5-4B') +TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3.5-9B') MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2)) NUM_GPUS = MODEL_GPUS + 2*SAMPLER_GPUS MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 2048)) -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) LEARNING_RATE = float(os.environ.get('LR', 5e-5)) N_SAMPLES = int(os.environ.get('N_SAMPLES', 1)) @@ -74,18 +78,31 @@ GKD_BETA = float(os.environ.get('GKD_BETA', 0.5)) GKD_TEMPERATURE = float(os.environ.get('GKD_TEMPERATURE', 1.0)) GKD_TOPK = int(os.environ.get('GKD_TOPK', 64)) -SYSTEM_PROMPT = ('You are a helpful math assistant. 
Solve the problem step by step and put ' - 'your final answer within #### ') ADAPTER_NAME = 'default' +# OlympiadBench subsets +SUBSETS = [ + 'OE_MM_maths_zh_CEE', + 'OE_MM_physics_zh_CEE', + 'OE_TO_maths_zh_CEE', +] + # ── Dataset ─────────────────────────────────────────────────────────────────── def create_dataset(): - """Prompt-only dataset; student vLLM will generate completions on-policy.""" - dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) - dataset.set_template('Template', model_id=STUDENT_MODEL_ID, max_length=1024) - dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT)) + """OlympiadBench multimodal dataset; student vLLM will generate completions on-policy.""" + ds = DatasetMeta('ms://AI-ModelScope/OlympiadBench', subset_name=SUBSETS[0], split='train') + dataset = LazyDataset(ds) + dataset.map(OlympiadBenchProcessor(language='zh'), dataset_meta=ds) + + for subset in SUBSETS[1:]: + ds = DatasetMeta('ms://AI-ModelScope/OlympiadBench', subset_name=subset, split='train') + dataset.add_dataset(ds) + dataset.map(OlympiadBenchProcessor(language='zh'), dataset_meta=ds) + + dataset.set_template('Qwen3_5Template', model_id=STUDENT_MODEL_ID, max_length=2048, enable_thinking=False) + dataset.mix_dataset(interleave=True) return dataset @@ -155,7 +172,7 @@ def main(): ) # ── Student model (trainable) ────────────────────────────────────────────── - student_model = TransformersModel( + student_model = MegatronModel( model_id=STUDENT_MODEL_ID, device_mesh=model_mesh, remote_group='student_model', @@ -165,10 +182,10 @@ def main(): LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear'), gradient_accumulation_steps=1, ) - student_model.set_optimizer('AdamW', lr=LEARNING_RATE) - student_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0) + student_model.set_optimizer('default', lr=LEARNING_RATE) + student_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS) student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE)) - student_model.set_template('Template', model_id=STUDENT_MODEL_ID) + student_model.set_template('Qwen3_5Template', model_id=STUDENT_MODEL_ID, enable_thinking=False) # ── Student vLLM sampler (for on-policy generation) ──────────────────────── student_sampler = vLLMSampler( @@ -176,20 +193,33 @@ def main(): # enable_lora=True used with ckpt_manager.sync_weights(merge_and_sync=False) # meaning only sync lora weights, if merge_and_sync=True, # lora will be merged into the base model and sync all weights to vLLM - engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 4096, 'enable_lora': True, 'max_loras': 1}, + engine_args={ + 'gpu_memory_utilization': 0.75, + 'max_model_len': 8192, + 'enable_lora': True, + 'max_loras': 1, + 'limit_mm_per_prompt': {'image': 3}, + 'enable_tower_connector_lora': True, + }, device_mesh=sampler_mesh, remote_group='student_sampler', ) - student_sampler.set_template(Template, model_id=STUDENT_MODEL_ID) + student_sampler.set_template('Qwen3_5Template', model_id=STUDENT_MODEL_ID, enable_thinking=False) # ── Teacher vLLM sampler (for prompt logprobs) ─────────────────────────────── teacher_sampler = vLLMSampler( model_id=TEACHER_MODEL_ID, - engine_args={'gpu_memory_utilization': 0.85, 'max_model_len': 4096, 'logprobs_mode': 'raw_logprobs', 'max_logprobs': 64}, + engine_args={ + 'gpu_memory_utilization': 0.75, + 'max_model_len': 8192, + 'logprobs_mode': 'raw_logprobs', + 'max_logprobs': 64, + 'limit_mm_per_prompt': {'image': 3}, + }, 
device_mesh=sampler_mesh, remote_group='teacher_sampler', ) - teacher_sampler.set_template(Template, model_id=TEACHER_MODEL_ID) + teacher_sampler.set_template('Qwen3_5Template', model_id=TEACHER_MODEL_ID, enable_thinking=False) # ── DataLoader (prompt-only) ─────────────────────────────── dataloader = DataLoader( @@ -239,7 +269,7 @@ def main(): student_model.forward_backward(inputs=input_data, **teacher_output) student_model.clip_grad_and_step() - if optim_step > 0 and optim_step % 10 == 0: + if optim_step > 0 and optim_step % 1 == 0: metric = student_model.calculate_metric(is_training=True) logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}') diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index f25f0fa3..a0af7cf4 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -367,6 +367,21 @@ def dispatch_func(arg, n): sliced_args = tuple(arg[i] for arg in args) sliced_kwargs = {k: v[i] for k, v in kwargs.items()} result.append((workers[i], sliced_args, sliced_kwargs)) + + # Raise early if some ranks got data and others didn't (causes hangs). + def _check_uniform(slices): + lens = [len(s) if s is not None and isinstance(s, (list, tuple)) else 0 for s in slices] + return not lens or all(l > 0 for l in lens) or all(l == 0 for l in lens) + + for arg in args: + if not _check_uniform(arg): + raise ValueError(f'Batch too small for {length} workers, some ranks have no data. ' + f'Please increase batch size to at least {length}.') + for v in kwargs.values(): + if not _check_uniform(v): + raise ValueError(f'Batch too small for {length} workers, some ranks have no data. ' + f'Please increase batch size to at least {length}.') + return result elif isinstance(dispatch, Callable): length = len(workers) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index b52aa6f6..5c7d5be3 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -242,7 +242,7 @@ async def _sample_single( # Pick input_ids first because prompt may not contain the response # if vLLM is used sequentially # multi-modal does not support input_ids - prompt=feat['input_ids'] if 'input_ids' in feat and multi_modal_data else feat['prompt'], + prompt=feat['input_ids'] if 'input_ids' in feat and not multi_modal_data else feat['prompt'], sampling_params=sampling_params, lora_request=lora_request, multi_modal_data=multi_modal_data, From ab8e686431be5402136ba6f65963bd2a65c736fe Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 11 Apr 2026 12:45:54 +0800 Subject: [PATCH 5/8] fix --- cookbook/rl/gkd_on_policy.py | 56 ++++++++++++------ .../model/transformers/transformers.py | 5 ++ 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index 79ba33ea..7a238775 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -55,7 +55,7 @@ from twinkle.dataloader import DataLoader from twinkle.dataset import DatasetMeta, LazyDataset from twinkle.loss import GKDLoss -from twinkle.model import MegatronModel +from twinkle.model import TransformersModel from twinkle.preprocessor.olympiad_bench import OlympiadBenchProcessor from twinkle.sampler import vLLMSampler @@ -64,6 +64,7 @@ # ── Configuration ───────────────────────────────────────────────────────────── STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3.5-4B') TEACHER_MODEL_ID =
os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3.5-9B') +USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '0'))) MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2)) @@ -172,20 +173,35 @@ def main(): ) # ── Student model (trainable) ────────────────────────────────────────────── - student_model = MegatronModel( - model_id=STUDENT_MODEL_ID, - device_mesh=model_mesh, - remote_group='student_model', - ) - student_model.add_adapter_to_model( - ADAPTER_NAME, - LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear'), - gradient_accumulation_steps=1, - ) - student_model.set_optimizer('default', lr=LEARNING_RATE) - student_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS) - student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE)) - student_model.set_template('Qwen3_5Template', model_id=STUDENT_MODEL_ID, enable_thinking=False) + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear') + + if USE_MEGATRON: + from twinkle.model.megatron import MegatronModel + student_model = MegatronModel( + model_id=STUDENT_MODEL_ID, + device_mesh=model_mesh, + remote_group='student_model', + ) + else: + from transformers import Qwen3_5ForConditionalGeneration + student_model = TransformersModel( + model_id=STUDENT_MODEL_ID, + model_cls=Qwen3_5ForConditionalGeneration, + device_mesh=model_mesh, + remote_group='student_model', + ) + + student_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1) + + if USE_MEGATRON: + student_model.set_optimizer('default', lr=LEARNING_RATE, adapter_name=ADAPTER_NAME) + student_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE, adapter_name=ADAPTER_NAME) + else: + student_model.set_optimizer('AdamW', lr=LEARNING_RATE) + student_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0) + + student_model.set_loss(GKDLoss(beta=GKD_BETA, temperature=GKD_TEMPERATURE), adapter_name=ADAPTER_NAME) + student_model.set_template('Qwen3_5Template', model_id=STUDENT_MODEL_ID, adapter_name=ADAPTER_NAME, enable_thinking=False) # ── Student vLLM sampler (for on-policy generation) ──────────────────────── student_sampler = vLLMSampler( @@ -266,19 +282,19 @@ def main(): ) # 5. 
Student forward + GKD backward - student_model.forward_backward(inputs=input_data, **teacher_output) - student_model.clip_grad_and_step() + student_model.forward_backward(inputs=input_data, adapter_name=ADAPTER_NAME, **teacher_output) + student_model.clip_grad_and_step(adapter_name=ADAPTER_NAME) if optim_step > 0 and optim_step % 1 == 0: - metric = student_model.calculate_metric(is_training=True) + metric = student_model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME) logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metric}') if optim_step > 0 and optim_step % 50 == 0: - student_model.save(f'gkd-onpolicy-ckpt-{optim_step}') + student_model.save(f'gkd-onpolicy-ckpt-{optim_step}', adapter_name=ADAPTER_NAME) optim_step += 1 - student_model.save('gkd-onpolicy-final') + student_model.save('gkd-onpolicy-final', adapter_name=ADAPTER_NAME) logger.info('GKD on-policy training completed.') diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 390666d6..7d8022d6 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -1167,6 +1167,7 @@ def send_weights( adapter_name: str = None, base_sync_done: bool = False, merge_and_sync: bool = False, + model_keys: List[str] = None, **kwargs, ): if adapter_name is None: @@ -1179,6 +1180,10 @@ def _normalize(name: str, keep_base_layer: bool) -> str: name = name.replace('base_model.model.', '') if not keep_base_layer: name = name.replace('.base_layer', '') + else: + if 'conv1d.weight' in name: + if any('conv1d.base_layer.weight' in name for name in model_keys): + name = name.replace('conv1d.weight', 'conv1d.base_layer.weight') return name def _print_weight_example(names): From 5412a64273ca030effd6e5b350c33777c09dfa7c Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 11 Apr 2026 20:19:57 +0800 Subject: [PATCH 6/8] wip --- cookbook/rl/gkd_on_policy.py | 2 +- .../sampler/vllm_sampler/vllm_sampler.py | 13 +++++++--- src/twinkle/template/base.py | 25 ++++++++++++++++--- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/cookbook/rl/gkd_on_policy.py b/cookbook/rl/gkd_on_policy.py index 7a238775..9b44e559 100644 --- a/cookbook/rl/gkd_on_policy.py +++ b/cookbook/rl/gkd_on_policy.py @@ -64,7 +64,7 @@ # ── Configuration ───────────────────────────────────────────────────────────── STUDENT_MODEL_ID = os.environ.get('STUDENT_MODEL_ID', 'ms://Qwen/Qwen3.5-4B') TEACHER_MODEL_ID = os.environ.get('TEACHER_MODEL_ID', 'ms://Qwen/Qwen3.5-9B') -USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '0'))) +USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1'))) MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 2)) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 5c7d5be3..01f81b47 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -248,9 +248,15 @@ async def _sample_single( multi_modal_data=multi_modal_data, mm_processor_kwargs=feat.get('mm_processor_kwargs'), ) + if 'input_ids' not in feat or multi_modal_data: - feat['input_ids'] = response.prompt_token_ids - feat['labels'] = [-100] * len(response.prompt_token_ids) + if 'input_ids' in feat: + if len(feat['input_ids']) != len(response.prompt_token_ids): + breakpoint() + raise RuntimeError(f'Input ids length {len(feat["input_ids"])} does not match prompt_token_ids length 
{len(response.prompt_token_ids)}') + else: + feat['input_ids'] = response.prompt_token_ids + feat['labels'] = [-100] * len(response.prompt_token_ids) if not logprobs_only: # response.sequences contains num_samples sequences for this prompt sequences = [] @@ -331,12 +337,13 @@ def sample( if sampling_params.max_tokens == 0: sampling_params.max_tokens = 1 logprobs_only = True + assert not is_trajectory, 'Logprobs only not supported for Trajectory inputs' multi_modal_data_list = [] for feat in inputs_list: multi_modal_data_list.append(self._extract_multi_modal_data(feat)) - if is_trajectory or any(multi_modal_data_list): + if is_trajectory or any(multi_modal_data_list) and not logprobs_only: template = self.template assert template is not None, \ 'Use set_template to add a template when trying to input Trajectory' diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index ce9abd6a..41f5c51d 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -162,8 +162,11 @@ def concat_input_feature(self, prompt_input_feature: InputFeature, new_tokens: L assert self.truncation_strategy != 'split', 'concat_input_feature does not support `truncation_strategy=split`' result = copy.deepcopy(prompt_input_feature) prompt_ids = result['input_ids'] + labels = result['labels'] input_ids = list(prompt_ids) + new_tokens - labels = [-100] * len(prompt_ids) + new_tokens + labels = labels[-1:] + labels[:-1] # roll to input order + labels = labels + new_tokens + labels = labels[1:] + labels[:1] # roll to input-1 order result['input_ids'] = input_ids result['labels'] = labels if 'mm_token_type_ids' in result: @@ -181,6 +184,23 @@ def concat_input_feature(self, prompt_input_feature: InputFeature, new_tokens: L response_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True) messages.append(Message(role='assistant', content=response_text)) result['messages'] = messages + prompt = self.batch_encode( + [result], + add_generation_prompt=False, + tokenize=False, + )[0] # regenerate prompt + + decoded = self.decode(input_ids[-10]) + suffix = ['\n'] + # Models like qwen3 may end with \n, this will mismatch + # with the vLLM output, which ends with the eos token + if self.tokenizer.eos_token: + suffix += [self.tokenizer.eos_token] + _prompt = prompt['prompt'] + for token in suffix: + if _prompt.endswith(token) and not decoded.endswith(token): + _prompt = _prompt[:-len(token)] + result['prompt'] = _prompt return result def _add_default_system(self, trajectory: Trajectory) -> List[Trajectory]: @@ -467,8 +487,6 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo apply_chat_template_kwargs['tokenize'] = True # Set default values for processor_kwargs - if 'enable_thinking' not in kwargs: - processor_kwargs['enable_thinking'] = self.enable_thinking if 'padding' not in kwargs: processor_kwargs['padding'] = False @@ -482,6 +500,7 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo add_generation_prompt=add_generation_prompt, return_tensors='pt', processor_kwargs=processor_kwargs, + enable_thinking=self.enable_thinking, **apply_chat_template_kwargs) else: # No processor_kwargs support, pass all kwargs directly From 13dac86cdb961445c9474c056fdfea604150ec4c Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 11 Apr 2026 21:22:41 +0800 Subject: [PATCH 7/8] fix --- cookbook/rl/short_math_grpo.py | 6 ++-- src/twinkle/infra/__init__.py | 2 +- .../sampler/vllm_sampler/vllm_sampler.py | 20 ++++--------- src/twinkle/template/base.py | 24
++++----------- src/twinkle/template/qwen3_5_vl.py | 29 +++++++++++++++++++ 5 files changed, 43 insertions(+), 38 deletions(-) diff --git a/cookbook/rl/short_math_grpo.py b/cookbook/rl/short_math_grpo.py index bbfda68b..7637dae6 100644 --- a/cookbook/rl/short_math_grpo.py +++ b/cookbook/rl/short_math_grpo.py @@ -87,7 +87,7 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: # ========== Dataset ========== def create_gsm8k_dataset(): dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) - dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=4096, truncation_strategy='delete', enable_thinking=False) + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=4096, truncation_strategy='delete', enable_thinking=True) dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT)) dataset.encode(add_generation_prompt=True) return dataset @@ -153,7 +153,7 @@ def main(): model.set_loss('GRPOLoss', epsilon=0.2) model.set_processor(InputProcessor) - model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False) + model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=True) sampler = vLLMSampler( model_id=MODEL_ID, @@ -170,7 +170,7 @@ def main(): device_mesh=sampler_mesh, remote_group='sampler', ) - sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False) + sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=True) ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler) diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index a0af7cf4..8462d06b 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -371,7 +371,7 @@ def dispatch_func(arg, n): # Raise early if some ranks got data and others didn't (causes hangs). def _check_uniform(slices): lens = [len(s) if s is not None and isinstance(s, (list, tuple)) else 0 for s in slices] - return not lens or all(l > 0 for l in lens) or all(l == 0 for l in lens) + return not lens or all(length > 0 for length in lens) or all(length == 0 for length in lens) for arg in args: if not _check_uniform(arg): diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 01f81b47..70f9121e 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -155,17 +155,10 @@ def encode_trajectory_for_vllm(self, template = self.template if template is None: raise ValueError(f"Template not set for adapter '{adapter_name}'. 
Use set_template() first.") - - prompt = template.batch_encode( - [trajectory], - add_generation_prompt=add_generation_prompt, - tokenize=False, - )[0] encoded = template.batch_encode( [trajectory], add_generation_prompt=add_generation_prompt, )[0] - encoded['prompt'] = prompt['prompt'] for key in encoded: if isinstance(encoded[key], np.ndarray): encoded[key] = encoded[key].tolist() @@ -239,10 +232,7 @@ async def _sample_single( A SampleResponse object """ response = await self.engine.sample( - # Pick input_ids first because prompt may not contain the response - # if vLLM is used sequentially - # multi-modal does not support input_ids - prompt=feat['input_ids'] if 'input_ids' in feat and not multi_modal_data else feat['prompt'], + prompt=self.template.get_vllm_input_ids(feat['input_ids']), sampling_params=sampling_params, lora_request=lora_request, multi_modal_data=multi_modal_data, @@ -252,8 +242,8 @@ async def _sample_single( if 'input_ids' not in feat or multi_modal_data: if 'input_ids' in feat: if len(feat['input_ids']) != len(response.prompt_token_ids): - breakpoint() - raise RuntimeError(f'Input ids length {len(feat["input_ids"])} does not match prompt_token_ids length {len(response.prompt_token_ids)}') + raise RuntimeError(f'Input ids length {len(feat["input_ids"])} does not ' + f'match prompt_token_ids length {len(response.prompt_token_ids)}') else: feat['input_ids'] = response.prompt_token_ids feat['labels'] = [-100] * len(response.prompt_token_ids) @@ -332,7 +322,7 @@ def sample( inputs_list = self._normalize_inputs(inputs) # Check if inputs are Trajectory (not encoded) - aligned with Model.forward logic - is_trajectory = 'prompt' not in inputs_list[0] and 'input_ids' not in inputs_list[0] + is_trajectory = 'input_ids' not in inputs_list[0] logprobs_only = False if sampling_params.max_tokens == 0: sampling_params.max_tokens = 1 logprobs_only = True @@ -343,7 +333,7 @@ def sample( for feat in inputs_list: multi_modal_data_list.append(self._extract_multi_modal_data(feat)) - if is_trajectory or any(multi_modal_data_list) and not logprobs_only: + if is_trajectory and not logprobs_only: template = self.template assert template is not None, \ 'Use set_template to add a template when trying to input Trajectory' diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 41f5c51d..5945328e 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -164,9 +164,9 @@ def concat_input_feature(self, prompt_input_feature: InputFeature, new_tokens: L prompt_ids = result['input_ids'] labels = result['labels'] input_ids = list(prompt_ids) + new_tokens - labels = labels[-1:] + labels[:-1] # roll to input order + labels = labels[-1:] + labels[:-1] # roll to input order labels = labels + new_tokens - labels = labels[1:] + labels[:1] # roll to input-1 order + labels = labels[1:] + labels[:1] # roll to input-1 order result['input_ids'] = input_ids result['labels'] = labels if 'mm_token_type_ids' in result: @@ -184,23 +184,6 @@ def concat_input_feature(self, prompt_input_feature: InputFeature, new_tokens: L response_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True) messages.append(Message(role='assistant', content=response_text)) result['messages'] = messages - prompt = self.batch_encode( - [result], - add_generation_prompt=False, - tokenize=False, - )[0] # regenerate prompt - - decoded = self.decode(input_ids[-10]) - suffix = ['\n'] - # Models like qwen3 may end with \n, this will mismatch - # with the vLLM output, which ends with the eos token - if self.tokenizer.eos_token: - suffix
+= [self.tokenizer.eos_token] - _prompt = prompt['prompt'] - for token in suffix: - if _prompt.endswith(token) and not decoded.endswith(token): - _prompt = _prompt[:-len(token)] - result['prompt'] = _prompt return result def _add_default_system(self, trajectory: Trajectory) -> List[Trajectory]: @@ -254,6 +237,9 @@ def _truncate_feature(self, feature: InputFeature, strategy: str) -> InputFeatur def set_mm_position_ids(self, input_feature: InputFeature): return np.arange(len(input_feature['input_ids'])) + def get_vllm_input_ids(input_ids): + return input_ids + def _check_max_length(self, input_feature: InputFeature) -> List[InputFeature]: if not self.max_length or 'input_ids' not in input_feature: return [input_feature] diff --git a/src/twinkle/template/qwen3_5_vl.py b/src/twinkle/template/qwen3_5_vl.py index 3a8ffe9a..22799bab 100644 --- a/src/twinkle/template/qwen3_5_vl.py +++ b/src/twinkle/template/qwen3_5_vl.py @@ -130,6 +130,35 @@ def set_mm_position_ids(self, input_feature: InputFeature): **kwargs) return self._concat_text_position_ids(position_ids) + def get_vllm_input_ids(self, input_ids): + """Collapse each <vision_start><image_pad>...<image_pad><vision_end> group + into <vision_start><image_pad><vision_end> (single pad token).""" + image_token_id = self.config.image_token_id + vision_start_id = self.config.vision_start_token_id + vision_end_id = self.config.vision_end_token_id + + result = [] + i = 0 + while i < len(input_ids): + if input_ids[i] == vision_start_id: + result.append(vision_start_id) + i += 1 + # Skip all consecutive image_pad tokens, keep only one + found_pad = False + while i < len(input_ids) and input_ids[i] == image_token_id: + if not found_pad: + result.append(image_token_id) + found_pad = True + i += 1 + # Append vision_end if present + if i < len(input_ids) and input_ids[i] == vision_end_id: + result.append(vision_end_id) + i += 1 + else: + result.append(input_ids[i]) + i += 1 + return result + @staticmethod def _concat_text_position_ids(position_ids): seq_len = position_ids.shape[-1] From 4fc96dde2d6c577af9887147a28d05b114bcc495 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 11 Apr 2026 21:34:26 +0800 Subject: [PATCH 8/8] fix --- src/twinkle/model/transformers/transformers.py | 2 +- src/twinkle/template/base.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 7d8022d6..210ea3a0 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -1182,7 +1182,7 @@ def _normalize(name: str, keep_base_layer: bool) -> str: name = name.replace('.base_layer', '') else: if 'conv1d.weight' in name: - if any('conv1d.base_layer.weight' in name for name in model_keys): + if model_keys and any('conv1d.base_layer.weight' in name for name in model_keys): name = name.replace('conv1d.weight', 'conv1d.base_layer.weight') return name diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 5945328e..f76afd5e 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -162,7 +162,7 @@ def concat_input_feature(self, prompt_input_feature: InputFeature, new_tokens: L assert self.truncation_strategy != 'split', 'concat_input_feature does not support `truncation_strategy=split`' result = copy.deepcopy(prompt_input_feature) prompt_ids = result['input_ids'] - labels = result['labels'] + labels = list(result['labels']) input_ids = list(prompt_ids) + new_tokens labels = labels[-1:] + labels[:-1] # roll to input order labels = labels + new_tokens @@
-237,7 +237,7 @@ def _truncate_feature(self, feature: InputFeature, strategy: str) -> InputFeatur def set_mm_position_ids(self, input_feature: InputFeature): return np.arange(len(input_feature['input_ids'])) - def get_vllm_input_ids(input_ids): + def get_vllm_input_ids(self, input_ids): return input_ids def _check_max_length(self, input_feature: InputFeature) -> List[InputFeature]:
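
Note: below is a minimal standalone sketch of the placeholder-collapsing behavior that PATCH 7 adds in Qwen3_5VLTemplate.get_vllm_input_ids, so the sampler can pass collapsed input_ids to vLLM and let vLLM re-expand the image placeholder from the attached multi_modal_data. The token IDs here are made up for illustration; the real values come from the model config (image_token_id, vision_start_token_id, vision_end_token_id).

# Illustrative token IDs, not the real Qwen3.5-VL vocabulary.
VISION_START, IMAGE_PAD, VISION_END = 100, 101, 102

def collapse_image_pads(input_ids):
    # Collapse each vision_start, image_pad..., vision_end run into
    # vision_start, single image_pad, vision_end.
    result, i = [], 0
    while i < len(input_ids):
        if input_ids[i] == VISION_START:
            result.append(VISION_START)
            i += 1
            found_pad = False
            while i < len(input_ids) and input_ids[i] == IMAGE_PAD:
                if not found_pad:
                    result.append(IMAGE_PAD)  # keep exactly one pad token
                    found_pad = True
                i += 1
            if i < len(input_ids) and input_ids[i] == VISION_END:
                result.append(VISION_END)
                i += 1
        else:
            result.append(input_ids[i])
            i += 1
    return result

# A prompt with a four-token image span collapses to a single placeholder:
assert collapse_image_pads([1, 100, 101, 101, 101, 101, 102, 2]) == [1, 100, 101, 102, 2]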