From 11d22b926eaf382533d662958fe1dfdaa83e31e6 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Tue, 2 Jan 2024 02:10:40 +0000
Subject: [PATCH 1/4] Enable loading of chained PEFT models

---
 fastchat/model/model_adapter.py | 12 +++++++++---
 pyproject.toml                  |  2 +-
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 32ce84fa2..e8380b48e 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -174,7 +174,8 @@ def load_model(
     revision: str = "main",
     debug: bool = False,
     trust_remote_code: bool=False,
-    base_model_revision: str ="main"
+    base_model_revision: str ="main",
+    base_model_path: str = None,
 ):
     """Load a model from Hugging Face."""
     # get model adapter
@@ -297,6 +298,7 @@ def load_model(
         kwargs["trust_remote_code"] = trust_remote_code
     if is_adapter_model(model_path, revision=revision) is True:
         kwargs["base_model_revision"] = base_model_revision
+        kwargs["base_model_path"] = base_model_path

     # Load model
     model, tokenizer = adapter.load_model(model_path, kwargs)
@@ -487,7 +489,10 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         from peft import PeftConfig, PeftModel
         revision = from_pretrained_kwargs.get("revision", "main")
         config = PeftConfig.from_pretrained(model_path, revision=revision)
-        base_model_path = config.base_model_name_or_path
+        if "base_model_path" in from_pretrained_kwargs and from_pretrained_kwargs["base_model_path"] is not None:
+            base_model_path = from_pretrained_kwargs["base_model_path"]
+        else:
+            base_model_path = config.base_model_name_or_path
         if "peft" in base_model_path:
             raise ValueError(
                 f"PeftModelAdapter cannot load a base model with 'peft' in the name: {config.base_model_name_or_path}"
@@ -524,13 +529,14 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
             return model, tokenizer

         # In the normal case, load up the base model weights again.
-        base_adapter = get_model_adapter(base_model_path)
         base_model_from_pretrained_kwargs = {
             "revision": from_pretrained_kwargs.get("base_model_revision", "main"),
             "trust_remote_code": from_pretrained_kwargs.get("trust_remote_code", False),
             "device_map": from_pretrained_kwargs.get("device_map", "auto"),
             "torch_dtype": from_pretrained_kwargs.get("torch_dtype", torch.float16),
         }
+        base_adapter = get_model_adapter(base_model_path, revision=base_model_from_pretrained_kwargs["revision"])
+        print(f"Loading base model for {base_model_path=} and {base_model_from_pretrained_kwargs=}")
         base_model, tokenizer = base_adapter.load_model(
             base_model_path, base_model_from_pretrained_kwargs,
         )
diff --git a/pyproject.toml b/pyproject.toml
index 01e60c035..1e0063eb9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
 model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0"]
 webui = ["gradio"]
 train = ["einops", "flash-attn>=2.0", "wandb"]
-llm_judge = ["openai", "anthropic>=0.3", "ray"]
+llm_judge = ["openai", "anthropic>=0.3", "ray", "hf_transfer"]
 dev = ["black==23.3.0", "pylint==2.8.2"]

 [project.urls]
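Usage note: with this change the base weights for a PEFT adapter can be resolved from an explicit path instead of the `base_model_name_or_path` recorded in the adapter config. A minimal sketch of how the new parameter could be exercised; the model paths below are hypothetical, and the other arguments assume the defaults of the surrounding `load_model` signature:

    from fastchat.model.model_adapter import load_model

    # Load a LoRA adapter, but resolve its base weights from a local
    # checkout rather than the path stored in adapter_config.json.
    model, tokenizer = load_model(
        model_path="my-org/chat-model-lora",        # hypothetical adapter repo
        device="cuda",
        base_model_revision="main",
        base_model_path="/checkpoints/base-model",  # hypothetical local base weights
    )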
From 99872d9b1f2dc832dfa452b80ba85c6b42c93676 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Tue, 2 Jan 2024 02:51:17 +0000
Subject: [PATCH 2/4] Fix deps

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 1e0063eb9..01e60c035 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
 model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0"]
 webui = ["gradio"]
 train = ["einops", "flash-attn>=2.0", "wandb"]
-llm_judge = ["openai", "anthropic>=0.3", "ray", "hf_transfer"]
+llm_judge = ["openai", "anthropic>=0.3", "ray"]
 dev = ["black==23.3.0", "pylint==2.8.2"]

 [project.urls]

From 2524800466a9d722b3667ac151e0bb76910beb52 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Tue, 2 Jan 2024 06:07:03 +0000
Subject: [PATCH 3/4] Fix LoRA loading

---
 fastchat/model/model_adapter.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index e8380b48e..5f1168c07 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -540,6 +540,11 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         base_model, tokenizer = base_adapter.load_model(
             base_model_path, base_model_from_pretrained_kwargs,
         )
+        # If the base model is also a LoRA adapter, we need to merge those weights **before** loading the second adapter
+        # Without this, you will get garbage outputs!
+        if is_adapter_model(base_model_path, base_model_from_pretrained_kwargs["revision"]) is True:
+            print("Base model is adapter, merging LoRA weights")
+            base_model.merge_and_unload()
         print(f"Base model loaded on device {base_model.device} for {base_model_path=} and {base_model_from_pretrained_kwargs=}")
         model = PeftModel.from_pretrained(base_model, model_path, revision=revision)
         return model, tokenizer
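For context, the loading order this patch enforces can be reproduced standalone with peft; the repo IDs below are hypothetical. The key point is that the first adapter's deltas are folded into the base weights before the second adapter is applied on top:

    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    # Plain base weights (hypothetical ID).
    base = AutoModelForCausalLM.from_pretrained(
        "my-org/base-7b", torch_dtype=torch.float16
    )
    # The "base model" of the chain is itself a LoRA adapter: apply and merge it first.
    first = PeftModel.from_pretrained(base, "my-org/sft-lora")
    merged = first.merge_and_unload()  # fold the first adapter into the base weights
    # Only then stack the second adapter on top of the merged weights.
    model = PeftModel.from_pretrained(merged, "my-org/dpo-lora")

Note that the merged weights live in the value returned by merge_and_unload(), which is the detail the next patch fixes.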
From e536b51ddc5dd92972bbeaf8cb65e44650c149a4 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Tue, 2 Jan 2024 12:16:50 +0000
Subject: [PATCH 4/4] Fix merge

---
 fastchat/model/model_adapter.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 5f1168c07..a3d958fb1 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -544,7 +544,8 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         # Without this, you will get garbage outputs!
         if is_adapter_model(base_model_path, base_model_from_pretrained_kwargs["revision"]) is True:
             print("Base model is adapter, merging LoRA weights")
-            base_model.merge_and_unload()
+            base_model.eval()
+            base_model = base_model.merge_and_unload()
         print(f"Base model loaded on device {base_model.device} for {base_model_path=} and {base_model_from_pretrained_kwargs=}")
         model = PeftModel.from_pretrained(base_model, model_path, revision=revision)
         return model, tokenizer
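The reassignment matters because merge_and_unload() returns the underlying transformers model with the LoRA deltas folded in; discarding the return value leaves the PeftModel wrapper in the chain, so the second adapter would be stacked on the wrapper rather than on the merged weights. A minimal sketch with hypothetical repo IDs:

    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained(
        "my-org/base-7b", torch_dtype=torch.float16
    )
    peft_model = PeftModel.from_pretrained(base, "my-org/first-lora")
    peft_model.eval()  # disable dropout before merging, as the patch does

    # Wrong: `peft_model.merge_and_unload()` on its own discards the merged model.
    # Right: capture the returned transformers model and build on that.
    merged = peft_model.merge_and_unload()
    assert not isinstance(merged, PeftModel)
    model = PeftModel.from_pretrained(merged, "my-org/second-lora")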