From c34883d9fc779286e603ca6a06649fcb7e6aa829 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Wed, 22 Apr 2026 18:46:00 +0800 Subject: [PATCH] update --- .../sampler/vllm_sampler/vllm_engine.py | 1 + .../sampler/vllm_sampler/vllm_sampler.py | 1 + src/twinkle/server/model/tinker_handlers.py | 3 +- src/twinkle/server/model/twinkle_handlers.py | 4 ++- src/twinkle/server/sampler/app.py | 34 +++++++----------- src/twinkle/server/sampler/tinker_handlers.py | 2 ++ .../server/sampler/twinkle_handlers.py | 3 ++ src/twinkle/server/utils/checkpoint_base.py | 35 +++++++++++++++---- src/twinkle_client/http/http_utils.py | 6 ++-- 9 files changed, 54 insertions(+), 35 deletions(-) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index 719acf1c..d037f1cd 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -388,6 +388,7 @@ async def _get_or_load_lora( # Fast path: return cached request for this path. if lora_path in self._lora_request_cache: + logger.info(f'Using cached LoRA request for {lora_path}') return self._lora_request_cache[lora_path] if not os.path.exists(lora_path): diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 5f1dae58..cca376f3 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -345,6 +345,7 @@ def sample( lora_request = None if adapter_path is not None: + logger.info(f'Loading LoRA from {adapter_path}') lora_request = self._run_in_loop(self.engine._get_or_load_lora(adapter_path)) if lora_request is None: logger.warning(f'Failed to pre-load LoRA from {adapter_path}, ' diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py index c9d378d7..e357d720 100644 --- a/src/twinkle/server/model/tinker_handlers.py +++ b/src/twinkle/server/model/tinker_handlers.py @@ -261,8 +261,7 @@ async def _do_save_for_sampler(): # Must save the checkpoint in the twinkle format before calling model.save() tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=True) logger.info(f'Saving weights to {save_dir}') - self.model.save( - name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=False) + self.model.save(name='latest', output_dir=save_dir, adapter_name=adapter_name, save_optimizer=False) payload = body.model_dump() payload['model_path'] = tinker_path metadata = await self.state.get_model_metadata(body.model_id) or {} diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index b0eb8d0b..7f566840 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -359,8 +359,10 @@ async def _task(): # Must save the checkpoint in the twinkle format before calling model.save() twinkle_path = checkpoint_manager.save( model_id=adapter_name, name=checkpoint_name, is_sampler=body.is_sampler) + # For sampler weights the actual data is always written to 'latest/'. + model_save_name = 'latest' if body.is_sampler else checkpoint_name checkpoint_dir = self.model.save( - name=checkpoint_name, + name=model_save_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=body.save_optimizer, diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py index 65cb2f5d..17734472 100644 --- a/src/twinkle/server/sampler/app.py +++ b/src/twinkle/server/sampler/app.py @@ -55,28 +55,18 @@ def __init__(self, replica_context = serve.get_replica_context() replica_id = replica_context.replica_id.unique_id - # Initialize sampler based on type - if sampler_type == 'vllm': - from twinkle.sampler import vLLMSampler - sampler_kwargs = engine_args or {} - self.sampler = vLLMSampler( - model_id=model_id, - engine_args=sampler_kwargs, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=replica_id, - **{ - k: v - for k, v in kwargs.items() if k not in ['engine_args'] - }) - else: - from twinkle.sampler import TorchSampler - self.sampler = TorchSampler( - model_id=model_id, - device_mesh=self.device_mesh, - instance_id=replica_id, - remote_group=self.device_group.name, - **kwargs) + from twinkle.sampler import vLLMSampler + sampler_kwargs = engine_args or {} + self.sampler = vLLMSampler( + model_id=model_id, + engine_args=sampler_kwargs, + device_mesh=self.device_mesh, + remote_group=self.device_group.name, + instance_id=replica_id, + **{ + k: v + for k, v in kwargs.items() if k not in ['engine_args'] + }) self.state: ServerStateProxy = get_server_state() diff --git a/src/twinkle/server/sampler/tinker_handlers.py b/src/twinkle/server/sampler/tinker_handlers.py index ae7496c0..19ddbfbe 100644 --- a/src/twinkle/server/sampler/tinker_handlers.py +++ b/src/twinkle/server/sampler/tinker_handlers.py @@ -52,6 +52,8 @@ async def _do_sample(): # Set template for sampler based on model type template = get_template_for_model(self.model_id) self.sampler.set_template(template, model_id=self.model_id) + # Reset prefix cache for new weights + self.sampler.reset_prefix_cache() # Get model_path from body or sampling session model_path = body.model_path diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py index 40139a53..b27ec23d 100644 --- a/src/twinkle/server/sampler/twinkle_handlers.py +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -96,6 +96,8 @@ async def _task(): from twinkle.server.common.checkpoint_factory import create_checkpoint_manager checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') _, adapter_path = checkpoint_manager.parse_adapter_uri(body.adapter_uri) + # Reset prefix cache only when new weights are loaded + self.sampler.reset_prefix_cache() # Parse inputs inputs = body.inputs @@ -116,6 +118,7 @@ async def _task(): if body.sampling_params: params = SamplingParams.from_dict(body.sampling_params) + # Sample responses = self.sampler.sample( inputs, params, diff --git a/src/twinkle/server/utils/checkpoint_base.py b/src/twinkle/server/utils/checkpoint_base.py index cbe05602..fc13a322 100644 --- a/src/twinkle/server/utils/checkpoint_base.py +++ b/src/twinkle/server/utils/checkpoint_base.py @@ -608,12 +608,16 @@ def save(self, model_id: str, name: str, is_sampler: bool = False, public: bool Args: model_id: The model identifier - name: Checkpoint name + name: Checkpoint name. For sampler checkpoints this is ignored; weights are + always stored under the fixed name ``'latest'`` and a per-save timestamp + symlink is created in the same ``sampler_weights/`` directory. is_sampler: Whether this is a sampler checkpoint public: Whether the checkpoint is public Returns: - The path for the checkpoint + The ``twinkle://`` path for the checkpoint. For sampler checkpoints this + points to the timestamp symlink so callers always receive a unique path + and bypass any filesystem-path-based weight cache. """ # Validate path safety if not validate_user_path(self.token, name): @@ -621,7 +625,10 @@ def save(self, model_id: str, name: str, is_sampler: bool = False, public: bool weights_type = 'sampler_weights' if is_sampler else 'weights' checkpoint_type = 'sampler' if is_sampler else 'training' - checkpoint_id = f'{weights_type}/{name}' + # Sampler weights are always stored under the fixed name 'latest' so only one + # version exists on disk at a time; cleanup is handled by _delete_existing_sampler_weights. + effective_name = 'latest' if is_sampler else name + checkpoint_id = f'{weights_type}/{effective_name}' path = f'{self.path_prefix}{model_id}/{checkpoint_id}' checkpoint_path = self.get_ckpt_dir(model_id, checkpoint_id) @@ -649,6 +656,19 @@ def save(self, model_id: str, name: str, is_sampler: bool = False, public: bool # Update last_checkpoint in run info self.training_run_manager.update(model_id, {'last_checkpoint': ckpt_data}) + + if is_sampler: + # Create a per-save timestamp symlink in sampler_weights/ so callers always + # receive a unique twinkle:// path and bypass any filesystem-path-based cache. + save_dir = self.get_save_dir(model_id, is_sampler=True) + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + fixed_path = os.path.join(save_dir, 'latest') + symlink_path = os.path.join(save_dir, timestamp) + if os.path.islink(symlink_path): + os.unlink(symlink_path) + os.symlink(fixed_path, symlink_path) + return f'{self.path_prefix}{model_id}/sampler_weights/{timestamp}' + return path def _delete_existing_sampler_weights(self, model_id: str): @@ -662,14 +682,15 @@ def _delete_existing_sampler_weights(self, model_id: str): sampler_weights_dir = run_dir / 'sampler_weights' if sampler_weights_dir.exists() and sampler_weights_dir.is_dir(): - # Delete all subdirectories in sampler_weights for item in sampler_weights_dir.iterdir(): - if item.is_dir(): - # Delete checkpoint metadata file first + if item.is_symlink(): + # Unlink symlinks explicitly; shutil.rmtree on a symlink-to-directory + # can follow the link and unexpectedly delete the target contents. + item.unlink() + elif item.is_dir(): meta_path = item / CHECKPOINT_INFO_FILENAME if meta_path.exists(): meta_path.unlink() - # Delete the directory shutil.rmtree(item) logger.info(f'Deleted existing sampler weights for model_id: {model_id}') diff --git a/src/twinkle_client/http/http_utils.py b/src/twinkle_client/http/http_utils.py index ddfe6c59..e9bb4fc7 100644 --- a/src/twinkle_client/http/http_utils.py +++ b/src/twinkle_client/http/http_utils.py @@ -91,7 +91,7 @@ def http_get( url: Optional[str] = None, params: Optional[Dict[str, Any]] = {}, additional_headers: Optional[Dict[str, str]] = {}, - timeout: int = 300, + timeout: int = 600, ) -> requests.Response: """ Send HTTP GET request with required headers. @@ -124,7 +124,7 @@ def http_post( json_data: Optional[Dict[str, Any]] = {}, data: Optional[Any] = {}, additional_headers: Optional[Dict[str, str]] = {}, - timeout: int = 300, + timeout: int = 600, ) -> requests.Response: """ Send HTTP POST request with required headers. @@ -161,7 +161,7 @@ def http_delete( url: Optional[str] = None, params: Optional[Dict[str, Any]] = {}, additional_headers: Optional[Dict[str, str]] = {}, - timeout: int = 300, + timeout: int = 600, ) -> requests.Response: """ Send HTTP DELETE request with required headers.