From 9e62ee9292fbe33c8958fc9fad34fc7fec871c81 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 17 Apr 2026 14:24:31 +0800 Subject: [PATCH 01/14] update serialize --- src/twinkle_client/common/serialize.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/twinkle_client/common/serialize.py b/src/twinkle_client/common/serialize.py index b2d1720c..3093ec58 100644 --- a/src/twinkle_client/common/serialize.py +++ b/src/twinkle_client/common/serialize.py @@ -50,11 +50,14 @@ def serialize_object(obj) -> str: data['_TWINKLE_TYPE_'] = 'DatasetMeta' return json.dumps(data, ensure_ascii=False) elif isinstance(obj, LoraConfig): - filtered_dict = { - _subkey: _subvalue - for _subkey, _subvalue in obj.__dict__.items() - if isinstance(_subvalue, basic_types) and not _subkey.startswith('_') - } + filtered_dict = {} + for _subkey, _subvalue in obj.__dict__.items(): + if isinstance(_subvalue, basic_types) and not _subkey.startswith('_'): + # Convert set/frozenset to list for JSON serialization + if isinstance(_subvalue, (set, frozenset)): + filtered_dict[_subkey] = list(_subvalue) + else: + filtered_dict[_subkey] = _subvalue filtered_dict['_TWINKLE_TYPE_'] = 'LoraConfig' return json.dumps(filtered_dict, ensure_ascii=False) elif isinstance(obj, BaseModel): From 1287e5f5cf1b4e3480feaa7d273e28cdbe19ed4d Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 17 Apr 2026 12:44:13 +0800 Subject: [PATCH 02/14] update server config --- cookbook/client/server/megatron/server_config.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cookbook/client/server/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml index 123cad7f..d95bfa4e 100644 --- a/cookbook/client/server/megatron/server_config.yaml +++ b/cookbook/client/server/megatron/server_config.yaml @@ -95,8 +95,10 @@ applications: device_type: cuda device_mesh: device_type: cuda - dp_size: 2 # 2-way data parallel - pp_size: 2 # 2-way pipeline parallel (~27GB/GPU) + tp_size: 2 + ep_size: 2 + pp_size: 2 + sequence_parallel: True queue_config: rps_limit: 20 # Max requests per second From 22fc350204fd728355399e50c9a7fc17c54df226 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 17 Apr 2026 15:02:52 +0800 Subject: [PATCH 03/14] update sampler save --- cookbook/client/twinkle/self_host/short_math_grpo.py | 4 +++- src/twinkle/server/model/twinkle_handlers.py | 5 +++-- src/twinkle_client/types/model.py | 1 + 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/cookbook/client/twinkle/self_host/short_math_grpo.py b/cookbook/client/twinkle/self_host/short_math_grpo.py index 03993c96..5a859fc8 100644 --- a/cookbook/client/twinkle/self_host/short_math_grpo.py +++ b/cookbook/client/twinkle/self_host/short_math_grpo.py @@ -210,11 +210,13 @@ def train(): # ========== 1. 
Save weights and update adapter_uri ========== # Instead of sync_weights, save the model checkpoint and pass # the resulting path to the sampler as adapter_uri + # Use is_sampler=True to delete old sampler weights and keep only the latest if step % SYNC_INTERVAL == 0: logger.info(f'Step {step}: Saving weights for sampler...') result = model.save( - name=f'grpo-sampler-step-{step}', + name='grpo-sampler-weights', save_optimizer=False, + is_sampler=True, ) current_adapter_uri = result.twinkle_path logger.info(f'Step {step}: Saved weights to {current_adapter_uri}') diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index ed2e62b9..00ce39f9 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -316,14 +316,15 @@ async def _task(): extra_kwargs = body.model_extra or {} checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') checkpoint_name = checkpoint_manager.get_ckpt_name(body.name) - save_dir = checkpoint_manager.get_save_dir(model_id=adapter_name, is_sampler=False) + save_dir = checkpoint_manager.get_save_dir(model_id=adapter_name, is_sampler=body.is_sampler) checkpoint_dir = self.model.save( name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=body.save_optimizer, **extra_kwargs) - twinkle_path = checkpoint_manager.save(model_id=adapter_name, name=checkpoint_name, is_sampler=False) + twinkle_path = checkpoint_manager.save( + model_id=adapter_name, name=checkpoint_name, is_sampler=body.is_sampler) return {'twinkle_path': twinkle_path, 'checkpoint_dir': checkpoint_dir} return await run_task(self.schedule_task_and_wait(_task, task_type='save')) diff --git a/src/twinkle_client/types/model.py b/src/twinkle_client/types/model.py index 2474c9e9..e2add1ce 100644 --- a/src/twinkle_client/types/model.py +++ b/src/twinkle_client/types/model.py @@ -65,6 +65,7 @@ class SaveRequest(BaseModel): adapter_name: str save_optimizer: bool = False name: Optional[str] = None + is_sampler: bool = False # If True, delete existing sampler weights before saving class Config: extra = 'allow' From 455334684597b898533c4606bd59ee915ba588ae Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 17 Apr 2026 18:03:30 +0800 Subject: [PATCH 04/14] update queue --- .../twinkle/self_host/short_math_grpo.py | 2 +- src/twinkle/server/model/tinker_handlers.py | 1 + src/twinkle/server/model/twinkle_handlers.py | 7 +- src/twinkle/server/utils/__init__.py | 3 +- src/twinkle/server/utils/task_queue.py | 672 ------------------ .../server/utils/task_queue/__init__.py | 26 + src/twinkle/server/utils/task_queue/config.py | 79 ++ src/twinkle/server/utils/task_queue/mixin.py | 357 ++++++++++ .../utils/{ => task_queue}/rate_limiter.py | 62 +- src/twinkle/server/utils/task_queue/types.py | 49 ++ src/twinkle/server/utils/task_queue/worker.py | 281 ++++++++ 11 files changed, 807 insertions(+), 732 deletions(-) delete mode 100644 src/twinkle/server/utils/task_queue.py create mode 100644 src/twinkle/server/utils/task_queue/__init__.py create mode 100644 src/twinkle/server/utils/task_queue/config.py create mode 100644 src/twinkle/server/utils/task_queue/mixin.py rename src/twinkle/server/utils/{ => task_queue}/rate_limiter.py (82%) create mode 100644 src/twinkle/server/utils/task_queue/types.py create mode 100644 src/twinkle/server/utils/task_queue/worker.py diff --git a/cookbook/client/twinkle/self_host/short_math_grpo.py b/cookbook/client/twinkle/self_host/short_math_grpo.py index 
5a859fc8..871d4599 100644 --- a/cookbook/client/twinkle/self_host/short_math_grpo.py +++ b/cookbook/client/twinkle/self_host/short_math_grpo.py @@ -91,7 +91,7 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: GRADIENT_ACCUMULATION_STEPS = 1 DATA_NUM = 2000 # Number of Math samples to use -USE_SWANLAB = True +USE_SWANLAB = False SWANLAB_PROJECT = 'twinkle-grpo' SWANLAB_EXPERIMENT_NAME = 'short-math-grpo' diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py index 37d9df60..c9d378d7 100644 --- a/src/twinkle/server/model/tinker_handlers.py +++ b/src/twinkle/server/model/tinker_handlers.py @@ -258,6 +258,7 @@ async def _do_save_for_sampler(): checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') checkpoint_name = checkpoint_manager.get_ckpt_name(body.path) save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=True) + # Must save the checkpoint in the twinkle format before calling model.save() tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=True) logger.info(f'Saving weights to {save_dir}') self.model.save( diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index 00ce39f9..d0e9c124 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -317,14 +317,15 @@ async def _task(): checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') checkpoint_name = checkpoint_manager.get_ckpt_name(body.name) save_dir = checkpoint_manager.get_save_dir(model_id=adapter_name, is_sampler=body.is_sampler) + # Must save the checkpoint in the twinkle format before calling model.save() + twinkle_path = checkpoint_manager.save( + model_id=adapter_name, name=checkpoint_name, is_sampler=body.is_sampler) checkpoint_dir = self.model.save( name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=body.save_optimizer, **extra_kwargs) - twinkle_path = checkpoint_manager.save( - model_id=adapter_name, name=checkpoint_name, is_sampler=body.is_sampler) return {'twinkle_path': twinkle_path, 'checkpoint_dir': checkpoint_dir} return await run_task(self.schedule_task_and_wait(_task, task_type='save')) @@ -384,7 +385,7 @@ async def _task(): async_upload=False, ) - future_ref = await self.schedule_task(_task, task_type='upload_to_hub') + future_ref = await self.schedule_background_task(_task, task_type='upload_to_hub') request_id = future_ref.get('request_id') if request_id is None: raise HTTPException(status_code=500, detail=f'Upload task scheduling failed: {future_ref}') diff --git a/src/twinkle/server/utils/__init__.py b/src/twinkle/server/utils/__init__.py index 9b4abe66..9f68d104 100644 --- a/src/twinkle/server/utils/__init__.py +++ b/src/twinkle/server/utils/__init__.py @@ -3,6 +3,5 @@ BaseTrainingRunManager) from .device_utils import auto_fill_device_group_visible_devices, wrap_builder_with_device_group_env from .lifecycle import AdapterManagerMixin, ProcessorManagerMixin, SessionResourceMixin -from .rate_limiter import RateLimiter -from .task_queue import QueueState, TaskQueueConfig, TaskQueueMixin, TaskStatus +from .task_queue import QueueState, RateLimiter, TaskQueueConfig, TaskQueueMixin, TaskStatus from .template_utils import get_template_for_model diff --git a/src/twinkle/server/utils/task_queue.py b/src/twinkle/server/utils/task_queue.py deleted file mode 100644 index 9cd99253..00000000 --- 
a/src/twinkle/server/utils/task_queue.py +++ /dev/null @@ -1,672 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -Task Queue Management for Tinker Server. - -This module provides: -1. TaskStatus - Enum for tracking task lifecycle states -2. TaskQueueConfig - Configuration for rate limits and queue behavior -3. TaskQueueMixin - Mixin class for serial task execution with rate limiting -""" -from __future__ import annotations - -import asyncio -import time -import traceback -import uuid -from collections import deque -from dataclasses import dataclass -from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Deque, Dict, Optional - -from twinkle.server.utils.metrics import get_task_metrics -from twinkle.utils.logger import get_logger -from .rate_limiter import RateLimiter - -if TYPE_CHECKING: - from twinkle.server.utils.state import ServerStateProxy - -logger = get_logger() - - -class TaskStatus(Enum): - """Task lifecycle status.""" - PENDING = 'pending' # Task created, waiting to be processed - QUEUED = 'queued' # Task in queue waiting for execution - RUNNING = 'running' # Task currently executing - COMPLETED = 'completed' # Task completed successfully - FAILED = 'failed' # Task failed with error - RATE_LIMITED = 'rate_limited' # Task rejected due to rate limiting - - -class QueueState(Enum): - """Queue state for tinker client compatibility. - - These states are returned to the tinker client to indicate the current - state of the task queue and help the client adjust its retry behavior. - """ - ACTIVE = 'active' # Queue is actively processing tasks - PAUSED_RATE_LIMIT = 'paused_rate_limit' # Queue paused due to rate limiting - PAUSED_CAPACITY = 'paused_capacity' # Queue paused due to capacity limits - UNKNOWN = 'unknown' # Unknown or unspecified state - - -@dataclass -class TaskQueueConfig: - """Configuration for task queue and rate limiting. - - Attributes: - rps_limit: Maximum requests per second per user token. - tps_limit: Maximum input tokens per second per user token. - window_seconds: Time window for rate limiting calculations. - queue_timeout: Maximum time a task can wait in queue (seconds). - enabled: Whether rate limiting is enabled. - token_cleanup_multiplier: Multiplier for token cleanup threshold. - token_cleanup_interval: How often to run cleanup task (seconds). - max_input_tokens: Maximum allowed input tokens per request (default 10000). - """ - rps_limit: float = 100.0 # 10 requests per second - tps_limit: float = 16000.0 # 10000 input tokens per second - window_seconds: float = 1.0 # 1 second sliding window - queue_timeout: float = 300.0 # 5 minutes queue timeout - enabled: bool = True # Rate limiting enabled by default - # Remove tokens after 10x window inactivity - token_cleanup_multiplier: float = 10.0 - token_cleanup_interval: float = 60.0 # Run cleanup every 60 seconds - max_input_tokens: int = 16000 # Maximum input tokens per request - - @classmethod - def from_dict(cls, config_dict: dict[str, Any] | None = None) -> TaskQueueConfig: - """Create TaskQueueConfig from a dictionary. - - Args: - config_dict: Dictionary with configuration values. 
Supports keys: - - rps_limit: requests per second limit - - tps_limit: input tokens per second limit - - window_seconds: sliding window duration - - queue_timeout: queue timeout in seconds - - enabled: whether rate limiting is enabled - - token_cleanup_multiplier: multiplier for token cleanup threshold - - token_cleanup_interval: cleanup task interval in seconds - - max_input_tokens: maximum input tokens per request - - Returns: - TaskQueueConfig instance with values from dict merged with defaults. - """ - config = cls() - if config_dict: - if 'rps_limit' in config_dict: - config.rps_limit = float(config_dict['rps_limit']) - if 'tps_limit' in config_dict: - config.tps_limit = float(config_dict['tps_limit']) - if 'window_seconds' in config_dict: - config.window_seconds = float(config_dict['window_seconds']) - if 'queue_timeout' in config_dict: - config.queue_timeout = float(config_dict['queue_timeout']) - if 'enabled' in config_dict: - config.enabled = bool(config_dict['enabled']) - if 'token_cleanup_multiplier' in config_dict: - config.token_cleanup_multiplier = float(config_dict['token_cleanup_multiplier']) - if 'token_cleanup_interval' in config_dict: - config.token_cleanup_interval = float(config_dict['token_cleanup_interval']) - if 'max_input_tokens' in config_dict: - config.max_input_tokens = int(config_dict['max_input_tokens']) - return config - - -@dataclass -class _QueuedTask: - request_id: str - coro_factory: Callable[[], Coroutine] - model_id: str | None - token: str | None - input_tokens: int - task_type: str | None - created_at: float - first_rate_limited_at: float | None = None - - -class TaskQueueMixin: - """Mixin providing task queue management, rate limiting, and status tracking. - - This mixin should be inherited by classes that need to: - 1. Execute async tasks serially through a queue - 2. Apply per-user rate limiting (rps and tps) - 3. Track task lifecycle status for proper client polling - - Requirements: - - Inheriting class must have `self.state: ServerStateProxy` attribute - - Call `_init_task_queue()` in `__init__` to initialize the queue - - Call `await _start_worker()` to start the background worker - - Example: - class MyService(TaskQueueMixin): - def __init__(self): - self.state = get_server_state() - self._init_task_queue(TaskQueueConfig.from_dict(config_dict)) - - async def my_endpoint(self, request, body): - async def _do_work(): - return await some_operation() - return await self.schedule_task( - _do_work, - model_id=body.model_id, - token=request.state.token, - input_tokens=len(body.tokens) - ) - """ - - # Type hint for state attribute that inheriting classes must provide - state: ServerStateProxy - - def _init_task_queue(self, config: TaskQueueConfig | None = None, deployment_name: str = '') -> None: - """Initialize the task queue system. - - Args: - config: Optional TaskQueueConfig. If None, uses default config. - deployment_name: Deployment name for metrics labels (e.g. 'Model', 'Sampler'). - """ - self._task_queue_config = config or TaskQueueConfig() - # Per-key queues, but executed by a single global worker. 
- self._task_queues: dict[str, asyncio.Queue] = {} - self._queue_order: Deque[str] = deque() - self._new_task_event: asyncio.Event = asyncio.Event() - - # Metrics initialization - self._deployment_name = deployment_name - self._task_metrics = get_task_metrics(deployment_name) if deployment_name else None - - # Initialize rate limiter for RPS/TPS control - self._rate_limiter = RateLimiter( - rps_limit=self._task_queue_config.rps_limit, - tps_limit=self._task_queue_config.tps_limit, - window_seconds=self._task_queue_config.window_seconds, - token_cleanup_multiplier=self._task_queue_config.token_cleanup_multiplier, - token_cleanup_interval=self._task_queue_config.token_cleanup_interval, - active_tokens_gauge=self._task_metrics.rate_limiter_active_tokens if self._task_metrics else None, - deployment_name=deployment_name, - ) - # Start the rate limiter cleanup task - self._rate_limiter.start_cleanup_task() - - # Single worker to ensure model operations remain serial. - self._worker_task: asyncio.Task | None = None - self._worker_started = False - self._worker_start_lock = asyncio.Lock() - - # Event loop reference for thread-safe callbacks (e.g., adapter expiration thread) - self._event_loop: asyncio.AbstractEventLoop | None = None - - @staticmethod - def _queue_key( - model_id: str | None, - token: str | None, - ) -> str: - if model_id: - return f'model:{model_id}' - if token: - return f'token:{token}' - return 'default' - - async def _ensure_worker_started(self) -> None: - """Ensure the single background worker is running.""" - if self._worker_started and self._worker_task is not None and not self._worker_task.done(): - return - - async with self._worker_start_lock: - if self._worker_started and self._worker_task is not None and not self._worker_task.done(): - return - self._worker_task = asyncio.create_task(self._queue_worker()) - self._worker_started = True - - def _ensure_queue_registered(self, queue_key: str) -> None: - if queue_key not in self._task_queues: - self._task_queues[queue_key] = asyncio.Queue() - if queue_key not in self._queue_order: - self._queue_order.append(queue_key) - - async def _queue_worker(self) -> None: - """Single background worker that processes tasks serially across all queues. - - Selection policy: round-robin across queue keys. If a task is rate-limited - at execution time, it is requeued and the worker tries other queues. 
- """ - logger.debug('[TaskQueue] Worker started') - while True: - try: - # Wait until there is at least one queue with a task - while True: - if any(q.qsize() > 0 for q in self._task_queues.values()): - break - self._new_task_event.clear() - await self._new_task_event.wait() - - executed_any = False - # Try each queue at most once per loop for fairness - for _ in range(len(self._queue_order)): - queue_key = self._queue_order[0] - self._queue_order.rotate(-1) - - q = self._task_queues.get(queue_key) - if q is None: - continue - - try: - task: _QueuedTask = q.get_nowait() - except asyncio.QueueEmpty: - continue - - # Record queue wait time and update depth gauge - if self._task_metrics: - queue_wait = time.monotonic() - task.created_at - task_type_label = task.task_type or 'unknown' - self._task_metrics.queue_wait_seconds.observe( - queue_wait, tags={ - 'deployment': self._deployment_name, - 'task_type': task_type_label - }) - total_depth = sum(qq.qsize() for qq in self._task_queues.values()) - self._task_metrics.queue_depth.set(total_depth, tags={'deployment': self._deployment_name}) - - now = time.monotonic() - - # Global queue timeout - if (now - task.created_at) > self._task_queue_config.queue_timeout: - error_payload = { - 'error': f'Queue timeout exceeded: waited {now - task.created_at:.2f}s', - 'category': 'Server' - } - await self.state.store_future_status( - task.request_id, - TaskStatus.FAILED.value, - task.model_id, - result=error_payload, - queue_state=QueueState.PAUSED_CAPACITY.value, - queue_state_reason=error_payload['error'], - ) - if self._task_metrics: - self._task_metrics.tasks_total.inc( - tags={ - 'deployment': self._deployment_name, - 'task_type': task.task_type or 'unknown', - 'status': 'timeout' - }) - q.task_done() - continue - - # Rate limiting check has been moved to schedule_task(), so tasks here should pass rate limits - - # Execute - executed_any = True - await self.state.store_future_status( - task.request_id, TaskStatus.RUNNING.value, task.model_id, queue_state=QueueState.ACTIVE.value) - - exec_start = time.monotonic() - task_status = 'completed' - try: - coro = task.coro_factory() - result = await coro - await self.state.store_future_status( - task.request_id, - TaskStatus.COMPLETED.value, - task.model_id, - result=result, - queue_state=QueueState.ACTIVE.value) - except Exception: - task_status = 'failed' - error_payload = {'error': traceback.format_exc(), 'category': 'Server'} - await self.state.store_future_status( - task.request_id, - TaskStatus.FAILED.value, - task.model_id, - result=error_payload, - queue_state=QueueState.ACTIVE.value) - finally: - q.task_done() - if self._task_metrics: - exec_time = time.monotonic() - exec_start - self._task_metrics.execution_seconds.observe( - exec_time, - tags={ - 'deployment': self._deployment_name, - 'task_type': task.task_type or 'unknown' - }) - self._task_metrics.tasks_total.inc( - tags={ - 'deployment': self._deployment_name, - 'task_type': task.task_type or 'unknown', - 'status': task_status - }) - - # Keep serial semantics: execute at most one runnable task per loop - break - - if not executed_any: - # All available tasks were rate-limited; avoid busy looping. 
- await asyncio.sleep(min(self._task_queue_config.window_seconds, 0.1)) - - except asyncio.CancelledError: - logger.warning('[TaskQueue] Worker cancelled') - break - except Exception: - logger.warning('Error in task queue worker') - continue - - async def _fail_queue_tasks_async(self, queue_key: str, reason: str) -> None: - q = self._task_queues.get(queue_key) - if q is None: - return - - drained: list[_QueuedTask] = [] - while True: - try: - drained.append(q.get_nowait()) - except asyncio.QueueEmpty: - break - - for task in drained: - error_payload = {'error': reason, 'category': 'Server'} - await self.state.store_future_status( - task.request_id, - TaskStatus.FAILED.value, - task.model_id, - result=error_payload, - queue_state=QueueState.UNKNOWN.value, - queue_state_reason=reason, - ) - q.task_done() - - # Remove queue structures - self._task_queues.pop(queue_key, None) - try: - while queue_key in self._queue_order: - self._queue_order.remove(queue_key) - except ValueError: - pass - - def fail_pending_tasks_for_model(self, model_id: str, reason: str) -> None: - """Fail and drop queued tasks for a model. Safe to call from non-async threads.""" - queue_key = self._queue_key(model_id=model_id, token=None) - if self._event_loop is None: - # Best-effort: nothing we can do safely without a loop. - logger.warning(f'[TaskQueue] fail_pending_tasks_for_model called without event loop: {queue_key}') - return - - def _schedule() -> None: - asyncio.create_task(self._fail_queue_tasks_async(queue_key, reason)) - - self._event_loop.call_soon_threadsafe(_schedule) - - async def _perform_preflight_checks( - self, - request_id: str, - model_id: str | None, - token: str | None, - input_tokens: int, - batch_size: int | None = None, - data_world_size: int | None = None, - ) -> dict[str, Any] | None: - """Perform pre-flight checks including rate limiting and token validation. - - Args: - request_id: The request ID for status tracking. - model_id: Optional model_id for error reporting. - token: Optional user token for rate limiting. - input_tokens: Number of input tokens for validation. - batch_size: Optional batch size for validation. - data_world_size: Optional data world size for batch size validation. - - Returns: - None if checks pass, or error response dict if checks fail. 
- """ - if not token or not self._task_queue_config.enabled: - return None - - # Check max input tokens - if input_tokens > self._task_queue_config.max_input_tokens: - error_msg = f'Input tokens ({input_tokens}) exceed maximum allowed ({self._task_queue_config.max_input_tokens})' # noqa: E501 - error_payload = {'error': error_msg, 'category': 'User'} - await self.state.store_future_status( - request_id, - TaskStatus.FAILED.value, - model_id, - result=error_payload, - queue_state=QueueState.UNKNOWN.value, - queue_state_reason=error_msg, - ) - return {'request_id': request_id, 'model_id': model_id} - - # Check batch size if provided - if batch_size is not None and data_world_size is not None: - if batch_size < data_world_size: - error_msg = f'Batch size {batch_size} must be greater than or equal to data world size {data_world_size}' # noqa: E501 - error_payload = {'error': error_msg, 'category': 'User'} - await self.state.store_future_status( - request_id, - TaskStatus.FAILED.value, - model_id, - result=error_payload, - queue_state=QueueState.UNKNOWN.value, - queue_state_reason=error_msg, - ) - return {'request_id': request_id, 'model_id': model_id} - - # Check rate limits - allowed, reason = await self._rate_limiter.check_and_record(token, input_tokens) - if not allowed: - if self._task_metrics: - self._task_metrics.rate_limit_rejections.inc(tags={'deployment': self._deployment_name}) - error_msg = f'Rate limit exceeded: {reason}' - error_payload = {'error': error_msg, 'category': 'User'} - await self.state.store_future_status( - request_id, - TaskStatus.FAILED.value, - model_id, - result=error_payload, - queue_state=QueueState.PAUSED_RATE_LIMIT.value, - queue_state_reason=error_msg, - ) - return {'request_id': request_id, 'model_id': model_id} - - return None - - async def schedule_task( - self, - coro_factory: Callable[[], Coroutine], - model_id: str | None = None, - token: str | None = None, - input_tokens: int = 0, - batch_size: int | None = None, - data_world_size: int | None = None, - task_type: str | None = None, - ) -> dict[str, Any]: - """Schedule an async task with rate limiting and status tracking. - - This method replaces the old `schedule_task` function with proper - status tracking to fix the race condition where clients would receive - 404 instead of 408 when polling before task execution started. - - Key improvements: - 1. Register PENDING status BEFORE creating the task - 2. Apply rate limiting per user token - 3. Execute tasks serially through a queue - - Args: - coro_factory: Factory that creates the coroutine to execute. The coroutine - will be created only after passing rate limiting and when it's time - to execute the queued task. - model_id: Optional model_id to associate with the result. - token: Optional user token for rate limiting. - input_tokens: Number of input tokens for tps rate limiting. - batch_size: Optional batch size for validation. - data_world_size: Optional data world size for batch size validation. - task_type: Optional task type for logging/observability. - - Returns: - Dict containing request_id and model_id for future retrieval. - """ - # Generate request_id first so it can be included in error responses - request_id = f'req_{uuid.uuid4().hex}' - - # 1. 
Pre-flight checks: rate limiting, max token validation, and batch size validation - preflight_result = await self._perform_preflight_checks(request_id, model_id, token, input_tokens, batch_size, - data_world_size) - if preflight_result is not None: - return preflight_result - - if self._event_loop is None: - self._event_loop = asyncio.get_running_loop() - - logger.debug( - f'[TaskQueue] Scheduling task {request_id}, rps_limit={self._task_queue_config.rps_limit}, enabled={self._task_queue_config.enabled}' # noqa: E501 - ) - - # 2. Register PENDING status FIRST - await self.state.store_future_status( - request_id, TaskStatus.PENDING.value, model_id, queue_state=QueueState.ACTIVE.value) - - # 3. Route to per-model/per-token queue - queue_key = self._queue_key(model_id=model_id, token=token) - self._ensure_queue_registered(queue_key) - - # 4. Ensure worker is started - await self._ensure_worker_started() - - # 5. Put task in queue and update status - q = self._task_queues[queue_key] - logger.debug( - f'[TaskQueue] Adding task {request_id} to queue key={queue_key} (current size: {q.qsize()}) type={task_type}' # noqa: E501 - ) - await q.put( - _QueuedTask( - request_id=request_id, - coro_factory=coro_factory, - model_id=model_id, - token=token, - input_tokens=input_tokens, - task_type=task_type, - created_at=time.monotonic(), - )) - await self.state.store_future_status( - request_id, TaskStatus.QUEUED.value, model_id, queue_state=QueueState.ACTIVE.value) - logger.debug(f'[TaskQueue] Task {request_id} queued, new queue size: {q.qsize()} key={queue_key}') - - self._new_task_event.set() - - if self._task_metrics: - total_depth = sum(q.qsize() for q in self._task_queues.values()) - self._task_metrics.queue_depth.set(total_depth, tags={'deployment': self._deployment_name}) - - return {'request_id': request_id, 'model_id': model_id} - - def get_queue_stats(self) -> dict[str, Any]: - """Get current queue statistics. - - Returns: - Dict with queue size and worker status. - """ - return { - 'queue_size': sum(q.qsize() for q in self._task_queues.values()), - 'queue_count': len(self._task_queues), - 'worker_running': self._worker_task is not None and not self._worker_task.done(), - 'rate_limit_config': { - 'rps_limit': self._task_queue_config.rps_limit, - 'tps_limit': self._task_queue_config.tps_limit, - 'enabled': self._task_queue_config.enabled, - } - } - - def get_rate_limit_stats(self, token: str) -> dict[str, Any]: - """Get rate limiting stats for a specific user token. - - Args: - token: User token to get stats for. - - Returns: - Dict with current and available rate limits. - """ - return self._rate_limiter.get_stats(token) - - def get_rate_limiter_memory_stats(self) -> dict[str, Any]: - """Get memory usage statistics from the rate limiter. - - Returns: - Dict with active token count and cleanup configuration. - """ - return self._rate_limiter.get_memory_stats() - - async def schedule_task_and_wait( - self, - coro_factory: Callable[[], Coroutine], - model_id: str | None = None, - token: str | None = None, - input_tokens: int = 0, - task_type: str | None = None, - ) -> Any: - """Schedule an async task and wait for its result synchronously. - - This is the twinkle-side counterpart to :meth:`schedule_task`. - It enqueues the task through the same serial worker, then blocks - (via async sleep) until the task completes, and returns the result - directly instead of a future reference dict. - - Args: - coro_factory: Factory that creates the coroutine to execute. 
- model_id: Optional model_id to associate with the result. - token: Optional user token for rate limiting. - input_tokens: Number of input tokens for tps rate limiting. - task_type: Optional task type for logging/observability. - - Returns: - The direct return value of the coroutine. - - Raises: - RuntimeError: If the task fails. - """ - future_ref = await self.schedule_task( - coro_factory, - model_id=model_id, - token=token, - input_tokens=input_tokens, - task_type=task_type, - ) - request_id = future_ref.get('request_id') - if request_id is None: - # Pre-flight check failed; surface the error from the stored future - raise RuntimeError(f'Task scheduling failed: {future_ref}') - - while True: - record = await self.state.get_future(request_id) - if record and record.get('status') not in ('pending', 'queued', 'running'): - break - await asyncio.sleep(0.05) - - if record['status'] == 'failed': - error = record.get('result', {}).get('error', 'Unknown error') - raise RuntimeError(error) - - return record['result'] - - async def shutdown_task_queue(self) -> None: - """Gracefully shutdown the task queue and cleanup tasks. - - This should be called when shutting down the server to ensure - proper cleanup of background tasks. - """ - # Stop the rate limiter cleanup task - await self._rate_limiter.stop_cleanup_task() - - # Cancel the worker task if running - if self._worker_task and not self._worker_task.done(): - self._worker_task.cancel() - try: - await self._worker_task - except asyncio.CancelledError: - pass - - self._worker_task = None - self._worker_started = False - - self._task_queues.clear() - self._queue_order.clear() - - logger.debug('[TaskQueue] Task queue shutdown complete') diff --git a/src/twinkle/server/utils/task_queue/__init__.py b/src/twinkle/server/utils/task_queue/__init__.py new file mode 100644 index 00000000..5c90d318 --- /dev/null +++ b/src/twinkle/server/utils/task_queue/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Task Queue package. + +Public exports (backward-compatible with the former task_queue.py module): +- TaskStatus - task lifecycle enum +- QueueState - queue state enum for tinker client compatibility +- TaskQueueConfig - queue and rate-limit configuration dataclass +- TaskQueueMixin - mixin with schedule_task / schedule_background_task +- RateLimiter - sliding-window rate limiter +""" +from .config import TaskQueueConfig +from .mixin import TaskQueueMixin +from .rate_limiter import RateLimiter +from .types import QueuedTask, QueueState, TaskStatus +from .worker import ComputeWorker + +__all__ = [ + 'TaskStatus', + 'QueueState', + 'QueuedTask', + 'TaskQueueConfig', + 'TaskQueueMixin', + 'ComputeWorker', + 'RateLimiter', +] diff --git a/src/twinkle/server/utils/task_queue/config.py b/src/twinkle/server/utils/task_queue/config.py new file mode 100644 index 00000000..a8b6437b --- /dev/null +++ b/src/twinkle/server/utils/task_queue/config.py @@ -0,0 +1,79 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Task queue configuration. + +Provides TaskQueueConfig for controlling rate limits, timeouts, +and queue behavior. +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass +class TaskQueueConfig: + """Configuration for task queue and rate limiting. + + Attributes: + rps_limit: Maximum requests per second per user token. + tps_limit: Maximum input tokens per second per user token. 
+ window_seconds: Time window for rate limiting calculations. + queue_timeout: Maximum time a task can wait in queue (seconds). + execution_timeout: Maximum time a task can execute (seconds). 0 means no limit. + enabled: Whether rate limiting is enabled. + token_cleanup_multiplier: Multiplier for token cleanup threshold. + token_cleanup_interval: How often to run cleanup task (seconds). + max_input_tokens: Maximum allowed input tokens per request (default 16000). + """ + rps_limit: float = 100.0 # 100 requests per second + tps_limit: float = 16000.0 # 16000 input tokens per second + window_seconds: float = 1.0 # 1 second sliding window + queue_timeout: float = 300.0 # 5 minutes queue timeout + execution_timeout: float = 120.0 # 120 seconds execution timeout (0 to disable) + enabled: bool = True # Rate limiting enabled by default + # Remove tokens after 10x window inactivity + token_cleanup_multiplier: float = 10.0 + token_cleanup_interval: float = 60.0 # Run cleanup every 60 seconds + max_input_tokens: int = 16000 # Maximum input tokens per request + + @classmethod + def from_dict(cls, config_dict: dict[str, Any] | None = None) -> TaskQueueConfig: + """Create TaskQueueConfig from a dictionary. + + Args: + config_dict: Dictionary with configuration values. Supports keys: + - rps_limit: requests per second limit + - tps_limit: input tokens per second limit + - window_seconds: sliding window duration + - queue_timeout: queue timeout in seconds + - execution_timeout: task execution timeout in seconds (0 to disable) + - enabled: whether rate limiting is enabled + - token_cleanup_multiplier: multiplier for token cleanup threshold + - token_cleanup_interval: cleanup task interval in seconds + - max_input_tokens: maximum input tokens per request + + Returns: + TaskQueueConfig instance with values from dict merged with defaults. + """ + config = cls() + if config_dict: + if 'rps_limit' in config_dict: + config.rps_limit = float(config_dict['rps_limit']) + if 'tps_limit' in config_dict: + config.tps_limit = float(config_dict['tps_limit']) + if 'window_seconds' in config_dict: + config.window_seconds = float(config_dict['window_seconds']) + if 'queue_timeout' in config_dict: + config.queue_timeout = float(config_dict['queue_timeout']) + if 'execution_timeout' in config_dict: + config.execution_timeout = float(config_dict['execution_timeout']) + if 'enabled' in config_dict: + config.enabled = bool(config_dict['enabled']) + if 'token_cleanup_multiplier' in config_dict: + config.token_cleanup_multiplier = float(config_dict['token_cleanup_multiplier']) + if 'token_cleanup_interval' in config_dict: + config.token_cleanup_interval = float(config_dict['token_cleanup_interval']) + if 'max_input_tokens' in config_dict: + config.max_input_tokens = int(config_dict['max_input_tokens']) + return config diff --git a/src/twinkle/server/utils/task_queue/mixin.py b/src/twinkle/server/utils/task_queue/mixin.py new file mode 100644 index 00000000..0491fa99 --- /dev/null +++ b/src/twinkle/server/utils/task_queue/mixin.py @@ -0,0 +1,357 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +TaskQueueMixin: serial compute queue + background-task execution. 
+ +Two execution paths: + schedule_task() / schedule_task_and_wait() -> serial compute queue (GPU ops) + schedule_background_task() -> fire-and-forget asyncio Task (I/O ops) +""" +from __future__ import annotations + +import asyncio +import time +import traceback +import uuid +from typing import TYPE_CHECKING, Any, Callable, Coroutine + +from twinkle.server.utils.metrics import get_task_metrics +from twinkle.utils.logger import get_logger +from .config import TaskQueueConfig +from .rate_limiter import RateLimiter +from .types import QueuedTask, QueueState, TaskStatus +from .worker import ComputeWorker + +if TYPE_CHECKING: + from twinkle.server.utils.state import ServerStateProxy + +logger = get_logger() + + +class TaskQueueMixin: + """Mixin providing two task execution paths. + + Execution paths + --------------- + 1. Compute queue (schedule_task / schedule_task_and_wait): + Single background worker, serial execution, round-robin across queues. + Use for GPU operations: forward, backward, step, save, load, etc. + + 2. Background task (schedule_background_task): + asyncio.create_task, runs concurrently with compute queue. + Use for pure I/O: upload_to_hub, etc. + Status is still tracked; clients can poll the same status endpoints. + + Requirements + ------------ + Inheriting class must expose self.state: ServerStateProxy and call + _init_task_queue() during __init__. + """ + + state: ServerStateProxy + + def _init_task_queue(self, config: TaskQueueConfig | None = None, deployment_name: str = '') -> None: + """Initialise the task queue, rate limiter, and compute worker.""" + self._task_queue_config = config or TaskQueueConfig() + self._deployment_name = deployment_name + self._task_metrics = get_task_metrics(deployment_name) if deployment_name else None + + self._rate_limiter = RateLimiter( + rps_limit=self._task_queue_config.rps_limit, + tps_limit=self._task_queue_config.tps_limit, + window_seconds=self._task_queue_config.window_seconds, + token_cleanup_multiplier=self._task_queue_config.token_cleanup_multiplier, + token_cleanup_interval=self._task_queue_config.token_cleanup_interval, + active_tokens_gauge=self._task_metrics.rate_limiter_active_tokens if self._task_metrics else None, + deployment_name=deployment_name, + ) + self._rate_limiter.start_cleanup_task() + + self._compute_worker = ComputeWorker( + state=self.state, + config=self._task_queue_config, + task_metrics=self._task_metrics, + deployment_name=deployment_name, + ) + + self._event_loop: asyncio.AbstractEventLoop | None = None + + @staticmethod + def _queue_key(model_id: str | None, token: str | None) -> str: + if model_id: + return f'model:{model_id}' + if token: + return f'token:{token}' + return 'default' + + async def _perform_preflight_checks( + self, + request_id: str, + model_id: str | None, + token: str | None, + input_tokens: int, + batch_size: int | None = None, + data_world_size: int | None = None, + ) -> dict[str, Any] | None: + """Run rate-limit and validation checks before queuing a task. + + Returns None if all checks pass, or an error-response dict on failure. 
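+
+        Checks run in order: the max_input_tokens cap, batch_size >=
+        data_world_size (only when both are provided), then the
+        sliding-window RPS/TPS rate limiter. All checks are skipped when
+        no token is supplied or rate limiting is disabled.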
+ """ + if not token or not self._task_queue_config.enabled: + return None + + if input_tokens > self._task_queue_config.max_input_tokens: + error_msg = (f'Input tokens ({input_tokens}) exceed maximum allowed ' + f'({self._task_queue_config.max_input_tokens})') + error_payload = {'error': error_msg, 'category': 'User'} + await self.state.store_future_status( + request_id, + TaskStatus.FAILED.value, + model_id, + result=error_payload, + queue_state=QueueState.UNKNOWN.value, + queue_state_reason=error_msg, + ) + return {'request_id': request_id, 'model_id': model_id} + + if batch_size is not None and data_world_size is not None: + if batch_size < data_world_size: + error_msg = (f'Batch size {batch_size} must be >= data world size {data_world_size}') + error_payload = {'error': error_msg, 'category': 'User'} + await self.state.store_future_status( + request_id, + TaskStatus.FAILED.value, + model_id, + result=error_payload, + queue_state=QueueState.UNKNOWN.value, + queue_state_reason=error_msg, + ) + return {'request_id': request_id, 'model_id': model_id} + + allowed, reason = await self._rate_limiter.check_and_record(token, input_tokens) + if not allowed: + if self._task_metrics: + self._task_metrics.rate_limit_rejections.inc(tags={'deployment': self._deployment_name}) + error_msg = f'Rate limit exceeded: {reason}' + error_payload = {'error': error_msg, 'category': 'User'} + await self.state.store_future_status( + request_id, + TaskStatus.FAILED.value, + model_id, + result=error_payload, + queue_state=QueueState.PAUSED_RATE_LIMIT.value, + queue_state_reason=error_msg, + ) + return {'request_id': request_id, 'model_id': model_id} + + return None + + async def schedule_task( + self, + coro_factory: Callable[[], Coroutine], + model_id: str | None = None, + token: str | None = None, + input_tokens: int = 0, + batch_size: int | None = None, + data_world_size: int | None = None, + task_type: str | None = None, + ) -> dict[str, Any]: + """Schedule a GPU compute task through the serial compute queue. + + Tasks are processed one at a time in round-robin order across all + per-adapter/per-token queues. Use for any operation that touches GPU + state: forward, backward, step, save, load, add_adapter, etc. + + Args: + coro_factory: Zero-argument callable that creates the coroutine. + model_id: Adapter/model id for queue routing and result association. + token: User token for rate limiting. + input_tokens: Token count for TPS rate limiting. + batch_size: Optional batch size, validated against data_world_size. + data_world_size: Optional data world size for batch validation. + task_type: Label for logging and metrics. 
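+
+        Example (illustrative, adapted from the former module's class
+        docstring; `request` and `body` are assumed handler arguments):
+
+            async def _do_work():
+                return await some_operation()
+            ref = await self.schedule_task(
+                _do_work,
+                model_id=body.model_id,
+                token=request.state.token,
+                input_tokens=len(body.tokens),
+                task_type='forward',
+            )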
+ + Returns: + {'request_id': str, 'model_id': str | None} + """ + request_id = f'req_{uuid.uuid4().hex}' + + preflight_result = await self._perform_preflight_checks(request_id, model_id, token, input_tokens, batch_size, + data_world_size) + if preflight_result is not None: + return preflight_result + + if self._event_loop is None: + self._event_loop = asyncio.get_running_loop() + + logger.info(f'[TaskQueue] Scheduling task {request_id}, type={task_type or "unknown"}, ' + f'model_id={model_id}, input_tokens={input_tokens}') + + await self.state.store_future_status( + request_id, TaskStatus.PENDING.value, model_id, queue_state=QueueState.ACTIVE.value) + + queue_key = self._queue_key(model_id=model_id, token=token) + self._compute_worker.ensure_queue_registered(queue_key) + await self._compute_worker.ensure_started() + + q = self._compute_worker.task_queues[queue_key] + await q.put( + QueuedTask( + request_id=request_id, + coro_factory=coro_factory, + model_id=model_id, + token=token, + input_tokens=input_tokens, + task_type=task_type, + created_at=time.monotonic(), + )) + await self.state.store_future_status( + request_id, TaskStatus.QUEUED.value, model_id, queue_state=QueueState.ACTIVE.value) + logger.info(f'[TaskQueue] Task {request_id} queued, queue_key={queue_key}, ' + f'queue_size={q.qsize()}, total_queues={len(self._compute_worker.task_queues)}') + + self._compute_worker.new_task_event.set() + + if self._task_metrics: + total_depth = sum(q.qsize() for q in self._compute_worker.task_queues.values()) + self._task_metrics.queue_depth.set(total_depth, tags={'deployment': self._deployment_name}) + + return {'request_id': request_id, 'model_id': model_id} + + async def schedule_task_and_wait( + self, + coro_factory: Callable[[], Coroutine], + model_id: str | None = None, + token: str | None = None, + input_tokens: int = 0, + task_type: str | None = None, + ) -> Any: + """Schedule a compute task and block until it completes. + + Twinkle-side counterpart to schedule_task(). Enqueues the task through + the serial worker, polls until a terminal state, and returns the result. + + Raises: + RuntimeError: If the task fails or scheduling is rejected. + """ + future_ref = await self.schedule_task( + coro_factory, + model_id=model_id, + token=token, + input_tokens=input_tokens, + task_type=task_type, + ) + request_id = future_ref.get('request_id') + if request_id is None: + raise RuntimeError(f'Task scheduling failed: {future_ref}') + + while True: + record = await self.state.get_future(request_id) + if record and record.get('status') not in ('pending', 'queued', 'running'): + break + await asyncio.sleep(0.05) + + if record['status'] == 'failed': + error = record.get('result', {}).get('error', 'Unknown error') + raise RuntimeError(error) + + return record['result'] + + async def schedule_background_task( + self, + coro_factory: Callable[[], Coroutine], + model_id: str | None = None, + task_type: str | None = None, + ) -> dict[str, Any]: + """Schedule a fire-and-forget background task (bypasses compute queue). + + Designed for pure I/O operations such as upload_to_hub that do not + require GPU serialization. The task is launched immediately as an + asyncio.create_task so it runs concurrently with the compute queue + without blocking any other user's training operations. + + Status is tracked via state.store_future_status so clients can poll + progress through the same status endpoints as schedule_task(). + + Args: + coro_factory: Zero-argument callable that creates the coroutine. 
+ model_id: Optional model id for result association. + task_type: Label for logging. + + Returns: + {'request_id': str, 'model_id': str | None} + """ + request_id = f'req_{uuid.uuid4().hex}' + logger.info(f'[TaskQueue] Scheduling background task {request_id}, ' + f'type={task_type or "unknown"}, model_id={model_id}') + + await self.state.store_future_status( + request_id, TaskStatus.RUNNING.value, model_id, queue_state=QueueState.ACTIVE.value) + + async def _run() -> None: + try: + result = await coro_factory() + await self.state.store_future_status( + request_id, + TaskStatus.COMPLETED.value, + model_id, + result=result, + queue_state=QueueState.ACTIVE.value) + logger.info(f'[TaskQueue] Background task {request_id} completed, type={task_type or "unknown"}') + except Exception: + error_payload = {'error': traceback.format_exc(), 'category': 'Server'} + await self.state.store_future_status( + request_id, + TaskStatus.FAILED.value, + model_id, + result=error_payload, + queue_state=QueueState.ACTIVE.value) + logger.error(f'[TaskQueue] Background task {request_id} FAILED, type={task_type or "unknown"}:\n' + f'{traceback.format_exc(limit=3)}') + + asyncio.create_task(_run()) + return {'request_id': request_id, 'model_id': model_id} + + async def _fail_queue_tasks_async(self, queue_key: str, reason: str) -> None: + await self._compute_worker.fail_queue_tasks(queue_key, reason) + + def fail_pending_tasks_for_model(self, model_id: str, reason: str) -> None: + """Fail and drop all queued tasks for a model. Thread-safe.""" + queue_key = self._queue_key(model_id=model_id, token=None) + if self._event_loop is None: + logger.warning(f'[TaskQueue] fail_pending_tasks_for_model called without event loop: {queue_key}') + return + + def _schedule() -> None: + asyncio.create_task(self._fail_queue_tasks_async(queue_key, reason)) + + self._event_loop.call_soon_threadsafe(_schedule) + + def get_queue_stats(self) -> dict[str, Any]: + """Return current compute queue statistics.""" + return { + 'queue_size': + sum(q.qsize() for q in self._compute_worker.task_queues.values()), + 'queue_count': + len(self._compute_worker.task_queues), + 'worker_running': (self._compute_worker._worker_task is not None + and not self._compute_worker._worker_task.done()), + 'rate_limit_config': { + 'rps_limit': self._task_queue_config.rps_limit, + 'tps_limit': self._task_queue_config.tps_limit, + 'enabled': self._task_queue_config.enabled, + }, + } + + def get_rate_limit_stats(self, token: str) -> dict[str, Any]: + """Return rate-limiting stats for a user token.""" + return self._rate_limiter.get_stats(token) + + def get_rate_limiter_memory_stats(self) -> dict[str, Any]: + """Return memory usage statistics from the rate limiter.""" + return self._rate_limiter.get_memory_stats() + + async def shutdown_task_queue(self) -> None: + """Gracefully shut down the compute queue and release resources.""" + await self._rate_limiter.stop_cleanup_task() + await self._compute_worker.stop() + logger.debug('[TaskQueue] Task queue shutdown complete') diff --git a/src/twinkle/server/utils/rate_limiter.py b/src/twinkle/server/utils/task_queue/rate_limiter.py similarity index 82% rename from src/twinkle/server/utils/rate_limiter.py rename to src/twinkle/server/utils/task_queue/rate_limiter.py index 845cf246..22942880 100644 --- a/src/twinkle/server/utils/rate_limiter.py +++ b/src/twinkle/server/utils/task_queue/rate_limiter.py @@ -10,7 +10,7 @@ import asyncio import time -from typing import Any, Dict, List, Optional, Tuple +from typing import Any from 
twinkle.utils.logger import get_logger @@ -81,12 +81,7 @@ def __init__( self._deployment_name = deployment_name def _cleanup_old_requests(self, token: str, current_time: float) -> None: - """Remove requests outside the sliding window. - - Args: - token: User token to clean up. - current_time: Current timestamp. - """ + """Remove requests outside the sliding window.""" if token not in self._token_requests: return cutoff_time = current_time - self.window_seconds @@ -99,11 +94,7 @@ def _cleanup_old_requests(self, token: str, current_time: float) -> None: del self._last_activity[token] async def _cleanup_inactive_tokens(self) -> None: - """Background task that periodically removes inactive tokens. - - This prevents unbounded memory growth by removing tokens that haven't - been active for token_cleanup_multiplier * window_seconds. - """ + """Background task that periodically removes inactive tokens.""" logger.debug(f'[RateLimiter] Cleanup task started (interval={self.token_cleanup_interval}s)') while True: try: @@ -111,15 +102,12 @@ async def _cleanup_inactive_tokens(self) -> None: async with self._lock: current_time = time.time() - inactive_threshold = current_time - \ - (self.window_seconds * self.token_cleanup_multiplier) + inactive_threshold = current_time - (self.window_seconds * self.token_cleanup_multiplier) - # Find tokens that haven't been active recently tokens_to_remove = [ token for token, last_time in self._last_activity.items() if last_time < inactive_threshold ] - # Remove inactive tokens for token in tokens_to_remove: if token in self._token_requests: del self._token_requests[token] @@ -142,21 +130,14 @@ async def _cleanup_inactive_tokens(self) -> None: continue def start_cleanup_task(self) -> None: - """Start the background cleanup task. - - This should be called once when the rate limiter is initialized. - It's safe to call multiple times - subsequent calls are ignored. - """ + """Start the background cleanup task. Safe to call multiple times.""" if not self._cleanup_started: self._cleanup_task = asyncio.create_task(self._cleanup_inactive_tokens()) self._cleanup_started = True logger.debug('[RateLimiter] Background cleanup task started') async def stop_cleanup_task(self) -> None: - """Stop the background cleanup task. - - This should be called when shutting down the server. - """ + """Stop the background cleanup task.""" if self._cleanup_task and not self._cleanup_task.done(): self._cleanup_task.cancel() try: @@ -168,42 +149,27 @@ async def stop_cleanup_task(self) -> None: async def check_and_record(self, token: str, input_tokens: int) -> tuple[bool, str | None]: """Check if request is allowed and record it if so. - Args: - token: User token for rate limiting. - input_tokens: Number of input tokens in this request. - Returns: Tuple of (allowed: bool, reason: Optional[str]). - If allowed is False, reason contains the rate limit explanation. 
""" async with self._lock: current_time = time.time() - - # Clean up old requests self._cleanup_old_requests(token, current_time) - # Initialize if needed if token not in self._token_requests: self._token_requests[token] = [] - - # Update last activity time self._last_activity[token] = current_time requests = self._token_requests[token] - - # Count current window stats request_count = len(requests) token_count = sum(count for _, count in requests) - # Check rps limit if request_count >= self.rps_limit: return False, f'RPS limit exceeded: {request_count}/{self.rps_limit} requests/s' - # Check tps limit if token_count + input_tokens > self.tps_limit: return False, f'TPS limit exceeded: {token_count + input_tokens}/{self.tps_limit} tokens/s' - # Record this request self._token_requests[token].append((current_time, input_tokens)) if self._active_tokens_gauge is not None: tags = {'deployment': self._deployment_name} if self._deployment_name else {} @@ -211,18 +177,10 @@ async def check_and_record(self, token: str, input_tokens: int) -> tuple[bool, s return True, None def get_stats(self, token: str) -> dict[str, Any]: - """Get current rate limiting stats for a token. - - Args: - token: User token to get stats for. - - Returns: - Dict with current rps, tps, and limits. - """ + """Get current rate limiting stats for a token.""" current_time = time.time() self._cleanup_old_requests(token, current_time) - # Update last activity time even for stats queries if token in self._token_requests: self._last_activity[token] = current_time @@ -240,11 +198,7 @@ def get_stats(self, token: str) -> dict[str, Any]: } def get_memory_stats(self) -> dict[str, Any]: - """Get memory usage statistics for monitoring. - - Returns: - Dict with active token count and cleanup configuration. - """ + """Get memory usage statistics for monitoring.""" return { 'active_tokens': len(self._token_requests), 'tracked_tokens': len(self._last_activity), diff --git a/src/twinkle/server/utils/task_queue/types.py b/src/twinkle/server/utils/task_queue/types.py new file mode 100644 index 00000000..8599bbd9 --- /dev/null +++ b/src/twinkle/server/utils/task_queue/types.py @@ -0,0 +1,49 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Core type definitions for the task queue system. + +Provides: +- TaskStatus: Enum for tracking task lifecycle states +- QueueState: Enum for tinker client compatibility (retry behavior hints) +- QueuedTask: Dataclass representing a queued work item +""" +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Coroutine + + +class TaskStatus(Enum): + """Task lifecycle status.""" + PENDING = 'pending' # Task created, waiting to be processed + QUEUED = 'queued' # Task in queue waiting for execution + RUNNING = 'running' # Task currently executing + COMPLETED = 'completed' # Task completed successfully + FAILED = 'failed' # Task failed with error + RATE_LIMITED = 'rate_limited' # Task rejected due to rate limiting + + +class QueueState(Enum): + """Queue state for tinker client compatibility. + + These states are returned to the tinker client to indicate the current + state of the task queue and help the client adjust its retry behavior. 
+ """ + ACTIVE = 'active' # Queue is actively processing tasks + PAUSED_RATE_LIMIT = 'paused_rate_limit' # Queue paused due to rate limiting + PAUSED_CAPACITY = 'paused_capacity' # Queue paused due to capacity limits + UNKNOWN = 'unknown' # Unknown or unspecified state + + +@dataclass +class QueuedTask: + """Dataclass representing a task waiting in the compute queue.""" + request_id: str + coro_factory: Callable[[], Coroutine] + model_id: str | None + token: str | None + input_tokens: int + task_type: str | None + created_at: float + first_rate_limited_at: float | None = None diff --git a/src/twinkle/server/utils/task_queue/worker.py b/src/twinkle/server/utils/task_queue/worker.py new file mode 100644 index 00000000..5616d8f9 --- /dev/null +++ b/src/twinkle/server/utils/task_queue/worker.py @@ -0,0 +1,281 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Compute queue worker. + +Provides ComputeWorker: a single background asyncio Task that processes +GPU compute tasks serially across all per-adapter/per-token queues using +a round-robin policy. +""" +from __future__ import annotations + +import asyncio +import time +import traceback +from collections import deque +from typing import TYPE_CHECKING, Any, Deque + +from twinkle.utils.logger import get_logger +from .config import TaskQueueConfig +from .types import QueuedTask, QueueState, TaskStatus + +if TYPE_CHECKING: + from twinkle.server.utils.metrics import TaskMetrics + from twinkle.server.utils.state import ServerStateProxy + +logger = get_logger() + + +class ComputeWorker: + """Serial background worker that processes GPU compute tasks. + + Implements a round-robin scheduling policy across per-adapter/per-token + queues so that no single user's long-running task (e.g. save/load) can + starve other users waiting in different queues. + + Only one task is executed at a time to preserve serial GPU semantics. + Upload and other pure-I/O tasks should NOT be submitted here; use the + background-task path in TaskQueueMixin instead. 
+ """ + + def __init__( + self, + state: ServerStateProxy, + config: TaskQueueConfig, + task_metrics: TaskMetrics | None, + deployment_name: str, + ) -> None: + self._state = state + self._config = config + self._task_metrics = task_metrics + self._deployment_name = deployment_name + + self.task_queues: dict[str, asyncio.Queue] = {} + self.queue_order: Deque[str] = deque() + self.new_task_event: asyncio.Event = asyncio.Event() + + self._worker_task: asyncio.Task | None = None + self._started = False + self._start_lock = asyncio.Lock() + + async def ensure_started(self) -> None: + """Ensure the background worker coroutine is running.""" + if self._started and self._worker_task is not None and not self._worker_task.done(): + return + async with self._start_lock: + if self._started and self._worker_task is not None and not self._worker_task.done(): + return + self._worker_task = asyncio.create_task(self._worker_loop()) + self._started = True + + async def stop(self) -> None: + """Cancel the worker and wait for it to exit cleanly.""" + if self._worker_task and not self._worker_task.done(): + self._worker_task.cancel() + try: + await self._worker_task + except asyncio.CancelledError: + pass + self._worker_task = None + self._started = False + self.task_queues.clear() + self.queue_order.clear() + + def ensure_queue_registered(self, queue_key: str) -> None: + """Register a new per-key queue if it does not yet exist.""" + if queue_key not in self.task_queues: + self.task_queues[queue_key] = asyncio.Queue() + if queue_key not in self.queue_order: + self.queue_order.append(queue_key) + + async def _store_task_failed( + self, + task: QueuedTask, + error: str, + queue_state: str, + queue_state_reason: str | None = None, + ) -> None: + """Store FAILED status with a standardised error payload.""" + await self._state.store_future_status( + task.request_id, + TaskStatus.FAILED.value, + task.model_id, + result={ + 'error': error, + 'category': 'Server' + }, + queue_state=queue_state, + queue_state_reason=queue_state_reason, + ) + + async def fail_queue_tasks(self, queue_key: str, reason: str) -> None: + """Drain a queue and mark all pending tasks as FAILED.""" + q = self.task_queues.get(queue_key) + if q is None: + return + + drained: list[QueuedTask] = [] + while True: + try: + drained.append(q.get_nowait()) + except asyncio.QueueEmpty: + break + + for task in drained: + await self._store_task_failed(task, reason, QueueState.UNKNOWN.value, queue_state_reason=reason) + q.task_done() + + self.task_queues.pop(queue_key, None) + try: + while queue_key in self.queue_order: + self.queue_order.remove(queue_key) + except ValueError: + pass + + # ------------------------------------------------------------------ + # Worker loop helpers + # ------------------------------------------------------------------ + + async def _wait_for_work(self) -> None: + """Block until at least one queue has a pending task.""" + while True: + if any(q.qsize() > 0 for q in self.task_queues.values()): + return + self.new_task_event.clear() + await self.new_task_event.wait() + + async def _fail_timed_out_task(self, task: QueuedTask, waited: float, q: asyncio.Queue) -> None: + """Mark a queue-timed-out task as FAILED and release it from the queue.""" + error = f'Queue timeout exceeded: waited {waited:.2f}s' + await self._store_task_failed(task, error, QueueState.PAUSED_CAPACITY.value, queue_state_reason=error) + if self._task_metrics: + self._task_metrics.tasks_total.inc(tags={ + 'deployment': self._deployment_name, + 'task_type': 
task.task_type or 'unknown', + 'status': 'timeout', + }) + q.task_done() + + async def _execute_task(self, task: QueuedTask, queue_key: str, q: asyncio.Queue) -> None: + """Execute a single task: update status, run coroutine, record metrics. + + Handles execution timeout, general exceptions, and always calls + q.task_done() in the finally block. + """ + await self._state.store_future_status( + task.request_id, TaskStatus.RUNNING.value, task.model_id, queue_state=QueueState.ACTIVE.value) + + task_type = task.task_type or 'unknown' + exec_start = time.monotonic() + task_status = 'completed' + exec_time = 0.0 + try: + coro = task.coro_factory() + logger.info(f'[ComputeWorker] Task {task.request_id} started, ' + f'type={task_type}, queue_key={queue_key}, model_id={task.model_id}') + if self._config.execution_timeout > 0: + result = await asyncio.wait_for(coro, timeout=self._config.execution_timeout) + else: + result = await coro + exec_time = time.monotonic() - exec_start + logger.info(f'[ComputeWorker] Task {task.request_id} completed, ' + f'type={task_type}, exec_time={exec_time:.2f}s') + await self._state.store_future_status( + task.request_id, + TaskStatus.COMPLETED.value, + task.model_id, + result=result, + queue_state=QueueState.ACTIVE.value) + except asyncio.TimeoutError: + task_status = 'timeout' + exec_time = time.monotonic() - exec_start + error = (f'Execution timeout exceeded: {self._config.execution_timeout}s, ' + f'actual execution time: {exec_time:.2f}s') + logger.error(f'[ComputeWorker] Task {task.request_id} TIMEOUT after {exec_time:.2f}s, ' + f'type={task_type}, queue_key={queue_key}') + await self._store_task_failed(task, error, QueueState.ACTIVE.value) + except Exception: + task_status = 'failed' + exec_time = time.monotonic() - exec_start + error = traceback.format_exc() + logger.warning(f'[ComputeWorker] Task {task.request_id} FAILED after {exec_time:.2f}s, ' + f'type={task_type}, error:\n{traceback.format_exc(limit=3)}') + await self._store_task_failed(task, error, QueueState.ACTIVE.value) + finally: + q.task_done() + if self._task_metrics: + self._task_metrics.execution_seconds.observe( + exec_time, tags={ + 'deployment': self._deployment_name, + 'task_type': task_type + }) + self._task_metrics.tasks_total.inc(tags={ + 'deployment': self._deployment_name, + 'task_type': task_type, + 'status': task_status, + }) + + async def _try_run_one(self) -> bool: + """Round-robin: pick one runnable task and execute it. + + Iterates queues in round-robin order (at most once each). + - Queue-timed-out tasks are failed and skipped; the next queue is tried. + - The first runnable task is executed and the method returns True. + - Returns False if all queues were empty or every dequeued task had timed out. 
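+
+        Illustrative walk-through: with queue_order = [A, B, C] and only B
+        holding a task, A is rotated past as empty, B's head task executes and
+        True is returned, and the next call resumes the scan from C.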
+ """ + for _ in range(len(self.queue_order)): + queue_key = self.queue_order[0] + self.queue_order.rotate(-1) + + q = self.task_queues.get(queue_key) + if q is None: + continue + + try: + task: QueuedTask = q.get_nowait() + except asyncio.QueueEmpty: + continue + + # Record queue-wait metrics + if self._task_metrics: + queue_wait = time.monotonic() - task.created_at + self._task_metrics.queue_wait_seconds.observe( + queue_wait, tags={ + 'deployment': self._deployment_name, + 'task_type': task.task_type or 'unknown' + }) + total_depth = sum(qq.qsize() for qq in self.task_queues.values()) + self._task_metrics.queue_depth.set(total_depth, tags={'deployment': self._deployment_name}) + + # Check queue-level timeout (task waited too long before execution) + waited = time.monotonic() - task.created_at + if waited > self._config.queue_timeout: + await self._fail_timed_out_task(task, waited, q) + continue # try the next queue + + # Execute the task (serial: stops after the first execution) + await self._execute_task(task, queue_key, q) + return True + + return False + + # ------------------------------------------------------------------ + # Main worker loop + # ------------------------------------------------------------------ + + async def _worker_loop(self) -> None: + """Main worker loop: wait for work, run one task, repeat.""" + logger.info(f'[ComputeWorker] Started, queue_timeout={self._config.queue_timeout}, ' + f'execution_timeout={self._config.execution_timeout}') + while True: + try: + await self._wait_for_work() + executed = await self._try_run_one() + if not executed: + # All dequeued tasks were timed-out; yield briefly to avoid busy spin. + await asyncio.sleep(min(self._config.window_seconds, 0.1)) + except asyncio.CancelledError: + logger.warning('[ComputeWorker] Worker cancelled') + break + except Exception: + logger.warning(f'[ComputeWorker] Unexpected error:\n{traceback.format_exc(limit=3)}') + continue From 272d07732d89ee8928449866062b853851308e57 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sat, 18 Apr 2026 12:46:22 +0800 Subject: [PATCH 05/14] update --- src/twinkle/server/utils/task_queue/worker.py | 66 +++++++++++-------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/src/twinkle/server/utils/task_queue/worker.py b/src/twinkle/server/utils/task_queue/worker.py index 5616d8f9..78c3e8dd 100644 --- a/src/twinkle/server/utils/task_queue/worker.py +++ b/src/twinkle/server/utils/task_queue/worker.py @@ -87,6 +87,41 @@ def ensure_queue_registered(self, queue_key: str) -> None: if queue_key not in self.queue_order: self.queue_order.append(queue_key) + # ------------------------------------------------------------------ + # Metrics helpers + # ------------------------------------------------------------------ + + def _record_tasks_total(self, task_type: str, status: str) -> None: + """Increment the tasks_total counter if metrics are enabled.""" + if self._task_metrics: + self._task_metrics.tasks_total.inc(tags={ + 'deployment': self._deployment_name, + 'task_type': task_type, + 'status': status, + }) + + def _record_execution_time(self, task_type: str, exec_time: float) -> None: + """Observe execution duration in the execution_seconds histogram if metrics are enabled.""" + if self._task_metrics: + self._task_metrics.execution_seconds.observe( + exec_time, tags={ + 'deployment': self._deployment_name, + 'task_type': task_type, + }) + + def _record_queue_metrics(self, task_type: str, queue_wait: float) -> None: + """Observe queue wait time and update current queue depth 
if metrics are enabled.""" + if self._task_metrics: + self._task_metrics.queue_wait_seconds.observe( + queue_wait, tags={ + 'deployment': self._deployment_name, + 'task_type': task_type, + }) + total_depth = sum(qq.qsize() for qq in self.task_queues.values()) + self._task_metrics.queue_depth.set(total_depth, tags={'deployment': self._deployment_name}) + + # ------------------------------------------------------------------ + async def _store_task_failed( self, task: QueuedTask, @@ -147,12 +182,7 @@ async def _fail_timed_out_task(self, task: QueuedTask, waited: float, q: asyncio """Mark a queue-timed-out task as FAILED and release it from the queue.""" error = f'Queue timeout exceeded: waited {waited:.2f}s' await self._store_task_failed(task, error, QueueState.PAUSED_CAPACITY.value, queue_state_reason=error) - if self._task_metrics: - self._task_metrics.tasks_total.inc(tags={ - 'deployment': self._deployment_name, - 'task_type': task.task_type or 'unknown', - 'status': 'timeout', - }) + self._record_tasks_total(task.task_type or 'unknown', 'timeout') q.task_done() async def _execute_task(self, task: QueuedTask, queue_key: str, q: asyncio.Queue) -> None: @@ -202,17 +232,8 @@ async def _execute_task(self, task: QueuedTask, queue_key: str, q: asyncio.Queue await self._store_task_failed(task, error, QueueState.ACTIVE.value) finally: q.task_done() - if self._task_metrics: - self._task_metrics.execution_seconds.observe( - exec_time, tags={ - 'deployment': self._deployment_name, - 'task_type': task_type - }) - self._task_metrics.tasks_total.inc(tags={ - 'deployment': self._deployment_name, - 'task_type': task_type, - 'status': task_status, - }) + self._record_execution_time(task_type, exec_time) + self._record_tasks_total(task_type, task_status) async def _try_run_one(self) -> bool: """Round-robin: pick one runnable task and execute it. @@ -236,15 +257,8 @@ async def _try_run_one(self) -> bool: continue # Record queue-wait metrics - if self._task_metrics: - queue_wait = time.monotonic() - task.created_at - self._task_metrics.queue_wait_seconds.observe( - queue_wait, tags={ - 'deployment': self._deployment_name, - 'task_type': task.task_type or 'unknown' - }) - total_depth = sum(qq.qsize() for qq in self.task_queues.values()) - self._task_metrics.queue_depth.set(total_depth, tags={'deployment': self._deployment_name}) + queue_wait = time.monotonic() - task.created_at + self._record_queue_metrics(task.task_type or 'unknown', queue_wait) # Check queue-level timeout (task waited too long before execution) waited = time.monotonic() - task.created_at From b0182b933920e5dbc962b30f40d8b610c006e7b7 Mon Sep 17 00:00:00 2001 From: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com> Date: Sat, 18 Apr 2026 10:55:07 +0800 Subject: [PATCH 06/14] Fix save (#165) --- src/twinkle/model/multi_lora.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index 649c0c62..1cc96535 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -190,11 +190,17 @@ def match_target_modules( if target_modules == 'all-linear': return True + # Strip LoRA-specific suffixes (e.g. ".lora_A.default.weight") so that + # a full parameter name like "model.layers.0.attn.proj.lora_A.default.weight" + # can be matched against target_modules like ["attn.proj"]. 
+    cleaned = module_name
+    if '.lora_' in module_name:
+        cleaned = re.sub(r'\.lora_\w+(\.[\w-]+)*$', '', module_name)
     if isinstance(target_modules, str):
-        return re.fullmatch(target_modules, module_name) is not None
+        return re.fullmatch(target_modules, cleaned) is not None
 
-    if isinstance(target_modules, list):
-        return any(module_name.endswith(t) for t in target_modules)
+    if isinstance(target_modules, (list, set)):
+        return any(cleaned.endswith(t) for t in target_modules)
 
     return False
 

From 3b71b9b54ab8bce39472bfb223918e7d2c56195a Mon Sep 17 00:00:00 2001
From: Yunnglin
Date: Sat, 18 Apr 2026 13:23:18 +0800
Subject: [PATCH 07/14] update

---
 cookbook/client/tinker/self_host/dpo.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cookbook/client/tinker/self_host/dpo.py b/cookbook/client/tinker/self_host/dpo.py
index d55e9ce3..51474ca0 100644
--- a/cookbook/client/tinker/self_host/dpo.py
+++ b/cookbook/client/tinker/self_host/dpo.py
@@ -51,7 +51,7 @@
 max_length = 2048
 lora_rank = 8
 system_prompt = 'You are a helpful assistant.'
-use_swanlab = True
+use_swanlab = False
 
 
 # ---------------------------------------------------------------------------

From f9f6a76727983c646a0e8392b84b6a138ac60a90 Mon Sep 17 00:00:00 2001
From: Yunnglin
Date: Sun, 19 Apr 2026 10:31:50 +0800
Subject: [PATCH 08/14] update log

---
 src/twinkle/server/utils/task_queue/mixin.py  |  8 +++----
 src/twinkle/server/utils/task_queue/worker.py | 23 ++++++++-----------
 2 files changed, 13 insertions(+), 18 deletions(-)

diff --git a/src/twinkle/server/utils/task_queue/mixin.py b/src/twinkle/server/utils/task_queue/mixin.py
index 0491fa99..65689029 100644
--- a/src/twinkle/server/utils/task_queue/mixin.py
+++ b/src/twinkle/server/utils/task_queue/mixin.py
@@ -183,9 +183,6 @@ async def schedule_task(
         if self._event_loop is None:
             self._event_loop = asyncio.get_running_loop()
 
-        logger.info(f'[TaskQueue] Scheduling task {request_id}, type={task_type or "unknown"}, '
-                    f'model_id={model_id}, input_tokens={input_tokens}')
-
         await self.state.store_future_status(
             request_id, TaskStatus.PENDING.value, model_id, queue_state=QueueState.ACTIVE.value)
 
@@ -206,8 +203,9 @@ async def schedule_task(
         ))
         await self.state.store_future_status(
             request_id, TaskStatus.QUEUED.value, model_id, queue_state=QueueState.ACTIVE.value)
-        logger.info(f'[TaskQueue] Task {request_id} queued, queue_key={queue_key}, '
-                    f'queue_size={q.qsize()}, total_queues={len(self._compute_worker.task_queues)}')
+        logger.info(f'[TaskQueue] Task {request_id} queued, type={task_type or "unknown"}, '
+                    f'model_id={model_id}, queue_key={queue_key}, '
+                    f'queue_depth={q.qsize()}, input_tokens={input_tokens}')
 
         self._compute_worker.new_task_event.set()
 
diff --git a/src/twinkle/server/utils/task_queue/worker.py b/src/twinkle/server/utils/task_queue/worker.py
index 78c3e8dd..77740cb7 100644
--- a/src/twinkle/server/utils/task_queue/worker.py
+++ b/src/twinkle/server/utils/task_queue/worker.py
@@ -200,15 +200,14 @@ async def _execute_task(self, task: QueuedTask, queue_key: str, q: asyncio.Queue
         exec_time = 0.0
         try:
             coro = task.coro_factory()
-            logger.info(f'[ComputeWorker] Task {task.request_id} started, '
-                        f'type={task_type}, queue_key={queue_key}, model_id={task.model_id}')
+            logger.debug(f'[ComputeWorker] Task {task.request_id} executing, '
+                         f'type={task_type}, queue_key={queue_key}')
             if self._config.execution_timeout > 0:
                 result = await asyncio.wait_for(coro, timeout=self._config.execution_timeout)
             else:
                 result = await coro
             exec_time = time.monotonic() - exec_start
-
logger.info(f'[ComputeWorker] Task {task.request_id} completed, ' - f'type={task_type}, exec_time={exec_time:.2f}s') + logger.info(f'[ComputeWorker] Task {task.request_id} completed in {exec_time:.2f}s, type={task_type}') await self._state.store_future_status( task.request_id, TaskStatus.COMPLETED.value, @@ -227,8 +226,8 @@ async def _execute_task(self, task: QueuedTask, queue_key: str, q: asyncio.Queue task_status = 'failed' exec_time = time.monotonic() - exec_start error = traceback.format_exc() - logger.warning(f'[ComputeWorker] Task {task.request_id} FAILED after {exec_time:.2f}s, ' - f'type={task_type}, error:\n{traceback.format_exc(limit=3)}') + logger.error(f'[ComputeWorker] Task {task.request_id} FAILED after {exec_time:.2f}s, ' + f'type={task_type}:\n{traceback.format_exc(limit=3)}') await self._store_task_failed(task, error, QueueState.ACTIVE.value) finally: q.task_done() @@ -256,14 +255,12 @@ async def _try_run_one(self) -> bool: except asyncio.QueueEmpty: continue - # Record queue-wait metrics + # Record queue-wait metrics and check queue-level timeout queue_wait = time.monotonic() - task.created_at self._record_queue_metrics(task.task_type or 'unknown', queue_wait) - # Check queue-level timeout (task waited too long before execution) - waited = time.monotonic() - task.created_at - if waited > self._config.queue_timeout: - await self._fail_timed_out_task(task, waited, q) + if queue_wait > self._config.queue_timeout: + await self._fail_timed_out_task(task, queue_wait, q) continue # try the next queue # Execute the task (serial: stops after the first execution) @@ -288,8 +285,8 @@ async def _worker_loop(self) -> None: # All dequeued tasks were timed-out; yield briefly to avoid busy spin. await asyncio.sleep(min(self._config.window_seconds, 0.1)) except asyncio.CancelledError: - logger.warning('[ComputeWorker] Worker cancelled') + logger.info('[ComputeWorker] Worker stopped') break except Exception: - logger.warning(f'[ComputeWorker] Unexpected error:\n{traceback.format_exc(limit=3)}') + logger.error(f'[ComputeWorker] Unexpected error:\n{traceback.format_exc(limit=3)}') continue From 803506ea7afdbef02e41374e686ccbc509824438 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sun, 19 Apr 2026 12:24:36 +0800 Subject: [PATCH 09/14] update --- src/twinkle/server/model/twinkle_handlers.py | 129 ++++++++++++------- src/twinkle/server/utils/task_queue/mixin.py | 4 + 2 files changed, 90 insertions(+), 43 deletions(-) diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index d0e9c124..b0eb8d0b 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -82,7 +82,7 @@ async def create(request: Request, body: types.CreateRequest, @app.post('/twinkle/forward', response_model=types.ForwardResponse) async def forward(request: Request, body: types.ForwardRequest, self: ModelManagement = Depends(self_fn)) -> types.ForwardResponse: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -92,7 +92,19 @@ async def _task(): ret = self.model.forward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) return {'result': ret} - return await run_task(self.schedule_task_and_wait(_task, task_type='forward')) + inputs_list = body.inputs if isinstance(body.inputs, list) else [body.inputs] + input_tokens = sum(len(inp.get('input_ids', [])) if isinstance(inp, dict) else 0 for inp in 
inputs_list) + batch_size = len(inputs_list) + return await run_task( + self.schedule_task_and_wait( + _task, + model_id=adapter_name, + token=token, + input_tokens=input_tokens, + batch_size=batch_size, + data_world_size=self.device_mesh.data_world_size, + task_type='forward', + )) @app.post('/twinkle/forward_only', response_model=types.ForwardResponse) async def forward_only( @@ -100,7 +112,7 @@ async def forward_only( body: types.ForwardOnlyRequest, self: ModelManagement = Depends(self_fn), ) -> types.ForwardResponse: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -110,7 +122,16 @@ async def _task(): ret = self.model.forward_only(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) return {'result': ret} - return await run_task(self.schedule_task_and_wait(_task, task_type='forward_only')) + inputs_list = body.inputs if isinstance(body.inputs, list) else [body.inputs] + input_tokens = sum(len(inp.get('input_ids', [])) if isinstance(inp, dict) else 0 for inp in inputs_list) + return await run_task( + self.schedule_task_and_wait( + _task, + model_id=adapter_name, + token=token, + input_tokens=input_tokens, + task_type='forward_only', + )) @app.post('/twinkle/calculate_loss', response_model=types.CalculateLossResponse) async def calculate_loss( @@ -118,7 +139,7 @@ async def calculate_loss( body: types.AdapterRequest, self: ModelManagement = Depends(self_fn), ) -> types.CalculateLossResponse: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -127,11 +148,12 @@ async def _task(): ret = self.model.calculate_loss(adapter_name=adapter_name, **extra_kwargs) return {'result': ret} - return await run_task(self.schedule_task_and_wait(_task, task_type='calculate_loss')) + return await run_task( + self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='calculate_loss')) @app.post('/twinkle/backward') async def backward(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -139,7 +161,7 @@ async def _task(): extra_kwargs = body.model_extra or {} self.model.backward(adapter_name=adapter_name, **extra_kwargs) - await run_task(self.schedule_task_and_wait(_task, task_type='backward')) + await run_task(self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='backward')) @app.post('/twinkle/forward_backward', response_model=types.ForwardBackwardResponse) async def forward_backward( @@ -147,7 +169,7 @@ async def forward_backward( body: types.ForwardRequest, self: ModelManagement = Depends(self_fn), ) -> types.ForwardBackwardResponse: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) def first_element(data): @@ -168,7 +190,19 @@ async def _task(): ret = self.model.forward_backward(inputs=all_inputs, adapter_name=adapter_name, **extra_kwargs) return {'result': ret} - return await run_task(self.schedule_task_and_wait(_task, task_type='forward_backward')) + inputs_list = body.inputs if isinstance(body.inputs, list) else [body.inputs] + input_tokens = sum(len(inp.get('input_ids', 
[])) if isinstance(inp, dict) else 0 for inp in inputs_list) + batch_size = len(inputs_list) + return await run_task( + self.schedule_task_and_wait( + _task, + model_id=adapter_name, + token=token, + input_tokens=input_tokens, + batch_size=batch_size, + data_world_size=self.device_mesh.data_world_size, + task_type='forward_backward', + )) @app.post('/twinkle/clip_grad_norm', response_model=types.ClipGradNormResponse) async def clip_grad_norm( @@ -176,7 +210,7 @@ async def clip_grad_norm( body: types.AdapterRequest, self: ModelManagement = Depends(self_fn), ) -> types.ClipGradNormResponse: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -185,11 +219,12 @@ async def _task(): ret = self.model.clip_grad_norm(adapter_name=adapter_name, **extra_kwargs) return {'result': str(ret)} - return await run_task(self.schedule_task_and_wait(_task, task_type='clip_grad_norm')) + return await run_task( + self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='clip_grad_norm')) @app.post('/twinkle/step') async def step(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -197,11 +232,11 @@ async def _task(): extra_kwargs = body.model_extra or {} self.model.step(adapter_name=adapter_name, **extra_kwargs) - await run_task(self.schedule_task_and_wait(_task, task_type='step')) + await run_task(self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='step')) @app.post('/twinkle/zero_grad') async def zero_grad(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -209,11 +244,11 @@ async def _task(): extra_kwargs = body.model_extra or {} self.model.zero_grad(adapter_name=adapter_name, **extra_kwargs) - await run_task(self.schedule_task_and_wait(_task, task_type='zero_grad')) + await run_task(self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='zero_grad')) @app.post('/twinkle/lr_step') async def lr_step(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -221,7 +256,7 @@ async def _task(): extra_kwargs = body.model_extra or {} self.model.lr_step(adapter_name=adapter_name, **extra_kwargs) - await run_task(self.schedule_task_and_wait(_task, task_type='lr_step')) + await run_task(self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='lr_step')) @app.post('/twinkle/clip_grad_and_step') async def clip_grad_and_step( @@ -229,7 +264,7 @@ async def clip_grad_and_step( body: types.ClipGradAndStepRequest, self: ModelManagement = Depends(self_fn), ) -> None: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -242,7 +277,8 @@ async def _task(): **extra_kwargs, ) - await 
run_task(self.schedule_task_and_wait(_task, task_type='clip_grad_and_step')) + await run_task( + self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='clip_grad_and_step')) @app.post('/twinkle/get_train_configs', response_model=types.GetTrainConfigsResponse) async def get_train_configs( @@ -250,7 +286,7 @@ async def get_train_configs( body: types.AdapterRequest, self: ModelManagement = Depends(self_fn), ) -> types.GetTrainConfigsResponse: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -259,11 +295,12 @@ async def _task(): ret = self.model.get_train_configs(adapter_name=adapter_name, **extra_kwargs) return {'result': ret} - return await run_task(self.schedule_task_and_wait(_task, task_type='get_train_configs')) + return await run_task( + self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='get_train_configs')) @app.post('/twinkle/set_loss') async def set_loss(request: Request, body: types.SetLossRequest, self: ModelManagement = Depends(self_fn)) -> None: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -271,7 +308,7 @@ async def _task(): extra_kwargs = body.model_extra or {} self.model.set_loss(body.loss_cls, adapter_name=adapter_name, **extra_kwargs) - await run_task(self.schedule_task_and_wait(_task, task_type='set_loss')) + await run_task(self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='set_loss')) @app.post('/twinkle/set_optimizer') async def set_optimizer( @@ -279,7 +316,7 @@ async def set_optimizer( body: types.SetOptimizerRequest, self: ModelManagement = Depends(self_fn), ) -> None: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -287,7 +324,8 @@ async def _task(): extra_kwargs = body.model_extra or {} self.model.set_optimizer(body.optimizer_cls, adapter_name=adapter_name, **extra_kwargs) - await run_task(self.schedule_task_and_wait(_task, task_type='set_optimizer')) + await run_task( + self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='set_optimizer')) @app.post('/twinkle/set_lr_scheduler') async def set_lr_scheduler( @@ -295,7 +333,7 @@ async def set_lr_scheduler( body: types.SetLrSchedulerRequest, self: ModelManagement = Depends(self_fn), ) -> None: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -303,7 +341,8 @@ async def _task(): extra_kwargs = body.model_extra or {} self.model.set_lr_scheduler(body.scheduler_cls, adapter_name=adapter_name, **extra_kwargs) - await run_task(self.schedule_task_and_wait(_task, task_type='set_lr_scheduler')) + await run_task( + self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='set_lr_scheduler')) @app.post('/twinkle/save', response_model=types.SaveResponse) async def save(request: Request, body: types.SaveRequest, @@ -328,7 +367,7 @@ async def _task(): **extra_kwargs) return {'twinkle_path': twinkle_path, 'checkpoint_dir': checkpoint_dir} - return await run_task(self.schedule_task_and_wait(_task, task_type='save')) + return await run_task(self.schedule_task_and_wait(_task, 
model_id=adapter_name, token=token, task_type='save')) @app.post('/twinkle/load') async def load(request: Request, body: types.LoadRequest, self: ModelManagement = Depends(self_fn)) -> None: @@ -348,7 +387,7 @@ async def _task(): token=token, **extra_kwargs) - await run_task(self.schedule_task_and_wait(_task, task_type='load')) + await run_task(self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='load')) @app.post('/twinkle/upload_to_hub', response_model=types.UploadToHubResponse) async def upload_to_hub( @@ -433,7 +472,8 @@ async def _task(): training_run_manager.save(adapter_name, run_config) return {'status': 'ok', 'adapter_name': adapter_name} - return await run_task(self.schedule_task_and_wait(_task, task_type='add_adapter_to_model')) + return await run_task( + self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='add_adapter_to_model')) @app.post('/twinkle/apply_patch') async def apply_patch( @@ -441,7 +481,7 @@ async def apply_patch( body: types.ApplyPatchRequest, self: ModelManagement = Depends(self_fn), ) -> None: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -450,7 +490,7 @@ async def _task(): patch_cls = deserialize_object(body.patch_cls) self.model.apply_patch(patch_cls, adapter_name=adapter_name, **extra_kwargs) - await run_task(self.schedule_task_and_wait(_task, task_type='apply_patch')) + await run_task(self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='apply_patch')) @app.post('/twinkle/add_metric') async def add_metric( @@ -458,7 +498,7 @@ async def add_metric( body: types.AddMetricRequest, self: ModelManagement = Depends(self_fn), ) -> None: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -467,7 +507,7 @@ async def _task(): metric_cls = deserialize_object(body.metric_cls) self.model.add_metric(metric_cls, is_training=body.is_training, adapter_name=adapter_name, **extra_kwargs) - await run_task(self.schedule_task_and_wait(_task, task_type='add_metric')) + await run_task(self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='add_metric')) @app.post('/twinkle/set_template') async def set_template( @@ -475,7 +515,7 @@ async def set_template( body: types.SetTemplateRequest, self: ModelManagement = Depends(self_fn), ) -> None: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -483,7 +523,7 @@ async def _task(): extra_kwargs = body.model_extra or {} self.model.set_template(body.template_cls, adapter_name=adapter_name, **extra_kwargs) - await run_task(self.schedule_task_and_wait(_task, task_type='set_template')) + await run_task(self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='set_template')) @app.post('/twinkle/set_processor') async def set_processor( @@ -491,7 +531,7 @@ async def set_processor( body: types.SetProcessorRequest, self: ModelManagement = Depends(self_fn), ) -> None: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -499,7 +539,8 @@ async def _task(): extra_kwargs = body.model_extra or {} 
self.model.set_processor(body.processor_cls, adapter_name=adapter_name, **extra_kwargs) - await run_task(self.schedule_task_and_wait(_task, task_type='set_processor')) + await run_task( + self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='set_processor')) @app.post('/twinkle/calculate_metric', response_model=types.CalculateMetricResponse) async def calculate_metric( @@ -507,7 +548,7 @@ async def calculate_metric( body: types.CalculateMetricRequest, self: ModelManagement = Depends(self_fn), ) -> types.CalculateMetricResponse: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -516,7 +557,8 @@ async def _task(): ret = self.model.calculate_metric(is_training=body.is_training, adapter_name=adapter_name, **extra_kwargs) return {'result': ret} - return await run_task(self.schedule_task_and_wait(_task, task_type='calculate_metric')) + return await run_task( + self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='calculate_metric')) @app.post('/twinkle/get_state_dict', response_model=types.GetStateDictResponse) async def get_state_dict( @@ -524,7 +566,7 @@ async def get_state_dict( body: types.GetStateDictRequest, self: ModelManagement = Depends(self_fn), ) -> types.GetStateDictResponse: - await self._on_request_start(request) + token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -533,4 +575,5 @@ async def _task(): ret = self.model.get_state_dict(adapter_name=adapter_name, **extra_kwargs) return {'result': ret} - return await run_task(self.schedule_task_and_wait(_task, task_type='get_state_dict')) + return await run_task( + self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='get_state_dict')) diff --git a/src/twinkle/server/utils/task_queue/mixin.py b/src/twinkle/server/utils/task_queue/mixin.py index 65689029..1962a9d8 100644 --- a/src/twinkle/server/utils/task_queue/mixin.py +++ b/src/twinkle/server/utils/task_queue/mixin.py @@ -221,6 +221,8 @@ async def schedule_task_and_wait( model_id: str | None = None, token: str | None = None, input_tokens: int = 0, + batch_size: int | None = None, + data_world_size: int | None = None, task_type: str | None = None, ) -> Any: """Schedule a compute task and block until it completes. 
@@ -236,6 +238,8 @@ async def schedule_task_and_wait( model_id=model_id, token=token, input_tokens=input_tokens, + batch_size=batch_size, + data_world_size=data_world_size, task_type=task_type, ) request_id = future_ref.get('request_id') From a694f05b5e16a351f394d246d0055a9010f69188 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sun, 19 Apr 2026 13:15:00 +0800 Subject: [PATCH 10/14] update sampler --- src/twinkle/server/__main__.py | 6 +++ src/twinkle/server/gateway/proxy.py | 4 ++ src/twinkle/server/gateway/server.py | 23 ++++++--- .../server/gateway/tinker_gateway_handlers.py | 10 ++-- src/twinkle/server/launcher.py | 49 ++++++++++++++---- src/twinkle/server/model/app.py | 33 +++++++----- .../server/processor/twinkle_handlers.py | 4 +- .../server/sampler/twinkle_handlers.py | 50 +++++++++++++------ src/twinkle/server/utils/task_queue/mixin.py | 5 +- 9 files changed, 130 insertions(+), 54 deletions(-) diff --git a/src/twinkle/server/__main__.py b/src/twinkle/server/__main__.py index e18283c3..8f97ef09 100644 --- a/src/twinkle/server/__main__.py +++ b/src/twinkle/server/__main__.py @@ -9,6 +9,7 @@ from __future__ import annotations import argparse +import os import sys from pathlib import Path @@ -77,6 +78,11 @@ def main(args: list[str] | None = None) -> int: try: from twinkle.server.launcher import launch_server + # Apply log level so that all loggers (including those created later) + # pick up the user-specified level via the LOG_LEVEL env var that + # get_logger() already reads. + os.environ['LOG_LEVEL'] = parsed_args.log_level + config_path = Path(parsed_args.config) if not config_path.exists(): logger.error(f'Config file not found: {config_path}') diff --git a/src/twinkle/server/gateway/proxy.py b/src/twinkle/server/gateway/proxy.py index 5ed9b7bf..0978b8e0 100644 --- a/src/twinkle/server/gateway/proxy.py +++ b/src/twinkle/server/gateway/proxy.py @@ -39,6 +39,10 @@ def __init__( # Disable proxy env vars to avoid external routing self.client = httpx.AsyncClient(timeout=None, trust_env=False) + async def close(self) -> None: + """Close the underlying httpx.AsyncClient to release connections.""" + await self.client.aclose() + def _build_target_url(self, service_type: str, base_model: str, endpoint: str) -> str: """Build the target URL for internal service routing. diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py index 98e2dba5..dd89b84c 100644 --- a/src/twinkle/server/gateway/server.py +++ b/src/twinkle/server/gateway/server.py @@ -8,6 +8,7 @@ from __future__ import annotations import asyncio +from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException, Request from ray import serve from typing import Any @@ -29,9 +30,10 @@ class GatewayServer: def __init__(self, supported_models: list | None = None, - server_config: dict[str, Any] = {}, + server_config: dict[str, Any] | None = None, http_options: dict[str, Any] | None = None, **kwargs) -> None: + server_config = server_config or {} self.state = get_server_state(**server_config) self.route_prefix = kwargs.get('route_prefix', '/api/v1') self.http_options = http_options or {} @@ -71,7 +73,7 @@ async def _get_base_model(self, model_id: str) -> str: def build_server_app(deploy_options: dict[str, Any], supported_models: list | None = None, - server_config: dict[str, Any] = {}, + server_config: dict[str, Any] | None = None, http_options: dict[str, Any] | None = None, **kwargs): """Build and configure the unified gateway server application. 
@@ -88,7 +90,19 @@ def build_server_app(deploy_options: dict[str, Any], Returns: Configured Ray Serve deployment bound with options """ - app = FastAPI() + + def get_self() -> GatewayServer: + return serve.get_replica_context().servable_object + + @asynccontextmanager + async def lifespan(app: FastAPI): + yield + try: + await get_self().proxy.close() + except Exception: + pass + + app = FastAPI(lifespan=lifespan) @app.middleware('http') async def verify_token(request: Request, call_next): @@ -96,9 +110,6 @@ async def verify_token(request: Request, call_next): app.middleware('http')(create_metrics_middleware('Gateway')) - def get_self() -> GatewayServer: - return serve.get_replica_context().servable_object - _register_tinker_routes(app, get_self) _register_twinkle_routes(app, get_self) diff --git a/src/twinkle/server/gateway/tinker_gateway_handlers.py b/src/twinkle/server/gateway/tinker_gateway_handlers.py index 516da528..575ef82a 100644 --- a/src/twinkle/server/gateway/tinker_gateway_handlers.py +++ b/src/twinkle/server/gateway/tinker_gateway_handlers.py @@ -83,7 +83,7 @@ async def retrieve_future(request: Request, request_id = body.request_id max_wait = float(os.environ.get('TWINKLE_LONG_POLL_TIMEOUT', '30')) poll_interval = float(os.environ.get('TWINKLE_POLL_INTERVAL', '0.5')) - start = asyncio.get_event_loop().time() + start = asyncio.get_running_loop().time() while True: record = await self.state.get_future(request_id) @@ -95,7 +95,7 @@ async def retrieve_future(request: Request, if status not in ('pending', 'queued', 'running', 'rate_limited'): break - if asyncio.get_event_loop().time() - start >= max_wait: + if asyncio.get_running_loop().time() - start >= max_wait: response_data = {'type': 'try_again'} if queue_state := record.get('queue_state'): response_data['queue_state'] = queue_state @@ -105,10 +105,6 @@ async def retrieve_future(request: Request, await asyncio.sleep(poll_interval) - record = await self.state.get_future(request_id) - if not record: - return {'type': 'try_again'} - status = record.get('status') if status == 'rate_limited': @@ -263,6 +259,8 @@ async def asample(request: Request, body: types.SampleRequest, self: GatewayServ session = await self.state.get_sampling_session(body.sampling_session_id) if session: base_model = session.get('base_model') + if not base_model: + raise HTTPException(status_code=400, detail='base_model is required but could not be resolved') return await self.proxy.proxy_to_sampler(request, 'asample', base_model) @app.post('/save_weights_for_sampler') diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index 9d28e52b..e2a6179a 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -21,9 +21,10 @@ """ from __future__ import annotations -import time +import signal +import threading from pathlib import Path -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, NoReturn, Optional, Union from twinkle import get_logger from twinkle.server.utils.ray_serve_patch import apply_ray_serve_patches, get_runtime_env_for_patches @@ -146,9 +147,12 @@ def _start_serve(self) -> None: from ray import serve try: + from ray.serve.context import _get_global_client + _get_global_client() + # Serve is running, shut it down before re-starting serve.shutdown() - time.sleep(2) except Exception: + # Serve not running — nothing to shut down pass http_options = self.config.get('http_options', {}) @@ -182,6 +186,9 @@ def _deploy_application(self, app_config: dict[str, Any]) -> 
None: deploy_options = {} if deployments: + if len(deployments) > 1: + logger.warning(f'Application "{name}" has {len(deployments)} deployments configured, ' + f'but only the first deployment will be used.') deploy_config = deployments[0] if isinstance(deploy_config, dict): deploy_options = {k: v for k, v in deploy_config.items() if k != 'name'} @@ -197,7 +204,12 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None: logger.info(f'Deployed {name} at {route_prefix}') def launch(self) -> None: - """Launch the server with all configured applications.""" + """Launch the server with all configured applications. + + Blocks the calling thread to keep the server running. Installs signal + handlers for SIGINT/SIGTERM so that ``serve.shutdown()`` is called on + termination instead of leaving orphaned deployments. + """ # Apply Ray Serve patches before initializing Ray apply_ray_serve_patches() @@ -226,8 +238,26 @@ def launch(self) -> None: dict) else app_config.route_prefix print(f' - http://{host}:{port}{route_prefix}') - while True: - time.sleep(3600) + # Graceful shutdown via signal handling + shutdown_event = threading.Event() + + def _handle_signal(signum, frame): + sig_name = signal.Signals(signum).name + logger.info(f'Received {sig_name}, shutting down gracefully...') + shutdown_event.set() + + signal.signal(signal.SIGINT, _handle_signal) + signal.signal(signal.SIGTERM, _handle_signal) + + # Block until a termination signal is received + shutdown_event.wait() + + from ray import serve + try: + serve.shutdown() + logger.info('Ray Serve shut down successfully') + except Exception: + logger.warning('Error during Ray Serve shutdown', exc_info=True) @classmethod def from_yaml( @@ -264,20 +294,18 @@ def launch_server( config: dict[str, Any] | None = None, config_path: str | Path | None = None, ray_namespace: str | None = None, -) -> ServerLauncher: +) -> None: """ Launch a twinkle server with flexible configuration options. This is the main entry point for launching servers programmatically. + The call blocks until a SIGINT/SIGTERM signal is received. 
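+
+    Illustrative invocation (the config path is only an example):
+
+        launch_server(config_path='cookbook/client/server/megatron/server_config.yaml')
+        # Blocks here; SIGINT/SIGTERM triggers serve.shutdown() before returning.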
Args: config: Configuration dictionary (takes precedence over config_path) config_path: Path to YAML config file ray_namespace: Ray namespace - Returns: - The ServerLauncher instance - Raises: ValueError: If neither config nor config_path is provided @@ -306,4 +334,3 @@ def launch_server( ) launcher.launch() - return launcher diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 037a3f31..ecd841e3 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -7,6 +7,7 @@ """ from __future__ import annotations +from contextlib import asynccontextmanager from fastapi import FastAPI, Request from ray import serve from ray.serve.config import RequestRouterConfig @@ -45,7 +46,7 @@ def __init__(self, device_group: dict[str, Any], device_mesh: dict[str, Any], use_megatron: bool = False, - adapter_config: dict[str, Any] = {}, + adapter_config: dict[str, Any] | None = None, queue_config: dict[str, Any] | None = None, **kwargs): self.device_group = DeviceGroup(**device_group) @@ -83,7 +84,7 @@ def __init__(self, # Initialize mixins self._init_task_queue(TaskQueueConfig.from_dict(queue_config), deployment_name='Model') - self._init_adapter_manager(**adapter_config) + self._init_adapter_manager(**(adapter_config or {})) # Note: countdown task is started lazily in _ensure_sticky() async def _ensure_replica_registered(self): @@ -108,13 +109,10 @@ async def _on_request_start(self, request: Request) -> str: token = get_token_from_request(request) return token - def __del__(self): + async def shutdown(self) -> None: + """Explicit async cleanup — called via FastAPI shutdown event.""" try: - # Best-effort cleanup; event loop may already be closed - import asyncio - loop = asyncio.get_event_loop() - if loop.is_running(): - asyncio.create_task(self.state.unregister_replica(self.replica_id)) + await self.state.unregister_replica(self.replica_id) except Exception: pass @@ -136,7 +134,7 @@ def build_model_app(model_id: str, device_mesh: dict[str, Any], deploy_options: dict[str, Any], use_megatron: bool = False, - adapter_config: dict[str, Any] = {}, + adapter_config: dict[str, Any] | None = None, queue_config: dict[str, Any] | None = None, **kwargs): """Build a unified model management application for distributed training. @@ -157,9 +155,21 @@ def build_model_app(model_id: str, Returns: Configured Ray Serve deployment bound with parameters """ + # Build the FastAPI app and register all routes BEFORE serve.ingress so that # the frozen app contains the complete route table (visible to ProxyActor). 
- app = FastAPI() + def get_self() -> ModelManagement: + return serve.get_replica_context().servable_object + + @asynccontextmanager + async def lifespan(app: FastAPI): + yield + try: + await get_self().shutdown() + except Exception: + pass + + app = FastAPI(lifespan=lifespan) @app.middleware('http') async def verify_token(request: Request, call_next): @@ -167,9 +177,6 @@ async def verify_token(request: Request, call_next): app.middleware('http')(create_metrics_middleware('Model')) - def get_self() -> ModelManagement: - return serve.get_replica_context().servable_object - _register_tinker_routes(app, get_self) _register_twinkle_routes(app, get_self) diff --git a/src/twinkle/server/processor/twinkle_handlers.py b/src/twinkle/server/processor/twinkle_handlers.py index 14bde4b3..66799ee8 100644 --- a/src/twinkle/server/processor/twinkle_handlers.py +++ b/src/twinkle/server/processor/twinkle_handlers.py @@ -77,7 +77,7 @@ def _do_create(): return getattr(processor_module, class_type)( remote_group=_remote_group, device_mesh=_device_mesh, instance_id=processor_id, **resolved_kwargs) - processor = await asyncio.get_event_loop().run_in_executor(None, _do_create) + processor = await asyncio.get_running_loop().run_in_executor(None, _do_create) self.resource_dict[processor_id] = processor return types.ProcessorCreateResponse(processor_id='pid:' + processor_id) @@ -117,7 +117,7 @@ def _do_call(): except StopIteration: return True, None - is_exhausted, result = await asyncio.get_event_loop().run_in_executor(None, _do_call) + is_exhausted, result = await asyncio.get_running_loop().run_in_executor(None, _do_call) if function_name == '__next__': if is_exhausted: diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py index f20d1738..40139a53 100644 --- a/src/twinkle/server/sampler/twinkle_handlers.py +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -56,19 +56,37 @@ def _register_twinkle_sampler_routes(app: FastAPI, self_fn: Callable[[], Sampler It is wired in via Depends so it is resolved lazily at request time. """ + async def run_task(coro): + """Await a schedule_task_and_wait coroutine and surface any exception as a + structured HTTP 500 response so the client receives the full traceback instead + of an opaque connection-level error. + + Note: HTTPException is re-raised directly to preserve its status code and detail. + """ + try: + return await coro + except HTTPException: + raise + except Exception: + logger.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=traceback.format_exc()) + @app.post('/twinkle/create', response_model=types.CreateResponse) - def create(request: Request, self: SamplerManagement = Depends(self_fn)) -> types.CreateResponse: + async def create(request: Request, self: SamplerManagement = Depends(self_fn)) -> types.CreateResponse: """Health check / session creation endpoint.""" return types.CreateResponse() @app.post('/twinkle/sample', response_model=types.SampleResponseModelList) - def sample(request: Request, body: types.SampleRequest, - self: SamplerManagement = Depends(self_fn)) -> types.SampleResponseModelList: + async def sample( + request: Request, body: types.SampleRequest, + self: SamplerManagement = Depends(self_fn)) -> types.SampleResponseModelList: """Sample completions from the model. Supports Trajectory or InputFeature inputs, with optional LoRA adapter. 
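+
+        Illustrative request body (the sampling-param names and values are
+        assumptions; adapter_uri is a twinkle_path returned by an earlier save):
+
+            {"inputs": [{"input_ids": [1, 2, 3]}],
+             "sampling_params": {"max_tokens": 64},
+             "adapter_uri": "<twinkle_path>"}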
""" - try: + token = await self._on_request_start(request) + + async def _task(): # Resolve adapter adapter_path = None adapter_name = body.adapter_name or '' @@ -76,8 +94,6 @@ def sample(request: Request, body: types.SampleRequest, if body.adapter_uri: from twinkle.server.common.checkpoint_factory import create_checkpoint_manager - from twinkle.server.utils.validation import get_token_from_request - token = get_token_from_request(request) checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') _, adapter_path = checkpoint_manager.parse_adapter_uri(body.adapter_uri) @@ -100,15 +116,12 @@ def sample(request: Request, body: types.SampleRequest, if body.sampling_params: params = SamplingParams.from_dict(body.sampling_params) - # Call sampler responses = self.sampler.sample( inputs, params, adapter_name=full_adapter_name, adapter_path=adapter_path, ) - if callable(responses): - responses = responses() sample_models = [] for response in responses: @@ -122,7 +135,6 @@ def sample(request: Request, body: types.SampleRequest, if seq.new_input_feature is not None else None, ) for seq in response.sequences ] - sample_models.append( types.SampleResponseModel( sequences=sequences, @@ -130,12 +142,20 @@ def sample(request: Request, body: types.SampleRequest, topk_prompt_logprobs=response.topk_prompt_logprobs, )) return types.SampleResponseModelList(samples=sample_models) - except Exception: - logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=traceback.format_exc()) + + # Calculate metrics for queue scheduling + inputs_list = body.inputs if isinstance(body.inputs, list) else [body.inputs] + input_tokens = sum(len(inp.get('input_ids', [])) if isinstance(inp, dict) else 0 for inp in inputs_list) + return await run_task( + self.schedule_task_and_wait( + _task, + token=token, + input_tokens=input_tokens, + task_type='sample', + )) @app.post('/twinkle/set_template', response_model=types.SetTemplateResponse) - def set_template( + async def set_template( request: Request, body: types.SetTemplateRequest, self: SamplerManagement = Depends(self_fn), @@ -146,7 +166,7 @@ def set_template( return types.SetTemplateResponse() @app.post('/twinkle/add_adapter_to_sampler', response_model=types.AddAdapterResponse) - def add_adapter_to_sampler( + async def add_adapter_to_sampler( request: Request, body: types.AddAdapterRequest, self: SamplerManagement = Depends(self_fn), diff --git a/src/twinkle/server/utils/task_queue/mixin.py b/src/twinkle/server/utils/task_queue/mixin.py index 1962a9d8..a5ecbc7e 100644 --- a/src/twinkle/server/utils/task_queue/mixin.py +++ b/src/twinkle/server/utils/task_queue/mixin.py @@ -246,11 +246,14 @@ async def schedule_task_and_wait( if request_id is None: raise RuntimeError(f'Task scheduling failed: {future_ref}') + poll_interval = 0.05 + max_poll_interval = 1.0 while True: record = await self.state.get_future(request_id) if record and record.get('status') not in ('pending', 'queued', 'running'): break - await asyncio.sleep(0.05) + await asyncio.sleep(poll_interval) + poll_interval = min(poll_interval * 2, max_poll_interval) if record['status'] == 'failed': error = record.get('result', {}).get('error', 'Unknown error') From 1eef4e8a781f77dc51b8af8f01d0a6a3f97066f4 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sun, 19 Apr 2026 14:40:57 +0800 Subject: [PATCH 11/14] fix metrics --- src/twinkle/infra/__init__.py | 6 ++++++ src/twinkle/server/model/backends/megatron_model.py | 7 ++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git 
a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py
index 8462d06b..aa559e76 100644
--- a/src/twinkle/infra/__init__.py
+++ b/src/twinkle/infra/__init__.py
@@ -298,6 +298,12 @@ def _collect_func(method: Union[Literal['none', 'flatten', 'mean', 'sum', 'first
     elif method == 'last_pp':
         assert device_mesh is not None
         return [r for i, r in enumerate(result) if i in device_mesh.get_pp_last_ranks()]
+    elif method == 'last_pp_first':
+        # Return the first result from the last PP stage workers; with PP = 1
+        # every worker is in the last stage, so this reduces to result[0].
+        assert device_mesh is not None
+        last_pp = [r for i, r in enumerate(result) if i in device_mesh.get_pp_last_ranks()]
+        return last_pp[0] if last_pp else result[0]
     elif isinstance(method, Callable):
         # Callable
         return method(result, device_mesh=device_mesh)
diff --git a/src/twinkle/server/model/backends/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py
index 55cc4e72..b356ef45 100644
--- a/src/twinkle/server/model/backends/megatron_model.py
+++ b/src/twinkle/server/model/backends/megatron_model.py
@@ -75,7 +75,7 @@ def tinker_step(self, *, adam_params: types.AdamParams, **kwargs):
         super().step(**kwargs)
         super().zero_grad(**kwargs)
 
-    @remote_function(collect='first', lazy_collect=False)
+    @remote_function(collect='last_pp_first', lazy_collect=False)
     def tinker_calculate_metric(self, is_training, **kwargs):
         metric = super().calculate_metric(is_training, **kwargs)
         return clean_metrics(metric)
@@ -110,3 +110,8 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr
         """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O)."""
         output = super().forward_backward(inputs=inputs, **kwargs)
         return to_cpu_safe_output(output)
+
+    # Override the collect method so metrics are gathered from the last PP stage.
+    @remote_function(collect='last_pp_first', lazy_collect=False)
+    def calculate_metric(self, is_training, **kwargs):
+        return super().calculate_metric(is_training, **kwargs)

From b9fdd4b1b08ae5218f1a529f189cfc70e2f1dc65 Mon Sep 17 00:00:00 2001
From: Yunnglin
Date: Sun, 19 Apr 2026 15:36:16 +0800
Subject: [PATCH 12/14] update dpo padding

---
 src/twinkle/server/common/datum.py          | 35 +++++----
 src/twinkle/server/model/backends/common.py | 73 ++++++++++++-------
 .../server/model/backends/megatron_model.py | 10 +--
 .../model/backends/transformers_model.py    | 14 +---
 4 files changed, 76 insertions(+), 56 deletions(-)

diff --git a/src/twinkle/server/common/datum.py b/src/twinkle/server/common/datum.py
index 1cb1510e..86d12c77 100644
--- a/src/twinkle/server/common/datum.py
+++ b/src/twinkle/server/common/datum.py
@@ -2,7 +2,6 @@
 from __future__ import annotations
 
 import numpy as np
-from collections import defaultdict
 from tinker import types
 
 from twinkle.data_format.input_feature import InputFeature
@@ -56,26 +55,34 @@ def datum_to_input_feature(datum: types.Datum | list[types.Datum],
     return input_feature
 
 
-def extract_rl_feature(datum: types.Datum | list[types.Datum]) -> dict:
+def extract_rl_features_for_loss(datum: types.Datum | list[types.Datum]) -> dict:
+    """Extract RL features from datums for use as loss kwargs.
From b9fdd4b1b08ae5218f1a529f189cfc70e2f1dc65 Mon Sep 17 00:00:00 2001
From: Yunnglin
Date: Sun, 19 Apr 2026 15:36:16 +0800
Subject: [PATCH 12/14] update dpo padding

---
 src/twinkle/server/common/datum.py            | 35 +++++----
 src/twinkle/server/model/backends/common.py   | 73 ++++++++++++-------
 .../server/model/backends/megatron_model.py   | 10 +--
 .../model/backends/transformers_model.py      | 14 +---
 4 files changed, 76 insertions(+), 56 deletions(-)

diff --git a/src/twinkle/server/common/datum.py b/src/twinkle/server/common/datum.py
index 1cb1510e..86d12c77 100644
--- a/src/twinkle/server/common/datum.py
+++ b/src/twinkle/server/common/datum.py
@@ -2,7 +2,6 @@
 from __future__ import annotations

 import numpy as np
-from collections import defaultdict

 from tinker import types
 from twinkle.data_format.input_feature import InputFeature
@@ -56,26 +55,34 @@ def datum_to_input_feature(datum: types.Datum | list[types.Datum],
     return input_feature

-def extract_rl_feature(datum: types.Datum | list[types.Datum]) -> dict:
+def extract_rl_features_for_loss(datum: types.Datum | list[types.Datum]) -> dict:
+    """Extract RL features from datums for use as loss kwargs.
+
+    Converts per-datum feature lists into the format expected by loss functions:
+    - 'logprobs'   -> 'old_logps'   : list of per-datum log-probability lists (for GRPO)
+    - 'advantages' -> 'advantages'  : list of per-datum advantage lists (for GRPO)
+    - 'ref_logps'  -> 'ref_outputs' : {'logps': torch.Tensor [B, T]} (for DPO)
+    """
+    import torch
     if not isinstance(datum, list):
         datum = [datum]

-    result = defaultdict(list)
+    old_logps, advantages, ref_logps_lists = [], [], []
     for d in datum:
-        # 'logprobs' -> 'old_logps' (for GRPO loss)
         if 'logprobs' in d.loss_fn_inputs:
-            old_logps = d.loss_fn_inputs['logprobs'].to_numpy().tolist()
-            result['old_logps'].append(old_logps)
-
-        # 'advantages' -> 'advantages' (for GRPO loss)
+            old_logps.append(d.loss_fn_inputs['logprobs'].to_numpy().tolist())
         if 'advantages' in d.loss_fn_inputs:
-            advantages = d.loss_fn_inputs['advantages'].to_numpy().tolist()
-            result['advantages'].append(advantages)
-
-        # 'ref_logps' -> 'ref_logps' (for DPO loss)
+            advantages.append(d.loss_fn_inputs['advantages'].to_numpy().tolist())
         if 'ref_logps' in d.loss_fn_inputs:
-            ref_logps = d.loss_fn_inputs['ref_logps'].to_numpy().tolist()
-            result['ref_logps'].append(ref_logps)
+            ref_logps_lists.append(d.loss_fn_inputs['ref_logps'].to_numpy().tolist())
+
+    result = {}
+    if old_logps:
+        result['old_logps'] = old_logps
+    if advantages:
+        result['advantages'] = advantages
+    if ref_logps_lists:
+        result['ref_outputs'] = {'logps': torch.stack([torch.tensor(r, dtype=torch.float32) for r in ref_logps_lists])}

     return result
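The renamed helper no longer emits a raw `ref_logps` list that each backend pads for itself; it builds the `ref_outputs` tensor dict directly. Note that `torch.stack` requires every per-datum `ref_logps` list to have the same length, which holds because `forward_only` now returns logprobs at the padded length (see the `common.py` changes below). A sketch of the resulting shape, with hand-written numbers in place of real `loss_fn_inputs`:

    import torch

    ref_logps_lists = [[-0.5, -1.2, 0.0], [-0.3, -0.9, 0.0]]  # two datums, equal length
    ref_outputs = {'logps': torch.stack(
        [torch.tensor(r, dtype=torch.float32) for r in ref_logps_lists])}
    assert ref_outputs['logps'].shape == (2, 3)  # [B, T]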
diff --git a/src/twinkle/server/model/backends/common.py b/src/twinkle/server/model/backends/common.py
index 607794c3..839b1929 100644
--- a/src/twinkle/server/model/backends/common.py
+++ b/src/twinkle/server/model/backends/common.py
@@ -162,14 +162,26 @@ def _ensure_dpo_metric(self, adapter_name: str, beta: float):
             return
         self.add_metric('DPOMetric', adapter_name=adapter_name, beta=beta)

-    def _tinker_build_output(self, inputs, outputs):
+    def _apply_ref_outputs(self, loss_values: dict, loss_kwargs: dict, adapter_name: str) -> None:
+        """Pop ref_outputs from loss_values into loss_kwargs and propagate to train_status.
+
+        DPOMetric reads ref_outputs from train_status.forward_kwargs during accumulate_metrics,
+        so it must be set here before the subsequent loss calculation.
+        """
+        if 'ref_outputs' not in loss_values:
+            return
+        ref_outputs_dict = loss_values.pop('ref_outputs')
+        loss_kwargs['ref_outputs'] = ref_outputs_dict
+        self.optimizer_group[adapter_name].train_status.forward_kwargs['ref_outputs'] = ref_outputs_dict
+
+    def _tinker_build_output(self, inputs, outputs, return_full_logprobs: bool = False):
         """Extract logits/logps from model outputs and build per-datum output list."""
         logits = self._normalize_tensor_output(outputs.get('logits'))
         logps = self._normalize_tensor_output(outputs.get('logps'))
         if logits is None and logps is None:
             # non-last PP stage: no outputs produced, collector will discard this
             return []
-        return self._get_forward_output(inputs, logits, logps)
+        return self._get_forward_output(inputs, logits, logps, return_full_logprobs=return_full_logprobs)

     @staticmethod
     def _normalize_tensor_output(value):
@@ -200,25 +212,20 @@ def _normalize_tensor_output(value):
         raise ValueError(f'Unexpected type for tensor output: {type(value)}')

     @staticmethod
-    def _tinker_prepare_ref_outputs(loss_values: dict, loss_kwargs: dict):
-        """Convert ref_logps list-of-lists into a padded tensor and inject into loss_kwargs.
-
-        Returns the ref_outputs dict (or None if ref_logps not present), so callers
-        can optionally propagate it to train_status.forward_kwargs.
+    def _get_forward_output(inputs: List[types.Datum],
+                            logits: torch.Tensor,
+                            logps: torch.Tensor,
+                            return_full_logprobs: bool = False) -> List[dict]:
+        """Convert raw logits to the expected output format with logprobs and elementwise_loss.
+
+        When return_full_logprobs is True (forward_only / reference pass), logprobs is returned
+        at the full TP/CP-padded sequence length so that when the client sends it back as
+        ref_logps in the DPO forward_backward step the shape already matches the padded labels.
+        When return_full_logprobs is False (default, forward_backward pass), logprobs is
+        truncated to the original unpadded sequence length.
+        elementwise_loss is always computed on the original (unpadded) length because the
+        per-datum weights tensor has that length.
         """
-        if 'ref_logps' not in loss_values:
-            return None
-        import torch.nn.functional as F
-        ref_logps_lists = loss_values.pop('ref_logps')
-        max_len = max(len(r) for r in ref_logps_lists)
-        padded = [F.pad(torch.tensor(r, dtype=torch.float32), (0, max_len - len(r))) for r in ref_logps_lists]
-        ref_outputs_dict = {'logps': torch.stack(padded)}
-        loss_kwargs['ref_outputs'] = ref_outputs_dict
-        return ref_outputs_dict
-
-    @staticmethod
-    def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps: torch.Tensor) -> List[dict]:
-        """Convert raw logits to the expected output format with logprobs and elementwise_loss."""
         from twinkle.utils.torch_utils import selective_log_softmax
         if logps is not None:
             device = logps.device
@@ -233,21 +240,35 @@ def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps:
             labels = feature.loss_fn_inputs['target_tokens'].to_torch().long().view(-1).to(device)
             weights = feature.loss_fn_inputs['weights'].to_torch().view(-1).to(device)

-            seq_len = labels.numel()
+            seq_len = labels.numel()  # original unpadded length
             if logps is None:
                 assert logit is not None, 'logit must not be None when logps is None'
                 feature_logits = logit[:seq_len, :]
-                token_log_probs = selective_log_softmax(feature_logits, labels)
+                token_log_probs_orig = selective_log_softmax(feature_logits, labels)
+                if return_full_logprobs:
+                    # Extend to the full logit length (TP/CP-padded) by padding with 0.
+                    # Padded positions have label -100 so they are masked out by DPOLoss.
+                    padded_len = logit.shape[0]
+                    if padded_len > seq_len:
+                        import torch.nn.functional as F
+                        token_log_probs_full = F.pad(token_log_probs_orig, (0, padded_len - seq_len), value=0.0)
+                    else:
+                        token_log_probs_full = token_log_probs_orig
+                else:
+                    token_log_probs_full = token_log_probs_orig
             else:
-                token_log_probs = logps[idx, :seq_len]
+                token_log_probs_orig = logps[idx, :seq_len]
+                # When return_full_logprobs is True, retain the full TP/CP-padded slice.
+                # Positions beyond seq_len have label -100 and are masked by _compute_sequence_logps.
+                token_log_probs_full = logps[idx] if return_full_logprobs else token_log_probs_orig

             # elementwise_loss: positive NLL loss (0.0 where masked)
-            token_log_probs = token_log_probs.to(weights.device)
-            elementwise_loss = -token_log_probs * weights
+            token_log_probs_orig = token_log_probs_orig.to(weights.device)
+            elementwise_loss = -token_log_probs_orig * weights

             results.append({
-                'logprobs': types.TensorData.from_torch(token_log_probs.cpu()),
+                'logprobs': types.TensorData.from_torch(token_log_probs_full.cpu()),
                 'elementwise_loss': types.TensorData.from_torch(elementwise_loss.cpu())
             })
         return results
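The shape invariant behind `return_full_logprobs`: logprobs from a reference (`forward_only`) pass must align with the TP/CP-padded labels seen during the later `forward_backward`, so they are extended to the padded length rather than truncated. A standalone sketch of the padding step (the lengths here are invented):

    import torch
    import torch.nn.functional as F

    seq_len, padded_len = 5, 8                  # unpadded vs. TP/CP-padded length
    token_log_probs = torch.randn(seq_len)      # logprobs for the real positions
    # Pad with 0.0; the padded tail carries label -100 downstream, so the loss
    # masks it out and the fill value never contributes to the result.
    full = F.pad(token_log_probs, (0, padded_len - seq_len), value=0.0)
    assert full.shape == (padded_len,)
    assert torch.equal(full[:seq_len], token_log_probs)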
diff --git a/src/twinkle/server/model/backends/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py
index b356ef45..b8a8bfde 100644
--- a/src/twinkle/server/model/backends/megatron_model.py
+++ b/src/twinkle/server/model/backends/megatron_model.py
@@ -10,7 +10,7 @@
 from twinkle.data_format import InputFeature, Trajectory
 from twinkle.infra import collect_tensor_dict
 from twinkle.model.megatron import MultiLoraMegatronModel
-from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature
+from twinkle.server.common.datum import datum_to_input_feature, extract_rl_features_for_loss
 from twinkle.server.model.backends.common import (TwinkleCompatModelBase, clean_metrics,
                                                   collect_forward_backward_results, to_cpu_safe_output)
@@ -28,11 +28,9 @@ def tinker_forward_backward(self, *, inputs: List[types.Datum], adapter_name: st
         self._tinker_setup_loss(loss_fn, inputs, adapter_name, kwargs)
         template = self.get_template(adapter_name=adapter_name)
         input_features = datum_to_input_feature(inputs, template)
-        loss_values = extract_rl_feature(inputs)
+        loss_values = extract_rl_features_for_loss(inputs)
         loss_kwargs = kwargs.copy()
-        # ref_logps → padded tensor; megatron forward_backward auto-stores loss_kwargs in
-        # train_status.forward_kwargs (megatron.py:465), so DPOMetric reads it next step.
+        self._apply_ref_outputs(loss_values, loss_kwargs, adapter_name)
         loss_kwargs.update(loss_values)

         outputs = super().forward_backward(inputs=input_features, adapter_name=adapter_name, **loss_kwargs)
@@ -50,7 +48,7 @@ def tinker_forward_only(self, *, inputs: List[types.Datum], adapter_name: str =
         template = self.get_template(adapter_name)
         input_features = datum_to_input_feature(inputs, template)
         outputs = super().forward_only(inputs=input_features, adapter_name=adapter_name, **kwargs)
-        results = self._tinker_build_output(inputs, outputs)
+        results = self._tinker_build_output(inputs, outputs, return_full_logprobs=True)
         return [results, 0.0]

     @remote_function(dispatch='all')
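With `_apply_ref_outputs` both backends share one routing step: pop `ref_outputs` out of the extracted loss values, hand it to the loss call, and mirror it into `train_status.forward_kwargs` for DPOMetric. A runnable stand-in for that routing (plain dicts replace the real backend state; only the pop-and-mirror logic is shown):

    def apply_ref_outputs(loss_values: dict, loss_kwargs: dict, train_status_kwargs: dict) -> None:
        if 'ref_outputs' not in loss_values:
            return
        ref_outputs = loss_values.pop('ref_outputs')
        loss_kwargs['ref_outputs'] = ref_outputs          # consumed by DPOLoss
        train_status_kwargs['ref_outputs'] = ref_outputs  # read later by DPOMetric

    loss_values = {'ref_outputs': {'logps': [[0.0]]}, 'advantages': [[1.0]]}
    loss_kwargs, train_status_kwargs = {}, {}
    apply_ref_outputs(loss_values, loss_kwargs, train_status_kwargs)
    loss_kwargs.update(loss_values)  # only plain loss kwargs remain after the pop
    assert set(loss_kwargs) == {'ref_outputs', 'advantages'}
    assert 'ref_outputs' in train_status_kwargs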
- self._tinker_prepare_ref_outputs(loss_values, loss_kwargs) + self._apply_ref_outputs(loss_values, loss_kwargs, adapter_name) loss_kwargs.update(loss_values) outputs = super().forward_backward(inputs=input_features, adapter_name=adapter_name, **loss_kwargs) @@ -50,7 +48,7 @@ def tinker_forward_only(self, *, inputs: List[types.Datum], adapter_name: str = template = self.get_template(adapter_name) input_features = datum_to_input_feature(inputs, template) outputs = super().forward_only(inputs=input_features, adapter_name=adapter_name, **kwargs) - results = self._tinker_build_output(inputs, outputs) + results = self._tinker_build_output(inputs, outputs, return_full_logprobs=True) return [results, 0.0] @remote_function(dispatch='all') diff --git a/src/twinkle/server/model/backends/transformers_model.py b/src/twinkle/server/model/backends/transformers_model.py index 121b4df7..e4cddbb7 100644 --- a/src/twinkle/server/model/backends/transformers_model.py +++ b/src/twinkle/server/model/backends/transformers_model.py @@ -13,7 +13,7 @@ from twinkle.data_format import InputFeature, Trajectory from twinkle.infra import collect_tensor_dict from twinkle.model import MultiLoraTransformersModel -from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature +from twinkle.server.common.datum import datum_to_input_feature, extract_rl_features_for_loss from twinkle.server.model.backends.common import (TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results, to_cpu_safe_output) @@ -36,7 +36,7 @@ def tinker_forward_only(self, *, inputs: List[types.Datum], adapter_name: str = template = self.get_template(adapter_name) input_features = datum_to_input_feature(inputs, template) outputs = super().forward_only(inputs=input_features, adapter_name=adapter_name, **kwargs) - results = self._tinker_build_output(inputs, outputs) + results = self._tinker_build_output(inputs, outputs, return_full_logprobs=True) return [results, 0.0] @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results) @@ -45,15 +45,9 @@ def tinker_forward_backward(self, *, inputs: List[types.Datum], adapter_name: st template = self.get_template(adapter_name) input_features = datum_to_input_feature(inputs, template) outputs = super().forward(inputs=input_features, adapter_name=adapter_name, **kwargs) - loss_values = extract_rl_feature(inputs) + loss_values = extract_rl_features_for_loss(inputs) loss_kwargs = kwargs.copy() - # Convert ref_logps list-of-lists into a padded tensor wrapped in ref_outputs - # so that DPOLoss and DPOMetric can consume it via ref_outputs.get('logps'). - ref_outputs_dict = self._tinker_prepare_ref_outputs(loss_values, loss_kwargs) - if ref_outputs_dict is not None: - # Propagate to train_status.forward_kwargs so DPOMetric.accumulate - # gets ref_outputs on the next forward() call (where accumulate_metrics runs). 
From 3975438aa39015c14a89751a41c85c89157c2e36 Mon Sep 17 00:00:00 2001
From: Yunnglin
Date: Sun, 19 Apr 2026 15:40:43 +0800
Subject: [PATCH 13/14] update

---
 src/twinkle/server/common/__init__.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/twinkle/server/common/__init__.py b/src/twinkle/server/common/__init__.py
index 4c290eb2..2bb9f46d 100644
--- a/src/twinkle/server/common/__init__.py
+++ b/src/twinkle/server/common/__init__.py
@@ -1,11 +1,11 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 from .checkpoint_factory import create_checkpoint_manager, create_training_run_manager
-from .datum import datum_to_input_feature, extract_rl_feature, input_feature_to_datum
+from .datum import datum_to_input_feature, extract_rl_features_for_loss, input_feature_to_datum
 from .router import StickyLoraRequestRouter

 __all__ = [
     'datum_to_input_feature',
-    'extract_rl_feature',
+    'extract_rl_features_for_loss',
     'input_feature_to_datum',
     'create_checkpoint_manager',
     'create_training_run_manager',

From 98981957f1eb5edc029c137f7d5f320f0a9f98bc Mon Sep 17 00:00:00 2001
From: Yunnglin
Date: Sun, 19 Apr 2026 17:47:27 +0800
Subject: [PATCH 14/14] update

---
 src/twinkle/model/megatron/megatron.py              | 2 +-
 src/twinkle/server/model/backends/megatron_model.py | 5 -----
 2 files changed, 1 insertion(+), 6 deletions(-)

diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py
index a7a2bb66..dfc7e98d 100644
--- a/src/twinkle/model/megatron/megatron.py
+++ b/src/twinkle/model/megatron/megatron.py
@@ -657,7 +657,7 @@ def set_optimizer(self, optimizer_cls: Union[Optimizer, Type[Optimizer], str], *
     def _accumulate_metric(optimizer_config: MegatronOptimizerGroup, is_training):
         optimizer_config.accumulate_metrics(is_training)

-    @remote_function(collect='first', lazy_collect=False)
+    @remote_function(collect='last_pp_first', lazy_collect=False)
     def calculate_metric(self, is_training, **kwargs):
         adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
diff --git a/src/twinkle/server/model/backends/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py
index b8a8bfde..0f2c3bf8 100644
--- a/src/twinkle/server/model/backends/megatron_model.py
+++ b/src/twinkle/server/model/backends/megatron_model.py
@@ -108,8 +108,3 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr
         """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O)."""
         output = super().forward_backward(inputs=inputs, **kwargs)
         return to_cpu_safe_output(output)
-
-    # Use last_pp_first collect method
-    @remote_function(collect='last_pp_first', lazy_collect=False)
-    def calculate_metric(self, is_training, **kwargs):
-        return super().calculate_metric(is_training, **kwargs)
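Patch 14 completes the metrics fix from patch 11: rather than overriding `calculate_metric` in the server backend just to change the collect mode, the `last_pp_first` collect is now declared once on the base `calculate_metric` in `megatron.py`, and the redundant backend override is removed. A toy sketch of how a decorator like this can attach collect metadata for a dispatcher to read (`remote_function` is the project's real API; this stand-in only mirrors the registration idea):

    import functools

    def remote_function(collect='first', lazy_collect=True):
        def wrap(fn):
            @functools.wraps(fn)
            def inner(*args, **kwargs):
                return fn(*args, **kwargs)
            inner._remote_meta = {'collect': collect, 'lazy_collect': lazy_collect}
            return inner
        return wrap

    @remote_function(collect='last_pp_first', lazy_collect=False)
    def calculate_metric(is_training):
        return {'loss': 0.41}

    assert calculate_metric._remote_meta['collect'] == 'last_pp_first'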