From cbe006fa5320c4b2b8d31ff0c3f98de6871276ef Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 19 Jul 2023 12:00:21 +0530
Subject: [PATCH 1/9] change the expected values since we have better coverage.

---
 tests/models/test_lora_layers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py
index 1396561367e0..fd405a668c21 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/models/test_lora_layers.py
@@ -554,7 +554,7 @@ def test_a1111(self):
 
         images = images[0, -3:, -3:, -1].flatten()
 
-        expected = np.array([0.3743, 0.3893, 0.3835, 0.3891, 0.3949, 0.3649, 0.3858, 0.3802, 0.3245])
+        expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
 
         self.assertTrue(np.allclose(images, expected, atol=1e-4))
 

From b44363942251c5948e0a17629d3fbd2017aba380 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 19 Jul 2023 12:51:32 +0530
Subject: [PATCH 2/9] debugging

---
 src/diffusers/loaders.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index dd307350d385..cbb6f2dcc67b 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -1184,6 +1184,7 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
                 attn_module.out_proj = attn_module.out_proj.regular_linear_layer
 
         for _, aux_module in text_encoder_aux_modules(text_encoder):
+            print("Removed auxiliary modules too.")
             if isinstance(aux_module.fc1, PatchedLoraProjection):
                 aux_module.fc1 = aux_module.fc1.regular_linear_layer
                 aux_module.fc2 = aux_module.fc2.regular_linear_layer

From 175958283d6b5b393494942f32909cdde0071650 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 19 Jul 2023 13:00:25 +0530
Subject: [PATCH 3/9] restart generator for lora_load too.

---
 tests/models/test_lora_layers.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py
index fd405a668c21..a378429f75ab 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/models/test_lora_layers.py
@@ -594,6 +594,7 @@ def test_unload_lora(self):
         lora_filename = "Colored_Icons_by_vizsumit.safetensors"
 
         pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+        generator = torch.manual_seed(0)
         lora_images = pipe(
             prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
         ).images

From 1ea4e5e707758e685c0842712fd340be172998b9 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 19 Jul 2023 13:56:38 +0530
Subject: [PATCH 4/9] logging when kohya style checkpoint is detected.
---
 src/diffusers/loaders.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index cbb6f2dcc67b..35418a2003c0 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -62,6 +62,7 @@
 
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+TOTAL_EXAMPLE_KEYS = 5
 
 TEXT_INVERSION_NAME = "learned_embeds.bin"
 TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@@ -1315,9 +1316,12 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
         unet_state_dict_aux = {}
         te_state_dict_aux = {}
         network_alpha = None
+        unloaded_keys = []
 
         for key, value in state_dict.items():
-            if "lora_down" in key:
+            if "hada" in key or "skip" in key:
+                unloaded_keys.append(key)
+            elif "lora_down" in key:
                 lora_name = key.split(".")[0]
                 lora_name_up = lora_name + ".lora_up.weight"
                 lora_name_alpha = lora_name + ".alpha"
@@ -1352,6 +1356,7 @@
                     elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
                         unet_state_dict_aux[diffusers_name] = value
                         unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+
                 elif lora_name.startswith("lora_te_"):
                     diffusers_name = key.replace("lora_te_", "").replace("_", ".")
                     diffusers_name = diffusers_name.replace("text.model", "text_model")
@@ -1367,6 +1372,13 @@
                         te_state_dict_aux[diffusers_name] = value
                         te_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
 
+        logger.info("Kohya-style checkpoint detected.")
+        if len(unloaded_keys) > 0:
+            example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS])
+            logger.warning(
+                f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for."
+            )
+
         unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
         te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
         new_state_dict = {**unet_state_dict, **te_state_dict}

From 6e4859dd36c6908584f082e2492f0e925af8b8ab Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 19 Jul 2023 14:38:29 +0530
Subject: [PATCH 5/9] debugging

---
 src/diffusers/loaders.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 35418a2003c0..e84c2d9f2df7 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -508,6 +508,7 @@ def _load_lora_aux(self, state_dict, network_alpha=None):
 
             # install lora
             target_module.lora_layer = lora
+            print(f"From UNet aux: {target_module}")
 
 
 class TextualInversionLoaderMixin:

From fc8d31cbd06e8741afcf3d17b28c358dc3381a32 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 19 Jul 2023 14:51:14 +0530
Subject: [PATCH 6/9] debugging

---
 src/diffusers/loaders.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index e84c2d9f2df7..32994933f183 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -492,7 +492,9 @@ def _load_lora_aux(self, state_dict, network_alpha=None):
             if len(target_modules) == 0:
                 logger.warning(f"Could not find module {key} in the model. Skipping.")
                 continue
-
+
+            print(f"target modules: {len(target_modules)}")
+            print(f"target modules: {target_modules[:2]}")
             target_module = target_modules[0]
             value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
 
@@ -508,7 +510,7 @@ def _load_lora_aux(self, state_dict, network_alpha=None):
 
             # install lora
             target_module.lora_layer = lora
-            print(f"From UNet aux: {target_module}")
+            # print(f"From UNet aux: {target_module}")
 
 
 class TextualInversionLoaderMixin:

From fa59930cbaa709ef8436f4c93caa09ca5753c449 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 19 Jul 2023 15:34:48 +0530
Subject: [PATCH 7/9] handle unloading.

---
 src/diffusers/loaders.py     | 10 +++++++---
 src/diffusers/models/lora.py |  4 ++++
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 32994933f183..27436383331a 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -188,6 +188,7 @@ def map_from(module, state_dict, *args, **kwargs):
 class UNet2DConditionLoadersMixin:
     text_encoder_name = TEXT_ENCODER_NAME
     unet_name = UNET_NAME
+    aux_state_dict_populated = None
 
     def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         r"""
@@ -493,8 +494,6 @@ def _load_lora_aux(self, state_dict, network_alpha=None):
                 logger.warning(f"Could not find module {key} in the model. Skipping.")
                 continue
 
-            print(f"target modules: {len(target_modules)}")
-            print(f"target modules: {target_modules[:2]}")
             target_module = target_modules[0]
             value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
 
@@ -510,7 +509,6 @@ def _load_lora_aux(self, state_dict, network_alpha=None):
 
             # install lora
             target_module.lora_layer = lora
-            # print(f"From UNet aux: {target_module}")
 
 
 class TextualInversionLoaderMixin:
@@ -1066,6 +1064,7 @@ def load_lora_into_unet(cls, state_dict, network_alpha, unet, state_dict_aux=None):
 
         if state_dict_aux:
             unet._load_lora_aux(state_dict_aux, network_alpha=network_alpha)
+            unet.aux_state_dict_populated = True
 
     @classmethod
     def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0, state_dict_aux=None):
@@ -1415,6 +1414,11 @@ def unload_lora_weights(self):
             self.unet.set_attn_processor(unet_attn_proc_cls())
         else:
             self.unet.set_default_attn_processor()
+
+        if self.unet.aux_state_dict_populated:
+            for _, module in self.unet.named_modules():
+                if hasattr(module, "old_forward") and module.old_forward is not None:
+                    module.forward = module.old_forward
 
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()

diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py
index 4949e3c082be..78ab03081fc5 100644
--- a/src/diffusers/models/lora.py
+++ b/src/diffusers/models/lora.py
@@ -87,11 +87,13 @@ class Conv2dWithLoRA(nn.Conv2d):
     def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
         super().__init__(*args, **kwargs)
         self.lora_layer = lora_layer
+        self.old_forward = None
 
     def forward(self, x):
         if self.lora_layer is None:
             return super().forward(x)
         else:
+            self.old_forward = super().forward
             return super().forward(x) + self.lora_layer(x)
 
 
@@ -103,9 +105,11 @@ class LinearWithLoRA(nn.Linear):
     def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
         super().__init__(*args, **kwargs)
         self.lora_layer = lora_layer
+        self.old_forward = None
 
     def forward(self, x):
         if self.lora_layer is None:
             return super().forward(x)
         else:
+            self.old_forward = super().forward
             return super().forward(x) + self.lora_layer(x)

From 9ae0186e36232a80598bb63f46e7bd3a1cb63167 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 19 Jul 2023 15:43:02 +0530
Subject: [PATCH 8/9] remove unneeded print.

---
 src/diffusers/loaders.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 27436383331a..ed7fb69c4fed 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -493,7 +493,7 @@ def _load_lora_aux(self, state_dict, network_alpha=None):
             if len(target_modules) == 0:
                 logger.warning(f"Could not find module {key} in the model. Skipping.")
                 continue
-
+
             target_module = target_modules[0]
             value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
 
@@ -1187,7 +1187,6 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
                 attn_module.out_proj = attn_module.out_proj.regular_linear_layer
 
         for _, aux_module in text_encoder_aux_modules(text_encoder):
-            print("Removed auxiliary modules too.")
             if isinstance(aux_module.fc1, PatchedLoraProjection):
                 aux_module.fc1 = aux_module.fc1.regular_linear_layer
                 aux_module.fc2 = aux_module.fc2.regular_linear_layer
@@ -1414,10 +1413,10 @@ def unload_lora_weights(self):
             self.unet.set_attn_processor(unet_attn_proc_cls())
         else:
             self.unet.set_default_attn_processor()
-
+
         if self.unet.aux_state_dict_populated:
             for _, module in self.unet.named_modules():
-                if hasattr(module, "old_forward") and module.old_forward is not None:
+                if hasattr(module, "old_forward") and module.old_forward is not None:
                     module.forward = module.old_forward
 
         # Safe to call the following regardless of LoRA.

From c26a02d13fc64bb4dff0983a9694f5fea827b413 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 19 Jul 2023 21:13:04 +0530
Subject: [PATCH 9/9] Update src/diffusers/loaders.py

Co-authored-by: Batuhan Taskaya
---
 src/diffusers/loaders.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index ed7fb69c4fed..7401345d93b3 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -1418,6 +1418,7 @@ def unload_lora_weights(self):
             for _, module in self.unet.named_modules():
                 if hasattr(module, "old_forward") and module.old_forward is not None:
                     module.forward = module.old_forward
+            self.unet.aux_state_dict_populated = False
 
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
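For context, the end-to-end flow this patch series touches looks roughly like the sketch below. It is illustrative only: the base-model and LoRA repo/file names are placeholders, not the checkpoints exercised in the tests above.

    import torch
    from diffusers import StableDiffusionPipeline

    # Any Stable Diffusion base checkpoint works here (placeholder id).
    pipe = StableDiffusionPipeline.from_pretrained(
        "some-org/some-sd-base-model", torch_dtype=torch.float16
    ).to("cuda")

    # Load a Kohya-style LoRA checkpoint (placeholder repo/file names).
    # With these patches, unsupported key types (e.g. LoHa "hada" keys) are
    # skipped and reported via a warning instead of being silently ignored.
    pipe.load_lora_weights("some-user/some-kohya-lora", weight_name="lora.safetensors")

    generator = torch.manual_seed(0)
    lora_images = pipe("a prompt", generator=generator, num_inference_steps=30).images

    # Unload the LoRA. The auxiliary (non-attention) LoRA layers installed on
    # Conv2dWithLoRA/LinearWithLoRA modules are detached as well, and their
    # original forward passes are restored via the recorded old_forward.
    pipe.unload_lora_weights()

    generator = torch.manual_seed(0)
    images_without_lora = pipe("a prompt", generator=generator, num_inference_steps=30).images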