From 27d1ad4b9767fc61f062a595642d22251e32aae5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 29 Aug 2023 20:44:42 +0000 Subject: [PATCH 1/7] Fix Unfuse Lora --- src/diffusers/loaders.py | 26 +++++++++++++++----------- src/diffusers/models/lora.py | 9 ++++----- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 699395418c51..f0b8ddcc7dab 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -85,12 +85,16 @@ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank= self.lora_scale = lora_scale + # overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved + # when saving the whole text encoder model + def state_dict(self, *args, destination=None, prefix='', keep_vars=False): + return self.regular_linear_layer.state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) + def _fuse_lora(self): if self.lora_linear_layer is None: return dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device - logger.info(f"Fusing LoRA weights for {self.__class__}") w_orig = self.regular_linear_layer.weight.data.float() w_up = self.lora_linear_layer.up.weight.data.float() @@ -112,14 +116,14 @@ def _fuse_lora(self): def _unfuse_lora(self): if not (hasattr(self, "w_up") and hasattr(self, "w_down")): return - logger.info(f"Unfusing LoRA weights for {self.__class__}") fused_weight = self.regular_linear_layer.weight.data dtype, device = fused_weight.dtype, fused_weight.device - self.w_up = self.w_up.to(device=device, dtype=dtype) - self.w_down = self.w_down.to(device, dtype=dtype) - unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0] + w_up = self.w_up.to(device=device).float() + w_down = self.w_down.to(device).float() + + unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0] self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype) self.w_up = None @@ -1405,15 +1409,15 @@ def _remove_text_encoder_monkey_patch(self): def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder): for _, attn_module in text_encoder_attn_modules(text_encoder): if isinstance(attn_module.q_proj, PatchedLoraProjection): - attn_module.q_proj = attn_module.q_proj.regular_linear_layer - attn_module.k_proj = attn_module.k_proj.regular_linear_layer - attn_module.v_proj = attn_module.v_proj.regular_linear_layer - attn_module.out_proj = attn_module.out_proj.regular_linear_layer + attn_module.q_proj.lora_linear_layer = None + attn_module.k_proj.lora_linear_layer = None + attn_module.v_proj.lora_linear_layer = None + attn_module.out_proj.linear_layer = None for _, mlp_module in text_encoder_mlp_modules(text_encoder): if isinstance(mlp_module.fc1, PatchedLoraProjection): - mlp_module.fc1 = mlp_module.fc1.regular_linear_layer - mlp_module.fc2 = mlp_module.fc2.regular_linear_layer + mlp_module.fc1.lora_linear_layer = None + mlp_module.fc2.lora_linear_layer = None @classmethod def _modify_text_encoder( diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index cb0b3b45eb69..671c93a3b2b2 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -168,7 +168,6 @@ def _fuse_lora(self): return dtype, device = self.weight.data.dtype, self.weight.data.device - logger.info(f"Fusing LoRA weights for {self.__class__}") w_orig = self.weight.data.float() w_up = 
self.lora_layer.up.weight.data.float() @@ -190,14 +189,14 @@ def _fuse_lora(self): def _unfuse_lora(self): if not (hasattr(self, "w_up") and hasattr(self, "w_down")): return - logger.info(f"Unfusing LoRA weights for {self.__class__}") fused_weight = self.weight.data dtype, device = fused_weight.dtype, fused_weight.device - self.w_up = self.w_up.to(device=device, dtype=dtype) - self.w_down = self.w_down.to(device, dtype=dtype) - unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0] + w_up = self.w_up.to(device=device).float() + w_down = self.w_down.to(device).float() + + unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0] self.weight.data = unfused_weight.to(device=device, dtype=dtype) self.w_up = None From aec02399406dc7514acc774d74ac0d587d94e7d2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 29 Aug 2023 21:00:21 +0000 Subject: [PATCH 2/7] add tests --- tests/models/test_lora_layers.py | 65 ++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index f17529a1680e..3f0544bf0354 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -16,6 +16,7 @@ import tempfile import time import unittest +import copy import numpy as np import torch @@ -99,6 +100,16 @@ def set_lora_weights(lora_attn_parameters, randn_weight=False): else: torch.zero_(parameter) +def state_dicts_almost_equal(sd1, sd2): + sd1 = sorted(sd1) + sd2 = sorted(sd2) + + models_are_equal = True + for ten1, ten2 in zip(sd1.values(), sd2.values()): + if (ten1 - ten2).abs().sum() > 1e-3: + models_are_equal = False + + return models_are_equal class LoraLoaderMixinTests(unittest.TestCase): def get_dummy_components(self): @@ -674,6 +685,41 @@ def test_load_lora_locally(self): sd_pipe.unload_lora_weights() + def test_text_encoder_lora_state_dict_unchanged(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + + text_encoder_1_sd_keys = sorted(list(sd_pipe.text_encoder.state_dict().keys())) + text_encoder_2_sd_keys = sorted(list(sd_pipe.text_encoder.state_dict().keys())) + + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + safe_serialization=False, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + + text_encoder_1_sd_keys_2 = sorted(list(sd_pipe.text_encoder.state_dict().keys())) + text_encoder_2_sd_keys_2 = sorted(list(sd_pipe.text_encoder.state_dict().keys())) + + sd_pipe.unload_lora_weights() + + text_encoder_1_sd_keys_3 = sorted(list(sd_pipe.text_encoder.state_dict().keys())) + text_encoder_2_sd_keys_3 = sorted(list(sd_pipe.text_encoder.state_dict().keys())) + + assert text_encoder_1_sd_keys == text_encoder_1_sd_keys_2 + assert text_encoder_1_sd_keys == text_encoder_1_sd_keys_3 + + assert text_encoder_2_sd_keys == text_encoder_2_sd_keys_2 + assert text_encoder_2_sd_keys == text_encoder_2_sd_keys_3 + def test_load_lora_locally_safetensors(self): 
pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**pipeline_components) @@ -1187,3 +1233,22 @@ def test_sdxl_1_0_last_ben(self): expected = np.array([0.5244, 0.4347, 0.4312, 0.4246, 0.4398, 0.4409, 0.4884, 0.4938, 0.4094]) self.assertTrue(np.allclose(images, expected, atol=1e-3)) + + def test_sdxl_1_0_fuse_unfuse_all(self): + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict) + text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict) + unet_sd = copy.deepcopy(pipe.unet.state_dict) + + pipe.load_lora_weights("davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors") + pipe.fuse_lora() + pipe.unload_lora_weights() + pipe.unfuse_lora() + + new_text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict) + new_text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict) + new_unet_sd = copy.deepcopy(pipe.unet.state_dict) + + assert(state_dicts_almost_equal(text_encoder_1_sd, new_text_encoder_1_sd)) + assert(state_dicts_almost_equal(text_encoder_2_sd, new_text_encoder_2_sd)) + assert(state_dicts_almost_equal(unet_sd, new_unet_sd)) From 682f66458759ef3faa04195dc5c2da95efd17771 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 29 Aug 2023 21:05:02 +0000 Subject: [PATCH 3/7] Fix more --- src/diffusers/loaders.py | 6 ++++-- tests/models/test_lora_layers.py | 8 ++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index f0b8ddcc7dab..5c0d36303add 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -86,9 +86,11 @@ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank= self.lora_scale = lora_scale # overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved - # when saving the whole text encoder model + # when saving the whole text encoder model and when LoRA is unloaded or fused def state_dict(self, *args, destination=None, prefix='', keep_vars=False): - return self.regular_linear_layer.state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) + if self.lora_linear_layer is None: + return self.regular_linear_layer.state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) + return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) def _fuse_lora(self): if self.lora_linear_layer is None: diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 3f0544bf0354..cf19084c5f0c 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -714,11 +714,15 @@ def test_text_encoder_lora_state_dict_unchanged(self): text_encoder_1_sd_keys_3 = sorted(list(sd_pipe.text_encoder.state_dict().keys())) text_encoder_2_sd_keys_3 = sorted(list(sd_pipe.text_encoder.state_dict().keys())) - assert text_encoder_1_sd_keys == text_encoder_1_sd_keys_2 + # default & unloaded LoRA weights should have identical state_dicts assert text_encoder_1_sd_keys == text_encoder_1_sd_keys_3 + # default & loaded LoRA weights should NOT have identical state_dicts + assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2 # - assert text_encoder_2_sd_keys == text_encoder_2_sd_keys_2 + # default & unloaded LoRA weights should have identical state_dicts assert text_encoder_2_sd_keys == text_encoder_2_sd_keys_3 + # default & loaded LoRA weights should NOT have identical 
state_dicts + assert text_encoder_2_sd_keys != text_encoder_2_sd_keys_2 def test_load_lora_locally_safetensors(self): pipeline_components, lora_components = self.get_dummy_components() From e796874a5ddaed1ae9b1cc837cd45fc3714abcbd Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 29 Aug 2023 21:11:25 +0000 Subject: [PATCH 4/7] Fix more --- src/diffusers/loaders.py | 3 ++- tests/models/test_lora_layers.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 5c0d36303add..2401941bed6e 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -90,6 +90,7 @@ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank= def state_dict(self, *args, destination=None, prefix='', keep_vars=False): if self.lora_linear_layer is None: return self.regular_linear_layer.state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) + return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) def _fuse_lora(self): @@ -1414,7 +1415,7 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder): attn_module.q_proj.lora_linear_layer = None attn_module.k_proj.lora_linear_layer = None attn_module.v_proj.lora_linear_layer = None - attn_module.out_proj.linear_layer = None + attn_module.out_proj.lora_linear_layer = None for _, mlp_module in text_encoder_mlp_modules(text_encoder): if isinstance(mlp_module.fc1, PatchedLoraProjection): diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index cf19084c5f0c..ee0d5156f930 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -690,7 +690,7 @@ def test_text_encoder_lora_state_dict_unchanged(self): sd_pipe = StableDiffusionXLPipeline(**pipeline_components) text_encoder_1_sd_keys = sorted(list(sd_pipe.text_encoder.state_dict().keys())) - text_encoder_2_sd_keys = sorted(list(sd_pipe.text_encoder.state_dict().keys())) + text_encoder_2_sd_keys = sorted(list(sd_pipe.text_encoder_2.state_dict().keys())) sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) @@ -707,12 +707,13 @@ def test_text_encoder_lora_state_dict_unchanged(self): sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) text_encoder_1_sd_keys_2 = sorted(list(sd_pipe.text_encoder.state_dict().keys())) - text_encoder_2_sd_keys_2 = sorted(list(sd_pipe.text_encoder.state_dict().keys())) + text_encoder_2_sd_keys_2 = sorted(list(sd_pipe.text_encoder_2.state_dict().keys())) sd_pipe.unload_lora_weights() + print("suh du") text_encoder_1_sd_keys_3 = sorted(list(sd_pipe.text_encoder.state_dict().keys())) - text_encoder_2_sd_keys_3 = sorted(list(sd_pipe.text_encoder.state_dict().keys())) + text_encoder_2_sd_keys_3 = sorted(list(sd_pipe.text_encoder_2.state_dict().keys())) # default & unloaded LoRA weights should have identical state_dicts assert text_encoder_1_sd_keys == text_encoder_1_sd_keys_3 From 3bccb9878e86b398fefc09ee2f752c73c63c98ac Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 29 Aug 2023 21:13:32 +0000 Subject: [PATCH 5/7] Fix all --- tests/models/test_lora_layers.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index ee0d5156f930..a5455fbb8c6a 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -101,8 +101,8 @@ def set_lora_weights(lora_attn_parameters, 
randn_weight=False): torch.zero_(parameter) def state_dicts_almost_equal(sd1, sd2): - sd1 = sorted(sd1) - sd2 = sorted(sd2) + sd1 = dict(sorted(sd1.items())) + sd2 = dict(sorted(sd2.items())) models_are_equal = True for ten1, ten2 in zip(sd1.values(), sd2.values()): @@ -710,7 +710,6 @@ def test_text_encoder_lora_state_dict_unchanged(self): text_encoder_2_sd_keys_2 = sorted(list(sd_pipe.text_encoder_2.state_dict().keys())) sd_pipe.unload_lora_weights() - print("suh du") text_encoder_1_sd_keys_3 = sorted(list(sd_pipe.text_encoder.state_dict().keys())) text_encoder_2_sd_keys_3 = sorted(list(sd_pipe.text_encoder_2.state_dict().keys())) @@ -1241,18 +1240,18 @@ def test_sdxl_1_0_last_ben(self): def test_sdxl_1_0_fuse_unfuse_all(self): pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") - text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict) - text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict) - unet_sd = copy.deepcopy(pipe.unet.state_dict) + text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict()) + text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict()) + unet_sd = copy.deepcopy(pipe.unet.state_dict()) pipe.load_lora_weights("davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors") pipe.fuse_lora() pipe.unload_lora_weights() pipe.unfuse_lora() - new_text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict) - new_text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict) - new_unet_sd = copy.deepcopy(pipe.unet.state_dict) + new_text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict()) + new_text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict()) + new_unet_sd = copy.deepcopy(pipe.unet.state_dict()) assert(state_dicts_almost_equal(text_encoder_1_sd, new_text_encoder_1_sd)) assert(state_dicts_almost_equal(text_encoder_2_sd, new_text_encoder_2_sd)) From 88789150a28d930a970ccb64819b0b979cc7a213 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 29 Aug 2023 21:15:05 +0000 Subject: [PATCH 6/7] make style --- src/diffusers/loaders.py | 6 ++++-- tests/models/test_lora_layers.py | 26 ++++++++++++++------------ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 2401941bed6e..eb4204976193 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -87,9 +87,11 @@ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank= # overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved # when saving the whole text encoder model and when LoRA is unloaded or fused - def state_dict(self, *args, destination=None, prefix='', keep_vars=False): + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): if self.lora_linear_layer is None: - return self.regular_linear_layer.state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) + return self.regular_linear_layer.state_dict( + *args, destination=destination, prefix=prefix, keep_vars=keep_vars + ) return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index a5455fbb8c6a..848f2f44adc9 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -12,11 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +import copy import os import tempfile import time import unittest -import copy import numpy as np import torch @@ -100,6 +100,7 @@ def set_lora_weights(lora_attn_parameters, randn_weight=False): else: torch.zero_(parameter) + def state_dicts_almost_equal(sd1, sd2): sd1 = dict(sorted(sd1.items())) sd2 = dict(sorted(sd2.items())) @@ -108,9 +109,10 @@ def state_dicts_almost_equal(sd1, sd2): for ten1, ten2 in zip(sd1.values(), sd2.values()): if (ten1 - ten2).abs().sum() > 1e-3: models_are_equal = False - + return models_are_equal + class LoraLoaderMixinTests(unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) @@ -689,8 +691,8 @@ def test_text_encoder_lora_state_dict_unchanged(self): pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**pipeline_components) - text_encoder_1_sd_keys = sorted(list(sd_pipe.text_encoder.state_dict().keys())) - text_encoder_2_sd_keys = sorted(list(sd_pipe.text_encoder_2.state_dict().keys())) + text_encoder_1_sd_keys = sorted(sd_pipe.text_encoder.state_dict().keys()) + text_encoder_2_sd_keys = sorted(sd_pipe.text_encoder_2.state_dict().keys()) sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) @@ -706,18 +708,18 @@ def test_text_encoder_lora_state_dict_unchanged(self): self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - text_encoder_1_sd_keys_2 = sorted(list(sd_pipe.text_encoder.state_dict().keys())) - text_encoder_2_sd_keys_2 = sorted(list(sd_pipe.text_encoder_2.state_dict().keys())) + text_encoder_1_sd_keys_2 = sorted(sd_pipe.text_encoder.state_dict().keys()) + text_encoder_2_sd_keys_2 = sorted(sd_pipe.text_encoder_2.state_dict().keys()) sd_pipe.unload_lora_weights() - text_encoder_1_sd_keys_3 = sorted(list(sd_pipe.text_encoder.state_dict().keys())) - text_encoder_2_sd_keys_3 = sorted(list(sd_pipe.text_encoder_2.state_dict().keys())) + text_encoder_1_sd_keys_3 = sorted(sd_pipe.text_encoder.state_dict().keys()) + text_encoder_2_sd_keys_3 = sorted(sd_pipe.text_encoder_2.state_dict().keys()) # default & unloaded LoRA weights should have identical state_dicts assert text_encoder_1_sd_keys == text_encoder_1_sd_keys_3 # default & loaded LoRA weights should NOT have identical state_dicts - assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2 # + assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2 # # default & unloaded LoRA weights should have identical state_dicts assert text_encoder_2_sd_keys == text_encoder_2_sd_keys_3 @@ -1253,6 +1255,6 @@ def test_sdxl_1_0_fuse_unfuse_all(self): new_text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict()) new_unet_sd = copy.deepcopy(pipe.unet.state_dict()) - assert(state_dicts_almost_equal(text_encoder_1_sd, new_text_encoder_1_sd)) - assert(state_dicts_almost_equal(text_encoder_2_sd, new_text_encoder_2_sd)) - assert(state_dicts_almost_equal(unet_sd, new_unet_sd)) + assert state_dicts_almost_equal(text_encoder_1_sd, new_text_encoder_1_sd) + assert state_dicts_almost_equal(text_encoder_2_sd, new_text_encoder_2_sd) + assert state_dicts_almost_equal(unet_sd, new_unet_sd) From 17474c4b96e0b010f5da19639b4efba2f026c399 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 29 Aug 2023 21:42:41 +0000 Subject: [PATCH 7/7] make style --- src/diffusers/loaders.py | 42 
++++++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index eb4204976193..52970e48147d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1456,23 +1456,43 @@ def _modify_text_encoder( else: current_rank = rank + q_linear_layer = ( + attn_module.q_proj.regular_linear_layer + if isinstance(attn_module.q_proj, PatchedLoraProjection) + else attn_module.q_proj + ) attn_module.q_proj = PatchedLoraProjection( - attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype + q_linear_layer, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype ) lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters()) + k_linear_layer = ( + attn_module.k_proj.regular_linear_layer + if isinstance(attn_module.k_proj, PatchedLoraProjection) + else attn_module.k_proj + ) attn_module.k_proj = PatchedLoraProjection( - attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype + k_linear_layer, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype ) lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters()) + v_linear_layer = ( + attn_module.v_proj.regular_linear_layer + if isinstance(attn_module.v_proj, PatchedLoraProjection) + else attn_module.v_proj + ) attn_module.v_proj = PatchedLoraProjection( - attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype + v_linear_layer, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype ) lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters()) + out_linear_layer = ( + attn_module.out_proj.regular_linear_layer + if isinstance(attn_module.out_proj, PatchedLoraProjection) + else attn_module.out_proj + ) attn_module.out_proj = PatchedLoraProjection( - attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype + out_linear_layer, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype ) lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters()) @@ -1484,13 +1504,23 @@ def _modify_text_encoder( current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight") current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight") + fc1_linear_layer = ( + mlp_module.fc1.regular_linear_layer + if isinstance(mlp_module.fc1, PatchedLoraProjection) + else mlp_module.fc1 + ) mlp_module.fc1 = PatchedLoraProjection( - mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype + fc1_linear_layer, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype ) lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters()) + fc2_linear_layer = ( + mlp_module.fc2.regular_linear_layer + if isinstance(mlp_module.fc2, PatchedLoraProjection) + else mlp_module.fc2 + ) mlp_module.fc2 = PatchedLoraProjection( - mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype + fc2_linear_layer, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype ) lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())