From 397f2508627bb10c9ade271df634dbb5a9042cc0 Mon Sep 17 00:00:00 2001
From: tastelikefeet
Date: Tue, 14 Apr 2026 19:37:54 +0800
Subject: [PATCH 1/2] fix

---
 .../model/megatron/multi_lora_megatron.py | 47 +++++++++++++++----
 src/twinkle/model/multi_lora.py           | 14 ++++++
 src/twinkle/template/base.py              |  2 +-
 3 files changed, 53 insertions(+), 10 deletions(-)

diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py
index 346b9f86..5b4bee49 100644
--- a/src/twinkle/model/megatron/multi_lora_megatron.py
+++ b/src/twinkle/model/megatron/multi_lora_megatron.py
@@ -1,7 +1,9 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 import os
+import re
 import torch.distributed as dist
 import torch.nn as nn
+from contextlib import contextmanager
 from functools import partial
 from peft import LoraConfig
 from torch.optim import Optimizer
@@ -55,7 +57,7 @@ def __init__(
         self._model_path = HubOperation.download_model(model_id)
         self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id)
         self._default_tokenizer = None
-        self.use_distributed_optimizer = kwargs.get('use_distributed_optimizer', True)
+        self.use_distributed_optimizer = False
         self.variable_seq_lengths = kwargs.get('variable_seq_lengths', False)
         self.optimizer_group = {}
         torch_util.set_device()
@@ -154,21 +156,48 @@ def set_loss(self, loss_cls: Union[Loss, Type[Loss], str], **kwargs):
         self._check_adapter_valid(kwargs.get('adapter_name'))
         return super().set_loss(loss_cls, **kwargs)

+    @contextmanager
+    def optimizer_context(self, adapter_name: str):
+        """Temporarily replace named_parameters on each module in self.model
+        so that only parameters belonging to ``adapter_name`` are visible."""
+        pattern = re.compile(rf'\.lora_\w+\.{re.escape(adapter_name)}\.')
+        originals = []
+        for module in self.model:
+            orig = module.named_parameters
+
+            def make_filtered(orig_fn):
+
+                def filtered(prefix: str = '', recurse: bool = True, **kwargs):
+                    for name, param in orig_fn(prefix=prefix, recurse=recurse, **kwargs):
+                        if param.requires_grad and pattern.search(name):
+                            yield name, param
+
+                return filtered
+
+            module.named_parameters = make_filtered(orig)
+            originals.append((module, orig))
+        try:
+            yield
+        finally:
+            for module, orig in originals:
+                module.named_parameters = orig
+
     @remote_function(dispatch='all')
     def set_optimizer(self, optimizer_cls: Union[Optimizer, Type[Optimizer], str], **kwargs):
         self._check_adapter_valid(kwargs.get('adapter_name'))
-        with self.multi_adapter.adapter(kwargs.get('adapter_name')):
-            # Multi lora cannot config use_distributed_optimizer/loss_scale/mix_precision
-            kwargs.pop('use_distributed_optimizer', None)
-            kwargs.pop('loss_scale', None)
-            kwargs['fp16'] = False
-            kwargs['bf16'] = True
-            return super().set_optimizer(optimizer_cls, **kwargs)
+        with self.multi_adapter.adapter(kwargs.get('adapter_name')) as adapter_name:
+            with self.optimizer_context(adapter_name):
+                # Multi lora cannot config use_distributed_optimizer/loss_scale/mix_precision
+                kwargs.pop('use_distributed_optimizer', None)
+                kwargs.pop('loss_scale', None)
+                kwargs['fp16'] = False
+                kwargs['bf16'] = True
+                super().set_optimizer(optimizer_cls, **kwargs)

     @remote_function(dispatch='all')
     def set_lr_scheduler(self, scheduler_cls: Union[LRScheduler, Type[LRScheduler], str], **kwargs):
         self._check_adapter_valid(kwargs.get('adapter_name'))
-        return super().set_lr_scheduler(scheduler_cls, **kwargs)
+        super().set_lr_scheduler(scheduler_cls, **kwargs)

     @remote_function(dispatch='all', collect='first', sync=True)
     def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs):
diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py
index d8be4832..387e747c 100644
--- a/src/twinkle/model/multi_lora.py
+++ b/src/twinkle/model/multi_lora.py
@@ -424,6 +424,20 @@ def _patch_megatron(_module):
         else:
             module = _patch_peft(module)

+        # PEFT's add_adapter calls set_adapter(active_adapters) which only keeps the
+        # first adapter's requires_grad=True. We need ALL LoRA params to be trainable
+        # so that MegatronDDP registers them all in its gradient buffers (main_grad).
+        def _enable_all_lora_grad(_module):
+            for name, param in _module.named_parameters():
+                if 'lora_' in name and not param.requires_grad:
+                    param.requires_grad_(True)
+
+        if isinstance(module, list):
+            for _m in module:
+                _enable_all_lora_grad(_m)
+        else:
+            _enable_all_lora_grad(module)
+
         self.module = module
         return module

diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py
index f76afd5e..9772f42c 100644
--- a/src/twinkle/template/base.py
+++ b/src/twinkle/template/base.py
@@ -162,7 +162,7 @@ def concat_input_feature(self, prompt_input_feature: InputFeature, new_tokens: L
         assert self.truncation_strategy != 'split', 'concat_input_feature does not support `truncation_strategy=split`'
         result = copy.deepcopy(prompt_input_feature)
         prompt_ids = result['input_ids']
-        labels = list(result['labels'])
+        labels = list(result.get('labels', []))
         input_ids = list(prompt_ids) + new_tokens
         labels = labels[-1:] + labels[:-1]  # roll to input order
         labels = labels + new_tokens
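
Note on optimizer_context() above: Megatron-LM collects an optimizer's parameter
groups by iterating each model chunk's named_parameters(), so temporarily
shadowing that bound method during set_optimizer() means the freshly built
optimizer only ever sees the one adapter's LoRA weights. Below is a minimal
standalone sketch of the same shadowing pattern, assuming plain PyTorch and
PEFT-style parameter names such as '...q_proj.lora_A.<adapter>.weight';
visible_adapter is illustrative, not Twinkle code.

import re
from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def visible_adapter(module: nn.Module, adapter_name: str):
    """Expose only `adapter_name`'s trainable LoRA params via named_parameters()."""
    pattern = re.compile(rf'\.lora_\w+\.{re.escape(adapter_name)}\.')
    orig_fn = module.named_parameters

    def filtered(prefix: str = '', recurse: bool = True, **kwargs):
        for name, param in orig_fn(prefix=prefix, recurse=recurse, **kwargs):
            if param.requires_grad and pattern.search(name):
                yield name, param

    # The instance attribute shadows nn.Module's class-level method until restored.
    module.named_parameters = filtered
    try:
        yield module
    finally:
        module.named_parameters = orig_fn

# Usage: an optimizer created inside the context holds a single adapter's params.
# with visible_adapter(model, 'adapter_a'):
#     opt = torch.optim.AdamW(p for _, p in model.named_parameters())
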
From d68475bb292d72a4539ac5638d978de725220d8 Mon Sep 17 00:00:00 2001
From: tastelikefeet
Date: Tue, 14 Apr 2026 19:51:33 +0800
Subject: [PATCH 2/2] fix

---
 cookbook/client/server/megatron/server.py     | 19 +++++++++++++++++++
 .../client/server/megatron/server_config.yaml |  5 ++++-
 2 files changed, 23 insertions(+), 1 deletion(-)
 create mode 100644 cookbook/client/server/megatron/server.py

diff --git a/cookbook/client/server/megatron/server.py b/cookbook/client/server/megatron/server.py
new file mode 100644
index 00000000..938877eb
--- /dev/null
+++ b/cookbook/client/server/megatron/server.py
@@ -0,0 +1,19 @@
+# Twinkle Server Launcher - Tinker-Compatible Megatron Backend
+#
+# This script starts the Twinkle server with Tinker-compatible API support.
+# It reads the server_config.yaml in the same directory for all
+# configuration (model, sampler, deployment settings, etc.).
+# Run this script BEFORE running any client scripts (lora.py, sample.py, etc.).
+
+import os
+
+os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '0'
+
+from twinkle.server import launch_server
+
+# Resolve the path to server_config.yaml relative to this script's location
+file_dir = os.path.abspath(os.path.dirname(__file__))
+config_path = os.path.join(file_dir, 'server_config.yaml')
+
+# Launch the Twinkle server; this call blocks until the server is shut down
+launch_server(config_path=config_path)
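
Note on the launcher above: TWINKLE_TRUST_REMOTE_CODE is exported before
`from twinkle.server import launch_server`, which matters if the flag is read
at import time. A self-contained demonstration of that pitfall (fake_flags is
a hypothetical stand-in, not a Twinkle module):

import os
import types

os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '0'  # exported BEFORE the "import" below

# Stand-in for a module body that snapshots the flag once, at import time.
fake_flags = types.ModuleType('fake_flags')
exec("import os; TRUST = os.environ.get('TWINKLE_TRUST_REMOTE_CODE', '1')",
     fake_flags.__dict__)

assert fake_flags.TRUST == '0'  # had the export come after, this would be '1'
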
diff --git a/cookbook/client/server/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml
index 90c8d9ac..7169ace7 100644
--- a/cookbook/client/server/megatron/server_config.yaml
+++ b/cookbook/client/server/megatron/server_config.yaml
@@ -33,6 +33,7 @@ applications:
       runtime_env:
         env_vars:
           TWINKLE_TRUST_REMOTE_CODE: "0"
+          TWINKLE_LONG_POLL_TIMEOUT: "120"

   # 3. Sampler Service - Runs inference / sampling using vLLM engine
   # Used for generating text from the model (e.g., evaluating LoRA results).
@@ -74,6 +75,7 @@ applications:
       runtime_env:
         env_vars:
           TWINKLE_TRUST_REMOTE_CODE: "0"
+          TWINKLE_LONG_POLL_TIMEOUT: "120"

   # 2. Model Service - Hosts the base model for training.
   # Config: PP=2 x DP=2 on 4 GPUs, ~27GB weights/GPU, comfortable for LoRA training
@@ -99,7 +101,7 @@ applications:
       rps_limit: 20  # Max requests per second
       tps_limit: 32000  # Max tokens per second
       adapter_config:
-        adapter_timeout: 30  # Seconds before idle adapter unload
+        adapter_timeout: 120  # Seconds before idle adapter unload
         adapter_max_lifetime: 36000  # Maximum lifetime of an adapter in seconds (e.g., 10 hours)
       deployments:
         - name: ModelManagement
@@ -112,6 +114,7 @@ applications:
       runtime_env:
         env_vars:
           TWINKLE_TRUST_REMOTE_CODE: "0"
+          TWINKLE_LONG_POLL_TIMEOUT: "120"

   # 4. Processor Service
   - name: processor
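
The new TWINKLE_LONG_POLL_TIMEOUT entries and the adapter_timeout bump
(30 -> 120) move together: if an idle adapter were unloaded faster than a
client's long-poll window, a poll could come back to an already-evicted
adapter, so the idle timeout should be at least as long as the poll window.
That invariant is inferred from this diff rather than documented behavior; a
minimal sketch that checks it with PyYAML, run from the repo root
(check_config.py is hypothetical):

# check_config.py (hypothetical): assert that no adapter_timeout in the config
# is shorter than any TWINKLE_LONG_POLL_TIMEOUT exported to the deployments.
import yaml

with open('cookbook/client/server/megatron/server_config.yaml') as f:
    config = yaml.safe_load(f)

def walk(node):
    """Yield every (key, value) pair from a nested dict/list structure."""
    if isinstance(node, dict):
        for key, value in node.items():
            yield key, value
            yield from walk(value)
    elif isinstance(node, list):
        for item in node:
            yield from walk(item)

pairs = list(walk(config))
poll_windows = [int(value['TWINKLE_LONG_POLL_TIMEOUT'])
                for key, value in pairs
                if key == 'env_vars' and 'TWINKLE_LONG_POLL_TIMEOUT' in value]
idle_timeouts = [int(value) for key, value in pairs if key == 'adapter_timeout']

# Assumed invariant: no adapter may idle out while a long poll is still open.
for timeout in idle_timeouts:
    assert all(timeout >= window for window in poll_windows), (
        f'adapter_timeout={timeout}s is shorter than a long-poll window')
print('ok:', idle_timeouts, poll_windows)
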