diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py
index f087d3a6..351d9250 100644
--- a/src/twinkle/model/megatron/megatron.py
+++ b/src/twinkle/model/megatron/megatron.py
@@ -1170,7 +1170,14 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non
         if dp_rank == 0:
             self.hf_config.save_pretrained(output_dir)
             if isinstance(model[0], PeftModel):
+                config = model[0].peft_config[adapter_name]
+                target_modules = None
+                if getattr(config, 'origin_target_modules', None) == 'all-linear':
+                    target_modules = config.target_modules
+                    config.target_modules = 'all-linear'
                 model[0].peft_config[adapter_name].save_pretrained(output_dir)
+                if getattr(config, 'origin_target_modules', None) == 'all-linear':
+                    config.target_modules = target_modules

     def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_converter=None):
         """Save in Megatron checkpoint format."""
@@ -1274,6 +1281,9 @@ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str
             config_or_dir = LoraConfig(**config_or_dir)
         config = config_or_dir

+        if config.target_modules == 'all-linear':
+            config.origin_target_modules = 'all-linear'
+
         # Expand target_modules (e.g., 'all-linear' -> actual module names)
         if config.target_modules:
             if isinstance(config.target_modules, str):
diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py
index 5b4bee49..5ec42eac 100644
--- a/src/twinkle/model/megatron/multi_lora_megatron.py
+++ b/src/twinkle/model/megatron/multi_lora_megatron.py
@@ -41,6 +41,7 @@ def __init__(
         max_loras: int = 5,
         max_r: int = 32,
         max_length: int = 8192,
+        target_modules: Union[List[str], str] = 'all-linear',
         **kwargs,
     ):
         requires('megatron_core')
@@ -87,7 +88,7 @@ def __init__(
         self.model: List[nn.Module] = self.strategy.create_megatron_model(load_weights)
         MegatronPeft().__call__()
         self.multi_adapter = MultiLora(max_loras=max_loras, max_r=max_r, max_length=max_length)
-        self.model = self.multi_adapter.patch(self.model)
+        self.model = self.multi_adapter.patch(self.model, target_modules=target_modules)
         self.model = self.strategy.wrap_model(self.model)
         self.strategy.finish_param_config(self.model, None)
         self.multi_adapter.save_initial_weights()
@@ -95,6 +96,7 @@
         self._finish_config = True
         # Active group for compatibility with single adapter
         self.active_group = None
+        self.multi_adapter.reset_adapter_status()

     def _check_adapter_valid(self, adapter_name: str):
         assert adapter_name and adapter_name in self.optimizer_group, (f'Use a valid adapter_name first, '
diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py
index 387e747c..649c0c62 100644
--- a/src/twinkle/model/multi_lora.py
+++ b/src/twinkle/model/multi_lora.py
@@ -45,6 +45,14 @@ def _get_available_lora(self) -> Optional[LoraTenant]:
     def _count_available_loras(self):
         return len([_lora for _lora in self.loras if _lora.tenant_adapter_name is None])

+    def reset_adapter_status(self):
+        """Make lora_0 the active (requires_grad) adapter and disable all others."""
+        if isinstance(self.module, list):
+            for _module in self.module:
+                _module.set_adapter('lora_0')
+        else:
+            self.module.set_adapter('lora_0')
+
     def activate_adapter(self, tenant_adapter_name: str, call_enable=False):
         if not self.has_lora(tenant_adapter_name):
             raise ValueError(f'Adapter {tenant_adapter_name} does not exist')
@@ -374,11 +382,15 @@ def _get_weight_tensors(self):
             base_layer.forward = MethodType(_megatron_forward, base_layer)
             base_layer.layer_name = name

-    def patch(self, module: Union[torch.nn.Module, List[torch.nn.Module]], *args, **kwargs):
+    def patch(self,
+              module: Union[torch.nn.Module, List[torch.nn.Module]],
+              target_modules='all-linear',
+              *args,
+              **kwargs):
         for i in range(self.max_loras):
             config = LoraConfig(
                 r=self.max_r,
-                target_modules='all-linear',
+                target_modules=target_modules,
                 lora_alpha=32,
             )
             lora_tenant = LoraTenant(index=i, adapter_name=f'lora_{i}', config=config)
@@ -400,6 +412,8 @@ def _patch_peft(_module):
         def _patch_megatron(_module):
             # Expand target_modules (e.g., 'all-linear' -> actual module names)
             _config = deepcopy(config)
+            if _config.target_modules == 'all-linear':
+                _config.origin_target_modules = 'all-linear'
             if isinstance(_module, PeftModel):
                 _module.add_adapter(lora_tenant.adapter_name, _config)
             else:
diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py
index fc2b53cd..6a8b0532 100644
--- a/src/twinkle/model/transformers/multi_lora_transformers.py
+++ b/src/twinkle/model/transformers/multi_lora_transformers.py
@@ -33,6 +33,7 @@ def __init__(
         max_loras: int = 5,
         max_r: int = 32,
         max_length: int = 8192,
+        target_modules: Union[List[str], str] = 'all-linear',
         **kwargs):
         assert device_mesh.fsdp_world_size <= 0, f'MultiLora does not support FSDP, current is: {str(device_mesh)}'
         os.environ['TOKENIZERS_PARALLELISM'] = 'true'
@@ -66,13 +67,14 @@ def __init__(
         self.optimizer_group: Dict[str, OptimizerGroup] = {}
         self.multi_adapter = MultiLora(max_loras=max_loras, max_r=max_r, max_length=max_length)
         self.model.gradient_checkpointing_enable()
-        self.model = self.multi_adapter.patch(self.model)
+        self.model = self.multi_adapter.patch(self.model, target_modules=target_modules)
         self.strategy = AccelerateStrategy(mixed_precision=mixed_precision, device_mesh=None)
         self.model = self.strategy.wrap_model(self.model)
         self.multi_adapter.save_initial_weights()
         # Active group for compatibility with single adapter
         self.active_group = None
         self.handler = self.register_global_mm_forward_hook()
+        self.multi_adapter.reset_adapter_status()

     def _check_adapter_valid(self, adapter_name: str):
         assert adapter_name and adapter_name in self.optimizer_group, (f'Use a valid adapter_name first, '
diff --git a/src/twinkle/server/utils/template_utils.py b/src/twinkle/server/utils/template_utils.py
index ad015175..c593a8c0 100644
--- a/src/twinkle/server/utils/template_utils.py
+++ b/src/twinkle/server/utils/template_utils.py
@@ -10,6 +10,7 @@
 # Key: model name pattern to match, Value: template name
 MODEL_TEMPLATE_MAPPING = {
     'Qwen3.5': 'Qwen3_5Template',
+    'Qwen3.6': 'Qwen3_5Template',
     # Add more model-template mappings here as needed
     # 'ModelName': 'TemplateName',
 }
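
Note on the _save_hf_format change: _patch_adapter expands target_modules='all-linear' into a concrete module list, and without the swap that list would be serialized into adapter_config.json, tying the saved adapter to one architecture. A minimal standalone sketch of the round-trip, using a hypothetical DummyConfig in place of peft.LoraConfig (only target_modules, origin_target_modules, and save_pretrained mirror the diff):

from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class DummyConfig:
    # Hypothetical stand-in for peft.LoraConfig with just the fields the diff touches.
    target_modules: Union[List[str], str]
    origin_target_modules: Optional[str] = None

    def save_pretrained(self, output_dir: str) -> None:
        print(f'would write target_modules={self.target_modules!r} to {output_dir}')


def save_with_original_target_modules(config: DummyConfig, output_dir: str) -> None:
    # Temporarily restore the literal 'all-linear' for serialization,
    # then put the expanded list back so the in-memory config keeps working.
    expanded = None
    if getattr(config, 'origin_target_modules', None) == 'all-linear':
        expanded = config.target_modules
        config.target_modules = 'all-linear'
    config.save_pretrained(output_dir)
    if getattr(config, 'origin_target_modules', None) == 'all-linear':
        config.target_modules = expanded


config = DummyConfig(target_modules=['linear_qkv', 'linear_proj'],
                     origin_target_modules='all-linear')
save_with_original_target_modules(config, '/tmp/adapter')      # saves 'all-linear'
assert config.target_modules == ['linear_qkv', 'linear_proj']  # restored in memory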
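The new reset_adapter_status hook leans on PEFT's set_adapter semantics: activating one adapter sets requires_grad=True on its LoRA weights and False on every other adapter's. A sketch against plain peft showing the post-construction state the hook forces (the toy base model is illustrative; nothing here comes from twinkle):

import torch.nn as nn
from peft import LoraConfig, get_peft_model

# Toy base model with two linear submodules named '0' and '1'.
base = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))

model = get_peft_model(base, LoraConfig(r=8, target_modules=['0', '1']), adapter_name='lora_0')
model.add_adapter('lora_1', LoraConfig(r=8, target_modules=['0', '1']))

# What reset_adapter_status() does after patching: lora_0 active and
# trainable, every other adapter's parameters frozen.
model.set_adapter('lora_0')

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
assert trainable and all('lora_0' in n for n in trainable)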
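Downstream, the new constructor argument lets callers restrict which projections receive shared LoRA slots instead of always patching every linear layer. A hedged usage sketch: the import path is inferred from the file layout above, the model handle and module names are placeholders, and whether patch accepts an arbitrary pre-built module is not confirmed by the diff.

from twinkle.model.multi_lora import MultiLora  # path inferred from the diff

multi_adapter = MultiLora(max_loras=5, max_r=32, max_length=8192)
# Only the attention projections get slots; 'all-linear' remains the default.
model = multi_adapter.patch(model, target_modules=['linear_qkv', 'linear_proj'])
multi_adapter.save_initial_weights()
multi_adapter.reset_adapter_status()  # lora_0 trainable, the rest disabled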