From d9bb19c61f94fa77fc3691030bf8fd1aa33a3d29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=A8=E6=B3=93?= Date: Thu, 16 Apr 2026 16:00:28 +0800 Subject: [PATCH 1/4] fix --- src/twinkle/model/megatron/megatron.py | 10 ++++++++++ src/twinkle/model/multi_lora.py | 2 ++ 2 files changed, 12 insertions(+) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index f087d3a6..351d9250 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -1170,7 +1170,14 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non if dp_rank == 0: self.hf_config.save_pretrained(output_dir) if isinstance(model[0], PeftModel): + config = model[0].peft_config[adapter_name] + target_modules = None + if getattr(config, 'origin_target_modules', None) == 'all-linear': + target_modules = config.target_modules + config.target_modules = 'all-linear' model[0].peft_config[adapter_name].save_pretrained(output_dir) + if getattr(config, 'origin_target_modules', None) == 'all-linear': + config.target_modules = target_modules def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_converter=None): """Save in Megatron checkpoint format.""" @@ -1274,6 +1281,9 @@ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str config_or_dir = LoraConfig(**config_or_dir) config = config_or_dir + if config.target_modules == 'all-linear': + config.origin_target_modules = 'all-linear' + # Expand target_modules (e.g., 'all-linear' -> actual module names) if config.target_modules: if isinstance(config.target_modules, str): diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index 387e747c..eccea986 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -400,6 +400,8 @@ def _patch_peft(_module): def _patch_megatron(_module): # Expand target_modules (e.g., 'all-linear' -> actual module names) _config = deepcopy(config) 
+ if _config.target_modules == 'all-linear': + _config.origin_target_modules = 'all-linear' if isinstance(_module, PeftModel): _module.add_adapter(lora_tenant.adapter_name, _config) else: From 9ae5859887e6aea531ec0c5a0104137fbd88e6e9 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 16 Apr 2026 20:00:11 +0800 Subject: [PATCH 2/4] fix --- src/twinkle/model/megatron/multi_lora_megatron.py | 3 ++- src/twinkle/model/multi_lora.py | 10 ++++------ .../model/transformers/multi_lora_transformers.py | 3 ++- src/twinkle/server/utils/template_utils.py | 1 + 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 5b4bee49..56872791 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -41,6 +41,7 @@ def __init__( max_loras: int = 5, max_r: int = 32, max_length: int = 8192, + target_modules: Union[List[str], str] = 'all-linear', **kwargs, ): requires('megatron_core') @@ -87,7 +88,7 @@ def __init__( self.model: List[nn.Module] = self.strategy.create_megatron_model(load_weights) MegatronPeft().__call__() self.multi_adapter = MultiLora(max_loras=max_loras, max_r=max_r, max_length=max_length) - self.model = self.multi_adapter.patch(self.model) + self.model = self.multi_adapter.patch(self.model, target_modules=target_modules) self.model = self.strategy.wrap_model(self.model) self.strategy.finish_param_config(self.model, None) self.multi_adapter.save_initial_weights() diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index eccea986..32381bc0 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -54,14 +54,12 @@ def activate_adapter(self, tenant_adapter_name: str, call_enable=False): if call_enable: # This will cost time _module.enable_adapter_layers() - if _module.active_adapter != adapter_name: - _module.set_adapter(adapter_name) + 
_module.set_adapter(adapter_name) else: if call_enable: # This will cost time self.module.enable_adapter_layers() - if self.module.active_adapter != adapter_name: - self.module.set_adapter(adapter_name) + self.module.set_adapter(adapter_name) def deactivate_adapter(self): if isinstance(self.module, list): @@ -374,11 +372,11 @@ def _get_weight_tensors(self): base_layer.forward = MethodType(_megatron_forward, base_layer) base_layer.layer_name = name - def patch(self, module: Union[torch.nn.Module, List[torch.nn.Module]], *args, **kwargs): + def patch(self, module: Union[torch.nn.Module, List[torch.nn.Module]], target_modules='all-linear', *args, **kwargs): for i in range(self.max_loras): config = LoraConfig( r=self.max_r, - target_modules='all-linear', + target_modules=target_modules, lora_alpha=32, ) lora_tenant = LoraTenant(index=i, adapter_name=f'lora_{i}', config=config) diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index fc2b53cd..e7cf9a3d 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -33,6 +33,7 @@ def __init__( max_loras: int = 5, max_r: int = 32, max_length: int = 8192, + target_modules: Union[List[str], str] = 'all-linear', **kwargs): assert device_mesh.fsdp_world_size <= 0, f'MultiLora does not support FSDP, current is: {str(device_mesh)}' os.environ['TOKENIZERS_PARALLELISM'] = 'true' @@ -66,7 +67,7 @@ def __init__( self.optimizer_group: Dict[str, OptimizerGroup] = {} self.multi_adapter = MultiLora(max_loras=max_loras, max_r=max_r, max_length=max_length) self.model.gradient_checkpointing_enable() - self.model = self.multi_adapter.patch(self.model) + self.model = self.multi_adapter.patch(self.model, target_modules=target_modules) self.strategy = AccelerateStrategy(mixed_precision=mixed_precision, device_mesh=None) self.model = self.strategy.wrap_model(self.model) 
self.multi_adapter.save_initial_weights() diff --git a/src/twinkle/server/utils/template_utils.py b/src/twinkle/server/utils/template_utils.py index ad015175..c593a8c0 100644 --- a/src/twinkle/server/utils/template_utils.py +++ b/src/twinkle/server/utils/template_utils.py @@ -10,6 +10,7 @@ # Key: model name pattern to match, Value: template name MODEL_TEMPLATE_MAPPING = { 'Qwen3.5': 'Qwen3_5Template', + 'Qwen3.6': 'Qwen3_5Template', # Add more model-template mappings here as needed # 'ModelName': 'TemplateName', } From 829f040c6f1c283e966f460bafaa65326b3cb9f1 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 16 Apr 2026 20:00:36 +0800 Subject: [PATCH 3/4] lint --- src/twinkle/model/multi_lora.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index 32381bc0..18012f54 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -372,7 +372,11 @@ def _get_weight_tensors(self): base_layer.forward = MethodType(_megatron_forward, base_layer) base_layer.layer_name = name - def patch(self, module: Union[torch.nn.Module, List[torch.nn.Module]], target_modules='all-linear', *args, **kwargs): + def patch(self, + module: Union[torch.nn.Module, List[torch.nn.Module]], + target_modules='all-linear', + *args, + **kwargs): for i in range(self.max_loras): config = LoraConfig( r=self.max_r, From 92e77fd3a6ea0ccae3ca201ab9d9199aff9fd602 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 16 Apr 2026 20:19:57 +0800 Subject: [PATCH 4/4] fix --- src/twinkle/model/megatron/multi_lora_megatron.py | 1 + src/twinkle/model/multi_lora.py | 14 ++++++++++++-- .../model/transformers/multi_lora_transformers.py | 1 + 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 56872791..5ec42eac 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ 
b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -96,6 +96,7 @@ def __init__( self._finish_config = True # Active group for compatibility with single adapter self.active_group = None + self.multi_adapter.reset_adapter_status() def _check_adapter_valid(self, adapter_name: str): assert adapter_name and adapter_name in self.optimizer_group, (f'Use a valid adapter_name first, ' diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index 18012f54..649c0c62 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -45,6 +45,14 @@ def _get_available_lora(self) -> Optional[LoraTenant]: def _count_available_loras(self): return len([_lora for _lora in self.loras if _lora.tenant_adapter_name is None]) + def reset_adapter_status(self): + """Activate lora_0 so it requires grad; all other adapters are disabled.""" + if isinstance(self.module, list): + for _module in self.module: + _module.set_adapter('lora_0') + else: + self.module.set_adapter('lora_0') + def activate_adapter(self, tenant_adapter_name: str, call_enable=False): if not self.has_lora(tenant_adapter_name): raise ValueError(f'Adapter {tenant_adapter_name} does not exist') @@ -54,12 +62,14 @@ def activate_adapter(self, tenant_adapter_name: str, call_enable=False): if call_enable: # This will cost time _module.enable_adapter_layers() - _module.set_adapter(adapter_name) + if _module.active_adapter != adapter_name: + _module.set_adapter(adapter_name) else: if call_enable: # This will cost time self.module.enable_adapter_layers() - self.module.set_adapter(adapter_name) + if self.module.active_adapter != adapter_name: + self.module.set_adapter(adapter_name) def deactivate_adapter(self): if isinstance(self.module, list): diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index e7cf9a3d..6a8b0532 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++
b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -74,6 +74,7 @@ def __init__( # Active group for compatibility with single adapter self.active_group = None self.handler = self.register_global_mm_forward_hook() + self.multi_adapter.reset_adapter_status() def _check_adapter_valid(self, adapter_name: str): assert adapter_name and adapter_name in self.optimizer_group, (f'Use a valid adapter_name first, '