Fix model pp > 1 and tp > 1 errors #171
Conversation
Code Review
This pull request introduces a significant refactoring of the task queue system, modularizing it into a dedicated package and adding execution timeouts. It also implements graceful shutdown procedures for the gateway and model servers using FastAPI lifespans and signal handlers to ensure resources like the httpx client and Ray Serve are properly cleaned up. Other notable changes include enhanced DPO reference output handling, the addition of a 'last_pp_first' collection method for metrics, and more robust task scheduling with metadata tracking. Review feedback primarily focused on improving error handling during shutdown by logging exceptions instead of silently catching them and preventing a potential IndexError in the new collection logic.
Pull request overview
Fixes failures observed when running with pipeline parallelism (PP) > 1 and tensor parallelism (TP) > 1 by adjusting result collection semantics, tightening request/task execution ordering, and improving server lifecycle/shutdown behavior.
Changes:
- Introduce a dedicated task-queue package (config/types/worker/mixin) to serialize GPU compute work with round-robin fairness, rate limiting, and optional execution timeouts.
- Fix PP>1 metric collection by adding a new `last_pp_first` collect mode and using it for Megatron metric endpoints.
- Improve server reliability: add FastAPI lifespan shutdown hooks (model/gateway), safer event-loop APIs (`get_running_loop`), and launcher signal-based graceful shutdown.
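The lifespan-based shutdown mentioned above follows a standard pattern: create shared resources at startup and release them after the serving loop exits. A minimal stdlib-only sketch (`FakeClient` is a hypothetical stand-in for the real `httpx.AsyncClient` that the gateway closes; the dict-based `app` substitutes for the framework's application object):

```python
import asyncio
from contextlib import asynccontextmanager

class FakeClient:
    """Stand-in for httpx.AsyncClient, so the sketch stays stdlib-only."""
    def __init__(self):
        self.closed = False

    async def aclose(self):
        self.closed = True

@asynccontextmanager
async def lifespan(app: dict):
    # Startup: create shared resources before the server begins serving.
    app["client"] = FakeClient()
    yield
    # Shutdown: release resources so no connections or tasks leak.
    await app["client"].aclose()

async def serve_once() -> bool:
    app: dict = {}
    async with lifespan(app):
        pass  # requests would be handled here
    return app["client"].closed

closed_after_shutdown = asyncio.run(serve_once())
```

With FastAPI, the same context manager is passed as `FastAPI(lifespan=lifespan)` and the framework drives startup and shutdown around the serving loop.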
Reviewed changes
Copilot reviewed 28 out of 28 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| src/twinkle_client/types/model.py | Adds is_sampler flag to save requests to support sampler-weight save semantics. |
| src/twinkle/server/utils/task_queue/worker.py | New serial compute worker with RR scheduling, queue/exec timeouts, and metrics hooks. |
| src/twinkle/server/utils/task_queue/types.py | New task queue core enums + QueuedTask dataclass. |
| src/twinkle/server/utils/task_queue/rate_limiter.py | Moves/cleans up rate limiter used by task queue. |
| src/twinkle/server/utils/task_queue/mixin.py | New mixin API for compute-queue vs background-task scheduling + preflight checks. |
| src/twinkle/server/utils/task_queue/config.py | New task-queue configuration including exec timeout and token validation limits. |
| src/twinkle/server/utils/task_queue/__init__.py | Public exports to preserve the old twinkle.server.utils.task_queue import surface. |
| src/twinkle/server/utils/task_queue.py | Removes old monolithic implementation in favor of the new package. |
| src/twinkle/server/utils/__init__.py | Updates imports to re-export task-queue symbols from the new package. |
| src/twinkle/server/sampler/twinkle_handlers.py | Routes sampler requests through the task queue for serialized execution + better error surfacing. |
| src/twinkle/server/processor/twinkle_handlers.py | Uses get_running_loop() for executor calls (safer in async contexts). |
| src/twinkle/server/model/twinkle_handlers.py | Passes token/task metrics into queue; updates save/upload execution paths; adds sampler save behavior. |
| src/twinkle/server/model/tinker_handlers.py | Aligns sampler-save ordering/comments with twinkle handler behavior. |
| src/twinkle/server/model/backends/transformers_model.py | Updates DPO/GRPO feature extraction path and ensures ref outputs are applied consistently. |
| src/twinkle/server/model/backends/megatron_model.py | Uses new RL feature extraction and last_pp_first for metric collection under PP>1. |
| src/twinkle/server/model/backends/common.py | Adds _apply_ref_outputs; supports returning full padded logprobs for ref passes. |
| src/twinkle/server/model/app.py | Adds lifespan-based shutdown hook and fixes mutable default args. |
| src/twinkle/server/launcher.py | Adds signal-based graceful shutdown; changes programmatic launch behavior. |
| src/twinkle/server/gateway/tinker_gateway_handlers.py | Uses get_running_loop() timing; improves error handling (e.g., missing base_model). |
| src/twinkle/server/gateway/server.py | Adds lifespan hook to close proxy client; fixes mutable default arg. |
| src/twinkle/server/gateway/proxy.py | Adds close() to properly close httpx.AsyncClient. |
| src/twinkle/server/common/datum.py | Renames/changes RL feature extraction to produce loss kwargs (incl. ref_outputs). |
| src/twinkle/server/common/__init__.py | Updates public export name for RL feature extraction helper. |
| src/twinkle/server/main.py | Applies CLI log level via env var prior to server start. |
| src/twinkle/model/megatron/megatron.py | Uses last_pp_first for metric collection under PP>1. |
| src/twinkle/infra/__init__.py | Adds last_pp_first collect mode to support PP>1 “last stage” result selection. |
| cookbook/client/twinkle/self_host/short_math_grpo.py | Updates example to save sampler weights via is_sampler=True and a stable checkpoint name. |
| cookbook/client/tinker/self_host/dpo.py | Disables SwanLab by default in example config. |
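As a rough illustration of the serialized compute worker with round-robin fairness and optional execution timeouts that the task-queue package introduces (class and method names here are hypothetical, not the PR's actual API):

```python
import asyncio
from collections import deque

class SerialComputeWorker:
    """Sketch: serialize compute tasks, with round-robin fairness across
    clients and an optional per-task execution timeout (0 disables it)."""

    def __init__(self, execution_timeout: float = 0.0):
        self.queues: dict[str, deque] = {}
        self.execution_timeout = execution_timeout

    def submit(self, client_id: str, coro_fn):
        # coro_fn is a zero-arg callable returning a coroutine.
        self.queues.setdefault(client_id, deque()).append(coro_fn)

    async def drain(self):
        results = []
        while any(self.queues.values()):
            # Round-robin: take at most one task per client each pass.
            for client_id in list(self.queues):
                if not self.queues[client_id]:
                    continue
                coro = self.queues[client_id].popleft()()
                if self.execution_timeout > 0:
                    coro = asyncio.wait_for(coro, self.execution_timeout)
                results.append(await coro)  # serial: one task at a time
        return results
```

Deferring creation of the coroutine until dispatch (the zero-arg callable) keeps queued work cancellable before it starts, which matters for the queue-timeout path.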
```diff
-        asyncio.create_task(self.state.unregister_replica(self.replica_id))
+        await self.state.unregister_replica(self.replica_id)
+    except Exception:
+        pass
```
ModelManagement.shutdown() only unregisters the replica, but it doesn't stop the TaskQueueMixin background tasks (rate-limiter cleanup task and ComputeWorker task). This can leave pending asyncio tasks at shutdown and cause noisy warnings/leaks. Consider calling await self.shutdown_task_queue() (and any other mixin cleanup) as part of shutdown().
```diff
-            pass
+            pass
+        try:
+            await self.shutdown_task_queue()
+        except Exception:
+            pass
```
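The review feedback also asks for exceptions to be logged rather than silently swallowed during shutdown. A minimal sketch of such a shutdown path (function names are illustrative stand-ins for the replica-unregister and task-queue cleanup calls):

```python
import asyncio
import logging

logger = logging.getLogger(__name__)

async def shutdown(unregister_replica, shutdown_task_queue):
    """Run each cleanup step independently, logging failures instead of
    silently dropping them, so one failed step cannot skip the others."""
    try:
        await unregister_replica()
    except Exception:
        logger.exception("Failed to unregister replica during shutdown")
    try:
        await shutdown_task_queue()
    except Exception:
        logger.exception("Failed to stop task queue during shutdown")
```

`logger.exception` records the traceback at ERROR level, so shutdown failures stay visible without aborting the remaining cleanup.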
```python
save_dir = checkpoint_manager.get_save_dir(model_id=adapter_name, is_sampler=body.is_sampler)
# Must save the checkpoint in the twinkle format before calling model.save()
twinkle_path = checkpoint_manager.save(
    model_id=adapter_name, name=checkpoint_name, is_sampler=body.is_sampler)
checkpoint_dir = self.model.save(
    name=checkpoint_name,
    output_dir=save_dir,
    adapter_name=adapter_name,
    save_optimizer=body.save_optimizer,
    **extra_kwargs)
```
checkpoint_manager.save() records checkpoint metadata (including size_bytes via get_dir_size(checkpoint_path)) before model.save() writes any weights, so the saved checkpoint metadata will report size_bytes=0 (and may be inconsistent if model.save later fails). If you need the sampler-weight deletion behavior up front, consider writing/updating checkpoint metadata after model.save completes (or re-computing size_bytes).
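One way to realize the suggested reordering is to write weights first and record metadata afterwards, so `size_bytes` reflects the real on-disk checkpoint. A sketch under that assumption (`save_with_accurate_metadata` and its callback parameters are hypothetical helpers, not the project's actual API; `get_dir_size` mimics what the PR's helper of that name presumably computes):

```python
import os

def get_dir_size(path: str) -> int:
    """Recursively sum file sizes under a directory."""
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            total += os.path.getsize(os.path.join(root, name))
    return total

def save_with_accurate_metadata(model_save, record_metadata, checkpoint_dir: str) -> None:
    """Write weights first, then record metadata, so size_bytes reflects
    the real checkpoint and a failed save records no stale metadata."""
    model_save(checkpoint_dir)  # may raise; metadata is not yet written
    record_metadata(size_bytes=get_dir_size(checkpoint_dir))
```

If the sampler-weight deletion genuinely must happen before the save, an alternative is a two-phase write: record a placeholder entry up front, then update `size_bytes` (or delete the entry) once `model.save` finishes.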
```python
tps_limit: float = 16000.0        # 16000 input tokens per second
window_seconds: float = 1.0       # 1 second sliding window
queue_timeout: float = 300.0      # 5 minutes queue timeout
execution_timeout: float = 120.0  # 120 seconds execution timeout (0 to disable)
```
TaskQueueConfig introduces a default execution_timeout=120s. For large models, GPU ops like save/load/forward_backward can legitimately exceed this and will now fail via asyncio.wait_for cancellation, potentially leaving partial state. Consider defaulting execution_timeout to 0 (disabled) or a much higher value, and rely on explicit config overrides for deployments that want timeouts.
```diff
-execution_timeout: float = 120.0  # 120 seconds execution timeout (0 to disable)
+execution_timeout: float = 0.0  # Disabled by default; set explicitly to enforce a limit
```
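The risk described in the comment follows from how `asyncio.wait_for` behaves: on timeout it cancels the wrapped coroutine mid-flight, so a long-running operation can be interrupted with partial state left behind. A self-contained illustration (the `state` dict stands in for whatever partial GPU-side state a real save/forward_backward would leave):

```python
import asyncio

async def long_op(state: dict):
    """Stand-in for a long GPU operation (e.g. save or forward_backward)."""
    state["started"] = True
    try:
        await asyncio.sleep(10)      # the "work"
        state["finished"] = True     # never reached if cancelled
    except asyncio.CancelledError:
        state["cancelled"] = True    # partial state left behind
        raise

async def run_with_timeout(timeout: float) -> dict:
    state: dict = {}
    try:
        await asyncio.wait_for(long_op(state), timeout)
    except asyncio.TimeoutError:
        pass  # the operation was cancelled mid-flight
    return state
```

With a 120s default, any legitimately longer operation ends up on the `cancelled` path, which is why the suggestion defaults the timeout to disabled.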
```diff
 def launch_server(
     config: dict[str, Any] | None = None,
     config_path: str | Path | None = None,
     ray_namespace: str | None = None,
-) -> ServerLauncher:
+) -> None:
```
launch_server() previously returned a ServerLauncher, but it now returns None and blocks until SIGINT/SIGTERM. Since launch_server is re-exported as a public API (via twinkle.server's `__init__`), this is a backward-incompatible change for programmatic callers that expect to manage the launcher lifecycle themselves. Consider either keeping the return value (and offering a separate blocking helper) or clearly versioning/deprecating this API change.
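One backward-compatible shape matching that suggestion is to keep `launch_server` returning the launcher and add a separate blocking helper for the new signal-driven behavior. A minimal sketch (the `ServerLauncher` here is an illustrative stand-in, not the project's actual class, and `run_server_forever` is a hypothetical name):

```python
import signal
import threading

class ServerLauncher:
    """Illustrative stand-in for the project's ServerLauncher."""
    def __init__(self):
        self._stop = threading.Event()

    def shutdown(self):
        self._stop.set()

    def wait(self):
        # Block until shutdown() is called.
        self._stop.wait()

def launch_server(**config) -> ServerLauncher:
    # Keep the original contract: return the launcher so programmatic
    # callers can manage its lifecycle themselves.
    return ServerLauncher()

def run_server_forever(**config) -> None:
    """Separate blocking helper: launch, then block until SIGINT/SIGTERM."""
    launcher = launch_server(**config)
    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, lambda *_: launcher.shutdown())
    launcher.wait()
```

Note that `signal.signal` may only be called from the main thread, so the blocking helper is the natural place for the handlers while the returning variant stays handler-free.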