Fix moe weight sync #172
Code Review
This pull request adds a GRPO training script for GSM8K using MultiLoRA and Megatron, featuring filesystem-based LoRA synchronization. It refactors PEFT configuration logic and introduces a context manager for masking modules during export. Feedback identifies a leftover breakpoint and critical issues with in-place configuration modifications that could cause side effects or state corruption during model saving and configuration retrieval.
src/twinkle/model/megatron/megatron.py (1561-1565)
This block modifies config.target_modules in place but does not restore it, leading to a permanent side effect on the configuration object. This will cause subsequent forward passes to apply LoRA to all linear layers if target_modules was originally a specific list.
Additionally, line 1564 overwrites the 'all-linear' value in the returned dictionary with the original target_modules, which seems to defeat the purpose of the change if the goal was to provide an 'all-linear' config to the sampler.
If the intention is to return a dictionary with target_modules set to 'all-linear', it is safer to modify the dictionary directly without affecting the config object.
```python
_peft_config = config.to_dict() if hasattr(config, 'to_dict') else dict(config)
_peft_config['target_modules'] = 'all-linear'
return _peft_config
```
src/twinkle/model/megatron/multi_lora_megatron.py (102)
The breakpoint() call should be removed before merging. It will halt execution in non-interactive environments.
src/twinkle/model/megatron/megatron.py (1189-1192)
Modifying config.target_modules in place without a try...finally block is risky. If save_pretrained raises an exception, the configuration object will be left with the incorrect target_modules value, which could affect the model's behavior in subsequent steps.
```python
target_modules = config.target_modules
config.target_modules = 'all-linear'
try:
    model[0].peft_config[adapter_name].save_pretrained(output_dir)
finally:
    config.target_modules = target_modules
```
PR type
PR information
Write the detailed information belonging to this PR.
Experiment results
Paste your experiment results here (if needed).