src/twinkle/model/megatron/megatron.py (10 additions, 0 deletions)
@@ -1170,7 +1170,14 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=None
         if dp_rank == 0:
             self.hf_config.save_pretrained(output_dir)
             if isinstance(model[0], PeftModel):
+                config = model[0].peft_config[adapter_name]
+                target_modules = None
+                if getattr(config, 'origin_target_modules', None) == 'all-linear':
+                    target_modules = config.target_modules
+                    config.target_modules = 'all-linear'
                 model[0].peft_config[adapter_name].save_pretrained(output_dir)
+                if getattr(config, 'origin_target_modules', None) == 'all-linear':
+                    config.target_modules = target_modules

     def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_converter=None):
         """Save in Megatron checkpoint format."""
@@ -1274,6 +1281,9 @@ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str
             config_or_dir = LoraConfig(**config_or_dir)
         config = config_or_dir

+        if config.target_modules == 'all-linear':
+            config.origin_target_modules = 'all-linear'

         # Expand target_modules (e.g., 'all-linear' -> actual module names)
         if config.target_modules:
             if isinstance(config.target_modules, str):
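
A note on the save path above: _save_hf_format temporarily swaps the expanded module list back to the literal 'all-linear' string so the exported adapter_config.json stays portable, then restores the expanded list in memory. A minimal sketch of that idiom, assuming a plain peft LoraConfig and placeholder module names (origin_target_modules is the marker attribute this PR introduces, not a peft field):

    from peft import LoraConfig

    # Stand-in for a config whose target_modules was expanded from 'all-linear'.
    config = LoraConfig(r=8, target_modules=['linear_qkv', 'linear_proj'])
    config.origin_target_modules = 'all-linear'  # set earlier during patching

    expanded = None
    if getattr(config, 'origin_target_modules', None) == 'all-linear':
        expanded = config.target_modules       # stash the expanded module names
        config.target_modules = 'all-linear'   # persist the portable value
    config.save_pretrained('/tmp/adapter')     # adapter_config.json records 'all-linear'
    if expanded is not None:
        config.target_modules = expanded       # restore the in-memory expansion
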
src/twinkle/model/megatron/multi_lora_megatron.py (3 additions, 1 deletion)
@@ -41,6 +41,7 @@ def __init__(
         max_loras: int = 5,
         max_r: int = 32,
         max_length: int = 8192,
+        target_modules: Union[List[str], str] = 'all-linear',
         **kwargs,
     ):
         requires('megatron_core')
@@ -87,14 +88,15 @@ def __init__(
         self.model: List[nn.Module] = self.strategy.create_megatron_model(load_weights)
         MegatronPeft().__call__()
         self.multi_adapter = MultiLora(max_loras=max_loras, max_r=max_r, max_length=max_length)
-        self.model = self.multi_adapter.patch(self.model)
+        self.model = self.multi_adapter.patch(self.model, target_modules=target_modules)
         self.model = self.strategy.wrap_model(self.model)
         self.strategy.finish_param_config(self.model, None)
         self.multi_adapter.save_initial_weights()
         self._model_wrapped = True
         self._finish_config = True
         # Active group for compatibility with single adapter
         self.active_group = None
+        self.multi_adapter.reset_adapter_status()

     def _check_adapter_valid(self, adapter_name: str):
         assert adapter_name and adapter_name in self.optimizer_group, (f'Use a valid adapter_name first, '
src/twinkle/model/multi_lora.py (16 additions, 2 deletions)
@@ -45,6 +45,14 @@ def _get_available_lora(self) -> Optional[LoraTenant]:
     def _count_available_loras(self):
         return len([_lora for _lora in self.loras if _lora.tenant_adapter_name is None])

+    def reset_adapter_status(self):
+        """Force lora_0 to require grad and disable the others."""
+        if isinstance(self.module, list):
+            for _module in self.module:
+                _module.set_adapter('lora_0')
+        else:
+            self.module.set_adapter('lora_0')

     def activate_adapter(self, tenant_adapter_name: str, call_enable=False):
         if not self.has_lora(tenant_adapter_name):
             raise ValueError(f'Adapter {tenant_adapter_name} does not exist')
@@ -374,11 +382,15 @@ def _get_weight_tensors(self):
             base_layer.forward = MethodType(_megatron_forward, base_layer)
             base_layer.layer_name = name

-    def patch(self, module: Union[torch.nn.Module, List[torch.nn.Module]], *args, **kwargs):
+    def patch(self,
+              module: Union[torch.nn.Module, List[torch.nn.Module]],
+              target_modules='all-linear',
+              *args,
+              **kwargs):
         for i in range(self.max_loras):
             config = LoraConfig(
                 r=self.max_r,
-                target_modules='all-linear',
+                target_modules=target_modules,
                 lora_alpha=32,
             )
             lora_tenant = LoraTenant(index=i, adapter_name=f'lora_{i}', config=config)
@@ -400,6 +412,8 @@ def _patch_peft(_module):
         def _patch_megatron(_module):
             # Expand target_modules (e.g., 'all-linear' -> actual module names)
             _config = deepcopy(config)
+            if _config.target_modules == 'all-linear':
+                _config.origin_target_modules = 'all-linear'
             if isinstance(_module, PeftModel):
                 _module.add_adapter(lora_tenant.adapter_name, _config)
             else:
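
The two call sites (multi_lora_megatron.py above, multi_lora_transformers.py below) simply thread their new constructor argument into this patch signature, which copies it into every pre-allocated LoraConfig. A hedged usage sketch (the import path is inferred from this file's location, base_model stands in for an already-built model, and the module names are illustrative):

    from twinkle.model.multi_lora import MultiLora

    multi_adapter = MultiLora(max_loras=5, max_r=32, max_length=8192)
    model = multi_adapter.patch(base_model, target_modules=['q_proj', 'v_proj'])
    multi_adapter.save_initial_weights()
    multi_adapter.reset_adapter_status()  # leave only lora_0 trainable after setup
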
src/twinkle/model/transformers/multi_lora_transformers.py (3 additions, 1 deletion)
@@ -33,6 +33,7 @@ def __init__(
         max_loras: int = 5,
         max_r: int = 32,
         max_length: int = 8192,
+        target_modules: Union[List[str], str] = 'all-linear',
         **kwargs):
         assert device_mesh.fsdp_world_size <= 0, f'MultiLora does not support FSDP, current is: {str(device_mesh)}'
         os.environ['TOKENIZERS_PARALLELISM'] = 'true'
@@ -66,13 +67,14 @@
         self.optimizer_group: Dict[str, OptimizerGroup] = {}
         self.multi_adapter = MultiLora(max_loras=max_loras, max_r=max_r, max_length=max_length)
         self.model.gradient_checkpointing_enable()
-        self.model = self.multi_adapter.patch(self.model)
+        self.model = self.multi_adapter.patch(self.model, target_modules=target_modules)
         self.strategy = AccelerateStrategy(mixed_precision=mixed_precision, device_mesh=None)
         self.model = self.strategy.wrap_model(self.model)
         self.multi_adapter.save_initial_weights()
         # Active group for compatibility with single adapter
         self.active_group = None
         self.handler = self.register_global_mm_forward_hook()
+        self.multi_adapter.reset_adapter_status()

     def _check_adapter_valid(self, adapter_name: str):
         assert adapter_name and adapter_name in self.optimizer_group, (f'Use a valid adapter_name first, '
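
The reset_adapter_status() call added at the end of both constructors relies on peft's set_adapter semantics: activating one adapter marks only that adapter's LoRA weights as trainable and freezes the other pre-allocated slots. A small self-contained illustration with stock peft (the toy module and adapter names are placeholders, not code from this repository):

    import torch.nn as nn
    from peft import LoraConfig, get_peft_model

    base = nn.Sequential(nn.Linear(8, 8))
    model = get_peft_model(base, LoraConfig(r=4, target_modules=['0']), adapter_name='lora_0')
    model.add_adapter('lora_1', LoraConfig(r=4, target_modules=['0']))

    model.set_adapter('lora_0')  # same call reset_adapter_status makes for lora_0
    trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    print(trainable)  # only lora_0 parameters remain trainable
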
src/twinkle/server/utils/template_utils.py (1 addition, 0 deletions)
@@ -10,6 +10,7 @@
 # Key: model name pattern to match, Value: template name
 MODEL_TEMPLATE_MAPPING = {
     'Qwen3.5': 'Qwen3_5Template',
+    'Qwen3.6': 'Qwen3_5Template',
     # Add more model-template mappings here as needed
     # 'ModelName': 'TemplateName',
 }
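
The new entry points Qwen3.6 models at the existing Qwen3.5 template. For illustration, a mapping like this is typically consulted by matching each key against the model name; resolve_template below is a hypothetical helper, not the actual lookup in template_utils.py:

    MODEL_TEMPLATE_MAPPING = {
        'Qwen3.5': 'Qwen3_5Template',
        'Qwen3.6': 'Qwen3_5Template',
    }

    def resolve_template(model_name: str):
        for pattern, template_name in MODEL_TEMPLATE_MAPPING.items():
            if pattern.lower() in model_name.lower():
                return template_name
        return None

    assert resolve_template('Qwen3.6-7B-Instruct') == 'Qwen3_5Template'
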