Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/twinkle/sampler/vllm_sampler/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ async def _get_or_load_lora(

# Fast path: return cached request for this path.
if lora_path in self._lora_request_cache:
logger.info(f'Using cached LoRA request for {lora_path}')
Comment thread
Yunnglin marked this conversation as resolved.
return self._lora_request_cache[lora_path]

if not os.path.exists(lora_path):
Expand Down
1 change: 1 addition & 0 deletions src/twinkle/sampler/vllm_sampler/vllm_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def sample(

lora_request = None
if adapter_path is not None:
logger.info(f'Loading LoRA from {adapter_path}')
Comment thread
Yunnglin marked this conversation as resolved.
lora_request = self._run_in_loop(self.engine._get_or_load_lora(adapter_path))
if lora_request is None:
logger.warning(f'Failed to pre-load LoRA from {adapter_path}, '
Expand Down
3 changes: 1 addition & 2 deletions src/twinkle/server/model/tinker_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,7 @@ async def _do_save_for_sampler():
# Must save the checkpoint in the twinkle format before calling model.save()
tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=True)
logger.info(f'Saving weights to {save_dir}')
self.model.save(
name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=False)
self.model.save(name='latest', output_dir=save_dir, adapter_name=adapter_name, save_optimizer=False)
payload = body.model_dump()
payload['model_path'] = tinker_path
metadata = await self.state.get_model_metadata(body.model_id) or {}
Expand Down
4 changes: 3 additions & 1 deletion src/twinkle/server/model/twinkle_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,10 @@ async def _task():
# Must save the checkpoint in the twinkle format before calling model.save()
twinkle_path = checkpoint_manager.save(
model_id=adapter_name, name=checkpoint_name, is_sampler=body.is_sampler)
# For sampler weights the actual data is always written to 'latest/'.
model_save_name = 'latest' if body.is_sampler else checkpoint_name
checkpoint_dir = self.model.save(
name=checkpoint_name,
name=model_save_name,
output_dir=save_dir,
adapter_name=adapter_name,
save_optimizer=body.save_optimizer,
Expand Down
34 changes: 12 additions & 22 deletions src/twinkle/server/sampler/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,28 +55,18 @@ def __init__(self,
replica_context = serve.get_replica_context()
replica_id = replica_context.replica_id.unique_id

# Initialize sampler based on type
if sampler_type == 'vllm':
from twinkle.sampler import vLLMSampler
sampler_kwargs = engine_args or {}
self.sampler = vLLMSampler(
model_id=model_id,
engine_args=sampler_kwargs,
device_mesh=self.device_mesh,
remote_group=self.device_group.name,
instance_id=replica_id,
**{
k: v
for k, v in kwargs.items() if k not in ['engine_args']
})
else:
from twinkle.sampler import TorchSampler
self.sampler = TorchSampler(
model_id=model_id,
device_mesh=self.device_mesh,
instance_id=replica_id,
remote_group=self.device_group.name,
**kwargs)
from twinkle.sampler import vLLMSampler
sampler_kwargs = engine_args or {}
self.sampler = vLLMSampler(
model_id=model_id,
engine_args=sampler_kwargs,
device_mesh=self.device_mesh,
remote_group=self.device_group.name,
instance_id=replica_id,
**{
k: v
for k, v in kwargs.items() if k not in ['engine_args']
})
Comment thread
Yunnglin marked this conversation as resolved.

self.state: ServerStateProxy = get_server_state()

Expand Down
2 changes: 2 additions & 0 deletions src/twinkle/server/sampler/tinker_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ async def _do_sample():
# Set template for sampler based on model type
template = get_template_for_model(self.model_id)
self.sampler.set_template(template, model_id=self.model_id)
# Reset prefix cache for new weights
self.sampler.reset_prefix_cache()
Comment thread
Yunnglin marked this conversation as resolved.

# Get model_path from body or sampling session
model_path = body.model_path
Expand Down
3 changes: 3 additions & 0 deletions src/twinkle/server/sampler/twinkle_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ async def _task():
from twinkle.server.common.checkpoint_factory import create_checkpoint_manager
checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle')
_, adapter_path = checkpoint_manager.parse_adapter_uri(body.adapter_uri)
# Reset prefix cache only when new weights are loaded
self.sampler.reset_prefix_cache()
Comment thread
Yunnglin marked this conversation as resolved.

# Parse inputs
inputs = body.inputs
Expand All @@ -116,6 +118,7 @@ async def _task():
if body.sampling_params:
params = SamplingParams.from_dict(body.sampling_params)

# Sample
responses = self.sampler.sample(
inputs,
params,
Expand Down
35 changes: 28 additions & 7 deletions src/twinkle/server/utils/checkpoint_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,20 +608,27 @@ def save(self, model_id: str, name: str, is_sampler: bool = False, public: bool

Args:
model_id: The model identifier
name: Checkpoint name
name: Checkpoint name. Always checked with ``validate_user_path``. For sampler
checkpoints it is not used for the storage location; weights are always stored
under the fixed name ``'latest'`` and a per-save timestamp symlink is created
in the same ``sampler_weights/`` directory.
is_sampler: Whether this is a sampler checkpoint
public: Whether the checkpoint is public

Returns:
The path for the checkpoint
The ``twinkle://`` path for the checkpoint. For sampler checkpoints this
points to the timestamp symlink so callers always receive a unique path
and bypass any filesystem-path-based weight cache.
"""
# Validate path safety
if not validate_user_path(self.token, name):
raise ValueError(f'Invalid checkpoint name: {name}')

weights_type = 'sampler_weights' if is_sampler else 'weights'
checkpoint_type = 'sampler' if is_sampler else 'training'
checkpoint_id = f'{weights_type}/{name}'
# Sampler weights are always stored under the fixed name 'latest' so only one
# version exists on disk at a time; cleanup is handled by _delete_existing_sampler_weights.
effective_name = 'latest' if is_sampler else name
checkpoint_id = f'{weights_type}/{effective_name}'
path = f'{self.path_prefix}{model_id}/{checkpoint_id}'
checkpoint_path = self.get_ckpt_dir(model_id, checkpoint_id)

Expand Down Expand Up @@ -649,6 +656,19 @@ def save(self, model_id: str, name: str, is_sampler: bool = False, public: bool

# Update last_checkpoint in run info
self.training_run_manager.update(model_id, {'last_checkpoint': ckpt_data})

if is_sampler:
# Create a per-save timestamp symlink in sampler_weights/ so callers receive a
# fresh twinkle:// path and bypass any filesystem-path-based cache. The timestamp
# has one-second resolution, so two saves within the same second share a path.
save_dir = self.get_save_dir(model_id, is_sampler=True)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
fixed_path = os.path.join(save_dir, 'latest')
symlink_path = os.path.join(save_dir, timestamp)
if os.path.islink(symlink_path):
Comment thread
Yunnglin marked this conversation as resolved.
os.unlink(symlink_path)
os.symlink(fixed_path, symlink_path)
Comment thread
Yunnglin marked this conversation as resolved.
return f'{self.path_prefix}{model_id}/sampler_weights/{timestamp}'

return path

def _delete_existing_sampler_weights(self, model_id: str):
Expand All @@ -662,14 +682,15 @@ def _delete_existing_sampler_weights(self, model_id: str):
sampler_weights_dir = run_dir / 'sampler_weights'

if sampler_weights_dir.exists() and sampler_weights_dir.is_dir():
# Delete all subdirectories in sampler_weights
for item in sampler_weights_dir.iterdir():
if item.is_dir():
# Delete checkpoint metadata file first
if item.is_symlink():
# Unlink symlinks explicitly; Path.is_dir() follows symlinks, so a link to a
# directory would otherwise reach shutil.rmtree, which refuses symbolic links
# and raises OSError instead of removing the link.
item.unlink()
elif item.is_dir():
meta_path = item / CHECKPOINT_INFO_FILENAME
if meta_path.exists():
meta_path.unlink()
# Delete the directory
shutil.rmtree(item)
logger.info(f'Deleted existing sampler weights for model_id: {model_id}')

Expand Down
6 changes: 3 additions & 3 deletions src/twinkle_client/http/http_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def http_get(
url: Optional[str] = None,
params: Optional[Dict[str, Any]] = {},
additional_headers: Optional[Dict[str, str]] = {},
timeout: int = 300,
timeout: int = 600,
) -> requests.Response:
"""
Send HTTP GET request with required headers.
Expand Down Expand Up @@ -124,7 +124,7 @@ def http_post(
json_data: Optional[Dict[str, Any]] = {},
data: Optional[Any] = {},
additional_headers: Optional[Dict[str, str]] = {},
timeout: int = 300,
timeout: int = 600,
) -> requests.Response:
"""
Send HTTP POST request with required headers.
Expand Down Expand Up @@ -161,7 +161,7 @@ def http_delete(
url: Optional[str] = None,
params: Optional[Dict[str, Any]] = {},
additional_headers: Optional[Dict[str, str]] = {},
timeout: int = 300,
timeout: int = 600,
) -> requests.Response:
"""
Send HTTP DELETE request with required headers.
Expand Down
Loading