Merged
scripts/nlp_language_modeling/merge_lora_weights/merge.py (12 changes: 5 additions & 7 deletions)
@@ -28,7 +28,6 @@
 from torch.utils.data import DataLoader, Dataset
 
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
-from nemo.collections.nlp.models.language_modeling.megatron_gpt_peft_models import MegatronGPTLoRAModel
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
 from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
 from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
@@ -72,7 +71,9 @@ def load_lora(lora_nemo, tp):
 
             l = torch.load(ckpt_file, map_location=torch.device('cpu'))
             lora_state_dict[i] = l
-        return lora_state_dict
+        config_file = f"{tmpdir}/model_config.yaml"
+        lora_config = OmegaConf.load(config_file)
+        return lora_state_dict, lora_config
 
 
 def fix_for_O2(state_dict):
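For context, load_lora now reads roughly as sketched below. Only the last three lines (loading model_config.yaml via OmegaConf and returning the config alongside the weights) come from this diff; the archive extraction, the per-rank checkpoint paths, and the use of NLPSaveRestoreConnector._unpack_nemo_file are assumptions about the unchanged surrounding code, not part of the hunk.

import tempfile

import torch
from omegaconf import OmegaConf

from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector


def load_lora(lora_nemo, tp):
    lora_state_dict = {}
    with tempfile.TemporaryDirectory() as tmpdir:
        # Assumption: the .nemo archive is unpacked into a temp dir, as in
        # the unchanged lines above this hunk.
        NLPSaveRestoreConnector._unpack_nemo_file(lora_nemo, tmpdir)
        for i in range(tp):
            # Assumed layout: one checkpoint per tensor-parallel rank when tp > 1.
            ckpt_file = (
                f"{tmpdir}/model_weights.ckpt"
                if tp == 1
                else f"{tmpdir}/mp_rank_0{i}/model_weights.ckpt"
            )
            l = torch.load(ckpt_file, map_location=torch.device('cpu'))
            lora_state_dict[i] = l
        # New in this diff: also return the adapter config shipped inside the
        # .nemo archive, so callers no longer need to restore a full LoRA
        # model just to read its config.
        config_file = f"{tmpdir}/model_config.yaml"
        lora_config = OmegaConf.load(config_file)
        return lora_state_dict, lora_config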
@@ -195,12 +196,8 @@ def main(cfg) -> None:
     else:
         raise ValueError("need at least a nemo file or checkpoint dir")
 
-    lora_model_cfg = MegatronGPTLoRAModel.restore_from(
-        restore_path=cfg.lora_model_path, trainer=trainer, return_config=True, mcore=model.mcore_gpt,
-    )
-
     # load the lora weights on cpu for all ranks of the lora model
-    lora_weights = load_lora(cfg.lora_model_path, model.cfg.tensor_model_parallel_size)
+    lora_weights, lora_model_cfg = load_lora(cfg.lora_model_path, model.cfg.tensor_model_parallel_size)
 
     # merge the lora weights with the base model, for this current rank.
     merged_weights = merge(
@@ -209,6 +206,7 @@
         tp=model.cfg.tensor_model_parallel_size,
         num_layers=model.cfg.num_layers,
         curr_rank=model.global_rank,
+        mcore=model.mcore_gpt,
     )
 
     # load the merged_weights back into the base model, for this current rank.
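Taken together, the change in main() collapses the config lookup into the weight load and threads the base model's mcore flag through to merge(). A minimal sketch of the resulting call flow follows; the leading positional arguments to merge() fall outside the hunks shown and are assumptions, not part of the diff.

# Sketch of the flow in main() after this diff.
lora_weights, lora_model_cfg = load_lora(
    cfg.lora_model_path, model.cfg.tensor_model_parallel_size
)

merged_weights = merge(
    model.state_dict(),  # assumed: base-model weights for the current rank
    lora_weights,        # assumed: per-TP-rank LoRA weights from load_lora
    tp=model.cfg.tensor_model_parallel_size,
    num_layers=model.cfg.num_layers,
    curr_rank=model.global_rank,
    mcore=model.mcore_gpt,  # new in this diff: forwards the base model's mcore flag
)

Reading the adapter config straight out of the extracted .nemo archive also removes the need to instantiate MegatronGPTLoRAModel just to fetch its config, which is why both the restore_from call and the import at the top of the file are dropped.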