From b16edf776d46ad3b9740d8b29053fba128d51cd7 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 18 Apr 2026 18:59:42 +0800 Subject: [PATCH 1/3] fix (cherry picked from commit 4a51f5a9f01187e153dff7f13afbcd9e9ba8255b) --- src/twinkle/data_format/sampling.py | 44 ++++++++++++++++++ src/twinkle/dataloader/retry_sampler.py | 6 +-- src/twinkle/model/megatron/megatron.py | 1 + src/twinkle/processor/base.py | 2 +- .../sampler/vllm_sampler/vllm_engine.py | 45 ++++++++++--------- 5 files changed, 72 insertions(+), 26 deletions(-) diff --git a/src/twinkle/data_format/sampling.py b/src/twinkle/data_format/sampling.py index bbb8e342..d8549219 100644 --- a/src/twinkle/data_format/sampling.py +++ b/src/twinkle/data_format/sampling.py @@ -21,6 +21,50 @@ class SamplingParams: prompt_logprobs: int = None num_samples: int = 1 + def __post_init__(self): + if not isinstance(self.temperature, (int, float)): + raise ValueError(f'temperature must be a number, got {type(self.temperature)}') + if self.temperature < 0: + raise ValueError(f'temperature must be >= 0, got {self.temperature}') + + if not isinstance(self.top_p, (int, float)): + raise ValueError(f'top_p must be a number, got {type(self.top_p)}') + if not 0 < self.top_p <= 1: + raise ValueError(f'top_p must be in range (0, 1], got {self.top_p}') + + if not isinstance(self.top_k, int): + raise ValueError(f'top_k must be an int, got {type(self.top_k)}') + if self.top_k != -1 and self.top_k < 1: + raise ValueError(f'top_k must be -1 or >= 1, got {self.top_k}') + + if self.logprobs is not None: + if not isinstance(self.logprobs, int): + raise ValueError(f'logprobs must be an int or None, got {type(self.logprobs)}') + if self.logprobs < 0: + raise ValueError(f'logprobs must be >= 0, got {self.logprobs}') + + if self.prompt_logprobs is not None: + if not isinstance(self.prompt_logprobs, int): + raise ValueError(f'prompt_logprobs must be an int or None, got {type(self.prompt_logprobs)}') + if self.prompt_logprobs < 0: + raise ValueError(f'prompt_logprobs must be >= 0, got {self.prompt_logprobs}') + + if not isinstance(self.num_samples, int): + raise ValueError(f'num_samples must be an int, got {type(self.num_samples)}') + if self.num_samples < 1: + raise ValueError(f'num_samples must be >= 1, got {self.num_samples}') + + if self.max_tokens is not None: + if not isinstance(self.max_tokens, int): + raise ValueError(f'max_tokens must be an int or None, got {type(self.max_tokens)}') + if self.max_tokens < 1: + raise ValueError(f'max_tokens must be >= 1, got {self.max_tokens}') + + if not isinstance(self.repetition_penalty, (int, float)): + raise ValueError(f'repetition_penalty must be a number, got {type(self.repetition_penalty)}') + if self.repetition_penalty <= 0: + raise ValueError(f'repetition_penalty must be > 0, got {self.repetition_penalty}') + def to_vllm(self, **kwargs): """Convert to vLLM SamplingParams. """ diff --git a/src/twinkle/dataloader/retry_sampler.py b/src/twinkle/dataloader/retry_sampler.py index 62f05660..f87dbf5f 100644 --- a/src/twinkle/dataloader/retry_sampler.py +++ b/src/twinkle/dataloader/retry_sampler.py @@ -37,7 +37,7 @@ def __iter__(self): traceback.print_exc() continue else: - raise StopIteration(f'Max retries exceeded: {self.max_retries}, no valid data found.') + return origin_dataset_len = len(self.dataset) if total >= origin_dataset_len: @@ -45,7 +45,7 @@ def __iter__(self): for idx in np.random.RandomState().permutation(len(self.dataset)).tolist(): if total >= origin_dataset_len: - raise StopIteration + return for _ in range(self.max_retries): try: # Skip None values and raises @@ -59,7 +59,7 @@ def __iter__(self): traceback.print_exc() continue else: - raise ValueError(f'Max retries exceeded: {self.max_retries}, no valid data found.') + return def __len__(self): return len(self.dataset) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index dcece0ff..d88004d2 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -342,6 +342,7 @@ def forward_backward(self, if num_microbatches <= 1: loss_extra_kwargs_per_mb = [kwargs] else: + # Only support extra kwargs length==total_batch_size for mb_idx in range(num_microbatches): mb_start = mb_idx * micro_batch_size mb_end = mb_start + micro_batch_size diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index eb182bf5..a61b4f15 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -369,7 +369,7 @@ def is_mm_position_ids(position_ids): if key == 'position_ids' and is_mm_position_ids(values[0]): # mrope needs to cat the sequence and unsequeeze the middle dim value = torch.cat(values, dim=2).unsqueeze(1) - if isinstance(values[0], torch.Tensor): + elif isinstance(values[0], torch.Tensor): value = torch.cat(values, dim=0).unsqueeze(0) else: value = values diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index a1b7123e..1b10fba7 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -625,28 +625,29 @@ async def _flush_bucket(is_last: bool) -> None: chunk_offset += chunk_nbytes n_weights += 1 - # Send last bucket - await _flush_bucket(is_last=True) - - # Wait for worker to finish loading - await worker_task - - # Clean up - socket.close() - zmq_ctx.term() - if zmq_handle.startswith('ipc://'): - ipc_path = zmq_handle[len('ipc://'):] - try: - if os.path.exists(ipc_path): - os.remove(ipc_path) - except OSError: - pass - del buffer - if shm is not None: - shm.close() - shm.unlink() - del shm - gc.collect() + try: + # Send last bucket + await _flush_bucket(is_last=True) + + # Wait for worker to finish loading + await worker_task + finally: + # Clean up + socket.close() + zmq_ctx.term() + if zmq_handle.startswith('ipc://'): + ipc_path = zmq_handle[len('ipc://'):] + try: + if os.path.exists(ipc_path): + os.remove(ipc_path) + except OSError: + pass + del buffer + if shm is not None: + shm.close() + shm.unlink() + del shm + gc.collect() elapsed = time.time() - start_time mode = 'LoRA' if base_sync_done and peft_config else 'base' From bb1c433c27a50652c7dcef20f8da8425caab6d5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Sat, 18 Apr 2026 19:13:12 +0800 Subject: [PATCH 2/3] fix --- .../sampler/vllm_sampler/vllm_engine.py | 148 +++++++++--------- 1 file changed, 74 insertions(+), 74 deletions(-) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index 1b10fba7..be63e270 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -549,90 +549,90 @@ def _zmq_send_recv(payload, where: str): except zmq.error.Again as e: raise RuntimeError(f'IPC timeout ({zmq_timeout_s}s) during {where} on {zmq_handle}') from e - # Launch worker side concurrently - worker_task = asyncio.ensure_future( - self.engine.collective_rpc( - 'update_weights_from_ipc', - kwargs={ - 'peft_config': peft_config, - 'base_sync_done': base_sync_done, - 'use_shm': use_shm, - 'zmq_handle': zmq_handle, - }, - )) + try: + # Launch worker side concurrently + worker_task = asyncio.ensure_future( + self.engine.collective_rpc( + 'update_weights_from_ipc', + kwargs={ + 'peft_config': peft_config, + 'base_sync_done': base_sync_done, + 'use_shm': use_shm, + 'zmq_handle': zmq_handle, + }, + )) + + # Send IPC/SHM handle, wait for worker ready (non-blocking) + handle_payload = ipc_handle if use_gpu_ipc else {'name': shm_name, 'size': bucket_size} + await loop.run_in_executor(None, _zmq_send_recv, handle_payload, 'handle handshake') + + # Stream weights into buckets and send to worker + async def _chain_first(): + """Re-inject the peeked first tensor, then yield the rest.""" + yield first_name, first_tensor + async for item in weight_aiter: + yield item - # Send IPC/SHM handle, wait for worker ready (non-blocking) - handle_payload = ipc_handle if use_gpu_ipc else {'name': shm_name, 'size': bucket_size} - await loop.run_in_executor(None, _zmq_send_recv, handle_payload, 'handle handshake') - - # Stream weights into buckets and send to worker - async def _chain_first(): - """Re-inject the peeked first tensor, then yield the rest.""" - yield first_name, first_tensor - async for item in weight_aiter: - yield item - - offset = 0 - bucket_meta: list[dict] = [] - n_weights = 0 - - async def _flush_bucket(is_last: bool) -> None: - nonlocal offset, bucket_meta - if not bucket_meta and not is_last: - return - if buffer.device.type != 'cpu': - Torch.synchronize() - await loop.run_in_executor( - None, - _zmq_send_recv, - { - 'bucket_meta': bucket_meta, - 'is_last': is_last, - }, - 'final bucket' if is_last else 'bucket flush', - ) offset = 0 - bucket_meta = [] - - async for name, weight in _chain_first(): - if use_shm and weight.device.type != 'cpu': - weight = weight.cpu() - if not weight.is_contiguous(): - weight = weight.contiguous() - - weight_u8 = weight.view(-1).view(torch.uint8) - total_nbytes = int(weight_u8.numel()) - chunk_offset = 0 - while chunk_offset < total_nbytes: - if offset >= bucket_size: - await _flush_bucket(is_last=False) - - chunk_nbytes = min(bucket_size - offset, total_nbytes - chunk_offset) - buffer[offset:offset + chunk_nbytes].copy_( - weight_u8[chunk_offset:chunk_offset + chunk_nbytes], - non_blocking=True, + bucket_meta: list[dict] = [] + n_weights = 0 + + async def _flush_bucket(is_last: bool) -> None: + nonlocal offset, bucket_meta + if not bucket_meta and not is_last: + return + if buffer.device.type != 'cpu': + Torch.synchronize() + await loop.run_in_executor( + None, + _zmq_send_recv, + { + 'bucket_meta': bucket_meta, + 'is_last': is_last, + }, + 'final bucket' if is_last else 'bucket flush', ) - bucket_meta.append({ - 'name': name, - 'shape': weight.shape, - 'dtype': weight.dtype, - 'offset': offset, - 'nbytes': chunk_nbytes, - 'chunk_offset': chunk_offset, - 'total_nbytes': total_nbytes, - }) - offset += chunk_nbytes - chunk_offset += chunk_nbytes - n_weights += 1 + offset = 0 + bucket_meta = [] + + async for name, weight in _chain_first(): + if use_shm and weight.device.type != 'cpu': + weight = weight.cpu() + if not weight.is_contiguous(): + weight = weight.contiguous() + + weight_u8 = weight.view(-1).view(torch.uint8) + total_nbytes = int(weight_u8.numel()) + chunk_offset = 0 + while chunk_offset < total_nbytes: + if offset >= bucket_size: + await _flush_bucket(is_last=False) + + chunk_nbytes = min(bucket_size - offset, total_nbytes - chunk_offset) + buffer[offset:offset + chunk_nbytes].copy_( + weight_u8[chunk_offset:chunk_offset + chunk_nbytes], + non_blocking=True, + ) + bucket_meta.append({ + 'name': name, + 'shape': weight.shape, + 'dtype': weight.dtype, + 'offset': offset, + 'nbytes': chunk_nbytes, + 'chunk_offset': chunk_offset, + 'total_nbytes': total_nbytes, + }) + offset += chunk_nbytes + chunk_offset += chunk_nbytes + n_weights += 1 - try: # Send last bucket await _flush_bucket(is_last=True) # Wait for worker to finish loading await worker_task finally: - # Clean up + # Clean up — always release resources regardless of exceptions socket.close() zmq_ctx.term() if zmq_handle.startswith('ipc://'): From 09d00763e26309c67a38f9a5d17afc5f2632dc4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Sat, 18 Apr 2026 19:14:31 +0800 Subject: [PATCH 3/3] revert file --- src/twinkle/dataloader/retry_sampler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/twinkle/dataloader/retry_sampler.py b/src/twinkle/dataloader/retry_sampler.py index f87dbf5f..62f05660 100644 --- a/src/twinkle/dataloader/retry_sampler.py +++ b/src/twinkle/dataloader/retry_sampler.py @@ -37,7 +37,7 @@ def __iter__(self): traceback.print_exc() continue else: - return + raise StopIteration(f'Max retries exceeded: {self.max_retries}, no valid data found.') origin_dataset_len = len(self.dataset) if total >= origin_dataset_len: @@ -45,7 +45,7 @@ def __iter__(self): for idx in np.random.RandomState().permutation(len(self.dataset)).tolist(): if total >= origin_dataset_len: - return + raise StopIteration for _ in range(self.max_retries): try: # Skip None values and raises @@ -59,7 +59,7 @@ def __iter__(self): traceback.print_exc() continue else: - return + raise ValueError(f'Max retries exceeded: {self.max_retries}, no valid data found.') def __len__(self): return len(self.dataset)