19 changes: 19 additions & 0 deletions cookbook/client/server/megatron/server.py
@@ -0,0 +1,19 @@
# Twinkle Server Launcher - Tinker-Compatible Transformers Backend
#
# This script starts the Twinkle server with Tinker-compatible API support.
# It reads the server_config.yaml in the same directory for all
# configuration (model, sampler, deployment settings, etc.).
# Run this script BEFORE running any client scripts (lora.py, sample.py, etc.).

import os

os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '0'

from twinkle.server import launch_server

# Resolve the path to server_config.yaml relative to this script's location
file_dir = os.path.abspath(os.path.dirname(__file__))
config_path = os.path.join(file_dir, 'server_config.yaml')

# Launch the Twinkle server — this call blocks until the server is shut down
launch_server(config_path=config_path)
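For readers adapting this launcher, the sketch below shows the same pattern with the config path taken from an environment variable instead of a fixed file. The TWINKLE_CONFIG_PATH variable name is a hypothetical choice for this example; only launch_server(config_path=...) and TWINKLE_TRUST_REMOTE_CODE come from the script above.

import os

os.environ.setdefault('TWINKLE_TRUST_REMOTE_CODE', '0')  # keep remote code disabled by default

from twinkle.server import launch_server

# Hypothetical override: fall back to the config next to this file if the env var is unset.
default_cfg = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'server_config.yaml')
config_path = os.environ.get('TWINKLE_CONFIG_PATH', default_cfg)

launch_server(config_path=config_path)  # blocks until the server shuts down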
5 changes: 4 additions & 1 deletion cookbook/client/server/megatron/server_config.yaml
@@ -33,6 +33,7 @@ applications:
runtime_env:
env_vars:
TWINKLE_TRUST_REMOTE_CODE: "0"
TWINKLE_LONG_POLL_TIMEOUT: "120"

# 3. Sampler Service - Runs inference / sampling using vLLM engine
# Used for generating text from the model (e.g., evaluating LoRA results).
@@ -74,6 +75,7 @@ applications:
runtime_env:
env_vars:
TWINKLE_TRUST_REMOTE_CODE: "0"
TWINKLE_LONG_POLL_TIMEOUT: "120"

# 2. Model Service - Hosts the base model for training.
# Config: PP=2 x DP=2 on 4 GPUs, ~27GB weights/GPU, comfortable for LoRA training
@@ -99,7 +101,7 @@
rps_limit: 20 # Max requests per second
tps_limit: 32000 # Max tokens per second
adapter_config:
adapter_timeout: 30 # Seconds before idle adapter unload
adapter_timeout: 120 # Seconds before idle adapter unload
adapter_max_lifetime: 36000 # Maximum lifetime of an adapter in seconds (e.g., 10 hours)
deployments:
- name: ModelManagement
@@ -112,6 +114,7 @@
runtime_env:
env_vars:
TWINKLE_TRUST_REMOTE_CODE: "0"
TWINKLE_LONG_POLL_TIMEOUT: "120"

# 4. Processor Service
- name: processor
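The adapter_timeout bump from 30 to 120 seconds matters because an adapter that stays idle longer than the timeout is unloaded and has to be set up again on its next request; with the long-poll timeout now also at 120 seconds, a 30-second idle window could presumably evict adapters between polls. The sketch below is a generic illustration of what such an idle-timeout policy tracks; it is an assumption for explanation, not twinkle's actual unloading code.

import time

class IdleAdapterTracker:
    """Toy illustration of an idle-unload policy like adapter_timeout; not twinkle's code."""

    def __init__(self, timeout_s: float = 120.0):
        self.timeout_s = timeout_s
        self.last_used = {}  # adapter name -> monotonic timestamp of the last request

    def touch(self, adapter: str):
        self.last_used[adapter] = time.monotonic()

    def idle_adapters(self):
        now = time.monotonic()
        return [a for a, t in self.last_used.items() if now - t > self.timeout_s]

tracker = IdleAdapterTracker(timeout_s=120.0)
tracker.touch('my-lora')
print(tracker.idle_adapters())  # [] until 120 s pass without another touch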
47 changes: 38 additions & 9 deletions src/twinkle/model/megatron/multi_lora_megatron.py
@@ -1,7 +1,9 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os
import re
import torch.distributed as dist
import torch.nn as nn
from contextlib import contextmanager
from functools import partial
from peft import LoraConfig
from torch.optim import Optimizer
@@ -55,7 +57,7 @@ def __init__(
self._model_path = HubOperation.download_model(model_id)
self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id)
self._default_tokenizer = None
self.use_distributed_optimizer = kwargs.get('use_distributed_optimizer', True)
self.use_distributed_optimizer = False
self.variable_seq_lengths = kwargs.get('variable_seq_lengths', False)
self.optimizer_group = {}
torch_util.set_device()
@@ -154,21 +156,48 @@ def set_loss(self, loss_cls: Union[Loss, Type[Loss], str], **kwargs):
self._check_adapter_valid(kwargs.get('adapter_name'))
return super().set_loss(loss_cls, **kwargs)

@contextmanager
def optimizer_context(self, adapter_name: str):
"""Temporarily replace named_parameters on each module in self.model
so that only parameters belonging to ``adapter_name`` are visible."""
pattern = re.compile(rf'\.lora_\w+\.{re.escape(adapter_name)}\.')
originals = []
for module in self.model:
orig = module.named_parameters

def make_filtered(orig_fn):

def filtered(prefix: str = '', recurse: bool = True, **kwargs):
for name, param in orig_fn(prefix=prefix, recurse=recurse, **kwargs):
if param.requires_grad and pattern.search(name):
yield name, param

return filtered

module.named_parameters = make_filtered(orig)
originals.append((module, orig))
try:
yield
finally:
for module, orig in originals:
module.named_parameters = orig

@remote_function(dispatch='all')
def set_optimizer(self, optimizer_cls: Union[Optimizer, Type[Optimizer], str], **kwargs):
self._check_adapter_valid(kwargs.get('adapter_name'))
with self.multi_adapter.adapter(kwargs.get('adapter_name')):
# Multi lora cannot config use_distributed_optimizer/loss_scale/mix_precision
kwargs.pop('use_distributed_optimizer', None)
kwargs.pop('loss_scale', None)
kwargs['fp16'] = False
kwargs['bf16'] = True
return super().set_optimizer(optimizer_cls, **kwargs)
with self.multi_adapter.adapter(kwargs.get('adapter_name')) as adapter_name:
with self.optimizer_context(adapter_name):
# Multi lora cannot config use_distributed_optimizer/loss_scale/mix_precision
kwargs.pop('use_distributed_optimizer', None)
kwargs.pop('loss_scale', None)
kwargs['fp16'] = False
kwargs['bf16'] = True
super().set_optimizer(optimizer_cls, **kwargs)

@remote_function(dispatch='all')
def set_lr_scheduler(self, scheduler_cls: Union[LRScheduler, Type[LRScheduler], str], **kwargs):
self._check_adapter_valid(kwargs.get('adapter_name'))
return super().set_lr_scheduler(scheduler_cls, **kwargs)
super().set_lr_scheduler(scheduler_cls, **kwargs)

@remote_function(dispatch='all', collect='first', sync=True)
def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs):
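To see what optimizer_context buys, here is a self-contained toy that applies the same named_parameters monkey-patch to a module whose parameter names mimic PEFT's per-adapter layout (module.lora_A.<adapter>.weight). The ToyLoRALinear class and adapter names are invented for this sketch; only the regex and the patch/restore pattern mirror the diff above.

import re
import torch.nn as nn

class ToyLoRALinear(nn.Module):
    """Toy module whose parameter names mimic PEFT's per-adapter layout."""

    def __init__(self):
        super().__init__()
        self.base = nn.Linear(8, 8)
        self.lora_A = nn.ModuleDict({'a1': nn.Linear(8, 2, bias=False),
                                     'a2': nn.Linear(8, 2, bias=False)})

model = nn.Sequential(ToyLoRALinear())
adapter_name = 'a1'
pattern = re.compile(rf'\.lora_\w+\.{re.escape(adapter_name)}\.')

orig = model.named_parameters

def filtered(prefix='', recurse=True, **kwargs):
    # Yield only trainable parameters that belong to the selected adapter.
    for name, param in orig(prefix=prefix, recurse=recurse, **kwargs):
        if param.requires_grad and pattern.search(name):
            yield name, param

model.named_parameters = filtered
print([n for n, _ in model.named_parameters()])  # ['0.lora_A.a1.weight'] only
model.named_parameters = orig                    # restore, as the finally block does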
14 changes: 14 additions & 0 deletions src/twinkle/model/multi_lora.py
@@ -424,6 +424,20 @@ def _patch_megatron(_module):
else:
module = _patch_peft(module)

# PEFT's add_adapter calls set_adapter(active_adapters) which only keeps the
# first adapter's requires_grad=True. We need ALL LoRA params to be trainable
# so that MegatronDDP registers them all in its gradient buffers (main_grad).
def _enable_all_lora_grad(_module):
for name, param in _module.named_parameters():
if 'lora_' in name and not param.requires_grad:
param.requires_grad_(True)

if isinstance(module, list):
for _m in module:
_enable_all_lora_grad(_m)
else:
_enable_all_lora_grad(module)

self.module = module
return module

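A standalone sketch of the same re-enable loop, run against a toy layout rather than a real PEFT model: after simulating the behaviour described in the comment above (only the active adapter stays trainable), the loop flips every lora_ parameter back to trainable so a DDP wrapper would register all of them in its gradient buffers. The module and adapter names are invented for this sketch.

import torch.nn as nn

class ToyLoRALinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.base = nn.Linear(8, 8)
        self.lora_A = nn.ModuleDict({'a1': nn.Linear(8, 2, bias=False),
                                     'a2': nn.Linear(8, 2, bias=False)})

model = ToyLoRALinear()

# Simulate the described add_adapter side effect: the second adapter is left frozen.
for name, param in model.named_parameters():
    if '.a2.' in name:
        param.requires_grad_(False)

# The fix from the diff: make every LoRA parameter trainable again.
for name, param in model.named_parameters():
    if 'lora_' in name and not param.requires_grad:
        param.requires_grad_(True)

print(all(p.requires_grad for n, p in model.named_parameters() if 'lora_' in n))  # True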
2 changes: 1 addition & 1 deletion src/twinkle/template/base.py
@@ -162,7 +162,7 @@ def concat_input_feature(self, prompt_input_feature: InputFeature, new_tokens: L
assert self.truncation_strategy != 'split', 'concat_input_feature does not support `truncation_strategy=split`'
result = copy.deepcopy(prompt_input_feature)
prompt_ids = result['input_ids']
labels = list(result['labels'])
labels = list(result.get('labels', []))
input_ids = list(prompt_ids) + new_tokens
labels = labels[-1:] + labels[:-1] # roll to input order
labels = labels + new_tokens
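The one-line change above switches result['labels'] to result.get('labels', []) so a prompt feature that carries no 'labels' key no longer raises a KeyError. A toy walk-through of the surrounding label handling (the exact label alignment convention is an assumption for illustration):

result = {'input_ids': [1, 2, 3]}        # prompt feature without a 'labels' key
prompt_ids = result['input_ids']
labels = list(result.get('labels', []))  # [] instead of raising KeyError
new_tokens = [4, 5]

input_ids = list(prompt_ids) + new_tokens
labels = labels[-1:] + labels[:-1]       # roll to input order (a no-op on [])
labels = labels + new_tokens
print(input_ids, labels)                 # [1, 2, 3, 4, 5] [4, 5]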