From f6cd3d72219b8400ecbe2d92d5773848c8010ebb Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Tue, 2 Jan 2024 14:44:57 -0800
Subject: [PATCH] fix lora merge script

Signed-off-by: Chen Cui
---
 .../merge_lora_weights/merge.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/scripts/nlp_language_modeling/merge_lora_weights/merge.py b/scripts/nlp_language_modeling/merge_lora_weights/merge.py
index 9a957f487065..4015ce50f41c 100644
--- a/scripts/nlp_language_modeling/merge_lora_weights/merge.py
+++ b/scripts/nlp_language_modeling/merge_lora_weights/merge.py
@@ -28,7 +28,6 @@
 from torch.utils.data import DataLoader, Dataset

 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
-from nemo.collections.nlp.models.language_modeling.megatron_gpt_peft_models import MegatronGPTLoRAModel
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
 from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
 from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
@@ -72,7 +71,9 @@ def load_lora(lora_nemo, tp):
             l = torch.load(ckpt_file, map_location=torch.device('cpu'))
             lora_state_dict[i] = l

-        return lora_state_dict
+        config_file = f"{tmpdir}/model_config.yaml"
+        lora_config = OmegaConf.load(config_file)
+        return lora_state_dict, lora_config


 def fix_for_O2(state_dict):
@@ -195,12 +196,8 @@ def main(cfg) -> None:
     else:
         raise ValueError("need at least a nemo file or checkpoint dir")

-    lora_model_cfg = MegatronGPTLoRAModel.restore_from(
-        restore_path=cfg.lora_model_path, trainer=trainer, return_config=True, mcore=model.mcore_gpt,
-    )
-
     # load the lora weights on cpu for all ranks of the lora model
-    lora_weights = load_lora(cfg.lora_model_path, model.cfg.tensor_model_parallel_size)
+    lora_weights, lora_model_cfg = load_lora(cfg.lora_model_path, model.cfg.tensor_model_parallel_size)

     # merge the lora weights with the base model, for this current rank.
     merged_weights = merge(
@@ -209,6 +206,7 @@ def main(cfg) -> None:
         tp=model.cfg.tensor_model_parallel_size,
         num_layers=model.cfg.num_layers,
         curr_rank=model.global_rank,
+        mcore=model.mcore_gpt,
     )

     # load the merged_weights back into the base model, for this current rank.
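
For reviewers, a minimal sketch of how load_lora() is expected to read after this change, assuming the usual .nemo layout (a tar archive holding per-tensor-parallel-rank model_weights.ckpt shards plus model_config.yaml). The archive unpacking and checkpoint path construction below are illustrative assumptions, not the verbatim script; only the config loading and the two-value return mirror the patch itself.

    # Sketch of the patched load_lora(): load per-rank LoRA weights on CPU and
    # also return the adapter config read from model_config.yaml, replacing the
    # earlier MegatronGPTLoRAModel.restore_from(..., return_config=True) call.
    import os
    import tarfile
    import tempfile

    import torch
    from omegaconf import OmegaConf


    def load_lora(lora_nemo, tp):
        lora_state_dict = {}
        with tempfile.TemporaryDirectory() as tmpdir:
            # .nemo files are tar archives; unpack checkpoints and config to tmpdir
            with tarfile.open(lora_nemo, "r") as archive:
                archive.extractall(tmpdir)

            for i in range(tp):
                # checkpoint naming is an assumption; the real script derives the
                # shard path from the archive layout (one file per tensor-parallel rank)
                if tp == 1:
                    ckpt_file = os.path.join(tmpdir, "model_weights.ckpt")
                else:
                    ckpt_file = os.path.join(tmpdir, f"mp_rank_{i:02d}", "model_weights.ckpt")
                lora_state_dict[i] = torch.load(ckpt_file, map_location=torch.device("cpu"))

            # new in this patch: return the adapter config alongside the weights
            config_file = os.path.join(tmpdir, "model_config.yaml")
            lora_config = OmegaConf.load(config_file)
            return lora_state_dict, lora_config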