Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cookbook/rl/dpo_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def main():
# Logging
if optim_step % GRADIENT_ACCUMULATION_STEPS == 0:
metrics = policy_model.calculate_metric(is_training=True)
logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}')
logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS // GRADIENT_ACCUMULATION_STEPS}] {metrics}')

# Checkpointing
if optim_step % SAVE_STEPS == 0:
Expand Down
3 changes: 1 addition & 2 deletions cookbook/rl/dpo_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji)
MODEL_GPUS – GPUs for policy model (default: 8)
BATCH_SIZE – global batch size (preference pairs) (default: 8)
MAX_STEPS – total optimization steps (default: 1000)
LR – learning rate (default: 1e-4)
DPO_BETA – DPO temperature parameter (default: 0.1)
LOSS_TYPE – DPO variant (sigmoid/hinge/ipo) (default: sigmoid)
Expand Down Expand Up @@ -214,7 +213,7 @@ def main():
# Logging
if optim_step % GRADIENT_ACCUMULATION_STEPS == 0:
metrics = policy_model.calculate_metric(is_training=True)
logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}')
logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS // GRADIENT_ACCUMULATION_STEPS}] {metrics}')

# Checkpointing
if optim_step % SAVE_STEPS == 0:
Expand Down
5 changes: 4 additions & 1 deletion cookbook/rl/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,11 @@ def main():
# lora will be merged into the base model and sync all weights to vLLM
ckpt_manager.sync_weights(merge_and_sync=False)
sampler.reset_prefix_cache()
expand_prompts = []
for prompt in global_prompts:
expand_prompts.extend([prompt] * NUM_GENERATIONS)
Comment thread
tastelikefeet marked this conversation as resolved.
sample_responses = sampler.sample(
global_prompts*NUM_GENERATIONS,
expand_prompts,
sampling_params,
)

Expand Down
8 changes: 7 additions & 1 deletion cookbook/rl/grpo_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,15 @@ def main():
ckpt_manager.sync_weights(merge_and_sync=False)
sampler.reset_prefix_cache()

# Expand prompts: each prompt repeated NUM_GENERATIONS times consecutively
# so that GRPOAdvantage groups rewards correctly per prompt
expand_prompts = []
for prompt in batch:
expand_prompts.extend([prompt] * NUM_GENERATIONS)
Comment thread
tastelikefeet marked this conversation as resolved.

# Sample multiple completions per prompt
sample_responses = sampler.sample(
batch * NUM_GENERATIONS,
expand_prompts,
sampling_params,
)

Expand Down
13 changes: 3 additions & 10 deletions cookbook/rl/short_math_grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,7 @@ def main():

# Since we are training on text-only data, we avoid using 'all-linear' which would include the ViT layers.
lora_config = LoraConfig(
target_modules=[
'q_proj', 'k_proj', 'v_proj', 'o_proj',
'gate_proj', 'up_proj', 'down_proj',
'in_proj_qkv', 'in_proj_z', 'in_proj_a', 'in_proj_b', 'out_proj',
],
target_modules='all-linear',
r=LORA_RANK,
lora_alpha=LORA_RANK * 2,
lora_dropout=0.05,
Expand All @@ -143,7 +139,7 @@ def main():
remote_group='model',
)

model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)
model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
if USE_MEGATRON:
model.set_optimizer('default', lr=LEARNING_RATE)
model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE)
Expand All @@ -161,11 +157,8 @@ def main():
'gpu_memory_utilization': 0.8,
'max_model_len': 8192,
'max_lora_rank': 32, # same as lora_config
# NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976
# enable_lora=True used with ckpt_manager.sync_weights(merge_and_sync=False)
# meaning only sync lora weights, if merge_and_sync=True,
# lora will be merged into the base model and sync all weights to vLLM
'enable_lora': True,
'enable_tower_connector_lora': True,
},
device_mesh=sampler_mesh,
remote_group='sampler',
Expand Down
2 changes: 1 addition & 1 deletion src/twinkle/data_format/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
from typing import Any, Dict, List, Literal, Optional, Union

if sys.version_info <= (3, 11):
if sys.version_info[:2] <= (3, 11):
# Pydantic requirements.
from typing_extensions import TypedDict
else:
Expand Down
2 changes: 2 additions & 0 deletions src/twinkle/dataloader/device_mesh_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def fetch(self, _):
continue
else:
break
if _data is None:
raise RuntimeError(f'No valid data after {self.max_retries} retries')
data.append(_data)
except StopIteration:
self.ended = True
Expand Down
1 change: 1 addition & 0 deletions src/twinkle/dataloader/retry_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __iter__(self):
continue
yield idx
total += 1
break
except Exception: # noqa
import traceback
traceback.print_exc()
Expand Down
2 changes: 1 addition & 1 deletion src/twinkle/infra/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh: DeviceMesh)
result[key] = pad_and_stack_tensors(values)

elif isinstance(first_value, dict):
result[key] = collect_tensor_dict(values)
result[key] = collect_tensor_dict(values, device_mesh)

elif isinstance(first_value, np.ndarray) and first_value.size > 1:
raise NotImplementedError('Numpy array not supported for now.')
Expand Down
12 changes: 6 additions & 6 deletions src/twinkle/loss/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def __call__(

Args:
inputs: Dict containing 'input_ids' and 'labels' [batch, seq_len].
Batch should be organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n]
Batch should be interleaved as [chosen_1, rejected_1, chosen_2, rejected_2, ...]
outputs: Dict containing either:
- 'logps': [batch, seq_len] pre-computed log probs, OR
- 'logits': [batch, seq_len, vocab] from which logps will be computed
Expand Down Expand Up @@ -535,11 +535,11 @@ def __call__(

# Odds ratio: log(odds_chosen / odds_rejected)
# log_odds = log(p/(1-p)) = log(p) - log(1-p)
# Use numerically stable computation
prob_chosen = torch.exp(chosen_avg_logps).clamp(min=1e-7, max=1 - 1e-7)
prob_rejected = torch.exp(rejected_avg_logps).clamp(min=1e-7, max=1 - 1e-7)
log_odds_chosen = torch.log(prob_chosen) - torch.log(1 - prob_chosen)
log_odds_rejected = torch.log(prob_rejected) - torch.log(1 - prob_rejected)
# Compute entirely in log-space to avoid exp() underflow:
# log(p) = avg_logps (already in log-space)
# log(1-p) = log1p(-exp(avg_logps)) (numerically stable via log1p)
log_odds_chosen = chosen_avg_logps - torch.log1p(-torch.exp(chosen_avg_logps))
log_odds_rejected = rejected_avg_logps - torch.log1p(-torch.exp(rejected_avg_logps))

# ORPO odds ratio loss
odds_ratio = log_odds_chosen - log_odds_rejected
Expand Down
5 changes: 4 additions & 1 deletion src/twinkle/loss/gkd.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,16 @@ def _generalized_jsd_loss(
temperature = 1.0
elif topk is not None and teacher_logits is not None:
# Local teacher: select top-k from teacher, gather corresponding student logits
teacher_logits = teacher_logits.to(student_logits.device())
teacher_logits = teacher_logits.to(student_logits.device)
teacher_logits, topk_idx = torch.topk(teacher_logits, k=topk, dim=-1)
teacher_logits.div_(temperature)
student_logits = torch.gather(student_logits, dim=-1, index=topk_idx)
student_logits.div_(temperature)
temperature = 1.0

if teacher_logits is not None and teacher_logits.device != student_logits.device:
teacher_logits = teacher_logits.to(student_logits.device)

# ── Mask valid (response) tokens ──────────────────────────────────────
if labels is not None:
mask = labels != -100 # ignore_index is always -100 per convention
Expand Down
23 changes: 5 additions & 18 deletions src/twinkle/loss/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,24 +206,11 @@ def _unpack_packed_logps(
mask_flat = loss_mask.squeeze(0) # [total_tokens]

# ── Find sequence boundaries ─────────────────────────────────────
if position_ids is not None:
pos_flat = position_ids.squeeze(0) # [total_tokens]
# position_ids resets to 0 at each new sequence
boundary_indices = (pos_flat == 0).nonzero(as_tuple=True)[0]
else:
# Fallback: use loss_mask transitions. Each sequence has a
# prompt region (mask=0) followed by a response region (mask=1).
# Detect 0→1 transitions preceded by a 0→0 gap (new prompt).
# Simpler: find where mask goes from 1→0→...→0→1 (prompt gap).
# We mark boundaries at the start of each prompt (first 0 after 1).
shifted = torch.cat([torch.tensor([False], device=mask_flat.device), mask_flat[:-1]])
# Start of a new sequence: transition from mask=1 (end of prev response)
# to mask=0 (start of next prompt), or position 0 for the first sequence.
prompt_starts = ((~mask_flat) & shifted).nonzero(as_tuple=True)[0]
boundary_indices = torch.cat([
torch.tensor([0], device=mask_flat.device),
prompt_starts,
])
assert position_ids is not None, ('position_ids is required for unpacking packed sequences. '
'Ensure the processor passes position_ids in packing mode.')
pos_flat = position_ids.squeeze(0) # [total_tokens]
# position_ids resets to 0 at each new sequence
boundary_indices = (pos_flat == 0).nonzero(as_tuple=True)[0]

# Deduplicate & sort
boundary_indices = boundary_indices.unique(sorted=True)
Expand Down
7 changes: 5 additions & 2 deletions src/twinkle/metric/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M
# Compute sequence-level logps
seq_logps = self._compute_sequence_logps(logps, labels)

# DPO requires interleaved [chosen, rejected, ...] pairs → batch must be even
assert seq_logps.shape[0] % 2 == 0, (
f'DPO metric requires an even batch size (interleaved chosen/rejected pairs), '
f'but got batch_size={seq_logps.shape[0]}.')

# Split into chosen and rejected (interleaved format)
chosen_logps, rejected_logps = self._split_chosen_rejected(seq_logps)
chosen_labels, rejected_labels = self._split_chosen_rejected(labels)
Expand Down Expand Up @@ -191,8 +196,6 @@ def calculate(self):
total_count = sum(r['count'] for r in all_results)
has_rewards = any(r['has_rewards'] for r in all_results)

self.reset()

if total_count == 0:
return {}

Expand Down
50 changes: 30 additions & 20 deletions src/twinkle/model/transformers/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,22 +444,22 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T
if self.sp_strategy is not None and labels is None:
outputs = self.sp_strategy.postprocess_outputs(outputs)
inputs['labels'] = labels
optimizer_config.eval_status.inputs = inputs
optimizer_config.eval_status.outputs = outputs
optimizer_config.eval_status.forward_kwargs = kwargs
optimizer_config.eval_status.loss_value = outputs.get('aux_loss', 0)
if labels is not None:
loss_mask = (labels != -100).bool()
masked_labels = labels.clone()
masked_labels[~loss_mask] = 0
logits = outputs['logits']
logits.div_(temperature)
outputs['logps'] = selective_log_softmax(logits, masked_labels)
outputs = copy(outputs)
outputs['past_key_values'] = None
if not return_logits:
outputs['logits'] = None
return outputs
optimizer_config.eval_status.inputs = inputs
optimizer_config.eval_status.outputs = outputs
optimizer_config.eval_status.forward_kwargs = kwargs
optimizer_config.eval_status.loss_value = outputs.get('aux_loss', 0)
if labels is not None:
loss_mask = (labels != -100).bool()
masked_labels = labels.clone()
masked_labels[~loss_mask] = 0
logits = outputs['logits']
logits.div_(temperature)
outputs['logps'] = selective_log_softmax(logits, masked_labels)
outputs = copy(outputs)
outputs['past_key_values'] = None
if not return_logits:
outputs['logits'] = None
return outputs

@remote_function(collect='mean')
def calculate_loss(self, **kwargs):
Expand Down Expand Up @@ -531,11 +531,21 @@ def backward(self, **kwargs):
# Auto set a grad scaler
self.set_grad_scaler(adapter_name=adapter_name)
scaler = optimizer_config.scaler
if scaler is not None:
scaler.scale(loss_value).backward()
else:
loss_value.backward()

optimizer_config.cur_step += 1
should_sync = optimizer_config.do_grad_sync()

import contextlib
no_sync_ctx = contextlib.nullcontext()
if not should_sync and hasattr(self.model, 'no_sync'):
no_sync_ctx = self.model.no_sync()

with no_sync_ctx:
if scaler is not None:
scaler.scale(loss_value).backward()
else:
loss_value.backward()

optimizer_config.train_status.loss_value = None

@remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
Expand Down
12 changes: 8 additions & 4 deletions src/twinkle/preprocessor/olympiad_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,19 @@ def _collect_images(self, row: Dict[str, Any]) -> List[Any]:
return images

def _format_final_answer(self, final_answer: Any, unit: str = '') -> str:
"""Format final answer(s) as string for comparison."""
"""Format final answer(s) as string for comparison.

Note: Unit is intentionally NOT appended to the answer string.
OlympiadBench stores unit as separate metadata, and models may or may
not include it in their output. Appending it would cause spurious
mismatches for unrecognised units (mol, Hz, N, V, W, Pa, …).
Unit information is preserved in user_data for optional downstream use.
"""
if isinstance(final_answer, list):
answers = [str(a).strip() for a in final_answer if a]
answer_str = ', '.join(answers)
else:
answer_str = str(final_answer).strip() if final_answer else ''

if unit and answer_str:
answer_str = f'{answer_str} {unit}'
return answer_str

def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
Expand Down
26 changes: 20 additions & 6 deletions src/twinkle/processor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,19 +154,25 @@ def pad_cp_inputs(input_tensor: torch.Tensor, padding_value: int) -> torch.Tenso
torch.tensor(position_ids_f.shape, device=position_ids_f.device, dtype=torch.int32),
])

for key in ['input_ids', 'position_ids', 'attention_mask', 'labels']:
value = _input[key]
for key in [
'input_ids', 'position_ids', 'attention_mask', 'labels', 'completion_mask', 'mm_token_type_ids'
]:
value = _input.get(key)
if value is None:
continue
result = []
for i in range(cu_seqlens.shape[0]):
if i == cu_seqlens.shape[0] - 1:
break
_value_slice = value[:, cu_seqlens[i]:cu_seqlens[i + 1]]
result.append(pad_cp_inputs(_value_slice, padding_value=self.padding_map[key]))
value = torch.cat(result, dim=1)
_value_slice = value[..., cu_seqlens[i]:cu_seqlens[i + 1]]
result.append(pad_cp_inputs(_value_slice, padding_value=self.padding_map.get(key, 0)))
value = torch.cat(result, dim=-1)
_input[key] = value
elif self.device_mesh.sequence_parallel and tp_size > 1:
# Sequence parallel without CP still requires seq_len % TP == 0
for key in ['input_ids', 'position_ids', 'attention_mask', 'labels']:
for key in [
'input_ids', 'position_ids', 'attention_mask', 'labels', 'completion_mask', 'mm_token_type_ids'
]:
value = _input.get(key)
if value is not None:
_input[key] = pad_cp_inputs(value, padding_value=self.padding_map.get(key, 0))
Expand Down Expand Up @@ -222,6 +228,14 @@ def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], di
# attention_mask = split_cp_inputs(attention_mask, cu_seqlens_q, dim=1)
batch_labels = split_cp_inputs(batch_labels, cu_seqlens_q, dim=1)

completion_mask = inputs.get('completion_mask')
if completion_mask is not None:
inputs['completion_mask'] = split_cp_inputs(completion_mask, cu_seqlens_q, dim=-1)

mm_token_type_ids = inputs.get('mm_token_type_ids')
if mm_token_type_ids is not None:
inputs['mm_token_type_ids'] = split_cp_inputs(mm_token_type_ids, cu_seqlens_q, dim=-1)

inputs['input_ids'] = input_ids
inputs['position_ids'] = position_ids
inputs['attention_mask'] = attention_mask
Expand Down
Loading
Loading