Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cookbook/client/tinker/self_host/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
max_length = 2048
lora_rank = 8
system_prompt = 'You are a helpful assistant.'
use_swanlab = True
use_swanlab = False


# ---------------------------------------------------------------------------
Expand Down
6 changes: 4 additions & 2 deletions cookbook/client/twinkle/self_host/short_math_grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
GRADIENT_ACCUMULATION_STEPS = 1
DATA_NUM = 2000 # Number of Math samples to use

USE_SWANLAB = True
USE_SWANLAB = False
SWANLAB_PROJECT = 'twinkle-grpo'
SWANLAB_EXPERIMENT_NAME = 'short-math-grpo'

Expand Down Expand Up @@ -210,11 +210,13 @@ def train():
# ========== 1. Save weights and update adapter_uri ==========
# Instead of sync_weights, save the model checkpoint and pass
# the resulting path to the sampler as adapter_uri
# Use is_sampler=True to delete old sampler weights and keep only the latest
if step % SYNC_INTERVAL == 0:
logger.info(f'Step {step}: Saving weights for sampler...')
result = model.save(
name=f'grpo-sampler-step-{step}',
name='grpo-sampler-weights',
save_optimizer=False,
is_sampler=True,
)
current_adapter_uri = result.twinkle_path
logger.info(f'Step {step}: Saved weights to {current_adapter_uri}')
Expand Down
6 changes: 6 additions & 0 deletions src/twinkle/infra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,12 @@ def _collect_func(method: Union[Literal['none', 'flatten', 'mean', 'sum', 'first
elif method == 'last_pp':
assert device_mesh is not None
return [r for i, r in enumerate(result) if i in device_mesh.get_pp_last_ranks()]
elif method == 'last_pp_first':
# Return the first result from the last PP stage workers.
# Falls back to result[0] only when no result index is among the last-PP-stage ranks.
assert device_mesh is not None
last_pp = [r for i, r in enumerate(result) if i in device_mesh.get_pp_last_ranks()]
return last_pp[0] if last_pp else result[0]
Comment thread
Yunnglin marked this conversation as resolved.
elif isinstance(method, Callable):
# Callable
return method(result, device_mesh=device_mesh)
Expand Down
2 changes: 1 addition & 1 deletion src/twinkle/model/megatron/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ def set_optimizer(self, optimizer_cls: Union[Optimizer, Type[Optimizer], str], *
def _accumulate_metric(optimizer_config: MegatronOptimizerGroup, is_training):
optimizer_config.accumulate_metrics(is_training)

@remote_function(collect='first', lazy_collect=False)
@remote_function(collect='last_pp_first', lazy_collect=False)
def calculate_metric(self, is_training, **kwargs):
adapter_name = kwargs.pop('adapter_name', self._get_default_group())
optimizer_config = self.optimizer_group[adapter_name]
Expand Down
6 changes: 6 additions & 0 deletions src/twinkle/server/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from __future__ import annotations

import argparse
import os
import sys
from pathlib import Path

Expand Down Expand Up @@ -77,6 +78,11 @@ def main(args: list[str] | None = None) -> int:
try:
from twinkle.server.launcher import launch_server

# Apply log level so that all loggers (including those created later)
# pick up the user-specified level via the LOG_LEVEL env var that
# get_logger() already reads.
os.environ['LOG_LEVEL'] = parsed_args.log_level

config_path = Path(parsed_args.config)
if not config_path.exists():
logger.error(f'Config file not found: {config_path}')
Expand Down
4 changes: 2 additions & 2 deletions src/twinkle/server/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from .checkpoint_factory import create_checkpoint_manager, create_training_run_manager
from .datum import datum_to_input_feature, extract_rl_feature, input_feature_to_datum
from .datum import datum_to_input_feature, extract_rl_features_for_loss, input_feature_to_datum
from .router import StickyLoraRequestRouter

__all__ = [
'datum_to_input_feature',
'extract_rl_feature',
'extract_rl_features_for_loss',
'input_feature_to_datum',
'create_checkpoint_manager',
'create_training_run_manager',
Expand Down
35 changes: 21 additions & 14 deletions src/twinkle/server/common/datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from __future__ import annotations

import numpy as np
from collections import defaultdict
from tinker import types

from twinkle.data_format.input_feature import InputFeature
Expand Down Expand Up @@ -56,26 +55,34 @@ def datum_to_input_feature(datum: types.Datum | list[types.Datum],
return input_feature


def extract_rl_feature(datum: types.Datum | list[types.Datum]) -> dict:
def extract_rl_features_for_loss(datum: types.Datum | list[types.Datum]) -> dict:
"""Extract RL features from datums for use as loss kwargs.

Converts per-datum feature lists into the format expected by loss functions:
- 'logprobs' -> 'old_logps' : list of per-datum log-probability lists (for GRPO)
- 'advantages'-> 'advantages' : list of per-datum advantage lists (for GRPO)
- 'ref_logps' -> 'ref_outputs' : {'logps': torch.Tensor [B, T]} (for DPO)
"""
import torch
if not isinstance(datum, list):
datum = [datum]

result = defaultdict(list)
old_logps, advantages, ref_logps_lists = [], [], []
for d in datum:
# 'logprobs' -> 'old_logps' (for GRPO loss)
if 'logprobs' in d.loss_fn_inputs:
old_logps = d.loss_fn_inputs['logprobs'].to_numpy().tolist()
result['old_logps'].append(old_logps)

# 'advantages' -> 'advantages' (for GRPO loss)
old_logps.append(d.loss_fn_inputs['logprobs'].to_numpy().tolist())
if 'advantages' in d.loss_fn_inputs:
advantages = d.loss_fn_inputs['advantages'].to_numpy().tolist()
result['advantages'].append(advantages)

# 'ref_logps' -> 'ref_logps' (for DPO loss)
advantages.append(d.loss_fn_inputs['advantages'].to_numpy().tolist())
if 'ref_logps' in d.loss_fn_inputs:
ref_logps = d.loss_fn_inputs['ref_logps'].to_numpy().tolist()
result['ref_logps'].append(ref_logps)
ref_logps_lists.append(d.loss_fn_inputs['ref_logps'].to_numpy().tolist())

result = {}
if old_logps:
result['old_logps'] = old_logps
if advantages:
result['advantages'] = advantages
if ref_logps_lists:
result['ref_outputs'] = {'logps': torch.stack([torch.tensor(r, dtype=torch.float32) for r in ref_logps_lists])}
return result


Expand Down
4 changes: 4 additions & 0 deletions src/twinkle/server/gateway/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def __init__(
# Disable proxy env vars to avoid external routing
self.client = httpx.AsyncClient(timeout=None, trust_env=False)

async def close(self) -> None:
    """Close the underlying ``httpx.AsyncClient`` to release connections.

    Awaits ``self.client.aclose()``; call it when the proxy is no longer
    needed, e.g. during application shutdown.
    """
    await self.client.aclose()

def _build_target_url(self, service_type: str, base_model: str, endpoint: str) -> str:
"""Build the target URL for internal service routing.

Expand Down
23 changes: 17 additions & 6 deletions src/twinkle/server/gateway/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request
from ray import serve
from typing import Any
Expand All @@ -29,9 +30,10 @@ class GatewayServer:

def __init__(self,
supported_models: list | None = None,
server_config: dict[str, Any] = {},
server_config: dict[str, Any] | None = None,
http_options: dict[str, Any] | None = None,
**kwargs) -> None:
server_config = server_config or {}
self.state = get_server_state(**server_config)
self.route_prefix = kwargs.get('route_prefix', '/api/v1')
self.http_options = http_options or {}
Expand Down Expand Up @@ -71,7 +73,7 @@ async def _get_base_model(self, model_id: str) -> str:

def build_server_app(deploy_options: dict[str, Any],
supported_models: list | None = None,
server_config: dict[str, Any] = {},
server_config: dict[str, Any] | None = None,
http_options: dict[str, Any] | None = None,
**kwargs):
"""Build and configure the unified gateway server application.
Expand All @@ -88,17 +90,26 @@ def build_server_app(deploy_options: dict[str, Any],
Returns:
Configured Ray Serve deployment bound with options
"""
app = FastAPI()

def get_self() -> GatewayServer:
    """Return the servable object of the current Ray Serve replica context."""
    return serve.get_replica_context().servable_object

@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that closes the gateway proxy on shutdown.

    Nothing is done at startup. Once the application stops serving, the
    proxy of the current replica is closed on a best-effort basis: a failure
    is logged with its traceback instead of being silently discarded, and it
    never propagates out of the lifespan.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    import logging

    yield
    try:
        await get_self().proxy.close()
    except Exception:
        # Best-effort cleanup: shutdown must not fail because of the proxy,
        # but the error should still be visible in the logs.
        logging.getLogger(__name__).warning('Failed to close gateway proxy during shutdown', exc_info=True)
Comment thread
Yunnglin marked this conversation as resolved.

app = FastAPI(lifespan=lifespan)

@app.middleware('http')
async def verify_token(request: Request, call_next):
return await verify_request_token(request=request, call_next=call_next)

app.middleware('http')(create_metrics_middleware('Gateway'))

def get_self() -> GatewayServer:
return serve.get_replica_context().servable_object

_register_tinker_routes(app, get_self)
_register_twinkle_routes(app, get_self)

Expand Down
10 changes: 4 additions & 6 deletions src/twinkle/server/gateway/tinker_gateway_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def retrieve_future(request: Request,
request_id = body.request_id
max_wait = float(os.environ.get('TWINKLE_LONG_POLL_TIMEOUT', '30'))
poll_interval = float(os.environ.get('TWINKLE_POLL_INTERVAL', '0.5'))
start = asyncio.get_event_loop().time()
start = asyncio.get_running_loop().time()

while True:
record = await self.state.get_future(request_id)
Expand All @@ -95,7 +95,7 @@ async def retrieve_future(request: Request,
if status not in ('pending', 'queued', 'running', 'rate_limited'):
break

if asyncio.get_event_loop().time() - start >= max_wait:
if asyncio.get_running_loop().time() - start >= max_wait:
response_data = {'type': 'try_again'}
if queue_state := record.get('queue_state'):
response_data['queue_state'] = queue_state
Expand All @@ -105,10 +105,6 @@ async def retrieve_future(request: Request,

await asyncio.sleep(poll_interval)

record = await self.state.get_future(request_id)
if not record:
return {'type': 'try_again'}

status = record.get('status')

if status == 'rate_limited':
Expand Down Expand Up @@ -263,6 +259,8 @@ async def asample(request: Request, body: types.SampleRequest, self: GatewayServ
session = await self.state.get_sampling_session(body.sampling_session_id)
if session:
base_model = session.get('base_model')
if not base_model:
raise HTTPException(status_code=400, detail='base_model is required but could not be resolved')
return await self.proxy.proxy_to_sampler(request, 'asample', base_model)

@app.post('/save_weights_for_sampler')
Expand Down
49 changes: 38 additions & 11 deletions src/twinkle/server/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
"""
from __future__ import annotations

import time
import signal
import threading
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, NoReturn, Optional, Union

from twinkle import get_logger
from twinkle.server.utils.ray_serve_patch import apply_ray_serve_patches, get_runtime_env_for_patches
Expand Down Expand Up @@ -146,9 +147,12 @@ def _start_serve(self) -> None:
from ray import serve

try:
from ray.serve.context import _get_global_client
_get_global_client()
# Serve is running, shut it down before re-starting
serve.shutdown()
time.sleep(2)
except Exception:
# Serve not running — nothing to shut down
pass

http_options = self.config.get('http_options', {})
Expand Down Expand Up @@ -182,6 +186,9 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None:

deploy_options = {}
if deployments:
if len(deployments) > 1:
logger.warning(f'Application "{name}" has {len(deployments)} deployments configured, '
f'but only the first deployment will be used.')
deploy_config = deployments[0]
if isinstance(deploy_config, dict):
deploy_options = {k: v for k, v in deploy_config.items() if k != 'name'}
Expand All @@ -197,7 +204,12 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None:
logger.info(f'Deployed {name} at {route_prefix}')

def launch(self) -> None:
"""Launch the server with all configured applications."""
"""Launch the server with all configured applications.

Blocks the calling thread to keep the server running. Installs signal
handlers for SIGINT/SIGTERM so that ``serve.shutdown()`` is called on
termination instead of leaving orphaned deployments.
"""
# Apply Ray Serve patches before initializing Ray
apply_ray_serve_patches()

Expand Down Expand Up @@ -226,8 +238,26 @@ def launch(self) -> None:
dict) else app_config.route_prefix
print(f' - http://{host}:{port}{route_prefix}')

while True:
time.sleep(3600)
# Graceful shutdown via signal handling
shutdown_event = threading.Event()

def _handle_signal(signum, frame):
    """Signal handler: log the received signal and request shutdown.

    Args:
        signum: Number of the received signal.
        frame: Stack frame passed in by the signal machinery (unused).
    """
    sig_name = signal.Signals(signum).name
    logger.info(f'Received {sig_name}, shutting down gracefully...')
    # Wake up the caller blocked on shutdown_event.wait()
    shutdown_event.set()

signal.signal(signal.SIGINT, _handle_signal)
signal.signal(signal.SIGTERM, _handle_signal)

# Block until a termination signal is received
shutdown_event.wait()

from ray import serve
try:
serve.shutdown()
logger.info('Ray Serve shut down successfully')
except Exception:
logger.warning('Error during Ray Serve shutdown', exc_info=True)

@classmethod
def from_yaml(
Expand Down Expand Up @@ -264,20 +294,18 @@ def launch_server(
config: dict[str, Any] | None = None,
config_path: str | Path | None = None,
ray_namespace: str | None = None,
) -> ServerLauncher:
) -> None:
"""
Launch a twinkle server with flexible configuration options.

This is the main entry point for launching servers programmatically.
The call blocks until a SIGINT/SIGTERM signal is received.

Args:
config: Configuration dictionary (takes precedence over config_path)
config_path: Path to YAML config file
ray_namespace: Ray namespace

Returns:
The ServerLauncher instance

Raises:
ValueError: If neither config nor config_path is provided

Expand Down Expand Up @@ -306,4 +334,3 @@ def launch_server(
)

launcher.launch()
return launcher
Loading
Loading