From d638074f3efca3f28f35eedddbd3b18f5e8c33cf Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 18 Apr 2026 21:35:05 +0800 Subject: [PATCH 1/6] wip --- src/twinkle/data_format/message.py | 2 +- src/twinkle/loss/dpo.py | 12 ++--- src/twinkle/metric/dpo.py | 2 - .../model/transformers/transformers.py | 50 +++++++++++-------- 4 files changed, 37 insertions(+), 29 deletions(-) diff --git a/src/twinkle/data_format/message.py b/src/twinkle/data_format/message.py index 236f01d9..42d7afc8 100644 --- a/src/twinkle/data_format/message.py +++ b/src/twinkle/data_format/message.py @@ -2,7 +2,7 @@ import sys from typing import Any, Dict, List, Literal, Optional, Union -if sys.version_info <= (3, 11): +if sys.version_info[:2] <= (3, 11): # Pydantic requirements. from typing_extensions import TypedDict else: diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index d50837a0..fe526ab4 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -261,7 +261,7 @@ def __call__( Args: inputs: Dict containing 'input_ids' and 'labels' [batch, seq_len]. - Batch should be organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + Batch should be interleaved as [chosen_1, rejected_1, chosen_2, rejected_2, ...] outputs: Dict containing either: - 'logps': [batch, seq_len] pre-computed log probs, OR - 'logits': [batch, seq_len, vocab] from which logps will be computed @@ -535,11 +535,11 @@ def __call__( # Odds ratio: log(odds_chosen / odds_rejected) # log_odds = log(p/(1-p)) = log(p) - log(1-p) - # Use numerically stable computation - prob_chosen = torch.exp(chosen_avg_logps).clamp(min=1e-7, max=1 - 1e-7) - prob_rejected = torch.exp(rejected_avg_logps).clamp(min=1e-7, max=1 - 1e-7) - log_odds_chosen = torch.log(prob_chosen) - torch.log(1 - prob_chosen) - log_odds_rejected = torch.log(prob_rejected) - torch.log(1 - prob_rejected) + # Compute entirely in log-space to avoid exp() underflow: + # log(p) = avg_logps (already in log-space) + # log(1-p) = log1p(-exp(avg_logps)) (numerically stable via log1p) + log_odds_chosen = chosen_avg_logps - torch.log1p(-torch.exp(chosen_avg_logps)) + log_odds_rejected = rejected_avg_logps - torch.log1p(-torch.exp(rejected_avg_logps)) # ORPO odds ratio loss odds_ratio = log_odds_chosen - log_odds_rejected diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py index 8a1d4d6c..49668bab 100644 --- a/src/twinkle/metric/dpo.py +++ b/src/twinkle/metric/dpo.py @@ -191,8 +191,6 @@ def calculate(self): total_count = sum(r['count'] for r in all_results) has_rewards = any(r['has_rewards'] for r in all_results) - self.reset() - if total_count == 0: return {} diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 210ea3a0..8f545cc4 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -444,22 +444,22 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T if self.sp_strategy is not None and labels is None: outputs = self.sp_strategy.postprocess_outputs(outputs) inputs['labels'] = labels - optimizer_config.eval_status.inputs = inputs - optimizer_config.eval_status.outputs = outputs - optimizer_config.eval_status.forward_kwargs = kwargs - optimizer_config.eval_status.loss_value = outputs.get('aux_loss', 0) - if labels is not None: - loss_mask = (labels != -100).bool() - masked_labels = labels.clone() - masked_labels[~loss_mask] = 0 - logits = outputs['logits'] - logits.div_(temperature) - outputs['logps'] = selective_log_softmax(logits, masked_labels) - outputs = copy(outputs) - outputs['past_key_values'] = None - if not return_logits: - outputs['logits'] = None - return outputs + optimizer_config.eval_status.inputs = inputs + optimizer_config.eval_status.outputs = outputs + optimizer_config.eval_status.forward_kwargs = kwargs + optimizer_config.eval_status.loss_value = outputs.get('aux_loss', 0) + if labels is not None: + loss_mask = (labels != -100).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + logits = outputs['logits'] + logits.div_(temperature) + outputs['logps'] = selective_log_softmax(logits, masked_labels) + outputs = copy(outputs) + outputs['past_key_values'] = None + if not return_logits: + outputs['logits'] = None + return outputs @remote_function(collect='mean') def calculate_loss(self, **kwargs): @@ -531,11 +531,21 @@ def backward(self, **kwargs): # Auto set a grad scaler self.set_grad_scaler(adapter_name=adapter_name) scaler = optimizer_config.scaler - if scaler is not None: - scaler.scale(loss_value).backward() - else: - loss_value.backward() + optimizer_config.cur_step += 1 + should_sync = optimizer_config.do_grad_sync() + + import contextlib + no_sync_ctx = contextlib.nullcontext() + if not should_sync and hasattr(self.model, 'no_sync'): + no_sync_ctx = self.model.no_sync() + + with no_sync_ctx: + if scaler is not None: + scaler.scale(loss_value).backward() + else: + loss_value.backward() + optimizer_config.train_status.loss_value = None @remote_function(dispatch='slice_dp', collect=collect_tensor_dict) From 94b098b8505e7f0b0f9a590a931635ee72aeeed9 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 18 Apr 2026 21:43:51 +0800 Subject: [PATCH 2/6] fix --- cookbook/rl/dpo_full.py | 2 +- cookbook/rl/dpo_lora.py | 3 +- cookbook/rl/grpo.py | 5 ++- cookbook/rl/grpo_mm.py | 8 +++- src/twinkle/preprocessor/olympiad_bench.py | 12 ++++-- src/twinkle/reward/olympiad_bench.py | 41 ++++++++++++++----- .../sampler/vllm_sampler/vllm_sampler.py | 2 +- 7 files changed, 52 insertions(+), 21 deletions(-) diff --git a/cookbook/rl/dpo_full.py b/cookbook/rl/dpo_full.py index 8c3e5a6f..8610b986 100644 --- a/cookbook/rl/dpo_full.py +++ b/cookbook/rl/dpo_full.py @@ -243,7 +243,7 @@ def main(): # Logging if optim_step % GRADIENT_ACCUMULATION_STEPS == 0: metrics = policy_model.calculate_metric(is_training=True) - logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}') + logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS // GRADIENT_ACCUMULATION_STEPS}] {metrics}') # Checkpointing if optim_step % SAVE_STEPS == 0: diff --git a/cookbook/rl/dpo_lora.py b/cookbook/rl/dpo_lora.py index 861e72a4..6733127e 100644 --- a/cookbook/rl/dpo_lora.py +++ b/cookbook/rl/dpo_lora.py @@ -34,7 +34,6 @@ DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji) MODEL_GPUS – GPUs for policy model (default: 8) BATCH_SIZE – global batch size (preference pairs) (default: 8) - MAX_STEPS – total optimization steps (default: 1000) LR – learning rate (default: 1e-4) DPO_BETA – DPO temperature parameter (default: 0.1) LOSS_TYPE – DPO variant (sigmoid/hinge/ipo) (default: sigmoid) @@ -214,7 +213,7 @@ def main(): # Logging if optim_step % GRADIENT_ACCUMULATION_STEPS == 0: metrics = policy_model.calculate_metric(is_training=True) - logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}') + logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS // GRADIENT_ACCUMULATION_STEPS}] {metrics}') # Checkpointing if optim_step % SAVE_STEPS == 0: diff --git a/cookbook/rl/grpo.py b/cookbook/rl/grpo.py index 7fc3f2fd..4f587668 100644 --- a/cookbook/rl/grpo.py +++ b/cookbook/rl/grpo.py @@ -142,8 +142,11 @@ def main(): # lora will be merged into the base model and sync all weights to vLLM ckpt_manager.sync_weights(merge_and_sync=False) sampler.reset_prefix_cache() + expand_prompts = [] + for prompt in global_prompts: + expand_prompts.extend([prompt] * NUM_GENERATIONS) sample_responses = sampler.sample( - global_prompts*NUM_GENERATIONS, + expand_prompts, sampling_params, ) diff --git a/cookbook/rl/grpo_mm.py b/cookbook/rl/grpo_mm.py index 9398ce8f..ff3a7450 100644 --- a/cookbook/rl/grpo_mm.py +++ b/cookbook/rl/grpo_mm.py @@ -223,9 +223,15 @@ def main(): ckpt_manager.sync_weights(merge_and_sync=False) sampler.reset_prefix_cache() + # Expand prompts: each prompt repeated NUM_GENERATIONS times consecutively + # so that GRPOAdvantage groups rewards correctly per prompt + expand_prompts = [] + for prompt in batch: + expand_prompts.extend([prompt] * NUM_GENERATIONS) + # Sample multiple completions per prompt sample_responses = sampler.sample( - batch * NUM_GENERATIONS, + expand_prompts, sampling_params, ) diff --git a/src/twinkle/preprocessor/olympiad_bench.py b/src/twinkle/preprocessor/olympiad_bench.py index 25fe57b6..51904051 100644 --- a/src/twinkle/preprocessor/olympiad_bench.py +++ b/src/twinkle/preprocessor/olympiad_bench.py @@ -61,15 +61,19 @@ def _collect_images(self, row: Dict[str, Any]) -> List[Any]: return images def _format_final_answer(self, final_answer: Any, unit: str = '') -> str: - """Format final answer(s) as string for comparison.""" + """Format final answer(s) as string for comparison. + + Note: Unit is intentionally NOT appended to the answer string. + OlympiadBench stores unit as separate metadata, and models may or may + not include it in their output. Appending it would cause spurious + mismatches for unrecognised units (mol, Hz, N, V, W, Pa, …). + Unit information is preserved in user_data for optional downstream use. + """ if isinstance(final_answer, list): answers = [str(a).strip() for a in final_answer if a] answer_str = ', '.join(answers) else: answer_str = str(final_answer).strip() if final_answer else '' - - if unit and answer_str: - answer_str = f'{answer_str} {unit}' return answer_str def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: diff --git a/src/twinkle/reward/olympiad_bench.py b/src/twinkle/reward/olympiad_bench.py index 12736cd1..5b285a70 100644 --- a/src/twinkle/reward/olympiad_bench.py +++ b/src/twinkle/reward/olympiad_bench.py @@ -31,7 +31,11 @@ def _get_completion(trajectory: Dict[str, Any]) -> str: def _extract_boxed_answers(text: str) -> List[str]: - """Extract all answers from \\boxed{} in text, handling nested braces.""" + """Extract all answers from \\boxed{} in text, handling nested braces. + + Correctly skips LaTeX escaped braces (``\\{`` and ``\\}``) so that + expressions like ``\\boxed{\\{1, 2, 3\\}}`` are extracted intact. + """ answers = [] i = 0 while i < len(text): @@ -39,11 +43,14 @@ def _extract_boxed_answers(text: str) -> List[str]: idx = text.find('\\boxed{', i) if idx == -1: break - # Find matching closing brace + # Find matching closing brace, skipping escaped \{ and \} start = idx + 7 # len('\\boxed{') depth = 1 j = start while j < len(text) and depth > 0: + if text[j] == '\\' and j + 1 < len(text) and text[j + 1] in '{}': + j += 2 # skip escaped brace + continue if text[j] == '{': depth += 1 elif text[j] == '}': @@ -140,12 +147,16 @@ def _normalize_answer(answer: str) -> str: # === Phase 4: Unit removal with word boundaries === # Units: only match standalone units, not parts of words - answer = re.sub(r'\b(cm|mm|kg|J)\b', '', answer) # Common units with word boundary - # m/g/s after numbers, brackets, or at end of string - answer = re.sub(r'(?<=[0-9\])])([mgs])\b', '', answer) - # Also remove trailing m/g/s after comma+number pattern (e.g., "3,7m" → "3,7") - answer = re.sub(r'([0-9])([mgs])$', r'\1', answer) - answer = re.sub(r'(°|度|米|千克|克|秒)', '', answer) # Chinese units always remove + # Covers SI base/derived units and common physics/chemistry units + answer = re.sub(r'\b(cm|mm|km|nm|um|kg|mg|Hz|kHz|MHz|GHz|mol|Pa|kPa|MPa|' + r'eV|keV|MeV|GeV|cal|kcal|cd|lm|lx|Wb|Bq|Gy|Sv)\b', '', answer) + # Single-letter units (N, V, W, A, K, C, T, F, H, L) - only after numbers/brackets + answer = re.sub(r'(?<=[0-9\])])\s*([NVWAKCTFHLJmgs])\b', '', answer) + # Also remove trailing single-letter units after comma+number pattern + answer = re.sub(r'([0-9])([NVWAKCTFHLJmgs])$', r'\1', answer) + # Multi-char units that need post-number context to avoid variable-name collisions + answer = re.sub(r'(?<=[0-9\])])\s*(m/s|m/s2|km/h|kg/m3|N/m|J/mol|rad|sr)\b', '', answer) + answer = re.sub(r'(°|度|米|千克|克|秒|摩尔|帕|瓦|伏|安|赫兹|牛|焦)', '', answer) # Chinese units always remove # === Phase 5: Cleanup === # Ratio colon → slash: 3:2 → 3/2 @@ -193,11 +204,17 @@ def _split_answers(gt: str) -> List[str]: return answers -def _is_numeric_match(pred: str, gt: str, tolerance: float = 0.01) -> bool: - """Check if two values match numerically.""" +def _is_numeric_match(pred: str, gt: str, tolerance: float = 0.01, abs_tolerance: float = 1e-4) -> bool: + """Check if two values match numerically. + + Uses both relative and absolute tolerance to handle near-zero values + robustly. A match is declared when *either* criterion is satisfied. + """ try: pred_val = float(pred) gt_val = float(gt) + if abs(pred_val - gt_val) < abs_tolerance: + return True if gt_val == 0: return abs(pred_val) < tolerance return abs(pred_val - gt_val) / abs(gt_val) < tolerance @@ -205,11 +222,13 @@ def _is_numeric_match(pred: str, gt: str, tolerance: float = 0.01) -> bool: return False -def _numeric_similarity(pred: str, gt: str) -> float: +def _numeric_similarity(pred: str, gt: str, abs_tolerance: float = 1e-4) -> float: """Return similarity score [0, 1] for numeric values.""" try: pred_val = float(pred) gt_val = float(gt) + if abs(pred_val - gt_val) < abs_tolerance: + return 1.0 if gt_val == 0: return 1.0 if abs(pred_val) < 0.01 else 0.0 relative_error = abs(pred_val - gt_val) / abs(gt_val) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 64816cae..c707479e 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -242,7 +242,7 @@ async def _sample_single( if 'input_ids' not in feat or multi_modal_data: if 'input_ids' in feat: if len(feat['input_ids']) != len(response.prompt_token_ids): - raise RuntimeError(f'Input ids length {len(feat["input_ids"])} does not' + raise RuntimeError(f'Input ids length {len(feat["input_ids"])} does not ' f'match prompt_token_ids length {len(response.prompt_token_ids)}') else: feat['input_ids'] = response.prompt_token_ids From 4c75a842cbcb43ca9eec12c26067e7dc88bf24fb Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 18 Apr 2026 23:15:24 +0800 Subject: [PATCH 3/6] fix --- cookbook/rl/short_math_grpo.py | 2 +- src/twinkle/dataloader/device_mesh_fetcher.py | 2 ++ src/twinkle/dataloader/retry_sampler.py | 1 + src/twinkle/infra/collectors.py | 2 +- src/twinkle/loss/gkd.py | 5 ++- src/twinkle/loss/grpo.py | 25 ++++---------- src/twinkle/metric/dpo.py | 6 ++++ src/twinkle/processor/base.py | 24 ++++++++++---- src/twinkle/reward/gsm8k.py | 33 ++++++++++++++++--- src/twinkle/reward/math_reward.py | 20 ++++++++--- src/twinkle/reward/olympiad_bench.py | 2 +- .../sampler/vllm_sampler/vllm_engine.py | 1 + src/twinkle/template/base.py | 8 ++++- 13 files changed, 93 insertions(+), 38 deletions(-) diff --git a/cookbook/rl/short_math_grpo.py b/cookbook/rl/short_math_grpo.py index 7637dae6..6a02ac42 100644 --- a/cookbook/rl/short_math_grpo.py +++ b/cookbook/rl/short_math_grpo.py @@ -143,7 +143,7 @@ def main(): remote_group='model', ) - model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1) + model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) if USE_MEGATRON: model.set_optimizer('default', lr=LEARNING_RATE) model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE) diff --git a/src/twinkle/dataloader/device_mesh_fetcher.py b/src/twinkle/dataloader/device_mesh_fetcher.py index 9560fa0f..f545f701 100644 --- a/src/twinkle/dataloader/device_mesh_fetcher.py +++ b/src/twinkle/dataloader/device_mesh_fetcher.py @@ -66,6 +66,8 @@ def fetch(self, _): continue else: break + if _data is None: + raise RuntimeError(f'No valid data after {self.max_retries} retries') data.append(_data) except StopIteration: self.ended = True diff --git a/src/twinkle/dataloader/retry_sampler.py b/src/twinkle/dataloader/retry_sampler.py index 62f05660..6fe84f04 100644 --- a/src/twinkle/dataloader/retry_sampler.py +++ b/src/twinkle/dataloader/retry_sampler.py @@ -54,6 +54,7 @@ def __iter__(self): continue yield idx total += 1 + break except Exception: # noqa import traceback traceback.print_exc() diff --git a/src/twinkle/infra/collectors.py b/src/twinkle/infra/collectors.py index aaa60819..e8756e65 100644 --- a/src/twinkle/infra/collectors.py +++ b/src/twinkle/infra/collectors.py @@ -40,7 +40,7 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh: DeviceMesh) result[key] = pad_and_stack_tensors(values) elif isinstance(first_value, dict): - result[key] = collect_tensor_dict(values) + result[key] = collect_tensor_dict(values, device_mesh) elif isinstance(first_value, np.ndarray) and first_value.size > 1: raise NotImplementedError('Numpy array not supported for now.') diff --git a/src/twinkle/loss/gkd.py b/src/twinkle/loss/gkd.py index 399b7d1f..ea09b823 100644 --- a/src/twinkle/loss/gkd.py +++ b/src/twinkle/loss/gkd.py @@ -150,13 +150,16 @@ def _generalized_jsd_loss( temperature = 1.0 elif topk is not None and teacher_logits is not None: # Local teacher: select top-k from teacher, gather corresponding student logits - teacher_logits = teacher_logits.to(student_logits.device()) + teacher_logits = teacher_logits.to(student_logits.device) teacher_logits, topk_idx = torch.topk(teacher_logits, k=topk, dim=-1) teacher_logits.div_(temperature) student_logits = torch.gather(student_logits, dim=-1, index=topk_idx) student_logits.div_(temperature) temperature = 1.0 + if teacher_logits is not None and teacher_logits.device != student_logits.device: + teacher_logits = teacher_logits.to(student_logits.device) + # ── Mask valid (response) tokens ────────────────────────────────────── if labels is not None: mask = labels != -100 # ignore_index is always -100 per convention diff --git a/src/twinkle/loss/grpo.py b/src/twinkle/loss/grpo.py index f07a26c8..baee40ad 100644 --- a/src/twinkle/loss/grpo.py +++ b/src/twinkle/loss/grpo.py @@ -206,24 +206,13 @@ def _unpack_packed_logps( mask_flat = loss_mask.squeeze(0) # [total_tokens] # ── Find sequence boundaries ───────────────────────────────────── - if position_ids is not None: - pos_flat = position_ids.squeeze(0) # [total_tokens] - # position_ids resets to 0 at each new sequence - boundary_indices = (pos_flat == 0).nonzero(as_tuple=True)[0] - else: - # Fallback: use loss_mask transitions. Each sequence has a - # prompt region (mask=0) followed by a response region (mask=1). - # Detect 0→1 transitions preceded by a 0→0 gap (new prompt). - # Simpler: find where mask goes from 1→0→...→0→1 (prompt gap). - # We mark boundaries at the start of each prompt (first 0 after 1). - shifted = torch.cat([torch.tensor([False], device=mask_flat.device), mask_flat[:-1]]) - # Start of a new sequence: transition from mask=1 (end of prev response) - # to mask=0 (start of next prompt), or position 0 for the first sequence. - prompt_starts = ((~mask_flat) & shifted).nonzero(as_tuple=True)[0] - boundary_indices = torch.cat([ - torch.tensor([0], device=mask_flat.device), - prompt_starts, - ]) + assert position_ids is not None, ( + 'position_ids is required for unpacking packed sequences. ' + 'Ensure the processor passes position_ids in packing mode.' + ) + pos_flat = position_ids.squeeze(0) # [total_tokens] + # position_ids resets to 0 at each new sequence + boundary_indices = (pos_flat == 0).nonzero(as_tuple=True)[0] # Deduplicate & sort boundary_indices = boundary_indices.unique(sorted=True) diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py index 49668bab..93872a8c 100644 --- a/src/twinkle/metric/dpo.py +++ b/src/twinkle/metric/dpo.py @@ -115,6 +115,12 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M # Compute sequence-level logps seq_logps = self._compute_sequence_logps(logps, labels) + # DPO requires interleaved [chosen, rejected, ...] pairs → batch must be even + assert seq_logps.shape[0] % 2 == 0, ( + f'DPO metric requires an even batch size (interleaved chosen/rejected pairs), ' + f'but got batch_size={seq_logps.shape[0]}.' + ) + # Split into chosen and rejected (interleaved format) chosen_logps, rejected_logps = self._split_chosen_rejected(seq_logps) chosen_labels, rejected_labels = self._split_chosen_rejected(labels) diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index 17ba2307..ce55f0d3 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -154,19 +154,23 @@ def pad_cp_inputs(input_tensor: torch.Tensor, padding_value: int) -> torch.Tenso torch.tensor(position_ids_f.shape, device=position_ids_f.device, dtype=torch.int32), ]) - for key in ['input_ids', 'position_ids', 'attention_mask', 'labels']: - value = _input[key] + for key in ['input_ids', 'position_ids', 'attention_mask', 'labels', + 'completion_mask', 'mm_token_type_ids']: + value = _input.get(key) + if value is None: + continue result = [] for i in range(cu_seqlens.shape[0]): if i == cu_seqlens.shape[0] - 1: break - _value_slice = value[:, cu_seqlens[i]:cu_seqlens[i + 1]] - result.append(pad_cp_inputs(_value_slice, padding_value=self.padding_map[key])) - value = torch.cat(result, dim=1) + _value_slice = value[..., cu_seqlens[i]:cu_seqlens[i + 1]] + result.append(pad_cp_inputs(_value_slice, padding_value=self.padding_map.get(key, 0))) + value = torch.cat(result, dim=-1) _input[key] = value elif self.device_mesh.sequence_parallel and tp_size > 1: # Sequence parallel without CP still requires seq_len % TP == 0 - for key in ['input_ids', 'position_ids', 'attention_mask', 'labels']: + for key in ['input_ids', 'position_ids', 'attention_mask', 'labels', + 'completion_mask', 'mm_token_type_ids']: value = _input.get(key) if value is not None: _input[key] = pad_cp_inputs(value, padding_value=self.padding_map.get(key, 0)) @@ -222,6 +226,14 @@ def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], di # attention_mask = split_cp_inputs(attention_mask, cu_seqlens_q, dim=1) batch_labels = split_cp_inputs(batch_labels, cu_seqlens_q, dim=1) + completion_mask = inputs.get('completion_mask') + if completion_mask is not None: + inputs['completion_mask'] = split_cp_inputs(completion_mask, cu_seqlens_q, dim=-1) + + mm_token_type_ids = inputs.get('mm_token_type_ids') + if mm_token_type_ids is not None: + inputs['mm_token_type_ids'] = split_cp_inputs(mm_token_type_ids, cu_seqlens_q, dim=-1) + inputs['input_ids'] = input_ids inputs['position_ids'] = position_ids inputs['attention_mask'] = attention_mask diff --git a/src/twinkle/reward/gsm8k.py b/src/twinkle/reward/gsm8k.py index eb439675..347d49e4 100644 --- a/src/twinkle/reward/gsm8k.py +++ b/src/twinkle/reward/gsm8k.py @@ -4,6 +4,30 @@ from twinkle.reward.base import Reward +def _extract_last_boxed(text: str) -> str: + """Extract content from the last \\boxed{...}, handling nested braces.""" + idx = text.rfind('\\boxed{') + if idx == -1: + return '' + start = idx + len('\\boxed{') + depth = 1 + j = start + while j < len(text) and depth > 0: + if text[j] == '{': + depth += 1 + elif text[j] == '}': + depth -= 1 + j += 1 + if depth == 0: + return text[start:j - 1].strip() + return '' + + +def _has_boxed(text: str) -> bool: + """Check whether *text* contains a valid \\boxed{...} (nested-brace aware).""" + return bool(_extract_last_boxed(text)) + + class GSM8KAccuracyReward(Reward): """Accuracy reward for GSM8K: checks if the model's answer matches ground truth. @@ -15,9 +39,9 @@ class GSM8KAccuracyReward(Reward): def extract_answer(completion: str) -> str: """Extract the answer from model completion, preferring \\boxed{} over ####.""" text = completion[-500:] if len(completion) > 500 else completion - boxed = re.findall(r'\\boxed\{([^}]+)\}', text) + boxed = _extract_last_boxed(text) if boxed: - return boxed[-1].replace(',', '').replace(' ', '').strip() + return boxed.replace(',', '').replace(' ', '').strip() matches = re.findall(r'####\s*([\-\d,\.\s]+)', text) if matches: return matches[-1].replace(',', '').replace(' ', '').strip() @@ -35,7 +59,8 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: break # Get ground truth from user_data - user_data = trajectory.get('user_data') + user_data = trajectory.get('user_data') or [] + gt = '' for item in user_data: if item[0] == 'ground_truth': gt = item[1] @@ -70,6 +95,6 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: if msg.get('role') == 'assistant': completion = msg.get('content', '') break - has_answer = bool(re.search(r'\\boxed\{[^}]+\}', completion) or re.search(r'####\s*[\-\d,\.]+', completion)) + has_answer = bool(_has_boxed(completion) or re.search(r'####\s*[\-\d,\.]+', completion)) rewards.append(1.0 if has_answer else 0.0) return rewards diff --git a/src/twinkle/reward/math_reward.py b/src/twinkle/reward/math_reward.py index f8a9e36f..6c19514b 100644 --- a/src/twinkle/reward/math_reward.py +++ b/src/twinkle/reward/math_reward.py @@ -22,12 +22,22 @@ def check_terminate(answers: Union[str, List[str]]) -> List[bool]: @staticmethod def extract_boxed_result(text): - pattern = r'\\boxed{([^}]*)}' - match = re.search(pattern, text) - if match: - return match.group(1).strip() - else: + """Extract content from \\boxed{...}, handling nested braces.""" + idx = text.rfind('\\boxed{') + if idx == -1: return text + start = idx + len('\\boxed{') + depth = 1 + j = start + while j < len(text) and depth > 0: + if text[j] == '{': + depth += 1 + elif text[j] == '}': + depth -= 1 + j += 1 + if depth == 0: + return text[start:j - 1].strip() + return text @staticmethod def clean_latex(latex_str): diff --git a/src/twinkle/reward/olympiad_bench.py b/src/twinkle/reward/olympiad_bench.py index 5b285a70..c375ce6a 100644 --- a/src/twinkle/reward/olympiad_bench.py +++ b/src/twinkle/reward/olympiad_bench.py @@ -274,7 +274,7 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: rewards.append(1.0) else: sim = _numeric_similarity(pred, gt_val) - rewards.append(sim * 0.5) + rewards.append(sim * 0.99) else: pred_normalized = [_normalize_answer(p) for p in predicted] correct_count = 0 diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index be63e270..294ef5e3 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -261,6 +261,7 @@ async def sample(self, seq_logprobs = None if output.logprobs is not None: seq_logprobs = [] + breakpoint() for i, lp in enumerate(output.logprobs): if i < len(token_ids): sorted_items = sorted(lp.items(), key=lambda x: -(x[1].logprob))[:logprobs] diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 10d0984f..5784ddae 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -166,7 +166,7 @@ def concat_input_feature(self, prompt_input_feature: InputFeature, new_tokens: L input_ids = list(prompt_ids) + new_tokens labels = labels[-1:] + labels[:-1] # roll to input order labels = labels + new_tokens - labels = labels[1:] + labels[:1] # roll to input-1 order + # We don't need to roll back, self._invoke_post_pipeline will do this. result['input_ids'] = input_ids result['labels'] = labels if 'mm_token_type_ids' in result: @@ -228,10 +228,14 @@ def _truncate_feature(self, feature: InputFeature, strategy: str) -> InputFeatur result['input_ids'] = result['input_ids'][-self.max_length:] if 'labels' in result: result['labels'] = result['labels'][-self.max_length:] + if 'mm_token_type_ids' in result: + result['mm_token_type_ids'] = result['mm_token_type_ids'][..., -self.max_length:] elif strategy == 'right': result['input_ids'] = result['input_ids'][:self.max_length] if 'labels' in result: result['labels'] = result['labels'][:self.max_length] + if 'mm_token_type_ids' in result: + result['mm_token_type_ids'] = result['mm_token_type_ids'][..., :self.max_length] return InputFeature(**result) def set_mm_position_ids(self, input_feature: InputFeature): @@ -255,6 +259,8 @@ def _check_max_length(self, input_feature: InputFeature) -> List[InputFeature]: feat['input_ids'] = feat['input_ids'][start:end] if 'labels' in feat: feat['labels'] = feat['labels'][start:end] + if 'mm_token_type_ids' in feat: + feat['mm_token_type_ids'] = feat['mm_token_type_ids'][..., start:end] results.append(InputFeature(**feat)) return results From 7eed3079fabf6a29fe08b048a3d06c5cadde44c7 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 18 Apr 2026 23:57:39 +0800 Subject: [PATCH 4/6] fix --- src/twinkle/loss/grpo.py | 6 ++---- src/twinkle/metric/dpo.py | 3 +-- src/twinkle/processor/base.py | 10 ++++++---- src/twinkle/reward/olympiad_bench.py | 5 +++-- src/twinkle/sampler/vllm_sampler/vllm_engine.py | 13 ++++++++++--- 5 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/twinkle/loss/grpo.py b/src/twinkle/loss/grpo.py index baee40ad..471db66d 100644 --- a/src/twinkle/loss/grpo.py +++ b/src/twinkle/loss/grpo.py @@ -206,10 +206,8 @@ def _unpack_packed_logps( mask_flat = loss_mask.squeeze(0) # [total_tokens] # ── Find sequence boundaries ───────────────────────────────────── - assert position_ids is not None, ( - 'position_ids is required for unpacking packed sequences. ' - 'Ensure the processor passes position_ids in packing mode.' - ) + assert position_ids is not None, ('position_ids is required for unpacking packed sequences. ' + 'Ensure the processor passes position_ids in packing mode.') pos_flat = position_ids.squeeze(0) # [total_tokens] # position_ids resets to 0 at each new sequence boundary_indices = (pos_flat == 0).nonzero(as_tuple=True)[0] diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py index 93872a8c..e54baa2e 100644 --- a/src/twinkle/metric/dpo.py +++ b/src/twinkle/metric/dpo.py @@ -118,8 +118,7 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M # DPO requires interleaved [chosen, rejected, ...] pairs → batch must be even assert seq_logps.shape[0] % 2 == 0, ( f'DPO metric requires an even batch size (interleaved chosen/rejected pairs), ' - f'but got batch_size={seq_logps.shape[0]}.' - ) + f'but got batch_size={seq_logps.shape[0]}.') # Split into chosen and rejected (interleaved format) chosen_logps, rejected_logps = self._split_chosen_rejected(seq_logps) diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index ce55f0d3..95758e8c 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -154,8 +154,9 @@ def pad_cp_inputs(input_tensor: torch.Tensor, padding_value: int) -> torch.Tenso torch.tensor(position_ids_f.shape, device=position_ids_f.device, dtype=torch.int32), ]) - for key in ['input_ids', 'position_ids', 'attention_mask', 'labels', - 'completion_mask', 'mm_token_type_ids']: + for key in [ + 'input_ids', 'position_ids', 'attention_mask', 'labels', 'completion_mask', 'mm_token_type_ids' + ]: value = _input.get(key) if value is None: continue @@ -169,8 +170,9 @@ def pad_cp_inputs(input_tensor: torch.Tensor, padding_value: int) -> torch.Tenso _input[key] = value elif self.device_mesh.sequence_parallel and tp_size > 1: # Sequence parallel without CP still requires seq_len % TP == 0 - for key in ['input_ids', 'position_ids', 'attention_mask', 'labels', - 'completion_mask', 'mm_token_type_ids']: + for key in [ + 'input_ids', 'position_ids', 'attention_mask', 'labels', 'completion_mask', 'mm_token_type_ids' + ]: value = _input.get(key) if value is not None: _input[key] = pad_cp_inputs(value, padding_value=self.padding_map.get(key, 0)) diff --git a/src/twinkle/reward/olympiad_bench.py b/src/twinkle/reward/olympiad_bench.py index c375ce6a..ded5889f 100644 --- a/src/twinkle/reward/olympiad_bench.py +++ b/src/twinkle/reward/olympiad_bench.py @@ -148,8 +148,9 @@ def _normalize_answer(answer: str) -> str: # === Phase 4: Unit removal with word boundaries === # Units: only match standalone units, not parts of words # Covers SI base/derived units and common physics/chemistry units - answer = re.sub(r'\b(cm|mm|km|nm|um|kg|mg|Hz|kHz|MHz|GHz|mol|Pa|kPa|MPa|' - r'eV|keV|MeV|GeV|cal|kcal|cd|lm|lx|Wb|Bq|Gy|Sv)\b', '', answer) + answer = re.sub( + r'\b(cm|mm|km|nm|um|kg|mg|Hz|kHz|MHz|GHz|mol|Pa|kPa|MPa|' + r'eV|keV|MeV|GeV|cal|kcal|cd|lm|lx|Wb|Bq|Gy|Sv)\b', '', answer) # Single-letter units (N, V, W, A, K, C, T, F, H, L) - only after numbers/brackets answer = re.sub(r'(?<=[0-9\])])\s*([NVWAKCTFHLJmgs])\b', '', answer) # Also remove trailing single-letter units after comma+number pattern diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index 294ef5e3..1ac4d76a 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -261,11 +261,18 @@ async def sample(self, seq_logprobs = None if output.logprobs is not None: seq_logprobs = [] - breakpoint() for i, lp in enumerate(output.logprobs): if i < len(token_ids): - sorted_items = sorted(lp.items(), key=lambda x: -(x[1].logprob))[:logprobs] - seq_logprobs.append([(tid, lp_obj.logprob) for tid, lp_obj in sorted_items]) + if logprobs == 1: + # Single logprob mode: return the sampled token's logprob directly + assert token_ids[i] in lp, ( + f'Sampled token {token_ids[i]} not found in logprobs at position {i}. ' + f'Available tokens: {list(lp.keys())}') + seq_logprobs.append(lp[token_ids[i]].logprob) + else: + # Multiple logprobs mode: return top-k logprobs + sorted_items = sorted(lp.items(), key=lambda x: -(x[1].logprob))[:logprobs] + seq_logprobs.append([(tid, lp_obj.logprob) for tid, lp_obj in sorted_items]) # Map finish_reason to StopReason stop_reason: StopReason = 'length' From 375d92cac733b92dd1447b857513bce5c115822c Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 19 Apr 2026 00:11:15 +0800 Subject: [PATCH 5/6] fi --- src/twinkle/sampler/vllm_sampler/vllm_engine.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index 1ac4d76a..2023e9e9 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -264,11 +264,12 @@ async def sample(self, for i, lp in enumerate(output.logprobs): if i < len(token_ids): if logprobs == 1: - # Single logprob mode: return the sampled token's logprob directly - assert token_ids[i] in lp, ( - f'Sampled token {token_ids[i]} not found in logprobs at position {i}. ' - f'Available tokens: {list(lp.keys())}') - seq_logprobs.append(lp[token_ids[i]].logprob) + # Single logprob mode: return the sampled token's logprob + # in the same [(tid, logprob)] format as multi-logprob mode + tid = token_ids[i] + assert tid in lp, (f'Sampled token {tid} not found in logprobs at position {i}. ' + f'Available tokens: {list(lp.keys())}') + seq_logprobs.append([(tid, lp[tid].logprob)]) else: # Multiple logprobs mode: return top-k logprobs sorted_items = sorted(lp.items(), key=lambda x: -(x[1].logprob))[:logprobs] From 0e7841fc502603a896c21032aef484ccbcc4c038 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 19 Apr 2026 00:31:40 +0800 Subject: [PATCH 6/6] fix --- cookbook/rl/short_math_grpo.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/cookbook/rl/short_math_grpo.py b/cookbook/rl/short_math_grpo.py index 6a02ac42..e882b54a 100644 --- a/cookbook/rl/short_math_grpo.py +++ b/cookbook/rl/short_math_grpo.py @@ -118,11 +118,7 @@ def main(): # Since we are training on text-only data, we avoid using 'all-linear' which would include the ViT layers. lora_config = LoraConfig( - target_modules=[ - 'q_proj', 'k_proj', 'v_proj', 'o_proj', - 'gate_proj', 'up_proj', 'down_proj', - 'in_proj_qkv', 'in_proj_z', 'in_proj_a', 'in_proj_b', 'out_proj', - ], + target_modules='all-linear', r=LORA_RANK, lora_alpha=LORA_RANK * 2, lora_dropout=0.05, @@ -161,11 +157,8 @@ def main(): 'gpu_memory_utilization': 0.8, 'max_model_len': 8192, 'max_lora_rank': 32, # save as lora_config - # NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976 - # enable_lora=True used with ckpt_manager.sync_weights(merge_and_sync=False) - # meaning only sync lora weights, if merge_and_sync=True, - # lora will be merged into the base model and sync all weights to vLLM 'enable_lora': True, + 'enable_tower_connector_lora': True, }, device_mesh=sampler_mesh, remote_group='sampler',