diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 32ce84fa2..a3d958fb1 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -174,7 +174,8 @@ def load_model(
     revision: str = "main",
     debug: bool = False,
     trust_remote_code: bool=False,
-    base_model_revision: str ="main"
+    base_model_revision: str ="main",
+    base_model_path: str = None,
 ):
     """Load a model from Hugging Face."""
     # get model adapter
@@ -297,6 +298,7 @@ def load_model(
     kwargs["trust_remote_code"] = trust_remote_code
     if is_adapter_model(model_path, revision=revision) is True:
         kwargs["base_model_revision"] = base_model_revision
+        kwargs["base_model_path"] = base_model_path

     # Load model
     model, tokenizer = adapter.load_model(model_path, kwargs)
@@ -487,7 +489,10 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         from peft import PeftConfig, PeftModel
         revision = from_pretrained_kwargs.get("revision", "main")
         config = PeftConfig.from_pretrained(model_path, revision=revision)
-        base_model_path = config.base_model_name_or_path
+        if "base_model_path" in from_pretrained_kwargs and from_pretrained_kwargs["base_model_path"] is not None:
+            base_model_path = from_pretrained_kwargs["base_model_path"]
+        else:
+            base_model_path = config.base_model_name_or_path
         if "peft" in base_model_path:
             raise ValueError(
                 f"PeftModelAdapter cannot load a base model with 'peft' in the name: {config.base_model_name_or_path}"
@@ -524,16 +529,23 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
             return model, tokenizer
         # In the normal case, load up the base model weights again.
-        base_adapter = get_model_adapter(base_model_path)
         base_model_from_pretrained_kwargs = {
             "revision": from_pretrained_kwargs.get("base_model_revision", "main"),
             "trust_remote_code": from_pretrained_kwargs.get("trust_remote_code", False),
             "device_map": from_pretrained_kwargs.get("device_map", "auto"),
             "torch_dtype": from_pretrained_kwargs.get("torch_dtype", torch.float16),
         }
+        base_adapter = get_model_adapter(base_model_path, revision=base_model_from_pretrained_kwargs["revision"])
+        print(f"Loading base model for {base_model_path=} and {base_model_from_pretrained_kwargs=}")
         base_model, tokenizer = base_adapter.load_model(
             base_model_path,
             base_model_from_pretrained_kwargs,
         )
+        # If the base model is also a LoRA adapter, we need to merge those weights **before** loading the second adapter
+        # Without this, you will get garbage outputs!
+        if is_adapter_model(base_model_path, base_model_from_pretrained_kwargs["revision"]) is True:
+            print("Base model is adapter, merging LoRA weights")
+            base_model.eval()
+            base_model = base_model.merge_and_unload()
         print(f"Base model loaded on device {base_model.device} for {base_model_path=} and {base_model_from_pretrained_kwargs=}")
         model = PeftModel.from_pretrained(base_model, model_path, revision=revision)
         return model, tokenizer
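
For context, a minimal usage sketch of the new base_model_path override from caller code, assuming the patched load_model above; the repository identifiers below are placeholders, not part of this patch, and the remaining arguments follow load_model's existing signature.

# Hypothetical usage sketch; the repo ids are placeholders.
from fastchat.model.model_adapter import load_model

# The adapter's stored base_model_name_or_path may itself point at another
# LoRA adapter. Passing base_model_path explicitly names the true base model,
# so the intermediate adapter can be merged before the second one is applied.
model, tokenizer = load_model(
    "your-org/stacked-lora-adapter",             # placeholder adapter repo
    device="cuda",
    num_gpus=1,
    base_model_revision="main",
    base_model_path="meta-llama/Llama-2-7b-hf",  # placeholder base model
)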