Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions src/twinkle/data_format/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,50 @@ class SamplingParams:
prompt_logprobs: int = None
num_samples: int = 1

def __post_init__(self):
    """Validate the sampling parameters right after initialisation.

    Fields are checked in this order: ``temperature``, ``top_p``,
    ``top_k``, ``logprobs``, ``prompt_logprobs``, ``num_samples``,
    ``max_tokens`` and ``repetition_penalty``. The first violation found
    is reported. ``logprobs``, ``prompt_logprobs`` and ``max_tokens`` may
    be ``None``, in which case they are not checked.

    Accepted values:
        * ``temperature``: number, ``>= 0``
        * ``top_p``: number in ``(0, 1]``
        * ``top_k``: int, either ``-1`` or ``>= 1``
        * ``logprobs`` / ``prompt_logprobs``: ``None`` or int ``>= 0``
        * ``num_samples``: int ``>= 1``
        * ``max_tokens``: ``None`` or int ``>= 1``
        * ``repetition_penalty``: number, ``> 0``

    ``bool`` values and NaN are rejected for every field.

    Raises:
        ValueError: If a field has the wrong type or is out of range.
    """

    def _is_int(value):
        # bool is a subclass of int, but True/False are never a meaningful
        # count or index here, so they are treated as a type error.
        return isinstance(value, int) and not isinstance(value, bool)

    def _is_number(value):
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    # Range checks below are phrased as ``not (value <op> bound)`` on
    # purpose: NaN fails every comparison, so the positive form rejects it,
    # whereas ``value < bound`` would silently let NaN through.
    if not _is_number(self.temperature):
        raise ValueError(f'temperature must be a number, got {type(self.temperature)}')
    if not self.temperature >= 0:
        raise ValueError(f'temperature must be >= 0, got {self.temperature}')

    if not _is_number(self.top_p):
        raise ValueError(f'top_p must be a number, got {type(self.top_p)}')
    if not 0 < self.top_p <= 1:
        raise ValueError(f'top_p must be in range (0, 1], got {self.top_p}')

    if not _is_int(self.top_k):
        raise ValueError(f'top_k must be an int, got {type(self.top_k)}')
    # -1 is the only non-positive value allowed for top_k.
    if self.top_k != -1 and self.top_k < 1:
        raise ValueError(f'top_k must be -1 or >= 1, got {self.top_k}')

    # Both log-prob counts share the same rule: None, or a non-negative int.
    for field_name in ('logprobs', 'prompt_logprobs'):
        value = getattr(self, field_name)
        if value is None:
            continue
        if not _is_int(value):
            raise ValueError(f'{field_name} must be an int or None, got {type(value)}')
        if value < 0:
            raise ValueError(f'{field_name} must be >= 0, got {value}')

    if not _is_int(self.num_samples):
        raise ValueError(f'num_samples must be an int, got {type(self.num_samples)}')
    if self.num_samples < 1:
        raise ValueError(f'num_samples must be >= 1, got {self.num_samples}')

    if self.max_tokens is not None:
        if not _is_int(self.max_tokens):
            raise ValueError(f'max_tokens must be an int or None, got {type(self.max_tokens)}')
        if self.max_tokens < 1:
            raise ValueError(f'max_tokens must be >= 1, got {self.max_tokens}')

    if not _is_number(self.repetition_penalty):
        raise ValueError(f'repetition_penalty must be a number, got {type(self.repetition_penalty)}')
    if not self.repetition_penalty > 0:
        raise ValueError(f'repetition_penalty must be > 0, got {self.repetition_penalty}')
Comment thread
tastelikefeet marked this conversation as resolved.

def to_vllm(self, **kwargs):
"""Convert to vLLM SamplingParams.
"""
Expand Down
1 change: 1 addition & 0 deletions src/twinkle/model/megatron/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def forward_backward(self,
if num_microbatches <= 1:
loss_extra_kwargs_per_mb = [kwargs]
else:
# Only extra kwargs whose length equals total_batch_size are supported
for mb_idx in range(num_microbatches):
mb_start = mb_idx * micro_batch_size
mb_end = mb_start + micro_batch_size
Expand Down
2 changes: 1 addition & 1 deletion src/twinkle/processor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def is_mm_position_ids(position_ids):
if key == 'position_ids' and is_mm_position_ids(values[0]):
# mrope needs to cat the sequence and unsqueeze the middle dim
value = torch.cat(values, dim=2).unsqueeze(1)
if isinstance(values[0], torch.Tensor):
elif isinstance(values[0], torch.Tensor):
value = torch.cat(values, dim=0).unsqueeze(0)
else:
value = values
Expand Down
191 changes: 96 additions & 95 deletions src/twinkle/sampler/vllm_sampler/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,104 +549,105 @@ def _zmq_send_recv(payload, where: str):
except zmq.error.Again as e:
raise RuntimeError(f'IPC timeout ({zmq_timeout_s}s) during {where} on {zmq_handle}') from e

# Launch worker side concurrently
worker_task = asyncio.ensure_future(
self.engine.collective_rpc(
'update_weights_from_ipc',
kwargs={
'peft_config': peft_config,
'base_sync_done': base_sync_done,
'use_shm': use_shm,
'zmq_handle': zmq_handle,
},
))
try:
Comment thread
tastelikefeet marked this conversation as resolved.
# Launch worker side concurrently
worker_task = asyncio.ensure_future(
self.engine.collective_rpc(
'update_weights_from_ipc',
kwargs={
'peft_config': peft_config,
'base_sync_done': base_sync_done,
'use_shm': use_shm,
'zmq_handle': zmq_handle,
},
))

# Send IPC/SHM handle, wait for worker ready (non-blocking)
handle_payload = ipc_handle if use_gpu_ipc else {'name': shm_name, 'size': bucket_size}
await loop.run_in_executor(None, _zmq_send_recv, handle_payload, 'handle handshake')

# Stream weights into buckets and send to worker
async def _chain_first():
"""Re-inject the peeked first tensor, then yield the rest."""
yield first_name, first_tensor
async for item in weight_aiter:
yield item

# Send IPC/SHM handle, wait for worker ready (non-blocking)
handle_payload = ipc_handle if use_gpu_ipc else {'name': shm_name, 'size': bucket_size}
await loop.run_in_executor(None, _zmq_send_recv, handle_payload, 'handle handshake')

# Stream weights into buckets and send to worker
async def _chain_first():
"""Re-inject the peeked first tensor, then yield the rest."""
yield first_name, first_tensor
async for item in weight_aiter:
yield item

offset = 0
bucket_meta: list[dict] = []
n_weights = 0

async def _flush_bucket(is_last: bool) -> None:
nonlocal offset, bucket_meta
if not bucket_meta and not is_last:
return
if buffer.device.type != 'cpu':
Torch.synchronize()
await loop.run_in_executor(
None,
_zmq_send_recv,
{
'bucket_meta': bucket_meta,
'is_last': is_last,
},
'final bucket' if is_last else 'bucket flush',
)
offset = 0
bucket_meta = []

async for name, weight in _chain_first():
if use_shm and weight.device.type != 'cpu':
weight = weight.cpu()
if not weight.is_contiguous():
weight = weight.contiguous()

weight_u8 = weight.view(-1).view(torch.uint8)
total_nbytes = int(weight_u8.numel())
chunk_offset = 0
while chunk_offset < total_nbytes:
if offset >= bucket_size:
await _flush_bucket(is_last=False)

chunk_nbytes = min(bucket_size - offset, total_nbytes - chunk_offset)
buffer[offset:offset + chunk_nbytes].copy_(
weight_u8[chunk_offset:chunk_offset + chunk_nbytes],
non_blocking=True,
bucket_meta: list[dict] = []
n_weights = 0

async def _flush_bucket(is_last: bool) -> None:
nonlocal offset, bucket_meta
if not bucket_meta and not is_last:
return
if buffer.device.type != 'cpu':
Torch.synchronize()
await loop.run_in_executor(
None,
_zmq_send_recv,
{
'bucket_meta': bucket_meta,
'is_last': is_last,
},
'final bucket' if is_last else 'bucket flush',
)
bucket_meta.append({
'name': name,
'shape': weight.shape,
'dtype': weight.dtype,
'offset': offset,
'nbytes': chunk_nbytes,
'chunk_offset': chunk_offset,
'total_nbytes': total_nbytes,
})
offset += chunk_nbytes
chunk_offset += chunk_nbytes
n_weights += 1

# Send last bucket
await _flush_bucket(is_last=True)

# Wait for worker to finish loading
await worker_task

# Clean up
socket.close()
zmq_ctx.term()
if zmq_handle.startswith('ipc://'):
ipc_path = zmq_handle[len('ipc://'):]
try:
if os.path.exists(ipc_path):
os.remove(ipc_path)
except OSError:
pass
del buffer
if shm is not None:
shm.close()
shm.unlink()
del shm
gc.collect()
offset = 0
bucket_meta = []

async for name, weight in _chain_first():
if use_shm and weight.device.type != 'cpu':
weight = weight.cpu()
if not weight.is_contiguous():
weight = weight.contiguous()

weight_u8 = weight.view(-1).view(torch.uint8)
total_nbytes = int(weight_u8.numel())
chunk_offset = 0
while chunk_offset < total_nbytes:
if offset >= bucket_size:
await _flush_bucket(is_last=False)

chunk_nbytes = min(bucket_size - offset, total_nbytes - chunk_offset)
buffer[offset:offset + chunk_nbytes].copy_(
weight_u8[chunk_offset:chunk_offset + chunk_nbytes],
non_blocking=True,
)
bucket_meta.append({
'name': name,
'shape': weight.shape,
'dtype': weight.dtype,
'offset': offset,
'nbytes': chunk_nbytes,
'chunk_offset': chunk_offset,
'total_nbytes': total_nbytes,
})
offset += chunk_nbytes
chunk_offset += chunk_nbytes
n_weights += 1

# Send last bucket
await _flush_bucket(is_last=True)

# Wait for worker to finish loading
await worker_task
finally:
# Clean up — always release resources regardless of exceptions
Comment thread
tastelikefeet marked this conversation as resolved.
socket.close()
zmq_ctx.term()
if zmq_handle.startswith('ipc://'):
ipc_path = zmq_handle[len('ipc://'):]
try:
if os.path.exists(ipc_path):
os.remove(ipc_path)
except OSError:
pass
del buffer
if shm is not None:
shm.close()
shm.unlink()
del shm
gc.collect()

elapsed = time.time() - start_time
mode = 'LoRA' if base_sync_done and peft_config else 'base'
Expand Down
Loading