From c7a369afd3c88c7cbe412ff2b7b826d8d9aaaa64 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 18 Aug 2023 16:55:16 +0530 Subject: [PATCH 001/119] make controlnet sublcass from a loraloader --- src/diffusers/models/controlnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index ed3f3e687143..905b2c2b2280 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -19,7 +19,7 @@ from torch.nn import functional as F from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import FromOriginalControlnetMixin +from ..loaders import FromOriginalControlnetMixin, UNet2DConditionLoadersMixin from ..utils import BaseOutput, logging from .attention_processor import AttentionProcessor, AttnProcessor from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps @@ -101,7 +101,7 @@ def forward(self, conditioning): return embedding -class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): +class ControlNetModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, FromOriginalControlnetMixin): """ A ControlNet model. From 9a78f038fae40fddfd29fb292b28b2f436d86219 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 18 Aug 2023 17:48:24 +0530 Subject: [PATCH 002/119] wondering' --- src/diffusers/loaders.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 81404e4c9968..5837e3472b8d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -362,7 +362,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict for key, value_dict in lora_grouped_dict.items(): attn_processor = self + print(f"Self type: {type(self)}") for sub_key in key.split("."): + print(f"From UNet: {key}") attn_processor = getattr(attn_processor, sub_key) # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers From e9fe443cca4e8128ae61ef6e5326fe3b553ab47b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 18 Aug 2023 17:53:01 +0530 Subject: [PATCH 003/119] wondering' --- src/diffusers/loaders.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 5837e3472b8d..9fafd8334c08 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -360,12 +360,14 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}" ) + temp = 0 for key, value_dict in lora_grouped_dict.items(): attn_processor = self - print(f"Self type: {type(self)}") for sub_key in key.split("."): - print(f"From UNet: {key}") + if temp < 50: + print(f"From UNet: {key}, {sub_key}") attn_processor = getattr(attn_processor, sub_key) + temp += 1 # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers # or add_{k,v,q,out_proj}_proj_lora layers. From 2d4ae0026d6c8d3b6c6c88b05613b6994f71ce49 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 11:25:09 +0530 Subject: [PATCH 004/119] relax check. 
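Patches 001-003 make ControlNetModel inherit from UNet2DConditionLoadersMixin alongside its existing mixins, and add temporary prints inside load_attn_procs while tracing how LoRA keys resolve against the model. The point of the extra base class is that a ControlNet then picks up the same attention-processor / LoRA loading path that UNet2DConditionModel already has. A minimal sketch of what that buys, assuming a diffusers checkout with patch 001 applied; this is an illustration, not code from the series:

# Minimal sketch (not part of the patch series): with UNet2DConditionLoadersMixin
# among ControlNetModel's bases, a ControlNet exposes the same LoRA loading entry
# point as the UNet. Assumes a diffusers build that includes patch 001.
from diffusers import ControlNetModel
from diffusers.loaders import UNet2DConditionLoadersMixin

assert issubclass(ControlNetModel, UNet2DConditionLoadersMixin)
# load_attn_procs() comes from the mixin and accepts a repo id, a local
# directory, or an in-memory state dict of LoRA / attention-processor weights.
assert hasattr(ControlNetModel, "load_attn_procs")
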
--- src/diffusers/loaders.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 9fafd8334c08..92f7822a65bd 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1511,7 +1511,7 @@ def save_function(weights, filename): logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") @classmethod - def _convert_kohya_lora_to_diffusers(cls, state_dict): + def _convert_kohya_lora_to_diffusers(cls, state_dict, strict=True): unet_state_dict = {} te_state_dict = {} te2_state_dict = {} @@ -1647,11 +1647,12 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" network_alphas.update({new_name: alpha}) - if len(state_dict) > 0: + if strict and len(state_dict) > 0: raise ValueError( f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}" ) + logger.info("Kohya-style checkpoint detected.") unet_state_dict = {f"{cls.unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()} te_state_dict = { @@ -1666,7 +1667,11 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): te_state_dict.update(te2_state_dict) new_state_dict = {**unet_state_dict, **te_state_dict} - return new_state_dict, network_alphas + + if strict: + return state_dict, new_state_dict, network_alphas + else: + return new_state_dict, network_alphas def unload_lora_weights(self): """ From 49327162c9ccf816d6f81a17e0301d325286ed11 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 11:29:35 +0530 Subject: [PATCH 005/119] exploring --- src/diffusers/loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 92f7822a65bd..592b25a63928 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1669,9 +1669,9 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict, strict=True): new_state_dict = {**unet_state_dict, **te_state_dict} if strict: - return state_dict, new_state_dict, network_alphas - else: return new_state_dict, network_alphas + else: + return state_dict, new_state_dict, network_alphas def unload_lora_weights(self): """ From e73696082190c21a6404fa0d525b468faa2c5b56 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 11:33:43 +0530 Subject: [PATCH 006/119] sai controlnet --- src/diffusers/loaders.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 592b25a63928..74d43dad4234 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1511,7 +1511,7 @@ def save_function(weights, filename): logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") @classmethod - def _convert_kohya_lora_to_diffusers(cls, state_dict, strict=True): + def _convert_kohya_lora_to_diffusers(cls, state_dict, is_sai_controlnet=False, strict=True): unet_state_dict = {} te_state_dict = {} te2_state_dict = {} @@ -1524,7 +1524,7 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict, strict=True): lora_name_up = lora_name + ".lora_up.weight" lora_name_alpha = lora_name + ".alpha" - if lora_name.startswith("lora_unet_"): + if lora_name.startswith("lora_unet_") or is_sai_controlnet: diffusers_name = key.replace("lora_unet_", "").replace("_", ".") if "input.blocks" in diffusers_name: @@ -1652,7 +1652,6 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict, strict=True): f"The 
following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}" ) - logger.info("Kohya-style checkpoint detected.") unet_state_dict = {f"{cls.unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()} te_state_dict = { @@ -1667,8 +1666,8 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict, strict=True): te_state_dict.update(te2_state_dict) new_state_dict = {**unet_state_dict, **te_state_dict} - - if strict: + + if strict: return new_state_dict, network_alphas else: return state_dict, new_state_dict, network_alphas From 30dee21a346238c964da4e50e9ae4426ad45f454 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 13:20:14 +0530 Subject: [PATCH 007/119] let's see --- src/diffusers/loaders.py | 83 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 76 insertions(+), 7 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 74d43dad4234..8902079ec548 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1511,7 +1511,7 @@ def save_function(weights, filename): logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") @classmethod - def _convert_kohya_lora_to_diffusers(cls, state_dict, is_sai_controlnet=False, strict=True): + def _convert_kohya_lora_to_diffusers(cls, state_dict): unet_state_dict = {} te_state_dict = {} te2_state_dict = {} @@ -1524,7 +1524,7 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict, is_sai_controlnet=False, s lora_name_up = lora_name + ".lora_up.weight" lora_name_alpha = lora_name + ".alpha" - if lora_name.startswith("lora_unet_") or is_sai_controlnet: + if lora_name.startswith("lora_unet_"): diffusers_name = key.replace("lora_unet_", "").replace("_", ".") if "input.blocks" in diffusers_name: @@ -1647,7 +1647,7 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict, is_sai_controlnet=False, s new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" network_alphas.update({new_name: alpha}) - if strict and len(state_dict) > 0: + if len(state_dict) > 0: raise ValueError( f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}" ) @@ -1667,10 +1667,79 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict, is_sai_controlnet=False, s new_state_dict = {**unet_state_dict, **te_state_dict} - if strict: - return new_state_dict, network_alphas - else: - return state_dict, new_state_dict, network_alphas + return new_state_dict, network_alphas + + @classmethod + def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): + controlnet_state_dict = {} + + # every down weight has a corresponding up weight and potentially an alpha weight + lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")] + for key in lora_keys: + lora_name = key.split(".")[0] + lora_name_up = lora_name + ".lora_up.weight" + diffusers_name = key.replace("_", ".") + + if "input.blocks" in diffusers_name: + diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") + else: + diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") + + if "middle.block" in diffusers_name: + diffusers_name = diffusers_name.replace("middle.block", "mid_block") + else: + diffusers_name = diffusers_name.replace("mid.block", "mid_block") + if "output.blocks" in diffusers_name: + diffusers_name = diffusers_name.replace("output.blocks", "up_blocks") + else: + diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") + + diffusers_name = diffusers_name.replace("transformer.blocks", 
"transformer_blocks") + diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("proj.in", "proj_in") + diffusers_name = diffusers_name.replace("proj.out", "proj_out") + diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj") + + # SDXL specificity. + if "emb" in diffusers_name: + pattern = r"\.\d+(?=\D*$)" + diffusers_name = re.sub(pattern, "", diffusers_name, count=1) + if ".in." in diffusers_name: + diffusers_name = diffusers_name.replace("in.layers.2", "conv1") + if ".out." in diffusers_name: + diffusers_name = diffusers_name.replace("out.layers.3", "conv2") + if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name: + diffusers_name = diffusers_name.replace("op", "conv") + if "skip" in diffusers_name: + diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut") + + if "transformer_blocks" in diffusers_name: + if "attn1" in diffusers_name or "attn2" in diffusers_name: + diffusers_name = diffusers_name.replace("attn1", "attn1.processor") + diffusers_name = diffusers_name.replace("attn2", "attn2.processor") + controlnet_state_dict[diffusers_name] = state_dict.pop(key) + controlnet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif "ff" in diffusers_name: + controlnet_state_dict[diffusers_name] = state_dict.pop(key) + controlnet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif any(key in diffusers_name for key in ("proj_in", "proj_out")): + controlnet_state_dict[diffusers_name] = state_dict.pop(key) + controlnet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + else: + controlnet_state_dict[diffusers_name] = state_dict.pop(key) + controlnet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + if len(state_dict) > 0: + raise ValueError( + f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}" + ) + + logger.info("StabilityAI ControlNet LoRA checkpoint detected.") + + return controlnet_state_dict def unload_lora_weights(self): """ From 6f9e14bcfc0739df0cbb324822e3049cbebcc30f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 13:25:10 +0530 Subject: [PATCH 008/119] debugging --- src/diffusers/loaders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 8902079ec548..7496c8353045 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1729,6 +1729,7 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): controlnet_state_dict[diffusers_name] = state_dict.pop(key) controlnet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) else: + print(f"Actual key: {key}, Diffusers name: {diffusers_name}") controlnet_state_dict[diffusers_name] = state_dict.pop(key) controlnet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) From 2257ba9dd32b1b5b2b1dd9b28a5d6eb0ec443178 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 13:28:21 +0530 Subject: [PATCH 009/119] debugging --- src/diffusers/loaders.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 7496c8353045..f973cf48e2ad 
100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1730,6 +1730,9 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): controlnet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) else: print(f"Actual key: {key}, Diffusers name: {diffusers_name}") + print( + f"diffusers_name (replaced): {diffusers_name.replace('.down.', '.up.')}, lora_name_up: {lora_name_up}" + ) controlnet_state_dict[diffusers_name] = state_dict.pop(key) controlnet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) From 38fb6fe37be3214b798fcf47bd380b863a8933fc Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 13:38:42 +0530 Subject: [PATCH 010/119] debugging --- src/diffusers/loaders.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index f973cf48e2ad..3dbfe74c4ad7 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1576,6 +1576,8 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): unet_state_dict[diffusers_name] = state_dict.pop(key) unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) else: + if "weightsamplers" in key and "op" in key: + print(f"Actual key: {key} Diffusers name: {diffusers_name} lora_name_up: {lora_name_up}") unet_state_dict[diffusers_name] = state_dict.pop(key) unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) @@ -1673,7 +1675,7 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): controlnet_state_dict = {} - # every down weight has a corresponding up weight and potentially an alpha weight + # every down weight has a corresponding up weight lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")] for key in lora_keys: lora_name = key.split(".")[0] From c8ec943cbad777ce27b00a5b0777ddc0154a5061 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 13:44:10 +0530 Subject: [PATCH 011/119] remove unnecessary statements. --- src/diffusers/loaders.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 3dbfe74c4ad7..7f208a8f4bb9 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -360,15 +360,11 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}" ) - temp = 0 for key, value_dict in lora_grouped_dict.items(): attn_processor = self for sub_key in key.split("."): - if temp < 50: - print(f"From UNet: {key}, {sub_key}") attn_processor = getattr(attn_processor, sub_key) - temp += 1 - + # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers # or add_{k,v,q,out_proj}_proj_lora layers. if "lora.down.weight" in value_dict: From 070983480faaa273e652efbcdbc267482db97b59 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 13:47:50 +0530 Subject: [PATCH 012/119] simplify condition. 
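Patches 005-011 first try to route StabilityAI (SGM-style) ControlNet LoRA checkpoints through the existing Kohya converter via strict / is_sai_controlnet flags, then settle on a dedicated _convert_sai_controlnet_lora_to_diffusers helper (patch 007) that renames SGM keys into diffusers naming, with the remaining commits adding and removing debug prints. The standalone sketch below replays the renaming idea on one example key with an abridged substitution table; it is an illustration only, not the method from the patch, which also handles attn1/attn2 -> *.processor, resnets, samplers, and the embedding layers:

# Abridged illustration of the SGM -> diffusers key renaming performed by
# _convert_sai_controlnet_lora_to_diffusers. Only a few representative
# substitutions are shown; the helper in patch 007 covers many more cases.
def to_diffusers_key(sgm_key: str) -> str:
    name = sgm_key.replace("_", ".")  # the converter first turns underscores into dots
    for old, new in [
        ("input.blocks", "down_blocks"),
        ("middle.block", "mid_block"),
        ("output.blocks", "up_blocks"),
        ("transformer.blocks", "transformer_blocks"),
        ("to.q.lora", "to_q_lora"),
        ("to.out.0.lora", "to_out_lora"),
    ]:
        name = name.replace(old, new)
    return name

print(to_diffusers_key("input_blocks.1.1.transformer_blocks.0.attn1.to_q.lora_down.weight"))
# -> down_blocks.1.1.transformer_blocks.0.attn1.to_q_lora.down.weight
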
--- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 7f208a8f4bb9..8611832004b8 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1572,7 +1572,7 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): unet_state_dict[diffusers_name] = state_dict.pop(key) unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) else: - if "weightsamplers" in key and "op" in key: + if "op" in key: print(f"Actual key: {key} Diffusers name: {diffusers_name} lora_name_up: {lora_name_up}") unet_state_dict[diffusers_name] = state_dict.pop(key) unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) From 86515e44912fda1bd508a96878f08ce3c360bade Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 13:52:46 +0530 Subject: [PATCH 013/119] seeing. --- src/diffusers/loaders.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 8611832004b8..bda9969914ee 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1559,6 +1559,9 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): if "skip" in diffusers_name: diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut") + if "op" in key: + print(f"Actual key: {key} Diffusers name: {diffusers_name} lora_name_up: {lora_name_up}") + if "transformer_blocks" in diffusers_name: if "attn1" in diffusers_name or "attn2" in diffusers_name: diffusers_name = diffusers_name.replace("attn1", "attn1.processor") @@ -1572,8 +1575,6 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): unet_state_dict[diffusers_name] = state_dict.pop(key) unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) else: - if "op" in key: - print(f"Actual key: {key} Diffusers name: {diffusers_name} lora_name_up: {lora_name_up}") unet_state_dict[diffusers_name] = state_dict.pop(key) unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) From a9dfd863110a3be7cf2ee11610a911c17ccb2797 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 14:42:20 +0530 Subject: [PATCH 014/119] debugging --- src/diffusers/loaders.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index bda9969914ee..d9e2de2fa51a 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1124,6 +1124,10 @@ def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", bl + [str(block_id), inner_block_key, inner_layers_in_block] + key.split(delimiter)[block_slice_pos + 1 :] ) + if "op" in key: + print(key.split(delimiter)[: block_slice_pos - 1], + [str(block_id), inner_block_key, inner_layers_in_block], + key.split(delimiter)[block_slice_pos + 1 :]) new_state_dict[new_key] = state_dict.pop(key) for i in middle_block_ids: @@ -1527,7 +1531,6 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") else: diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") - if "middle.block" in diffusers_name: diffusers_name = diffusers_name.replace("middle.block", "mid_block") else: @@ -1559,9 +1562,6 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): if "skip" in diffusers_name: diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut") - if "op" in key: - 
print(f"Actual key: {key} Diffusers name: {diffusers_name} lora_name_up: {lora_name_up}") - if "transformer_blocks" in diffusers_name: if "attn1" in diffusers_name or "attn2" in diffusers_name: diffusers_name = diffusers_name.replace("attn1", "attn1.processor") From 4baa7e3945eb8b354e9f1cd4625c0be3d1345c41 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 15:17:26 +0530 Subject: [PATCH 015/119] debugging --- src/diffusers/loaders.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index d9e2de2fa51a..20e3c1586d96 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1124,10 +1124,6 @@ def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", bl + [str(block_id), inner_block_key, inner_layers_in_block] + key.split(delimiter)[block_slice_pos + 1 :] ) - if "op" in key: - print(key.split(delimiter)[: block_slice_pos - 1], - [str(block_id), inner_block_key, inner_layers_in_block], - key.split(delimiter)[block_slice_pos + 1 :]) new_state_dict[new_key] = state_dict.pop(key) for i in middle_block_ids: From df3dfe3668dda8b62ac29f715fc24243c3e17b4d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 15:30:42 +0530 Subject: [PATCH 016/119] debugging --- src/diffusers/loaders.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 20e3c1586d96..8135563dfdd0 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1671,9 +1671,13 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): # every down weight has a corresponding up weight lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")] for key in lora_keys: - lora_name = key.split(".")[0] - lora_name_up = lora_name + ".lora_up.weight" - diffusers_name = key.replace("_", ".") + if "linear_1" not in key: + lora_name = key.split(".")[0] + lora_name_up = lora_name + ".lora_up.weight" + diffusers_name = key.replace("_", ".") + else: + lora_name_up = key.replace("lora_down", ".lora_up") + diffusers_name = key if "input.blocks" in diffusers_name: diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") @@ -1724,10 +1728,6 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): controlnet_state_dict[diffusers_name] = state_dict.pop(key) controlnet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) else: - print(f"Actual key: {key}, Diffusers name: {diffusers_name}") - print( - f"diffusers_name (replaced): {diffusers_name.replace('.down.', '.up.')}, lora_name_up: {lora_name_up}" - ) controlnet_state_dict[diffusers_name] = state_dict.pop(key) controlnet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) From dde7ed64316f17ab9d3213ab3b334ab859a380d1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 15:32:16 +0530 Subject: [PATCH 017/119] debugging --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 8135563dfdd0..3c99fe9a4c21 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1676,7 +1676,7 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): lora_name_up = lora_name + ".lora_up.weight" diffusers_name = key.replace("_", ".") else: - lora_name_up = key.replace("lora_down", ".lora_up") + lora_name_up = key.replace("lora_down", "lora_up") diffusers_name = key if 
"input.blocks" in diffusers_name: From 04f663d6641599981d44d2b25384cc879458794f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 15:34:54 +0530 Subject: [PATCH 018/119] debugging --- src/diffusers/loaders.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 3c99fe9a4c21..19e14f41904f 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1671,7 +1671,7 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): # every down weight has a corresponding up weight lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")] for key in lora_keys: - if "linear_1" not in key: + if "time_embedding" not in key: lora_name = key.split(".")[0] lora_name_up = lora_name + ".lora_up.weight" diffusers_name = key.replace("_", ".") @@ -1679,6 +1679,9 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): lora_name_up = key.replace("lora_down", "lora_up") diffusers_name = key + if "time" in key: + print(f"Diffusers name: {diffusers_name} actual key: {key} up: {lora_name_up}") + if "input.blocks" in diffusers_name: diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") else: From e47b47dab6a007d82707b8aa90ba83d49b64aa15 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 15:39:41 +0530 Subject: [PATCH 019/119] debugging --- src/diffusers/loaders.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 19e14f41904f..7d4afcc9a274 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1667,11 +1667,12 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): @classmethod def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): controlnet_state_dict = {} + exceptional_keys = {"time_embedding", "add_embedding"} # every down weight has a corresponding up weight lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")] for key in lora_keys: - if "time_embedding" not in key: + if not any(k in key for k in exceptional_keys): lora_name = key.split(".")[0] lora_name_up = lora_name + ".lora_up.weight" diffusers_name = key.replace("_", ".") @@ -1679,8 +1680,8 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): lora_name_up = key.replace("lora_down", "lora_up") diffusers_name = key - if "time" in key: - print(f"Diffusers name: {diffusers_name} actual key: {key} up: {lora_name_up}") + # if "time" in key: + # print(f"Diffusers name: {diffusers_name} actual key: {key} up: {lora_name_up}") if "input.blocks" in diffusers_name: diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") From 54d1508c5affaad851ffab70aa69a2afaf06fc16 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 15:41:59 +0530 Subject: [PATCH 020/119] successful LoRA state dict parsing. 
--- src/diffusers/loaders.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 7d4afcc9a274..26e5f09df03d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1666,7 +1666,7 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): @classmethod def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): - controlnet_state_dict = {} + controlnet_lora_state_dict = {} exceptional_keys = {"time_embedding", "add_embedding"} # every down weight has a corresponding up weight @@ -1723,26 +1723,21 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): if "attn1" in diffusers_name or "attn2" in diffusers_name: diffusers_name = diffusers_name.replace("attn1", "attn1.processor") diffusers_name = diffusers_name.replace("attn2", "attn2.processor") - controlnet_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) elif "ff" in diffusers_name: - controlnet_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) elif any(key in diffusers_name for key in ("proj_in", "proj_out")): - controlnet_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) else: - controlnet_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - - if len(state_dict) > 0: - raise ValueError( - f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}" - ) + controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) logger.info("StabilityAI ControlNet LoRA checkpoint detected.") - return controlnet_state_dict + return controlnet_lora_state_dict, state_dict def unload_lora_weights(self): """ From 6adc8d55d550a74c20b94968d719870cf00f4452 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 15:49:51 +0530 Subject: [PATCH 021/119] successful LoRA state dict parsing. 
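Patch 020 stops raising on leftover keys and instead returns two dictionaries: the renamed LoRA pairs, and whatever remains of the input state dict (plain tensors such as biases, norms, and zero-conv weights that get loaded into the model directly rather than as LoRA layers). A toy example of that split; the keys and values below are invented for illustration and are not taken from a real checkpoint:

# Toy illustration of the "LoRA pairs vs. everything else" split returned by
# the converter from patch 020 onward. Keys and values are invented.
state_dict = {
    "down_blocks.0.attentions.0.proj_in.lora_down.weight": "lora A",
    "down_blocks.0.attentions.0.proj_in.lora_up.weight": "lora B",
    "controlnet_mid_block.bias": "plain tensor",  # no LoRA pair, stays in the remainder
}

lora_sd = {k: v for k, v in state_dict.items() if "lora_down" in k or "lora_up" in k}
rest_sd = {k: v for k, v in state_dict.items() if k not in lora_sd}

print(sorted(lora_sd))  # the two paired LoRA tensors
print(sorted(rest_sd))  # ['controlnet_mid_block.bias']
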
--- src/diffusers/loaders.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 26e5f09df03d..275e877eb058 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1668,9 +1668,11 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): controlnet_lora_state_dict = {} exceptional_keys = {"time_embedding", "add_embedding"} + print(f"Total state_dict: {len(state_dict)}") # every down weight has a corresponding up weight lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")] + print(f"Total LoRA state_dict: {len(lora_keys)}") for key in lora_keys: if not any(k in key for k in exceptional_keys): lora_name = key.split(".")[0] From 24a2551f667cd7b4db96115eedb6b3140b1c9af1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 16:00:19 +0530 Subject: [PATCH 022/119] debugging --- src/diffusers/loaders.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 275e877eb058..b34cb2a710d4 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1671,9 +1671,9 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): print(f"Total state_dict: {len(state_dict)}") # every down weight has a corresponding up weight - lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")] - print(f"Total LoRA state_dict: {len(lora_keys)}") - for key in lora_keys: + lora_state_dict = {k: v for k, v in state_dict.items() if k.endswith("lora_down.weight")} + print(f"Total LoRA state_dict: {len(lora_state_dict)}") + for key in lora_state_dict: if not any(k in key for k in exceptional_keys): lora_name = key.split(".")[0] lora_name_up = lora_name + ".lora_up.weight" @@ -1725,17 +1725,19 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): if "attn1" in diffusers_name or "attn2" in diffusers_name: diffusers_name = diffusers_name.replace("attn1", "attn1.processor") diffusers_name = diffusers_name.replace("attn2", "attn2.processor") - controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop(lora_name_up) elif "ff" in diffusers_name: - controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop(lora_name_up) elif any(key in diffusers_name for key in ("proj_in", "proj_out")): - controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop(lora_name_up) else: - controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) + 
controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop(lora_name_up) + + print(f"Remaining keys in the LoRA state dict: {lora_state_dict.keys()}") logger.info("StabilityAI ControlNet LoRA checkpoint detected.") From 09003fb60c3f8dc1fac9d9c0917395fc13a7544c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 16:02:58 +0530 Subject: [PATCH 023/119] debugging --- src/diffusers/loaders.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index b34cb2a710d4..d66a03706c70 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -364,7 +364,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict attn_processor = self for sub_key in key.split("."): attn_processor = getattr(attn_processor, sub_key) - + # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers # or add_{k,v,q,out_proj}_proj_lora layers. if "lora.down.weight" in value_dict: @@ -1671,7 +1671,7 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): print(f"Total state_dict: {len(state_dict)}") # every down weight has a corresponding up weight - lora_state_dict = {k: v for k, v in state_dict.items() if k.endswith("lora_down.weight")} + lora_state_dict = {k: v for k, v in state_dict.items() if k.endswith(("lora_down.weight", "lora_up.weight"))} print(f"Total LoRA state_dict: {len(lora_state_dict)}") for key in lora_state_dict: if not any(k in key for k in exceptional_keys): @@ -1682,9 +1682,6 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): lora_name_up = key.replace("lora_down", "lora_up") diffusers_name = key - # if "time" in key: - # print(f"Diffusers name: {diffusers_name} actual key: {key} up: {lora_name_up}") - if "input.blocks" in diffusers_name: diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") else: @@ -1726,16 +1723,24 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): diffusers_name = diffusers_name.replace("attn1", "attn1.processor") diffusers_name = diffusers_name.replace("attn2", "attn2.processor") controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop(lora_name_up) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( + lora_name_up + ) elif "ff" in diffusers_name: controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop(lora_name_up) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( + lora_name_up + ) elif any(key in diffusers_name for key in ("proj_in", "proj_out")): controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop(lora_name_up) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( + lora_name_up + ) else: controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop(lora_name_up) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( + lora_name_up + ) print(f"Remaining keys in the LoRA state dict: {lora_state_dict.keys()}") From 
8d19befc03d30c79a920fed8ff452b60a0dade6d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 16:08:30 +0530 Subject: [PATCH 024/119] debugging --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index d66a03706c70..bb099f27711e 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1673,7 +1673,7 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): # every down weight has a corresponding up weight lora_state_dict = {k: v for k, v in state_dict.items() if k.endswith(("lora_down.weight", "lora_up.weight"))} print(f"Total LoRA state_dict: {len(lora_state_dict)}") - for key in lora_state_dict: + for key in state_dict: if not any(k in key for k in exceptional_keys): lora_name = key.split(".")[0] lora_name_up = lora_name + ".lora_up.weight" From 260d5cc619f8a42da0920feb94bda09af970fcfb Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 16:09:53 +0530 Subject: [PATCH 025/119] debugging --- src/diffusers/loaders.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index bb099f27711e..7f03894ddcd0 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1673,7 +1673,8 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): # every down weight has a corresponding up weight lora_state_dict = {k: v for k, v in state_dict.items() if k.endswith(("lora_down.weight", "lora_up.weight"))} print(f"Total LoRA state_dict: {len(lora_state_dict)}") - for key in state_dict: + lora_keys = list(lora_state_dict.keys()) + for key in lora_keys: if not any(k in key for k in exceptional_keys): lora_name = key.split(".")[0] lora_name_up = lora_name + ".lora_up.weight" From 3ad63ea168f55cbd96f7089e3b81aae3b9b9d11d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 16:17:04 +0530 Subject: [PATCH 026/119] debugging --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 7f03894ddcd0..f4a4b89a433b 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1673,7 +1673,7 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): # every down weight has a corresponding up weight lora_state_dict = {k: v for k, v in state_dict.items() if k.endswith(("lora_down.weight", "lora_up.weight"))} print(f"Total LoRA state_dict: {len(lora_state_dict)}") - lora_keys = list(lora_state_dict.keys()) + lora_keys = [k for k in lora_state_dict.keys() if "lora_down.weight" in k] for key in lora_keys: if not any(k in key for k in exceptional_keys): lora_name = key.split(".")[0] From 58604783b128387f3854fe07bc2e6bfd0496340e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 16:22:38 +0530 Subject: [PATCH 027/119] debugging --- src/diffusers/loaders.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index f4a4b89a433b..faee3a79c1e3 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1672,8 +1672,9 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): # every down weight has a corresponding up weight lora_state_dict = {k: v for k, v in state_dict.items() if k.endswith(("lora_down.weight", "lora_up.weight"))} - print(f"Total LoRA state_dict: {len(lora_state_dict)}") lora_keys = [k for k in lora_state_dict.keys() if "lora_down.weight" in k] + print(f"Total 
LoRA state_dict: {len(lora_keys) * 2}") + assert len(lora_state_dict) == 2 * len(lora_keys) for key in lora_keys: if not any(k in key for k in exceptional_keys): lora_name = key.split(".")[0] From e572736547045d30b0095079717823aff7e4babd Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 16:27:16 +0530 Subject: [PATCH 028/119] debugging --- src/diffusers/loaders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index faee3a79c1e3..1224c43dc576 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1745,6 +1745,7 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): ) print(f"Remaining keys in the LoRA state dict: {lora_state_dict.keys()}") + assert 2 * len(lora_keys) == len(controlnet_lora_state_dict) logger.info("StabilityAI ControlNet LoRA checkpoint detected.") From c3e0dd830dcd786e2582debf01b1a371f953306f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 16:33:27 +0530 Subject: [PATCH 029/119] debugging --- src/diffusers/loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 1224c43dc576..348e8f8e22e5 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1675,7 +1675,7 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): lora_keys = [k for k in lora_state_dict.keys() if "lora_down.weight" in k] print(f"Total LoRA state_dict: {len(lora_keys) * 2}") assert len(lora_state_dict) == 2 * len(lora_keys) - for key in lora_keys: + for totality, key in enumerate(lora_keys): if not any(k in key for k in exceptional_keys): lora_name = key.split(".")[0] lora_name_up = lora_name + ".lora_up.weight" @@ -1743,7 +1743,7 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( lora_name_up ) - + print(f"Total traversed: {totality}.") print(f"Remaining keys in the LoRA state dict: {lora_state_dict.keys()}") assert 2 * len(lora_keys) == len(controlnet_lora_state_dict) From 3924166bed228d89c2e3fc132cf8b5435f351c21 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 16:38:02 +0530 Subject: [PATCH 030/119] debugging --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 348e8f8e22e5..4330bf7c293f 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1743,7 +1743,7 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( lora_name_up ) - print(f"Total traversed: {totality}.") + print(f"Total traversed: {lora_keys[totality:]}.") print(f"Remaining keys in the LoRA state dict: {lora_state_dict.keys()}") assert 2 * len(lora_keys) == len(controlnet_lora_state_dict) From 00fea8a0e777184b569f4c1e8c1ed8bf68b791dd Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 16:42:12 +0530 Subject: [PATCH 031/119] debugging --- src/diffusers/loaders.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 4330bf7c293f..7efa05b27e1b 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -38,6 +38,7 @@ is_transformers_available, logging, ) +from tqdm import tqdm from .utils.import_utils import BACKENDS_MAPPING @@ -1675,7 +1676,7 @@ def 
_convert_sai_controlnet_lora_to_diffusers(cls, state_dict): lora_keys = [k for k in lora_state_dict.keys() if "lora_down.weight" in k] print(f"Total LoRA state_dict: {len(lora_keys) * 2}") assert len(lora_state_dict) == 2 * len(lora_keys) - for totality, key in enumerate(lora_keys): + for totality, key in enumerate(tqdm(lora_keys)): if not any(k in key for k in exceptional_keys): lora_name = key.split(".")[0] lora_name_up = lora_name + ".lora_up.weight" @@ -1743,9 +1744,9 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( lora_name_up ) - print(f"Total traversed: {lora_keys[totality:]}.") - print(f"Remaining keys in the LoRA state dict: {lora_state_dict.keys()}") - assert 2 * len(lora_keys) == len(controlnet_lora_state_dict) + # print(f"Total traversed: {lora_keys[totality:]}.") + # print(f"Remaining keys in the LoRA state dict: {lora_state_dict.keys()}") + # assert 2 * len(lora_keys) == len(controlnet_lora_state_dict) logger.info("StabilityAI ControlNet LoRA checkpoint detected.") From 12d7b5dfd9f72f315e4084d2825b08005591ed69 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 16:44:31 +0530 Subject: [PATCH 032/119] debugging --- src/diffusers/loaders.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 7efa05b27e1b..ca34aa4e2f2b 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1676,7 +1676,7 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): lora_keys = [k for k in lora_state_dict.keys() if "lora_down.weight" in k] print(f"Total LoRA state_dict: {len(lora_keys) * 2}") assert len(lora_state_dict) == 2 * len(lora_keys) - for totality, key in enumerate(tqdm(lora_keys)): + for totality, key in enumerate(lora_keys): if not any(k in key for k in exceptional_keys): lora_name = key.split(".")[0] lora_name_up = lora_name + ".lora_up.weight" @@ -1740,11 +1740,12 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): lora_name_up ) else: + print(key) controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( lora_name_up ) - # print(f"Total traversed: {lora_keys[totality:]}.") + print(f"Total traversed: {lora_keys[totality - 1:]}.") # print(f"Remaining keys in the LoRA state dict: {lora_state_dict.keys()}") # assert 2 * len(lora_keys) == len(controlnet_lora_state_dict) From a58abee3d517f2b45858570fac49bc2e4b1d43a2 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 16:49:13 +0530 Subject: [PATCH 033/119] debugging --- src/diffusers/loaders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index ca34aa4e2f2b..2a1a90d5aac8 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1683,6 +1683,7 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): diffusers_name = key.replace("_", ".") else: lora_name_up = key.replace("lora_down", "lora_up") + print(lora_name_up) diffusers_name = key if "input.blocks" in diffusers_name: From 6295db5e174ea8cf9c52bc29189ea373913ca29a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 16:53:55 +0530 Subject: [PATCH 034/119] debugging --- src/diffusers/loaders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 2a1a90d5aac8..4ae8bb89a6af 100644 --- 
a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1741,7 +1741,6 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): lora_name_up ) else: - print(key) controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( lora_name_up From ae1a178b73309e7b4a6e3cbefd48f7654e50d37d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 16:59:28 +0530 Subject: [PATCH 035/119] debugging --- src/diffusers/loaders.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 4ae8bb89a6af..f2f48de676dd 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1676,7 +1676,7 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): lora_keys = [k for k in lora_state_dict.keys() if "lora_down.weight" in k] print(f"Total LoRA state_dict: {len(lora_keys) * 2}") assert len(lora_state_dict) == 2 * len(lora_keys) - for totality, key in enumerate(lora_keys): + for key in lora_keys: if not any(k in key for k in exceptional_keys): lora_name = key.split(".")[0] lora_name_up = lora_name + ".lora_up.weight" @@ -1740,14 +1740,19 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( lora_name_up ) + elif any(k in diffusers_name for k in exceptional_keys): + diffusers_name = diffusers_name.replace("lora_down", "lora.down") + controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( + lora_name_up + ) else: controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( lora_name_up ) - print(f"Total traversed: {lora_keys[totality - 1:]}.") - # print(f"Remaining keys in the LoRA state dict: {lora_state_dict.keys()}") - # assert 2 * len(lora_keys) == len(controlnet_lora_state_dict) + + assert 2 * len(lora_keys) == len(controlnet_lora_state_dict) logger.info("StabilityAI ControlNet LoRA checkpoint detected.") From 58c9f985aec089f8e6c9abb45645038347dec54e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 17:01:46 +0530 Subject: [PATCH 036/119] debugging --- src/diffusers/loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index f2f48de676dd..0a950d4a90c4 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -38,7 +38,6 @@ is_transformers_available, logging, ) -from tqdm import tqdm from .utils.import_utils import BACKENDS_MAPPING @@ -1755,8 +1754,9 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): assert 2 * len(lora_keys) == len(controlnet_lora_state_dict) logger.info("StabilityAI ControlNet LoRA checkpoint detected.") + non_lora_state_dict = {k: v for k, v in state_dict.items() if k not in controlnet_lora_state_dict} - return controlnet_lora_state_dict, state_dict + return controlnet_lora_state_dict, non_lora_state_dict def unload_lora_weights(self): """ From e047c4e9bd0f5fa3dde51df170d48509d6540d36 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 17:05:24 +0530 Subject: [PATCH 037/119] better state dict munging --- src/diffusers/loaders.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git 
a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 0a950d4a90c4..3b0fdc96aef9 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1668,13 +1668,9 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): controlnet_lora_state_dict = {} exceptional_keys = {"time_embedding", "add_embedding"} - print(f"Total state_dict: {len(state_dict)}") # every down weight has a corresponding up weight - lora_state_dict = {k: v for k, v in state_dict.items() if k.endswith(("lora_down.weight", "lora_up.weight"))} - lora_keys = [k for k in lora_state_dict.keys() if "lora_down.weight" in k] - print(f"Total LoRA state_dict: {len(lora_keys) * 2}") - assert len(lora_state_dict) == 2 * len(lora_keys) + lora_keys = [k for k in state_dict.keys() if "lora_down.weight" in k] for key in lora_keys: if not any(k in key for k in exceptional_keys): lora_name = key.split(".")[0] @@ -1725,38 +1721,37 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): if "attn1" in diffusers_name or "attn2" in diffusers_name: diffusers_name = diffusers_name.replace("attn1", "attn1.processor") diffusers_name = diffusers_name.replace("attn2", "attn2.processor") - controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( + controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop( lora_name_up ) elif "ff" in diffusers_name: - controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( + controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop( lora_name_up ) elif any(key in diffusers_name for key in ("proj_in", "proj_out")): - controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( + controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop( lora_name_up ) elif any(k in diffusers_name for k in exceptional_keys): diffusers_name = diffusers_name.replace("lora_down", "lora.down") - controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( + controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop( lora_name_up ) else: - controlnet_lora_state_dict[diffusers_name] = lora_state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = lora_state_dict.pop( + controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop( lora_name_up ) assert 2 * len(lora_keys) == len(controlnet_lora_state_dict) logger.info("StabilityAI ControlNet LoRA checkpoint detected.") - non_lora_state_dict = {k: v for k, v in state_dict.items() if k not in controlnet_lora_state_dict} - return controlnet_lora_state_dict, non_lora_state_dict + return controlnet_lora_state_dict, state_dict def unload_lora_weights(self): """ From 
4436870fd91737dc942fabc1bc6b58b05ad0f590 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 17:07:06 +0530 Subject: [PATCH 038/119] remove print --- src/diffusers/loaders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 3b0fdc96aef9..e9eac672f601 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1678,7 +1678,6 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): diffusers_name = key.replace("_", ".") else: lora_name_up = key.replace("lora_down", "lora_up") - print(lora_name_up) diffusers_name = key if "input.blocks" in diffusers_name: From 50f3f4a799e6b3e014bbc5d93cb32b042c361b42 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 17:20:00 +0530 Subject: [PATCH 039/119] make method a part of it now --- src/diffusers/loaders.py | 166 +++++++++++++++++++++++++++++++++++---- 1 file changed, 151 insertions(+), 15 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e9eac672f601..387304e7ac05 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1167,6 +1167,152 @@ def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", bl return new_state_dict + # A weird mix of `convert_ldm_unet_checkpoint()` and `_map_sgm_blocks_to_diffusers()`. + @classmethod + def _map_sai_controlnet_blocks_to_diffusers( + cls, state_dict, model_config, delimiter=".", join_delimiter="_", block_slice_pos=2 + ): + new_state_dict = {} + # safe to pop it out as it's blank. + if "lora_controlnet" in state_dict: + _ = state_dict.pop("lora_controlnet") + + inner_block_map = ["resnets", "attentions", "upsamplers"] + + # examples: 'input_blocks.0.0.bias', 'input_hint_block.0.bias' respectively. + indirect_patterns = [r"^\w+\.\d+\.\d+\.\w+$", r"^\w+\.\d+\.\w+$"] + + new_state_dict["conv_in.down.weight"] = state_dict.pop("input_blocks.0.0.down") + new_state_dict["conv_in.up.bias"] = state_dict.pop("input_blocks.0.0.bias") + new_state_dict["conv_in.up.weight"] = state_dict.pop("input_blocks.0.0.up") + + # Retrieves # of down, mid and up blocks + input_block_ids, middle_block_ids = set(), set() + for layer in state_dict: + if not re.match(indirect_patterns[0], layer) and not re.match(indirect_patterns[1], layer): + if "text" not in layer: + layer_id = int(layer.split(delimiter)[:block_slice_pos][-1]) + if "input_blocks" in layer: + input_block_ids.add(layer_id) + elif "middle_block" in layer: + middle_block_ids.add(layer_id) + else: + raise ValueError("Checkpoint not supported") + + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key] + for layer_id in input_block_ids + } + middle_blocks = { + layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key] + for layer_id in middle_block_ids + } + + # Rename keys accordingly + for i in input_block_ids: + block_id = (i - 1) // (model_config.layers_per_block + 1) + layer_in_block_id = (i - 1) % (model_config.layers_per_block + 1) + + for key in input_blocks[i]: + inner_block_id = int(key.split(delimiter)[block_slice_pos]) + inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers" + inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0" + new_key = join_delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + + [str(block_id), inner_block_key, 
inner_layers_in_block] + + key.split(delimiter)[block_slice_pos + 1 :] + ) + + new_key = new_key.replace("_bias", ".bias") + if "norm" in key: + new_key = new_key.replace("_weight", ".weight") + else: + down_pattern = r"_down\b" + up_pattern = r"_up\b" + new_key = re.sub(down_pattern, ".lora_down.weight", new_key) + new_key = re.sub(up_pattern, ".lora_up.weight", new_key) + + new_state_dict[new_key] = state_dict.pop(key) + + for i in middle_block_ids: + key_part = None + if i == 0: + key_part = [inner_block_map[0], "0"] + elif i == 1: + key_part = [inner_block_map[1], "0"] + elif i == 2: + key_part = [inner_block_map[0], "1"] + else: + raise ValueError(f"Invalid middle block id {i}.") + + for key in middle_blocks[i]: + new_key = join_delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:] + ) + new_key = new_key.replace("_bias", ".bias") + if "norm" in key: + new_key = new_key.replace("_weight", ".weight") + else: + down_pattern = r"_down\b" + up_pattern = r"_up\b" + new_key = re.sub(down_pattern, ".lora_down.weight", new_key) + new_key = re.sub(up_pattern, ".lora_up.weight", new_key) + + new_state_dict[new_key] = state_dict.pop(key) + + # conditioning embedding + orig_index = 0 + + new_state_dict["controlnet_cond_embedding.conv_in.weight"] = state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_state_dict["controlnet_cond_embedding.conv_in.bias"] = state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + orig_index += 2 + diffusers_index = 0 + + while diffusers_index < 7: + new_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + diffusers_index += 1 + orig_index += 2 + + # down blocks + for i in range(num_input_blocks + 1): + new_state_dict[f"controlnet_down_blocks.{i}.weight"] = state_dict.pop(f"zero_convs.{i}.0.weight") + new_state_dict[f"controlnet_down_blocks.{i}.bias"] = state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_state_dict["controlnet_mid_block.weight"] = state_dict.pop("middle_block_out.0.weight") + new_state_dict["controlnet_mid_block.bias"] = state_dict.pop("middle_block_out.0.bias") + + # time embeddings + new_state_dict["time_embedding.linear_1.lora_down.weight"] = state_dict.pop("time_embed.0.down") + new_state_dict["time_embedding.linear_1.lora_up.weight"] = state_dict.pop("time_embed.0.up") + new_state_dict["time_embedding.linear_1.bias"] = state_dict.pop("time_embed.0.bias") + new_state_dict["time_embedding.linear_2.lora_down.weight"] = state_dict.pop("time_embed.2.down") + new_state_dict["time_embedding.linear_2.lora_up.weight"] = state_dict.pop("time_embed.2.up") + new_state_dict["time_embedding.linear_2.bias"] = state_dict.pop("time_embed.2.bias") + + # additional embeddings. 
+ new_state_dict["add_embedding.linear_1.lora_down.weight"] = state_dict.pop("label_emb.0.0.down") + new_state_dict["add_embedding.linear_1.lora_up.weight"] = state_dict.pop("label_emb.0.0.up") + new_state_dict["add_embedding.linear_1.bias"] = state_dict.pop("label_emb.0.0.bias") + new_state_dict["add_embedding.linear_2.lora_down.weight"] = state_dict.pop("label_emb.0.2.down") + new_state_dict["add_embedding.linear_2.lora_up.weight"] = state_dict.pop("label_emb.0.2.up") + new_state_dict["add_embedding.linear_2.bias"] = state_dict.pop("label_emb.0.2.bias") + + assert len(state_dict) == 0, "All keys should have been popped at this point." + + return new_state_dict + @classmethod def load_lora_into_unet(cls, state_dict, network_alphas, unet): """ @@ -1721,30 +1867,20 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): diffusers_name = diffusers_name.replace("attn1", "attn1.processor") diffusers_name = diffusers_name.replace("attn2", "attn2.processor") controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop( - lora_name_up - ) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) elif "ff" in diffusers_name: controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop( - lora_name_up - ) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) elif any(key in diffusers_name for key in ("proj_in", "proj_out")): controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop( - lora_name_up - ) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) elif any(k in diffusers_name for k in exceptional_keys): diffusers_name = diffusers_name.replace("lora_down", "lora.down") controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop( - lora_name_up - ) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) else: controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop( - lora_name_up - ) + controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) assert 2 * len(lora_keys) == len(controlnet_lora_state_dict) From 48257fb21837068534cc9518a934af7b16725ffb Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 22 Aug 2023 17:25:44 +0530 Subject: [PATCH 040/119] fix --- src/diffusers/loaders.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 387304e7ac05..9ff51e9e8951 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1182,9 +1182,9 @@ def _map_sai_controlnet_blocks_to_diffusers( # examples: 'input_blocks.0.0.bias', 'input_hint_block.0.bias' respectively. 
indirect_patterns = [r"^\w+\.\d+\.\d+\.\w+$", r"^\w+\.\d+\.\w+$"] - new_state_dict["conv_in.down.weight"] = state_dict.pop("input_blocks.0.0.down") - new_state_dict["conv_in.up.bias"] = state_dict.pop("input_blocks.0.0.bias") - new_state_dict["conv_in.up.weight"] = state_dict.pop("input_blocks.0.0.up") + new_state_dict["conv_in.lora_down.weight"] = state_dict.pop("input_blocks.0.0.down") + new_state_dict["conv_in.bias"] = state_dict.pop("input_blocks.0.0.bias") + new_state_dict["conv_in.lora_up.weight"] = state_dict.pop("input_blocks.0.0.up") # Retrieves # of down, mid and up blocks input_block_ids, middle_block_ids = set(), set() @@ -1223,7 +1223,6 @@ def _map_sai_controlnet_blocks_to_diffusers( + [str(block_id), inner_block_key, inner_layers_in_block] + key.split(delimiter)[block_slice_pos + 1 :] ) - new_key = new_key.replace("_bias", ".bias") if "norm" in key: new_key = new_key.replace("_weight", ".weight") From 40480deb602e557b6d0076342897f251a87610b1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 24 Aug 2023 07:43:36 +0530 Subject: [PATCH 041/119] more stuff --- src/diffusers/loaders.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 9ff51e9e8951..78d3cd9bc313 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1812,12 +1812,12 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): @classmethod def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): controlnet_lora_state_dict = {} - exceptional_keys = {"time_embedding", "add_embedding"} + exceptional_keys_lora = {"time_embedding", "add_embedding"} # every down weight has a corresponding up weight lora_keys = [k for k in state_dict.keys() if "lora_down.weight" in k] for key in lora_keys: - if not any(k in key for k in exceptional_keys): + if not any(k in key for k in exceptional_keys_lora): lora_name = key.split(".")[0] lora_name_up = lora_name + ".lora_up.weight" diffusers_name = key.replace("_", ".") @@ -1873,7 +1873,7 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): elif any(key in diffusers_name for key in ("proj_in", "proj_out")): controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif any(k in diffusers_name for k in exceptional_keys): + elif any(k in diffusers_name for k in exceptional_keys_lora): diffusers_name = diffusers_name.replace("lora_down", "lora.down") controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) @@ -1885,6 +1885,8 @@ def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): logger.info("StabilityAI ControlNet LoRA checkpoint detected.") + # Need to handle the `state_dict` which should be same as how we do + # it for existing ControlNets that are in non-diffusers format. 
return controlnet_lora_state_dict, state_dict def unload_lora_weights(self): From 13dffc38923074459d07f1cd375124fec7161082 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 08:00:20 +0530 Subject: [PATCH 042/119] debugging --- src/diffusers/loaders.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 78d3cd9bc313..e7a3537248f5 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1199,6 +1199,8 @@ def _map_sai_controlnet_blocks_to_diffusers( else: raise ValueError("Checkpoint not supported") + print("Input blocks:\n") + print({".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer}) num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer}) input_blocks = { layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key] @@ -1809,6 +1811,9 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): return new_state_dict, network_alphas + # Differs from the existing checkpoint conversion functions. To not hurt the readability, + # it's better to delegate the SAI ControlNet handling related conversions to a separate + # function. @classmethod def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): controlnet_lora_state_dict = {} From 6b6195fa8ae61a07f6aeaa1b7d04e827c3f14f19 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 08:12:38 +0530 Subject: [PATCH 043/119] debugging --- src/diffusers/loaders.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e7a3537248f5..20cbb6ab0e94 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1200,7 +1200,8 @@ def _map_sai_controlnet_blocks_to_diffusers( raise ValueError("Checkpoint not supported") print("Input blocks:\n") - print({".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer}) + ib = {".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer} + print(ib, len(ib)) num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer}) input_blocks = { layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key] From 7e87bf935bcb54ff1d65cd1685ff480021e2a9bf Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 08:45:01 +0530 Subject: [PATCH 044/119] changes --- src/diffusers/loaders.py | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 20cbb6ab0e94..b310bbe9c568 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1186,34 +1186,37 @@ def _map_sai_controlnet_blocks_to_diffusers( new_state_dict["conv_in.bias"] = state_dict.pop("input_blocks.0.0.bias") new_state_dict["conv_in.lora_up.weight"] = state_dict.pop("input_blocks.0.0.up") - # Retrieves # of down, mid and up blocks - input_block_ids, middle_block_ids = set(), set() - for layer in state_dict: - if not re.match(indirect_patterns[0], layer) and not re.match(indirect_patterns[1], layer): - if "text" not in layer: - layer_id = int(layer.split(delimiter)[:block_slice_pos][-1]) - if "input_blocks" in layer: - input_block_ids.add(layer_id) - elif "middle_block" in layer: - middle_block_ids.add(layer_id) - else: - raise ValueError("Checkpoint not supported") + # # Retrieves # of down, mid and up blocks + # input_block_ids, middle_block_ids = 
set(), set() + # for layer in state_dict: + # if not re.match(indirect_patterns[0], layer) and not re.match(indirect_patterns[1], layer): + # if "text" not in layer: + # layer_id = int(layer.split(delimiter)[:block_slice_pos][-1]) + # if "input_blocks" in layer: + # input_block_ids.add(layer_id) + # elif "middle_block" in layer: + # middle_block_ids.add(layer_id) + # else: + # raise ValueError("Checkpoint not supported") print("Input blocks:\n") - ib = {".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer} + ib = {".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer and not(re.match(indirect_patterns[0], layer) and re.match(indirect_patterns[1], layer))} print(ib, len(ib)) - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer}) + + num_input_blocks = len({".".join(layer.split(delimiter)[:block_slice_pos]) for layer in state_dict if "input_blocks" in layer and not(re.match(indirect_patterns[0], layer) and re.match(indirect_patterns[1], layer))}) input_blocks = { layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key] - for layer_id in input_block_ids + for layer_id in range(num_input_blocks) } + num_middle_blocks = len({".".join(layer.split(delimiter)[:block_slice_pos]) for layer in state_dict if "middle_block" in layer}) middle_blocks = { layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key] - for layer_id in middle_block_ids + for layer_id in range(num_middle_blocks) } + print(len(input_blocks), len(middle_blocks)) # Rename keys accordingly - for i in input_block_ids: + for i in range(1, num_input_blocks): block_id = (i - 1) // (model_config.layers_per_block + 1) layer_in_block_id = (i - 1) % (model_config.layers_per_block + 1) @@ -1237,7 +1240,7 @@ def _map_sai_controlnet_blocks_to_diffusers( new_state_dict[new_key] = state_dict.pop(key) - for i in middle_block_ids: + for i in range(num_middle_blocks): key_part = None if i == 0: key_part = [inner_block_map[0], "0"] From 4c93de5db027db3287176a543dddb69546ecec0b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 08:46:59 +0530 Subject: [PATCH 045/119] changes --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index b310bbe9c568..5ca947a7126f 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1240,7 +1240,7 @@ def _map_sai_controlnet_blocks_to_diffusers( new_state_dict[new_key] = state_dict.pop(key) - for i in range(num_middle_blocks): + for i in range(num_middle_blocks - 1): key_part = None if i == 0: key_part = [inner_block_map[0], "0"] From 182e4552a7d7f3b7ec4b69ff2064f7bd8f71f577 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 08:48:54 +0530 Subject: [PATCH 046/119] changes --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 5ca947a7126f..da48fa33a604 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1314,7 +1314,7 @@ def _map_sai_controlnet_blocks_to_diffusers( new_state_dict["add_embedding.linear_2.lora_up.weight"] = state_dict.pop("label_emb.0.2.up") new_state_dict["add_embedding.linear_2.bias"] = state_dict.pop("label_emb.0.2.bias") - assert len(state_dict) == 0, "All keys should have been popped at this point." 
+ assert len(state_dict) == 0, f"All keys should have been popped at this point: {state_dict.keys()}." return new_state_dict From c13e824570d5f444504d11b9526da0fe92a0a1e8 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 08:51:03 +0530 Subject: [PATCH 047/119] changes --- src/diffusers/loaders.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index da48fa33a604..044aa71da60e 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1186,22 +1186,20 @@ def _map_sai_controlnet_blocks_to_diffusers( new_state_dict["conv_in.bias"] = state_dict.pop("input_blocks.0.0.bias") new_state_dict["conv_in.lora_up.weight"] = state_dict.pop("input_blocks.0.0.up") - # # Retrieves # of down, mid and up blocks - # input_block_ids, middle_block_ids = set(), set() - # for layer in state_dict: - # if not re.match(indirect_patterns[0], layer) and not re.match(indirect_patterns[1], layer): - # if "text" not in layer: - # layer_id = int(layer.split(delimiter)[:block_slice_pos][-1]) - # if "input_blocks" in layer: - # input_block_ids.add(layer_id) - # elif "middle_block" in layer: - # middle_block_ids.add(layer_id) - # else: - # raise ValueError("Checkpoint not supported") - - print("Input blocks:\n") - ib = {".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer and not(re.match(indirect_patterns[0], layer) and re.match(indirect_patterns[1], layer))} - print(ib, len(ib)) + # Retrieves # of down, mid and up blocks + input_block_ids, middle_block_ids = set(), set() + for layer in state_dict: + if not re.match(indirect_patterns[0], layer) and not re.match(indirect_patterns[1], layer): + if "text" not in layer: + layer_id = int(layer.split(delimiter)[:block_slice_pos][-1]) + if "input_blocks" in layer: + input_block_ids.add(layer_id) + elif "middle_block" in layer: + middle_block_ids.add(layer_id) + else: + raise ValueError("Checkpoint not supported") + + print(input_block_ids, middle_block_ids) num_input_blocks = len({".".join(layer.split(delimiter)[:block_slice_pos]) for layer in state_dict if "input_blocks" in layer and not(re.match(indirect_patterns[0], layer) and re.match(indirect_patterns[1], layer))}) input_blocks = { From dc27a087dc1f27c0b8d9dc1b3e0e13804eb4bfd6 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 08:56:42 +0530 Subject: [PATCH 048/119] changes --- src/diffusers/loaders.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 044aa71da60e..ef3aa616af02 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1201,20 +1201,18 @@ def _map_sai_controlnet_blocks_to_diffusers( print(input_block_ids, middle_block_ids) - num_input_blocks = len({".".join(layer.split(delimiter)[:block_slice_pos]) for layer in state_dict if "input_blocks" in layer and not(re.match(indirect_patterns[0], layer) and re.match(indirect_patterns[1], layer))}) input_blocks = { layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key] - for layer_id in range(num_input_blocks) + for layer_id in input_block_ids } - num_middle_blocks = len({".".join(layer.split(delimiter)[:block_slice_pos]) for layer in state_dict if "middle_block" in layer}) middle_blocks = { layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key] - for layer_id in range(num_middle_blocks) + for layer_id in middle_block_ids } 
print(len(input_blocks), len(middle_blocks)) # Rename keys accordingly - for i in range(1, num_input_blocks): + for i in input_block_ids: block_id = (i - 1) // (model_config.layers_per_block + 1) layer_in_block_id = (i - 1) % (model_config.layers_per_block + 1) @@ -1238,7 +1236,7 @@ def _map_sai_controlnet_blocks_to_diffusers( new_state_dict[new_key] = state_dict.pop(key) - for i in range(num_middle_blocks - 1): + for i in middle_block_ids: key_part = None if i == 0: key_part = [inner_block_map[0], "0"] From e2e547722cc6a28b54febfc62301066f4e2db145 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 08:59:54 +0530 Subject: [PATCH 049/119] changes --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index ef3aa616af02..3ef71d26f65d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1286,7 +1286,7 @@ def _map_sai_controlnet_blocks_to_diffusers( orig_index += 2 # down blocks - for i in range(num_input_blocks + 1): + for i in input_block_ids: new_state_dict[f"controlnet_down_blocks.{i}.weight"] = state_dict.pop(f"zero_convs.{i}.0.weight") new_state_dict[f"controlnet_down_blocks.{i}.bias"] = state_dict.pop(f"zero_convs.{i}.0.bias") From efec092b4d497d58346da38bf79d25b74fb3f700 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 09:01:51 +0530 Subject: [PATCH 050/119] changes --- src/diffusers/loaders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 3ef71d26f65d..bc92e8a9a58a 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1287,6 +1287,7 @@ def _map_sai_controlnet_blocks_to_diffusers( # down blocks for i in input_block_ids: + print(f"Popping up zero corvs: {i}") new_state_dict[f"controlnet_down_blocks.{i}.weight"] = state_dict.pop(f"zero_convs.{i}.0.weight") new_state_dict[f"controlnet_down_blocks.{i}.bias"] = state_dict.pop(f"zero_convs.{i}.0.bias") From e871eeefd08d4f12adb754304476f15a1d7c2bac Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 09:04:21 +0530 Subject: [PATCH 051/119] changes --- src/diffusers/loaders.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index bc92e8a9a58a..96309b4c4a9b 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1287,10 +1287,12 @@ def _map_sai_controlnet_blocks_to_diffusers( # down blocks for i in input_block_ids: - print(f"Popping up zero corvs: {i}") new_state_dict[f"controlnet_down_blocks.{i}.weight"] = state_dict.pop(f"zero_convs.{i}.0.weight") new_state_dict[f"controlnet_down_blocks.{i}.bias"] = state_dict.pop(f"zero_convs.{i}.0.bias") + new_state_dict[f"controlnet_down_blocks.0.weight"] = state_dict.pop(f"zero_convs.0.0.weight") + new_state_dict[f"controlnet_down_blocks.0.bias"] = state_dict.pop(f"zero_convs.0.0.bias") + # mid block new_state_dict["controlnet_mid_block.weight"] = state_dict.pop("middle_block_out.0.weight") new_state_dict["controlnet_mid_block.bias"] = state_dict.pop("middle_block_out.0.bias") From 9d43c953cc37b0eedd3ce836766a893744264b3b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 09:11:56 +0530 Subject: [PATCH 052/119] changes --- src/diffusers/loaders.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 96309b4c4a9b..b335ebe52da0 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1235,6 +1235,7 @@ def 
_map_sai_controlnet_blocks_to_diffusers( new_key = re.sub(up_pattern, ".lora_up.weight", new_key) new_state_dict[new_key] = state_dict.pop(key) + print(f"Input block key: {key}") for i in middle_block_ids: key_part = None @@ -1261,6 +1262,7 @@ def _map_sai_controlnet_blocks_to_diffusers( new_key = re.sub(up_pattern, ".lora_up.weight", new_key) new_state_dict[new_key] = state_dict.pop(key) + print(f"Middle block key: {key}") # conditioning embedding orig_index = 0 From 7c26e9037b6354e38193cadc964896e88814cae3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 09:45:22 +0530 Subject: [PATCH 053/119] changes --- src/diffusers/loaders.py | 2 -- .../stable_diffusion/convert_from_ckpt.py | 18 +++++++++++++----- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index b335ebe52da0..96309b4c4a9b 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1235,7 +1235,6 @@ def _map_sai_controlnet_blocks_to_diffusers( new_key = re.sub(up_pattern, ".lora_up.weight", new_key) new_state_dict[new_key] = state_dict.pop(key) - print(f"Input block key: {key}") for i in middle_block_ids: key_part = None @@ -1262,7 +1261,6 @@ def _map_sai_controlnet_blocks_to_diffusers( new_key = re.sub(up_pattern, ".lora_up.weight", new_key) new_state_dict[new_key] = state_dict.pop(key) - print(f"Middle block key: {key}") # conditioning embedding orig_index = 0 diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 1530de0f064f..7d20240558cb 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -377,7 +377,7 @@ def create_ldm_bert_config(original_config): def convert_ldm_unet_checkpoint( - checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False + checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False, controlnet_lora=False ): """ Takes a state dict and a config, and returns a converted checkpoint. 
@@ -419,10 +419,18 @@ def convert_ldm_unet_checkpoint( new_checkpoint = {} - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + if not controlnet_lora: + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + else: + new_checkpoint["time_embedding.linear_1.lora_down.weight"] = unet_state_dict["time_embed.0.down"] + new_checkpoint["time_embedding.linear_1.lora_up.weight"] = unet_state_dict["time_embed.0.up"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.lora_down.weight"] = unet_state_dict["time_embed.2.down"] + new_checkpoint["time_embedding.linear_2.lora_up.weight"] = unet_state_dict["time_embed.2.up"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] if config["class_embed_type"] is None: # No parameters to port From f9eb243c743e72e9391552f384e4ecc1b0fc6da1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 09:53:06 +0530 Subject: [PATCH 054/119] changes --- .../stable_diffusion/convert_from_ckpt.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 7d20240558cb..0278628fb691 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -444,10 +444,18 @@ def convert_ldm_unet_checkpoint( raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") if config["addition_embed_type"] == "text_time": - new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] - new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] - new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] - new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + if not controlnet_lora: + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + new_checkpoint["add_embedding.linear_1.lora_down.weight"] = unet_state_dict["label_emb.0.0.down"] + new_checkpoint["add_embedding.linear_1.lora_up.weight"] = unet_state_dict["label_emb.0.0.up"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.lora_down.weight"] = unet_state_dict["label_emb.0.2.down"] + new_checkpoint["add_embedding.linear_2.lora_up.weight"] = unet_state_dict["label_emb.0.2.up"] + new_checkpoint["add_embedding.linear_2.bias"] = 
unet_state_dict["label_emb.0.2.bias"] new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] From 000f74cedb36c16319225c22510db79e6d3ec728 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 09:55:46 +0530 Subject: [PATCH 055/119] changes --- .../pipelines/stable_diffusion/convert_from_ckpt.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 0278628fb691..a0b644d86b4a 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -457,8 +457,13 @@ def convert_ldm_unet_checkpoint( new_checkpoint["add_embedding.linear_2.lora_up.weight"] = unet_state_dict["label_emb.0.2.up"] new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + if not controlnet_lora: + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + else: + new_checkpoint["conv_in.lora_down.weight"] = unet_state_dict["input_blocks.0.0.down"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + new_checkpoint["conv_in.lora_up.weight"] = unet_state_dict["input_blocks.0.0.up"] if not controlnet: new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] From 101ceebe5af99eede9ed30c3d2b786de4059628a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 10:01:15 +0530 Subject: [PATCH 056/119] changes --- src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index a0b644d86b4a..97225f21bdc5 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -609,8 +609,9 @@ def convert_ldm_unet_checkpoint( orig_index += 2 diffusers_index = 0 + diffusers_index_limit = 6 if not controlnet_lora else 7 - while diffusers_index < 6: + while diffusers_index < diffusers_index_limit: new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) From d326f24fd59d2d06f3a2cb7d1a93903142db1424 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 10:06:42 +0530 Subject: [PATCH 057/119] changes --- .../pipelines/stable_diffusion/convert_from_ckpt.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 97225f21bdc5..3d6c65ecd58b 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -621,12 +621,13 @@ def convert_ldm_unet_checkpoint( diffusers_index += 1 orig_index += 2 - new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = 
unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) + if not controlnet_lora: + new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) # down blocks for i in range(num_input_blocks): From c35161dc9bd5032cf813b8f07220d3a54d7b5533 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 10:19:19 +0530 Subject: [PATCH 058/119] changes --- .../stable_diffusion/convert_from_ckpt.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 3d6c65ecd58b..94c66053eb7a 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -382,6 +382,10 @@ def convert_ldm_unet_checkpoint( """ Takes a state dict and a config, and returns a converted checkpoint. """ + if not controlnet and controlnet_lora: + raise ValueError(f"`controlnet_lora` cannot be done with `controlnet` set to {controlnet}.") + if controlnet and controlnet_lora: + skip_extract_state_dict = True if skip_extract_state_dict: unet_state_dict = checkpoint @@ -638,6 +642,18 @@ def convert_ldm_unet_checkpoint( new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + if controlnet_lora: + modified_new_checkpoint = {} + down_pattern = r"_down\b" + up_pattern = r"_up\b" + + for key in new_checkpoint: + new_key = re.sub(down_pattern, ".lora_down.weight", new_key) + new_key = re.sub(up_pattern, ".lora_up.weight", new_key) + modified_new_checkpoint[new_key] = new_checkpoint[key] + + new_checkpoint = modified_new_checkpoint + return new_checkpoint From e103f776c232c5bfee7e02f67cdb27089e35d5a7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 10:25:02 +0530 Subject: [PATCH 059/119] changes --- src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 94c66053eb7a..20262bdf60f2 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -648,6 +648,7 @@ def convert_ldm_unet_checkpoint( up_pattern = r"_up\b" for key in new_checkpoint: + new_key = key new_key = re.sub(down_pattern, ".lora_down.weight", new_key) new_key = re.sub(up_pattern, ".lora_up.weight", new_key) modified_new_checkpoint[new_key] = new_checkpoint[key] From 0e42a2c850b4fe128210480f8a89ad3724f885f7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 10:27:02 +0530 Subject: [PATCH 060/119] changes --- .../stable_diffusion/convert_from_ckpt.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 20262bdf60f2..c3ff28f9178a 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -642,18 +642,18 @@ def convert_ldm_unet_checkpoint( new_checkpoint["controlnet_mid_block.weight"] 
= unet_state_dict.pop("middle_block_out.0.weight") new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") - if controlnet_lora: - modified_new_checkpoint = {} - down_pattern = r"_down\b" - up_pattern = r"_up\b" + # if controlnet_lora: + # modified_new_checkpoint = {} + # down_pattern = r"_down\b" + # up_pattern = r"_up\b" - for key in new_checkpoint: - new_key = key - new_key = re.sub(down_pattern, ".lora_down.weight", new_key) - new_key = re.sub(up_pattern, ".lora_up.weight", new_key) - modified_new_checkpoint[new_key] = new_checkpoint[key] + # for key in new_checkpoint: + # new_key = key + # new_key = re.sub(down_pattern, ".lora_down.weight", new_key) + # new_key = re.sub(up_pattern, ".lora_up.weight", new_key) + # modified_new_checkpoint[new_key] = new_checkpoint[key] - new_checkpoint = modified_new_checkpoint + # new_checkpoint = modified_new_checkpoint return new_checkpoint From 5bdb7bb25d24d013910f7413e8e5d2bf5d8697e0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 10:31:54 +0530 Subject: [PATCH 061/119] changes --- .../stable_diffusion/convert_from_ckpt.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index c3ff28f9178a..988ba10fc2bb 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -642,18 +642,18 @@ def convert_ldm_unet_checkpoint( new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") - # if controlnet_lora: - # modified_new_checkpoint = {} - # down_pattern = r"_down\b" - # up_pattern = r"_up\b" + if controlnet_lora: + modified_new_checkpoint = {} + down_pattern = r"\.down$" + up_pattern = r"\.up$" - # for key in new_checkpoint: - # new_key = key - # new_key = re.sub(down_pattern, ".lora_down.weight", new_key) - # new_key = re.sub(up_pattern, ".lora_up.weight", new_key) - # modified_new_checkpoint[new_key] = new_checkpoint[key] + for key in new_checkpoint: + new_key = key + new_key = re.sub(down_pattern, ".lora_down.weight", new_key) + new_key = re.sub(up_pattern, ".lora_up.weight", new_key) + modified_new_checkpoint[new_key] = new_checkpoint[key] - # new_checkpoint = modified_new_checkpoint + new_checkpoint = modified_new_checkpoint return new_checkpoint From e143979ad3b6201be9d3c6053f7375a09b7c36d4 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 11:11:25 +0530 Subject: [PATCH 062/119] changes --- src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 988ba10fc2bb..dfecb2d5a6e6 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -613,7 +613,7 @@ def convert_ldm_unet_checkpoint( orig_index += 2 diffusers_index = 0 - diffusers_index_limit = 6 if not controlnet_lora else 7 + diffusers_index_limit = 6 while diffusers_index < diffusers_index_limit: new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( From 2baae10d268b1ca132a78324814e5978f028e5bc Mon Sep 17 00:00:00 2001 From: Sayak 
Paul Date: Tue, 5 Sep 2023 11:16:37 +0530 Subject: [PATCH 063/119] remove unnecessary stuff from loaders.py --- src/diffusers/loaders.py | 233 ------------------ .../stable_diffusion/convert_from_ckpt.py | 16 +- 2 files changed, 11 insertions(+), 238 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 96309b4c4a9b..d1524f1b9c23 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1167,156 +1167,6 @@ def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", bl return new_state_dict - # A weird mix of `convert_ldm_unet_checkpoint()` and `_map_sgm_blocks_to_diffusers()`. - @classmethod - def _map_sai_controlnet_blocks_to_diffusers( - cls, state_dict, model_config, delimiter=".", join_delimiter="_", block_slice_pos=2 - ): - new_state_dict = {} - # safe to pop it out as it's blank. - if "lora_controlnet" in state_dict: - _ = state_dict.pop("lora_controlnet") - - inner_block_map = ["resnets", "attentions", "upsamplers"] - - # examples: 'input_blocks.0.0.bias', 'input_hint_block.0.bias' respectively. - indirect_patterns = [r"^\w+\.\d+\.\d+\.\w+$", r"^\w+\.\d+\.\w+$"] - - new_state_dict["conv_in.lora_down.weight"] = state_dict.pop("input_blocks.0.0.down") - new_state_dict["conv_in.bias"] = state_dict.pop("input_blocks.0.0.bias") - new_state_dict["conv_in.lora_up.weight"] = state_dict.pop("input_blocks.0.0.up") - - # Retrieves # of down, mid and up blocks - input_block_ids, middle_block_ids = set(), set() - for layer in state_dict: - if not re.match(indirect_patterns[0], layer) and not re.match(indirect_patterns[1], layer): - if "text" not in layer: - layer_id = int(layer.split(delimiter)[:block_slice_pos][-1]) - if "input_blocks" in layer: - input_block_ids.add(layer_id) - elif "middle_block" in layer: - middle_block_ids.add(layer_id) - else: - raise ValueError("Checkpoint not supported") - - print(input_block_ids, middle_block_ids) - - input_blocks = { - layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key] - for layer_id in input_block_ids - } - middle_blocks = { - layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key] - for layer_id in middle_block_ids - } - print(len(input_blocks), len(middle_blocks)) - - # Rename keys accordingly - for i in input_block_ids: - block_id = (i - 1) // (model_config.layers_per_block + 1) - layer_in_block_id = (i - 1) % (model_config.layers_per_block + 1) - - for key in input_blocks[i]: - inner_block_id = int(key.split(delimiter)[block_slice_pos]) - inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers" - inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0" - new_key = join_delimiter.join( - key.split(delimiter)[: block_slice_pos - 1] - + [str(block_id), inner_block_key, inner_layers_in_block] - + key.split(delimiter)[block_slice_pos + 1 :] - ) - new_key = new_key.replace("_bias", ".bias") - if "norm" in key: - new_key = new_key.replace("_weight", ".weight") - else: - down_pattern = r"_down\b" - up_pattern = r"_up\b" - new_key = re.sub(down_pattern, ".lora_down.weight", new_key) - new_key = re.sub(up_pattern, ".lora_up.weight", new_key) - - new_state_dict[new_key] = state_dict.pop(key) - - for i in middle_block_ids: - key_part = None - if i == 0: - key_part = [inner_block_map[0], "0"] - elif i == 1: - key_part = [inner_block_map[1], "0"] - elif i == 2: - key_part = [inner_block_map[0], "1"] - else: - raise ValueError(f"Invalid middle block id {i}.") - - for key in 
middle_blocks[i]: - new_key = join_delimiter.join( - key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:] - ) - new_key = new_key.replace("_bias", ".bias") - if "norm" in key: - new_key = new_key.replace("_weight", ".weight") - else: - down_pattern = r"_down\b" - up_pattern = r"_up\b" - new_key = re.sub(down_pattern, ".lora_down.weight", new_key) - new_key = re.sub(up_pattern, ".lora_up.weight", new_key) - - new_state_dict[new_key] = state_dict.pop(key) - - # conditioning embedding - orig_index = 0 - - new_state_dict["controlnet_cond_embedding.conv_in.weight"] = state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_state_dict["controlnet_cond_embedding.conv_in.bias"] = state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - orig_index += 2 - diffusers_index = 0 - - while diffusers_index < 7: - new_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - diffusers_index += 1 - orig_index += 2 - - # down blocks - for i in input_block_ids: - new_state_dict[f"controlnet_down_blocks.{i}.weight"] = state_dict.pop(f"zero_convs.{i}.0.weight") - new_state_dict[f"controlnet_down_blocks.{i}.bias"] = state_dict.pop(f"zero_convs.{i}.0.bias") - - new_state_dict[f"controlnet_down_blocks.0.weight"] = state_dict.pop(f"zero_convs.0.0.weight") - new_state_dict[f"controlnet_down_blocks.0.bias"] = state_dict.pop(f"zero_convs.0.0.bias") - - # mid block - new_state_dict["controlnet_mid_block.weight"] = state_dict.pop("middle_block_out.0.weight") - new_state_dict["controlnet_mid_block.bias"] = state_dict.pop("middle_block_out.0.bias") - - # time embeddings - new_state_dict["time_embedding.linear_1.lora_down.weight"] = state_dict.pop("time_embed.0.down") - new_state_dict["time_embedding.linear_1.lora_up.weight"] = state_dict.pop("time_embed.0.up") - new_state_dict["time_embedding.linear_1.bias"] = state_dict.pop("time_embed.0.bias") - new_state_dict["time_embedding.linear_2.lora_down.weight"] = state_dict.pop("time_embed.2.down") - new_state_dict["time_embedding.linear_2.lora_up.weight"] = state_dict.pop("time_embed.2.up") - new_state_dict["time_embedding.linear_2.bias"] = state_dict.pop("time_embed.2.bias") - - # additional embeddings. - new_state_dict["add_embedding.linear_1.lora_down.weight"] = state_dict.pop("label_emb.0.0.down") - new_state_dict["add_embedding.linear_1.lora_up.weight"] = state_dict.pop("label_emb.0.0.up") - new_state_dict["add_embedding.linear_1.bias"] = state_dict.pop("label_emb.0.0.bias") - new_state_dict["add_embedding.linear_2.lora_down.weight"] = state_dict.pop("label_emb.0.2.down") - new_state_dict["add_embedding.linear_2.lora_up.weight"] = state_dict.pop("label_emb.0.2.up") - new_state_dict["add_embedding.linear_2.bias"] = state_dict.pop("label_emb.0.2.bias") - - assert len(state_dict) == 0, f"All keys should have been popped at this point: {state_dict.keys()}." - - return new_state_dict - @classmethod def load_lora_into_unet(cls, state_dict, network_alphas, unet): """ @@ -1814,89 +1664,6 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): return new_state_dict, network_alphas - # Differs from the existing checkpoint conversion functions. To not hurt the readability, - # it's better to delegate the SAI ControlNet handling related conversions to a separate - # function. 
- @classmethod - def _convert_sai_controlnet_lora_to_diffusers(cls, state_dict): - controlnet_lora_state_dict = {} - exceptional_keys_lora = {"time_embedding", "add_embedding"} - - # every down weight has a corresponding up weight - lora_keys = [k for k in state_dict.keys() if "lora_down.weight" in k] - for key in lora_keys: - if not any(k in key for k in exceptional_keys_lora): - lora_name = key.split(".")[0] - lora_name_up = lora_name + ".lora_up.weight" - diffusers_name = key.replace("_", ".") - else: - lora_name_up = key.replace("lora_down", "lora_up") - diffusers_name = key - - if "input.blocks" in diffusers_name: - diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") - else: - diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") - - if "middle.block" in diffusers_name: - diffusers_name = diffusers_name.replace("middle.block", "mid_block") - else: - diffusers_name = diffusers_name.replace("mid.block", "mid_block") - if "output.blocks" in diffusers_name: - diffusers_name = diffusers_name.replace("output.blocks", "up_blocks") - else: - diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") - - diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") - diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") - diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") - diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") - diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") - diffusers_name = diffusers_name.replace("proj.in", "proj_in") - diffusers_name = diffusers_name.replace("proj.out", "proj_out") - diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj") - - # SDXL specificity. - if "emb" in diffusers_name: - pattern = r"\.\d+(?=\D*$)" - diffusers_name = re.sub(pattern, "", diffusers_name, count=1) - if ".in." in diffusers_name: - diffusers_name = diffusers_name.replace("in.layers.2", "conv1") - if ".out." 
in diffusers_name: - diffusers_name = diffusers_name.replace("out.layers.3", "conv2") - if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name: - diffusers_name = diffusers_name.replace("op", "conv") - if "skip" in diffusers_name: - diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut") - - if "transformer_blocks" in diffusers_name: - if "attn1" in diffusers_name or "attn2" in diffusers_name: - diffusers_name = diffusers_name.replace("attn1", "attn1.processor") - diffusers_name = diffusers_name.replace("attn2", "attn2.processor") - controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif "ff" in diffusers_name: - controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif any(key in diffusers_name for key in ("proj_in", "proj_out")): - controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif any(k in diffusers_name for k in exceptional_keys_lora): - diffusers_name = diffusers_name.replace("lora_down", "lora.down") - controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - else: - controlnet_lora_state_dict[diffusers_name] = state_dict.pop(key) - controlnet_lora_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - - assert 2 * len(lora_keys) == len(controlnet_lora_state_dict) - - logger.info("StabilityAI ControlNet LoRA checkpoint detected.") - - # Need to handle the `state_dict` which should be same as how we do - # it for existing ControlNets that are in non-diffusers format. - return controlnet_lora_state_dict, state_dict - def unload_lora_weights(self): """ Unloads the LoRA parameters. diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index dfecb2d5a6e6..0f9eed65c6eb 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -377,7 +377,13 @@ def create_ldm_bert_config(original_config): def convert_ldm_unet_checkpoint( - checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False, controlnet_lora=False + checkpoint, + config, + path=None, + extract_ema=False, + controlnet=False, + skip_extract_state_dict=False, + controlnet_lora=False, ): """ Takes a state dict and a config, and returns a converted checkpoint. 
@@ -613,7 +619,7 @@ def convert_ldm_unet_checkpoint( orig_index += 2 diffusers_index = 0 - diffusers_index_limit = 6 + diffusers_index_limit = 6 while diffusers_index < diffusers_index_limit: new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( @@ -646,13 +652,13 @@ def convert_ldm_unet_checkpoint( modified_new_checkpoint = {} down_pattern = r"\.down$" up_pattern = r"\.up$" - + for key in new_checkpoint: new_key = key new_key = re.sub(down_pattern, ".lora_down.weight", new_key) - new_key = re.sub(up_pattern, ".lora_up.weight", new_key) + new_key = re.sub(up_pattern, ".lora_up.weight", new_key) modified_new_checkpoint[new_key] = new_checkpoint[key] - + new_checkpoint = modified_new_checkpoint return new_checkpoint From 95f09d8fb816687e4b851cededfb1d53973327dd Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 11:24:46 +0530 Subject: [PATCH 064/119] remove unneeded stuff. --- src/diffusers/loaders.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 6bddbcfc1994..ebd367fea72d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -137,7 +137,6 @@ def _unfuse_lora(self): self.w_down = None def forward(self, input): - # print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}") if self.lora_scale is None: self.lora_scale = 1.0 if self.lora_linear_layer is None: @@ -1013,14 +1012,6 @@ def lora_state_dict( r""" Return state dict for lora weights and the network alphas. - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): Can be either: @@ -2517,3 +2508,8 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): controlnet.to(torch_dtype=torch_dtype) return controlnet + + +class ControlNetLoaderMixin(LoraLoaderMixin): + def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): + pass From d88c806a5d3dc1caf0084ebb2765be6c63288d6a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 11:46:52 +0530 Subject: [PATCH 065/119] better simplicity. --- src/diffusers/loaders.py | 56 ++++++++++++++----- .../stable_diffusion/convert_from_ckpt.py | 4 +- 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index ebd367fea72d..d481ed7cc8b5 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1007,6 +1007,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + controlnet=False, **kwargs, ): r""" @@ -1023,6 +1024,8 @@ def lora_state_dict( - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + controlnet (`bool`, *optional*, defaults to False): + If we're convert a ControlNet LoRA checkpoint. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. 
@@ -1134,20 +1137,21 @@ def lora_state_dict( state_dict = pretrained_model_name_or_path_or_dict network_alphas = None - if all( - ( - k.startswith("lora_te_") - or k.startswith("lora_unet_") - or k.startswith("lora_te1_") - or k.startswith("lora_te2_") - ) - for k in state_dict.keys() - ): - # Map SDXL blocks correctly. - if unet_config is not None: - # use unet config to remap block numbers - state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) - state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict) + if not controlnet: + if all( + ( + k.startswith("lora_te_") + or k.startswith("lora_unet_") + or k.startswith("lora_te1_") + or k.startswith("lora_te2_") + ) + for k in state_dict.keys() + ): + # Map SDXL blocks correctly. + if unet_config is not None: + # use unet config to remap block numbers + state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) + state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict) return state_dict, network_alphas @@ -2512,4 +2516,26 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): class ControlNetLoaderMixin(LoraLoaderMixin): def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): - pass + from .pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint + + state_dict, _ = self.lora_state_dict(pretrained_model_name_or_path_or_dict, controlnet=True, **kwargs) + controlnet_config = kwargs.pop("controlnet_config", None) + if controlnet_config is None: + raise ("Must provide a `controlnet_config`.") + + # ControlNet LoRA has a mix of things. Some parameters correspond to LoRA and some correspond + # to the ones belonging to the original state_dict (initialized from the underlying UNet). + # So, we first map the LoRA parameters and then we load the remaining state_dict into + # the ControlNet. + converted_state_dict = convert_ldm_unet_checkpoint( + state_dict, controlnet=True, config=controlnet_config, skip_extract_state_dict=True, controlnet_lora=True + ) + + # Load whatever is matching. + load_state_dict_results = self.controlnet.load_state_dict(converted_state_dict, strict=False) + if not all("lora" in k for k in load_state_dict_results.unexpected_keys): + raise ValueError( + f"The unexpected keys must only belong to LoRA parameters at this point: {load_state_dict_results.unexpected_keys}" + ) + + # Handle LoRA. 
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 3f9433e49245..d51aaccd4fca 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -655,8 +655,8 @@ def convert_ldm_unet_checkpoint( for key in new_checkpoint: new_key = key - new_key = re.sub(down_pattern, ".lora_down.weight", new_key) - new_key = re.sub(up_pattern, ".lora_up.weight", new_key) + new_key = re.sub(down_pattern, ".lora.down.weight", new_key) + new_key = re.sub(up_pattern, ".lora.up.weight", new_key) modified_new_checkpoint[new_key] = new_checkpoint[key] new_checkpoint = modified_new_checkpoint From 260bc7527e1d9c7099a1c131162afb569d9a2ed3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 12:06:27 +0530 Subject: [PATCH 066/119] better modularity --- src/diffusers/loaders.py | 64 +++++++++++++++++++ src/diffusers/models/controlnet.py | 6 +- .../controlnet/pipeline_controlnet_sd_xl.py | 5 +- 3 files changed, 72 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index d481ed7cc8b5..190a953f03ac 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2515,7 +2515,9 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): class ControlNetLoaderMixin(LoraLoaderMixin): + # Simplify ControlNet LoRA loading. def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): + from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer from .pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint state_dict, _ = self.lora_state_dict(pretrained_model_name_or_path_or_dict, controlnet=True, **kwargs) @@ -2539,3 +2541,65 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): ) # Handle LoRA. + lora_grouped_dict = defaultdict(dict) + lora_layers_list = [] + + all_keys = list(converted_state_dict.keys()) + for key in all_keys: + value = converted_state_dict.pop(key) + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + lora_grouped_dict[attn_processor_key][sub_key] = value + + if len(converted_state_dict) > 0: + raise ValueError( + f"The `converted_state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}" + ) + + for key, value_dict in lora_grouped_dict.items(): + attn_processor = self + for sub_key in key.split("."): + attn_processor = getattr(attn_processor, sub_key) + + # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers + # or add_{k,v,q,out_proj}_proj_lora layers. 
+ rank = value_dict["lora.down.weight"].shape[0] + + if isinstance(attn_processor, LoRACompatibleConv): + in_features = attn_processor.in_channels + out_features = attn_processor.out_channels + kernel_size = attn_processor.kernel_size + + lora = LoRAConv2dLayer( + in_features=in_features, + out_features=out_features, + rank=rank, + kernel_size=kernel_size, + stride=attn_processor.stride, + padding=attn_processor.padding, + ) + elif isinstance(attn_processor, LoRACompatibleLinear): + lora = LoRALinearLayer( + attn_processor.in_features, + attn_processor.out_features, + rank, + ) + else: + raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.") + + value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} + lora.load_state_dict(value_dict) + lora_layers_list.append((attn_processor, lora)) + + # set correct dtype & device + lora_layers_list = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in lora_layers_list] + + # set lora layers + for target_module, lora_layer in lora_layers_list: + target_module.set_lora_layer(lora_layer) + + def unload_lora_weights(self): + for _, module in self.controlnet.named_modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + # Implement `fuse_lora()` and `unfuse_lora()` (sayakpaul). diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 851906c670ee..4347c294f81a 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -19,7 +19,7 @@ from torch.nn import functional as F from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import FromOriginalControlnetMixin, UNet2DConditionLoadersMixin +from ..loaders import ControlNetLoaderMixin, FromOriginalControlnetMixin, UNet2DConditionLoadersMixin from ..utils import BaseOutput, logging from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -107,7 +107,9 @@ def forward(self, conditioning): return embedding -class ControlNetModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, FromOriginalControlnetMixin): +class ControlNetModel( + ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, FromOriginalControlnetMixin, ControlNetLoaderMixin +): """ A ControlNet model. diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 391d58134627..6ff88e8275f2 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -104,7 +104,10 @@ class StableDiffusionXLControlNetPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, + TextualInversionLoaderMixin, + LoraLoaderMixin, + FromSingleFileMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. From 5e5004da0de6f38609343f4557de52621c82174a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 12:10:54 +0530 Subject: [PATCH 067/119] fix: exception raise/. 
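[Note on the key grouping in patch 066 above] Before building `LoRAConv2dLayer` / `LoRALinearLayer` objects, the flat checkpoint keys are grouped into one small dictionary per target module and the module is found by walking attributes. A minimal sketch of that handling, assuming keys of the form "<module.path>.lora.{down,up}.weight"; all names here are illustrative, not the diffusers API.

    from collections import defaultdict

    def group_lora_keys(lora_state_dict: dict) -> dict:
        grouped = defaultdict(dict)
        for key, value in lora_state_dict.items():
            parts = key.split(".")
            module_path = ".".join(parts[:-3])   # e.g. "conv_in"
            sub_key = ".".join(parts[-3:])       # e.g. "lora.down.weight"
            grouped[module_path][sub_key] = value
        return grouped

    def get_submodule(root, dotted_path: str):
        # Walk attribute by attribute, mirroring the getattr loop in the patch.
        module = root
        for name in dotted_path.split("."):
            module = getattr(module, name)
        return module

The rank of each adapter then falls out of the first dimension of the corresponding "lora.down.weight" tensor, which is what the patch uses to size the new LoRA layers.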
--- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 190a953f03ac..15ff0cc7fda8 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2523,7 +2523,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): state_dict, _ = self.lora_state_dict(pretrained_model_name_or_path_or_dict, controlnet=True, **kwargs) controlnet_config = kwargs.pop("controlnet_config", None) if controlnet_config is None: - raise ("Must provide a `controlnet_config`.") + raise ValueError("Must provide a `controlnet_config`.") # ControlNet LoRA has a mix of things. Some parameters correspond to LoRA and some correspond # to the ones belonging to the original state_dict (initialized from the underlying UNet). From 11a85cdf2558d73f2fdb5add49ebb49ad3f217b4 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 12:15:47 +0530 Subject: [PATCH 068/119] empty lora controlnet key --- src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index d51aaccd4fca..15e372719f67 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -429,6 +429,10 @@ def convert_ldm_unet_checkpoint( new_checkpoint = {} + if controlnet_lora: + # Safe to pop as it doesn't have anything. + _ = unet_state_dict.pop("lora_controlnet") + if not controlnet_lora: new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] From d16673242e6b5f9032f8fea410800adc68531639 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 12:17:26 +0530 Subject: [PATCH 069/119] empty lora controlnet key --- src/diffusers/loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 15ff0cc7fda8..d6936c590206 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2534,7 +2534,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): ) # Load whatever is matching. 
- load_state_dict_results = self.controlnet.load_state_dict(converted_state_dict, strict=False) + load_state_dict_results = self.load_state_dict(converted_state_dict, strict=False) if not all("lora" in k for k in load_state_dict_results.unexpected_keys): raise ValueError( f"The unexpected keys must only belong to LoRA parameters at this point: {load_state_dict_results.unexpected_keys}" @@ -2598,7 +2598,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): target_module.set_lora_layer(lora_layer) def unload_lora_weights(self): - for _, module in self.controlnet.named_modules(): + for _, module in self.named_modules(): if hasattr(module, "set_lora_layer"): module.set_lora_layer(None) From b3b7798a30ba26aca85d0efcfc416b45509f2a98 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 12:26:48 +0530 Subject: [PATCH 070/119] debugging --- src/diffusers/loaders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index d6936c590206..e9c12108d6d3 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2562,6 +2562,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers # or add_{k,v,q,out_proj}_proj_lora layers. + print(f"Value dict: {value_dict}.") rank = value_dict["lora.down.weight"].shape[0] if isinstance(attn_processor, LoRACompatibleConv): From d0e1cfb5d46d1ea8d64b973531d64e1ff0da5dea Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 12:30:27 +0530 Subject: [PATCH 071/119] debugging --- src/diffusers/loaders.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e9c12108d6d3..5bc35ac73817 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -395,7 +395,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict raise ValueError( f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}" ) - + temp = 0 for key, value_dict in lora_grouped_dict.items(): attn_processor = self for sub_key in key.split("."): @@ -403,6 +403,8 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers # or add_{k,v,q,out_proj}_proj_lora layers. + if temp == 0: + print(f"Value dict: {value_dict}") rank = value_dict["lora.down.weight"].shape[0] if isinstance(attn_processor, LoRACompatibleConv): From 11ddd6cecff72669703cfcf6555fecaf7bde3d02 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 12:34:43 +0530 Subject: [PATCH 072/119] debugging --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 5bc35ac73817..0cf21ffa6c74 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -404,7 +404,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers # or add_{k,v,q,out_proj}_proj_lora layers. 
if temp == 0: - print(f"Value dict: {value_dict}") + print(f"Value dict: {value_dict.keys()}") rank = value_dict["lora.down.weight"].shape[0] if isinstance(attn_processor, LoRACompatibleConv): From 8f6608d67066c272b5f039ea07163aa73beb5a2d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 12:42:04 +0530 Subject: [PATCH 073/119] debugging --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 0cf21ffa6c74..eb77f32b0099 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2546,7 +2546,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): lora_grouped_dict = defaultdict(dict) lora_layers_list = [] - all_keys = list(converted_state_dict.keys()) + all_keys = [k for k in converted_state_dict if "lora" in k] for key in all_keys: value = converted_state_dict.pop(key) attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) From fa4782f3ec758c0f6a8a17dfc9e9c23de79e95f6 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 12:45:49 +0530 Subject: [PATCH 074/119] debugging --- src/diffusers/loaders.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index eb77f32b0099..987074599c12 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -373,10 +373,13 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict mapped_network_alphas = {} all_keys = list(state_dict.keys()) + temp = 0 for key in all_keys: value = state_dict.pop(key) attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) lora_grouped_dict[attn_processor_key][sub_key] = value + if temp == 0: + print(attn_processor_key, sub_key) # Create another `mapped_network_alphas` dictionary so that we can properly map them. if network_alphas is not None: From aa4f65f06605ac33c6b27cd640155317c1c8b1f6 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 12:47:07 +0530 Subject: [PATCH 075/119] debugging --- src/diffusers/loaders.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 987074599c12..56d7bde82c7f 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -380,6 +380,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict lora_grouped_dict[attn_processor_key][sub_key] = value if temp == 0: print(attn_processor_key, sub_key) + temp = 999 # Create another `mapped_network_alphas` dictionary so that we can properly map them. if network_alphas is not None: @@ -408,6 +409,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict # or add_{k,v,q,out_proj}_proj_lora layers. 
if temp == 0: print(f"Value dict: {value_dict.keys()}") + temp = 999 rank = value_dict["lora.down.weight"].shape[0] if isinstance(attn_processor, LoRACompatibleConv): From e238f3a7a6c365fd85ff008b7afd4f41c1e9b139 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 12:48:14 +0530 Subject: [PATCH 076/119] debugging --- src/diffusers/loaders.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 56d7bde82c7f..130e5e7aae64 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2552,10 +2552,14 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): lora_layers_list = [] all_keys = [k for k in converted_state_dict if "lora" in k] + temp = 0 for key in all_keys: value = converted_state_dict.pop(key) attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) lora_grouped_dict[attn_processor_key][sub_key] = value + if temp == 0: + print(attn_processor_key, sub_key) + temp = 999 if len(converted_state_dict) > 0: raise ValueError( From 8206ef02a21f21b0908355d1d942962ab4c6e5f6 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 12:52:24 +0530 Subject: [PATCH 077/119] debugging --- src/diffusers/loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 130e5e7aae64..c28024e38d36 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -379,7 +379,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) lora_grouped_dict[attn_processor_key][sub_key] = value if temp == 0: - print(attn_processor_key, sub_key) + print(key, attn_processor_key, sub_key) temp = 999 # Create another `mapped_network_alphas` dictionary so that we can properly map them. 
@@ -2558,7 +2558,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) lora_grouped_dict[attn_processor_key][sub_key] = value if temp == 0: - print(attn_processor_key, sub_key) + print(key, attn_processor_key, sub_key) temp = 999 if len(converted_state_dict) > 0: From 33cfc2d64d7b70717ab8481f84beae29c312cc39 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 12:54:47 +0530 Subject: [PATCH 078/119] debugging --- src/diffusers/loaders.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index c28024e38d36..476854b9c092 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2551,9 +2551,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): lora_grouped_dict = defaultdict(dict) lora_layers_list = [] - all_keys = [k for k in converted_state_dict if "lora" in k] temp = 0 - for key in all_keys: + for key in converted_state_dict: value = converted_state_dict.pop(key) attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) lora_grouped_dict[attn_processor_key][sub_key] = value From 71f3c91ac2c91a7639970b053efe34e235592aa7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 12:59:32 +0530 Subject: [PATCH 079/119] better state_dict munging --- src/diffusers/loaders.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 476854b9c092..2e63d2361fee 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2546,23 +2546,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): raise ValueError( f"The unexpected keys must only belong to LoRA parameters at this point: {load_state_dict_results.unexpected_keys}" ) + + # Filter out the rest of the state_dict for handling LoRA. + remaining_state_dict = {k: v for k, v in converted_state_dict.items() if k in load_state_dict_results.unexpected_keys} # Handle LoRA. 
lora_grouped_dict = defaultdict(dict) lora_layers_list = [] temp = 0 - for key in converted_state_dict: - value = converted_state_dict.pop(key) + for key in remaining_state_dict.keys(): + value = remaining_state_dict.pop(key) attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) lora_grouped_dict[attn_processor_key][sub_key] = value if temp == 0: print(key, attn_processor_key, sub_key) temp = 999 - if len(converted_state_dict) > 0: + if len(remaining_state_dict) > 0: raise ValueError( - f"The `converted_state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}" + f"The `remaining_state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}" ) for key, value_dict in lora_grouped_dict.items(): From 1bfbefba32ddb5c23c00f4b131da0a3f6bf6abc7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 13:00:57 +0530 Subject: [PATCH 080/119] better state_dict munging --- src/diffusers/loaders.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 2e63d2361fee..9fbfb7ddd7b3 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2555,7 +2555,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): lora_layers_list = [] temp = 0 - for key in remaining_state_dict.keys(): + all_keys = remaining_state_dict.keys() + for key in all_keys: value = remaining_state_dict.pop(key) attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) lora_grouped_dict[attn_processor_key][sub_key] = value From 8ad9b977f399b241dfadf11a45ef1b12c0e7aef1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 13:01:35 +0530 Subject: [PATCH 081/119] better state_dict munging --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 9fbfb7ddd7b3..a62fc7085b44 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2555,7 +2555,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): lora_layers_list = [] temp = 0 - all_keys = remaining_state_dict.keys() + all_keys = list(remaining_state_dict.keys()) for key in all_keys: value = remaining_state_dict.pop(key) attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) From d901a9a04aa04abedc1865e5539a11c298aa9560 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 13:10:31 +0530 Subject: [PATCH 082/119] sanity --- src/diffusers/loaders.py | 6 ++++-- .../pipelines/stable_diffusion/convert_from_ckpt.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index a62fc7085b44..733a9abfa3ae 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2546,9 +2546,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): raise ValueError( f"The unexpected keys must only belong to LoRA parameters at this point: {load_state_dict_results.unexpected_keys}" ) - + # Filter out the rest of the state_dict for handling LoRA. - remaining_state_dict = {k: v for k, v in converted_state_dict.items() if k in load_state_dict_results.unexpected_keys} + remaining_state_dict = { + k: v for k, v in converted_state_dict.items() if k in load_state_dict_results.unexpected_keys + } # Handle LoRA. 
lora_grouped_dict = defaultdict(dict) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 15e372719f67..9bfbab0c7d7b 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -661,6 +661,8 @@ def convert_ldm_unet_checkpoint( new_key = key new_key = re.sub(down_pattern, ".lora.down.weight", new_key) new_key = re.sub(up_pattern, ".lora.up.weight", new_key) + new_key = new_key.replace("lora_down", "lora.down") + new_key = new_key.replace("lora_up", "lora.up") modified_new_checkpoint[new_key] = new_checkpoint[key] new_checkpoint = modified_new_checkpoint From 610be144b0f9ba41bdcb58b3ad2f658e26e379af Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 13:15:09 +0530 Subject: [PATCH 083/119] sanity --- src/diffusers/loaders.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 733a9abfa3ae..02400fffae94 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -373,14 +373,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict mapped_network_alphas = {} all_keys = list(state_dict.keys()) - temp = 0 for key in all_keys: value = state_dict.pop(key) attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) lora_grouped_dict[attn_processor_key][sub_key] = value - if temp == 0: - print(key, attn_processor_key, sub_key) - temp = 999 # Create another `mapped_network_alphas` dictionary so that we can properly map them. if network_alphas is not None: @@ -399,7 +395,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict raise ValueError( f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}" ) - temp = 0 + for key, value_dict in lora_grouped_dict.items(): attn_processor = self for sub_key in key.split("."): @@ -407,9 +403,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers # or add_{k,v,q,out_proj}_proj_lora layers. 
- if temp == 0: - print(f"Value dict: {value_dict.keys()}") - temp = 999 rank = value_dict["lora.down.weight"].shape[0] if isinstance(attn_processor, LoRACompatibleConv): @@ -434,6 +427,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict mapped_network_alphas.get(key), ) else: + print(type(attn_processor), attn_processor.__class__.__name__) raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.") value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} @@ -2556,15 +2550,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): lora_grouped_dict = defaultdict(dict) lora_layers_list = [] - temp = 0 all_keys = list(remaining_state_dict.keys()) for key in all_keys: value = remaining_state_dict.pop(key) attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) lora_grouped_dict[attn_processor_key][sub_key] = value - if temp == 0: - print(key, attn_processor_key, sub_key) - temp = 999 if len(remaining_state_dict) > 0: raise ValueError( @@ -2578,7 +2568,6 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers # or add_{k,v,q,out_proj}_proj_lora layers. - print(f"Value dict: {value_dict}.") rank = value_dict["lora.down.weight"].shape[0] if isinstance(attn_processor, LoRACompatibleConv): From 2027143f81b3fd513aa2a3148891e3df7885db1d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 13:17:09 +0530 Subject: [PATCH 084/119] sanity --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 02400fffae94..6a18faa02075 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -427,7 +427,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict mapped_network_alphas.get(key), ) else: - print(type(attn_processor), attn_processor.__class__.__name__) raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.") value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} @@ -2590,6 +2589,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): rank, ) else: + print(type(attn_processor), attn_processor.__class__.__name__) raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.") value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} From f7fde8a68df2b850e151122897b5d2e2cf21c772 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 13:19:59 +0530 Subject: [PATCH 085/119] fix: embeddings. 
--- src/diffusers/models/embeddings.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 3bdd758117cd..a40d16a8a89d 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -19,6 +19,7 @@ from torch import nn from .activations import get_activation +from ..models.lora import LoRACompatibleLinear def get_timestep_embedding( @@ -166,10 +167,10 @@ def __init__( ): super().__init__() - self.linear_1 = nn.Linear(in_channels, time_embed_dim) + self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim) if cond_proj_dim is not None: - self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + self.cond_proj = LoRACompatibleLinear(cond_proj_dim, in_channels, bias=False) else: self.cond_proj = None @@ -179,7 +180,7 @@ def __init__( time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out) if post_act_fn is None: self.post_act = None From b35f61fac3b4bc6c9e92ab6c94118e2468efb4a0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 13:23:42 +0530 Subject: [PATCH 086/119] fix: embeddings. --- src/diffusers/models/controlnet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 4347c294f81a..cd46a419f27f 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -21,6 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import ControlNetLoaderMixin, FromOriginalControlnetMixin, UNet2DConditionLoadersMixin from ..utils import BaseOutput, logging +from ..models.lora import LoRACompatibleConv from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, @@ -80,7 +81,7 @@ def __init__( ): super().__init__() - self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + self.conv_in = LoRACompatibleConv(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) self.blocks = nn.ModuleList([]) From ebec2119cf705ea79eb1e4742cd5ed42e0c30cbb Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 13:25:17 +0530 Subject: [PATCH 087/119] fix: embeddings. --- src/diffusers/models/controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index cd46a419f27f..8c0130df0742 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -250,7 +250,7 @@ def __init__( # input conv_in_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2d( + self.conv_in = LoRACompatibleConv( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) From 367e6c0b25462b6c699461c69789d46fecdd2dfa Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 14:45:54 +0530 Subject: [PATCH 088/119] remove prints. 
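[Note on patches 085-087 above] Those three patches swap plain nn.Linear / nn.Conv2d modules in the time embedding and the ControlNet conv_in layers for their LoRACompatible counterparts, which is what makes it possible to hang LoRA layers off them via `set_lora_layer`. Roughly, such a wrapper behaves like the simplified sketch below; this is an illustration, not the actual diffusers class.

    import torch.nn as nn

    class SimplifiedLoRACompatibleLinear(nn.Linear):
        # A drop-in nn.Linear whose forward adds an optional low-rank residual.

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.lora_layer = None

        def set_lora_layer(self, lora_layer):
            self.lora_layer = lora_layer

        def forward(self, hidden_states):
            out = super().forward(hidden_states)
            if self.lora_layer is not None:
                out = out + self.lora_layer(hidden_states)
            return out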
--- src/diffusers/loaders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 6a18faa02075..435423952d22 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2589,7 +2589,6 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): rank, ) else: - print(type(attn_processor), attn_processor.__class__.__name__) raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.") value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} From dd0ce66cc42cb9db204ae3ddd94f12dbb1df62ec Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 5 Sep 2023 15:04:00 +0530 Subject: [PATCH 089/119] make style --- src/diffusers/loaders.py | 2 +- src/diffusers/models/controlnet.py | 2 +- src/diffusers/models/embeddings.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 435423952d22..b61b114fbd0a 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -395,7 +395,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict raise ValueError( f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}" ) - + for key, value_dict in lora_grouped_dict.items(): attn_processor = self for sub_key in key.split("."): diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 8c0130df0742..39cbc3e712bd 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -20,8 +20,8 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import ControlNetLoaderMixin, FromOriginalControlnetMixin, UNet2DConditionLoadersMixin -from ..utils import BaseOutput, logging from ..models.lora import LoRACompatibleConv +from ..utils import BaseOutput, logging from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index a40d16a8a89d..d6a919ef5d34 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -18,8 +18,8 @@ import torch from torch import nn -from .activations import get_activation from ..models.lora import LoRACompatibleLinear +from .activations import get_activation def get_timestep_embedding( From f17befc1a0fe10c1f8c776a6bb89003249725dd6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 18 Sep 2023 11:17:27 +0100 Subject: [PATCH 090/119] fix: doc --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index b61b114fbd0a..4adf366851f3 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1025,7 +1025,7 @@ def lora_state_dict( dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). controlnet (`bool`, *optional*, defaults to False): - If we're convert a ControlNet LoRA checkpoint. + If we're converting a ControlNet LoRA checkpoint. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. 
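[Usage note, as of patch 090] At this point in the series the loader is usable end to end: the ControlNet model converts the Control-LoRA checkpoint, copies the plain weights, and attaches LoRA layers. A rough usage sketch follows; the paths and the `original_unet_config` variable are placeholders, and the API shown is still work in progress in this PR.

    import torch
    from safetensors.torch import load_file
    from diffusers import ControlNetModel

    controlnet = ControlNetModel.from_pretrained(
        "path/to/sdxl-controlnet", torch_dtype=torch.float16
    )
    state_dict = load_file("path/to/control-lora.safetensors")

    # `controlnet_config` must describe the original (LDM-format) UNet so that
    # convert_ldm_unet_checkpoint can remap the keys; it is assumed to exist here.
    controlnet.load_lora_weights(state_dict, controlnet_config=original_unet_config)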
From a66a46847adc8599fa05b6c7cadea9dcf65068b5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 18 Sep 2023 11:36:23 +0100 Subject: [PATCH 091/119] debugging --- src/diffusers/loaders.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 4adf366851f3..4f9a63262ede 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2544,6 +2544,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): remaining_state_dict = { k: v for k, v in converted_state_dict.items() if k in load_state_dict_results.unexpected_keys } + converted_sd_keys = set(converted_state_dict.keys()) + print(f"Differences in between the keys here: {converted_sd_keys.difference(set(load_state_dict_results.unexpected_keys))}") # Handle LoRA. lora_grouped_dict = defaultdict(dict) From 96993823117f892ce1e40a79aa137aff2653ca4b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 18 Sep 2023 11:55:12 +0100 Subject: [PATCH 092/119] debugging --- src/diffusers/loaders.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 4f9a63262ede..4046bb337ee9 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2545,7 +2545,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): k: v for k, v in converted_state_dict.items() if k in load_state_dict_results.unexpected_keys } converted_sd_keys = set(converted_state_dict.keys()) - print(f"Differences in between the keys here: {converted_sd_keys.difference(set(load_state_dict_results.unexpected_keys))}") + # print(f"Differences in between the keys here: {converted_sd_keys.difference(set(load_state_dict_results.unexpected_keys))}") + print(f"Remaining state dict: {remaining_state_dict}") # Handle LoRA. lora_grouped_dict = defaultdict(dict) From 70c0c684289d95a3f71ea3a98448e910eb6ec420 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 18 Sep 2023 11:57:05 +0100 Subject: [PATCH 093/119] debugging --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 4046bb337ee9..b3c0025a960f 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2546,7 +2546,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): } converted_sd_keys = set(converted_state_dict.keys()) # print(f"Differences in between the keys here: {converted_sd_keys.difference(set(load_state_dict_results.unexpected_keys))}") - print(f"Remaining state dict: {remaining_state_dict}") + print(f"Remaining state dict: {remaining_state_dict.keys()}") # Handle LoRA. 
lora_grouped_dict = defaultdict(dict) From 432fa6b65d294ca4cdcff07febf1ec08866b33e2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 18 Sep 2023 11:58:45 +0100 Subject: [PATCH 094/119] debugging --- src/diffusers/loaders.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index b3c0025a960f..4805c5cc9a1d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2546,7 +2546,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): } converted_sd_keys = set(converted_state_dict.keys()) # print(f"Differences in between the keys here: {converted_sd_keys.difference(set(load_state_dict_results.unexpected_keys))}") - print(f"Remaining state dict: {remaining_state_dict.keys()}") + # print(f"Remaining state dict: {remaining_state_dict.keys()}") + print(set(load_state_dict_results.unexpected_keys).difference(set(remaining_state_dict))) # Handle LoRA. lora_grouped_dict = defaultdict(dict) From b1099e8b51f5ffc109aa6e84627037ca989866d6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 18 Sep 2023 12:38:56 +0100 Subject: [PATCH 095/119] minor clean up --- src/diffusers/loaders.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 4805c5cc9a1d..4adf366851f3 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2544,10 +2544,6 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): remaining_state_dict = { k: v for k, v in converted_state_dict.items() if k in load_state_dict_results.unexpected_keys } - converted_sd_keys = set(converted_state_dict.keys()) - # print(f"Differences in between the keys here: {converted_sd_keys.difference(set(load_state_dict_results.unexpected_keys))}") - # print(f"Remaining state dict: {remaining_state_dict.keys()}") - print(set(load_state_dict_results.unexpected_keys).difference(set(remaining_state_dict))) # Handle LoRA. 
lora_grouped_dict = defaultdict(dict) From 87ee3728bc9225d26dc928690dd4e529b363bce7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 18 Sep 2023 22:49:02 +0100 Subject: [PATCH 096/119] debugging --- src/diffusers/loaders.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 4adf366851f3..fc54e33e542e 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -397,6 +397,8 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict ) for key, value_dict in lora_grouped_dict.items(): + if "time" in key: + print(key) attn_processor = self for sub_key in key.split("."): attn_processor = getattr(attn_processor, sub_key) From 05b7f8b2badf9fa02343af9c94823303764bf993 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 18 Sep 2023 22:55:49 +0100 Subject: [PATCH 097/119] debugging --- src/diffusers/loaders.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index fc54e33e542e..e9ad960e4b37 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2563,6 +2563,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): ) for key, value_dict in lora_grouped_dict.items(): + if "time" in key: + print(key) attn_processor = self for sub_key in key.split("."): attn_processor = getattr(attn_processor, sub_key) From e1286db6d2c950ebb5eb902d6925c828fdb2c5e5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 18 Sep 2023 23:11:33 +0100 Subject: [PATCH 098/119] debugging --- src/diffusers/loaders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e9ad960e4b37..1f5926e30a41 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -401,6 +401,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict print(key) attn_processor = self for sub_key in key.split("."): + print(f"{key}: {sub_key}") attn_processor = getattr(attn_processor, sub_key) # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers From 9cfce5f19e39ca4f677d2be429c598056123a53c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 18 Sep 2023 23:13:35 +0100 Subject: [PATCH 099/119] debugging --- src/diffusers/loaders.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 1f5926e30a41..8cf029dfd6ce 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2568,6 +2568,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): print(key) attn_processor = self for sub_key in key.split("."): + if "time" in key: + print(f"{key}: {sub_key}") attn_processor = getattr(attn_processor, sub_key) # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers From 57d52b4e8e5dc456e72e29081a6122c54f981546 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 19 Sep 2023 09:08:04 +0100 Subject: [PATCH 100/119] debugging --- src/diffusers/loaders.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 8cf029dfd6ce..dbe39dcb5e9c 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2556,6 +2556,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): for key in all_keys: value = remaining_state_dict.pop(key) attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + print(f"key: {key}, 
attn_processor_key: {attn_processor_key}, sub_key: {sub_key}") lora_grouped_dict[attn_processor_key][sub_key] = value if len(remaining_state_dict) > 0: @@ -2564,12 +2565,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): ) for key, value_dict in lora_grouped_dict.items(): - if "time" in key: - print(key) attn_processor = self for sub_key in key.split("."): - if "time" in key: - print(f"{key}: {sub_key}") attn_processor = getattr(attn_processor, sub_key) # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers From 8dcc44ba31084785499d13a178dd0c06ce83c0d6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 19 Sep 2023 09:08:24 +0100 Subject: [PATCH 101/119] debugging --- src/diffusers/loaders.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index dbe39dcb5e9c..2b18b50639b9 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2556,7 +2556,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): for key in all_keys: value = remaining_state_dict.pop(key) attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) - print(f"key: {key}, attn_processor_key: {attn_processor_key}, sub_key: {sub_key}") + if "time" in key: + print(f"key: {key}, attn_processor_key: {attn_processor_key}, sub_key: {sub_key}") lora_grouped_dict[attn_processor_key][sub_key] = value if len(remaining_state_dict) > 0: From a054d80ceb62ff12daf84c898ea2243e8c1db869 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 28 Sep 2023 11:11:19 +0530 Subject: [PATCH 102/119] better support? --- src/diffusers/loaders.py | 6 ++- src/diffusers/models/lora.py | 82 ++++++++++++++++++++++++++++++------ 2 files changed, 74 insertions(+), 14 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 2b18b50639b9..8b514505e3a6 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2556,8 +2556,6 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): for key in all_keys: value = remaining_state_dict.pop(key) attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) - if "time" in key: - print(f"key: {key}, attn_processor_key: {attn_processor_key}, sub_key: {sub_key}") lora_grouped_dict[attn_processor_key][sub_key] = value if len(remaining_state_dict) > 0: @@ -2586,12 +2584,16 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): kernel_size=kernel_size, stride=attn_processor.stride, padding=attn_processor.padding, + initial_weight=attn_processor.weight, + initial_bias=attn_processor.bias, ) elif isinstance(attn_processor, LoRACompatibleLinear): lora = LoRALinearLayer( attn_processor.in_features, attn_processor.out_features, rank, + initial_weight=attn_processor.weight, + initial_bias=attn_processor.bias, ) else: raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.") diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 834a7051b06d..3b764637040f 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -40,7 +40,17 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0): class LoRALinearLayer(nn.Module): - def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): + def __init__( + self, + in_features, + out_features, + rank=4, + network_alpha=None, + device=None, + 
dtype=None, + initial_weight=None, + initial_bias=None, + ): super().__init__() self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) @@ -52,6 +62,10 @@ def __init__(self, in_features, out_features, rank=4, network_alpha=None, device self.out_features = out_features self.in_features = in_features + # Control-LoRA specific. + self.initial_weight = initial_weight + self.initial_bias = initial_bias + nn.init.normal_(self.down.weight, std=1 / rank) nn.init.zeros_(self.up.weight) @@ -59,18 +73,40 @@ def forward(self, hidden_states): orig_dtype = hidden_states.dtype dtype = self.down.weight.dtype - down_hidden_states = self.down(hidden_states.to(dtype)) - up_hidden_states = self.up(down_hidden_states) + if not (self.initial_weight and self.initial_bias): + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) - if self.network_alpha is not None: - up_hidden_states *= self.network_alpha / self.rank + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank - return up_hidden_states.to(orig_dtype) + return up_hidden_states.to(orig_dtype) + else: + initial_weight = self.initial_weight + if initial_weight.device != hidden_states.device: + initial_weight = initial_weight.to(hidden_states.device) + return torch.nn.functional.linear( + hidden_states.to(dtype), + initial_weight + + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))) + .reshape(self.initial_weight.shape) + .type(orig_dtype), + self.initial_bias, + ) class LoRAConv2dLayer(nn.Module): def __init__( - self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None + self, + in_features, + out_features, + rank=4, + kernel_size=(1, 1), + stride=(1, 1), + padding=0, + network_alpha=None, + initial_weight=None, + initial_bias=None, ): super().__init__() @@ -84,6 +120,13 @@ def __init__( self.network_alpha = network_alpha self.rank = rank + # Control-LoRA specific. 
+ self.initial_weight = initial_weight + self.initial_bias = initial_bias + self.stride = stride + self.kernel_size = kernel_size + self.padding = padding + nn.init.normal_(self.down.weight, std=1 / rank) nn.init.zeros_(self.up.weight) @@ -91,13 +134,28 @@ def forward(self, hidden_states): orig_dtype = hidden_states.dtype dtype = self.down.weight.dtype - down_hidden_states = self.down(hidden_states.to(dtype)) - up_hidden_states = self.up(down_hidden_states) + if not (self.initial_weight and self.initial_bias): + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) - if self.network_alpha is not None: - up_hidden_states *= self.network_alpha / self.rank + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank - return up_hidden_states.to(orig_dtype) + return up_hidden_states.to(orig_dtype) + else: + initial_weight = self.initial_weight + if initial_weight.device != hidden_states.device: + initial_weight = initial_weight.to(hidden_states.device) + return torch.nn.functional.conv2d( + hidden_states, + initial_weight + + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))) + .reshape(self.initial_weight.shape) + .type(orig_dtype), + self.initial_bias, + self.stride, + self.padding, + ) class LoRACompatibleConv(nn.Conv2d): From 64284b174299660dc7d86fc81f8c91e974b7be01 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 28 Sep 2023 11:14:59 +0530 Subject: [PATCH 103/119] make strict loading false --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 8b514505e3a6..7f20f1c41034 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2599,7 +2599,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.") value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} - lora.load_state_dict(value_dict) + lora.load_state_dict(value_dict, strict=False) lora_layers_list.append((attn_processor, lora)) # set correct dtype & device From 13e8c8777733bc5d4255caf75508bdb94dd32411 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 28 Sep 2023 11:19:18 +0530 Subject: [PATCH 104/119] better conditioning --- src/diffusers/models/lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 3b764637040f..ae596e578473 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -73,7 +73,7 @@ def forward(self, hidden_states): orig_dtype = hidden_states.dtype dtype = self.down.weight.dtype - if not (self.initial_weight and self.initial_bias): + if self.initial_weight is None and self.initial_bias is None: down_hidden_states = self.down(hidden_states.to(dtype)) up_hidden_states = self.up(down_hidden_states) @@ -134,7 +134,7 @@ def forward(self, hidden_states): orig_dtype = hidden_states.dtype dtype = self.down.weight.dtype - if not (self.initial_weight and self.initial_bias): + if self.initial_weight is None and self.initial_bias is None: down_hidden_states = self.down(hidden_states.to(dtype)) up_hidden_states = self.up(down_hidden_states) From b42169482c4cb611dd0324a4529c143985a17772 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 28 Sep 2023 11:55:19 +0530 Subject: [PATCH 105/119] another --- src/diffusers/models/lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index ae596e578473..bd61ed6ad5c7 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -88,7 +88,7 @@ def forward(self, hidden_states): return torch.nn.functional.linear( hidden_states.to(dtype), initial_weight - + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))) + + (torch.mm(self.up.weight.data.flatten(start_dim=1), self.down.weight.data.flatten(start_dim=1))) .reshape(self.initial_weight.shape) .type(orig_dtype), self.initial_bias, @@ -149,7 +149,7 @@ def forward(self, hidden_states): return torch.nn.functional.conv2d( hidden_states, initial_weight - + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))) + + (torch.mm(self.up.weight.flatten(start_dim=1), self.down.weight.flatten(start_dim=1))) .reshape(self.initial_weight.shape) .type(orig_dtype), self.initial_bias, From 5ceb0a2f089023d05d27258d5cbb6f65d90260e0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 28 Sep 2023 12:01:49 +0530 Subject: [PATCH 106/119] log --- src/diffusers/models/lora.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index bd61ed6ad5c7..9b3a2413db2d 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -82,6 +82,7 @@ def forward(self, hidden_states): return up_hidden_states.to(orig_dtype) else: + print(f"{self.__class__} is running Control LoRA.") initial_weight = self.initial_weight if initial_weight.device != hidden_states.device: initial_weight = initial_weight.to(hidden_states.device) @@ -143,6 +144,7 @@ def forward(self, hidden_states): return up_hidden_states.to(orig_dtype) else: + print(f"{self.__class__} is running Control LoRA.") initial_weight = self.initial_weight if initial_weight.device != hidden_states.device: initial_weight = initial_weight.to(hidden_states.device) From 567a2dee1a55dcc77f005730f3643ad5ac6988d5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 28 Sep 2023 12:31:52 +0530 Subject: [PATCH 107/119] log --- src/diffusers/models/lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 9b3a2413db2d..9dea1bc2b1c4 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -82,7 +82,7 @@ def forward(self, hidden_states): return up_hidden_states.to(orig_dtype) else: - print(f"{self.__class__} is running Control LoRA.") + # print(f"{self.__class__} is running Control LoRA.") initial_weight = self.initial_weight if initial_weight.device != hidden_states.device: initial_weight = initial_weight.to(hidden_states.device) @@ -144,7 +144,7 @@ def forward(self, hidden_states): return up_hidden_states.to(orig_dtype) else: - print(f"{self.__class__} is running Control LoRA.") + # print(f"{self.__class__} is running Control LoRA.") initial_weight = self.initial_weight if initial_weight.device != hidden_states.device: initial_weight = initial_weight.to(hidden_states.device) From c6a04063cc69c37a1da085cfa42d0beb11812e08 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 28 Sep 2023 13:14:18 +0530 Subject: [PATCH 108/119] remove print --- src/diffusers/loaders.py | 17 +++--- src/diffusers/models/lora.py | 106 +++++++++++++++++------------------ 2 files changed, 59 insertions(+), 64 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 7f20f1c41034..f007b0599418 100644 --- a/src/diffusers/loaders.py +++ 
b/src/diffusers/loaders.py @@ -397,11 +397,8 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict ) for key, value_dict in lora_grouped_dict.items(): - if "time" in key: - print(key) attn_processor = self for sub_key in key.split("."): - print(f"{key}: {sub_key}") attn_processor = getattr(attn_processor, sub_key) # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers @@ -2540,7 +2537,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): load_state_dict_results = self.load_state_dict(converted_state_dict, strict=False) if not all("lora" in k for k in load_state_dict_results.unexpected_keys): raise ValueError( - f"The unexpected keys must only belong to LoRA parameters at this point: {load_state_dict_results.unexpected_keys}" + f"The unexpected keys must only belong to LoRA parameters at this point, but found the following keys that are non-LoRA\n: {load_state_dict_results.unexpected_keys}" ) # Filter out the rest of the state_dict for handling LoRA. @@ -2584,22 +2581,24 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): kernel_size=kernel_size, stride=attn_processor.stride, padding=attn_processor.padding, - initial_weight=attn_processor.weight, - initial_bias=attn_processor.bias, + # initial_weight=attn_processor.weight, + # initial_bias=attn_processor.bias, ) elif isinstance(attn_processor, LoRACompatibleLinear): lora = LoRALinearLayer( attn_processor.in_features, attn_processor.out_features, rank, - initial_weight=attn_processor.weight, - initial_bias=attn_processor.bias, + # initial_weight=attn_processor.weight, + # initial_bias=attn_processor.bias, ) else: raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.") value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} - lora.load_state_dict(value_dict, strict=False) + load_state_dict_results = lora.load_state_dict(value_dict, strict=False) + if not all("initial" in k for k in load_state_dict_results.unexpected_keys): + raise ValueError("Incorrect `value_dict` for the LoRA layer.") lora_layers_list.append((attn_processor, lora)) # set correct dtype & device diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 9dea1bc2b1c4..791a19fb1f10 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -48,8 +48,8 @@ def __init__( network_alpha=None, device=None, dtype=None, - initial_weight=None, - initial_bias=None, + # initial_weight=None, + # initial_bias=None, ): super().__init__() @@ -62,9 +62,9 @@ def __init__( self.out_features = out_features self.in_features = in_features - # Control-LoRA specific. - self.initial_weight = initial_weight - self.initial_bias = initial_bias + # # Control-LoRA specific. 
+ # self.initial_weight = initial_weight + # self.initial_bias = initial_bias nn.init.normal_(self.down.weight, std=1 / rank) nn.init.zeros_(self.up.weight) @@ -73,27 +73,25 @@ def forward(self, hidden_states): orig_dtype = hidden_states.dtype dtype = self.down.weight.dtype - if self.initial_weight is None and self.initial_bias is None: - down_hidden_states = self.down(hidden_states.to(dtype)) - up_hidden_states = self.up(down_hidden_states) + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) - if self.network_alpha is not None: - up_hidden_states *= self.network_alpha / self.rank + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank - return up_hidden_states.to(orig_dtype) - else: - # print(f"{self.__class__} is running Control LoRA.") - initial_weight = self.initial_weight - if initial_weight.device != hidden_states.device: - initial_weight = initial_weight.to(hidden_states.device) - return torch.nn.functional.linear( - hidden_states.to(dtype), - initial_weight - + (torch.mm(self.up.weight.data.flatten(start_dim=1), self.down.weight.data.flatten(start_dim=1))) - .reshape(self.initial_weight.shape) - .type(orig_dtype), - self.initial_bias, - ) + return up_hidden_states.to(orig_dtype) + # else: + # initial_weight = self.initial_weight + # if initial_weight.device != hidden_states.device: + # initial_weight = initial_weight.to(hidden_states.device) + # return torch.nn.functional.linear( + # hidden_states.to(dtype), + # initial_weight + # + (torch.mm(self.up.weight.data.flatten(start_dim=1), self.down.weight.data.flatten(start_dim=1))) + # .reshape(self.initial_weight.shape) + # .type(orig_dtype), + # self.initial_bias, + # ) class LoRAConv2dLayer(nn.Module): @@ -106,8 +104,8 @@ def __init__( stride=(1, 1), padding=0, network_alpha=None, - initial_weight=None, - initial_bias=None, + # initial_weight=None, + # initial_bias=None, ): super().__init__() @@ -121,12 +119,12 @@ def __init__( self.network_alpha = network_alpha self.rank = rank - # Control-LoRA specific. - self.initial_weight = initial_weight - self.initial_bias = initial_bias - self.stride = stride - self.kernel_size = kernel_size - self.padding = padding + # # Control-LoRA specific. 
+ # self.initial_weight = initial_weight + # self.initial_bias = initial_bias + # self.stride = stride + # self.kernel_size = kernel_size + # self.padding = padding nn.init.normal_(self.down.weight, std=1 / rank) nn.init.zeros_(self.up.weight) @@ -135,29 +133,27 @@ def forward(self, hidden_states): orig_dtype = hidden_states.dtype dtype = self.down.weight.dtype - if self.initial_weight is None and self.initial_bias is None: - down_hidden_states = self.down(hidden_states.to(dtype)) - up_hidden_states = self.up(down_hidden_states) - - if self.network_alpha is not None: - up_hidden_states *= self.network_alpha / self.rank - - return up_hidden_states.to(orig_dtype) - else: - # print(f"{self.__class__} is running Control LoRA.") - initial_weight = self.initial_weight - if initial_weight.device != hidden_states.device: - initial_weight = initial_weight.to(hidden_states.device) - return torch.nn.functional.conv2d( - hidden_states, - initial_weight - + (torch.mm(self.up.weight.flatten(start_dim=1), self.down.weight.flatten(start_dim=1))) - .reshape(self.initial_weight.shape) - .type(orig_dtype), - self.initial_bias, - self.stride, - self.padding, - ) + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + # else: + # initial_weight = self.initial_weight + # if initial_weight.device != hidden_states.device: + # initial_weight = initial_weight.to(hidden_states.device) + # return torch.nn.functional.conv2d( + # hidden_states, + # initial_weight + # + (torch.mm(self.up.weight.flatten(start_dim=1), self.down.weight.flatten(start_dim=1))) + # .reshape(self.initial_weight.shape) + # .type(orig_dtype), + # self.initial_bias, + # self.stride, + # self.padding, + # ) class LoRACompatibleConv(nn.Conv2d): From 86f5980ce8b83c8585c98994ba489aa7c6f8fe9c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 28 Sep 2023 14:28:51 +0530 Subject: [PATCH 109/119] change class name --- src/diffusers/loaders.py | 2 +- src/diffusers/models/controlnet.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index f007b0599418..79274c98659d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2514,7 +2514,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): return controlnet -class ControlNetLoaderMixin(LoraLoaderMixin): +class ControlLoRAMixin(LoraLoaderMixin): # Simplify ControlNet LoRA loading. 
def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs): from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 39cbc3e712bd..f32e789079ef 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -19,7 +19,7 @@ from torch.nn import functional as F from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import ControlNetLoaderMixin, FromOriginalControlnetMixin, UNet2DConditionLoadersMixin +from ..loaders import ControlLoRAMixin, FromOriginalControlnetMixin, UNet2DConditionLoadersMixin from ..models.lora import LoRACompatibleConv from ..utils import BaseOutput, logging from .attention_processor import ( @@ -109,7 +109,7 @@ def forward(self, conditioning): class ControlNetModel( - ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, FromOriginalControlnetMixin, ControlNetLoaderMixin + ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, FromOriginalControlnetMixin, ControlLoRAMixin ): """ A ControlNet model. From 4087dbfbb660856e2a67afc719db62ebc6e6becf Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 9 Oct 2023 15:36:27 +0200 Subject: [PATCH 110/119] step by step debug --- src/diffusers/models/controlnet.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index f32e789079ef..afddb04e51d4 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -722,6 +722,7 @@ def forward( timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) + print(f"t_emb: {t_emb}") # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. @@ -729,6 +730,8 @@ def forward( t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) + print(f"emb: {emb}") + aug_emb = None if self.class_embedding is not None: From ef430bfae9c22ad4a6eec437a6e72d7ca490e498 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 9 Oct 2023 16:52:55 +0200 Subject: [PATCH 111/119] step by step debug --- src/diffusers/models/controlnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index afddb04e51d4..e9620c871fe5 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -722,7 +722,7 @@ def forward( timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) - print(f"t_emb: {t_emb}") + print(f"t_emb: {t_emb[0, :3]}") # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. @@ -730,7 +730,7 @@ def forward( t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) - print(f"emb: {emb}") + print(f"emb: {emb[0, :3]}") aug_emb = None From c4ad76e16c746f563f2878203e64f88967c93aea Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 9 Oct 2023 17:00:44 +0200 Subject: [PATCH 112/119] have t printed. 
--- src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 6ff88e8275f2..abd9710a1405 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1126,6 +1126,7 @@ def __call__( controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] + print(f"t: {t.dtype} {t.shape}") down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, From bf7afc2f788b8ab059f260fbd026e97c22799068 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 9 Oct 2023 17:11:08 +0200 Subject: [PATCH 113/119] remove dtype of t from commit trail. --- src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index abd9710a1405..7ce400ed4e29 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1126,7 +1126,7 @@ def __call__( controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] - print(f"t: {t.dtype} {t.shape}") + # print(f"t: {t.dtype} {t.shape}") down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, From 5871ecc980a148b0304e68204b329eaf6a5a07a8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 9 Oct 2023 17:13:29 +0200 Subject: [PATCH 114/119] remove dtype of t from commit trail. 
--- src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 7ce400ed4e29..6ff88e8275f2 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1126,7 +1126,6 @@ def __call__( controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] - # print(f"t: {t.dtype} {t.shape}") down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, From 332cbfd303d9c1635f8fd34751ce290f68caf10b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 9 Oct 2023 21:56:33 +0200 Subject: [PATCH 115/119] debug --- src/diffusers/models/controlnet.py | 1 + src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index e9620c871fe5..ef9928001809 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -97,6 +97,7 @@ def __init__( def forward(self, conditioning): embedding = self.conv_in(conditioning) + print(f"From conv_in embedding of ControlNet: {embedding[0, :5]}") embedding = F.silu(embedding) for block in self.blocks: diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 6ff88e8275f2..d28993feafaf 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1070,6 +1070,7 @@ def __call__( target_size = target_size or (height, width) add_text_embeds = pooled_prompt_embeds + print(f"pooled_prompt_embeds: {pooled_prompt_embeds.shape}") add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype ) @@ -1126,6 +1127,7 @@ def __call__( controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] + print(f"ControlNet conditioning dimension: {image.shape}.") down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, From 26662de8686a7f91bc1bfeed4f4946efbf6bed85 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 9 Oct 2023 21:58:17 +0200 Subject: [PATCH 116/119] debug --- src/diffusers/models/controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index ef9928001809..c3954e9f98f6 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -97,7 +97,7 @@ def __init__( def forward(self, conditioning): embedding = self.conv_in(conditioning) - print(f"From conv_in embedding of ControlNet: {embedding[0, :5]}") + print(f"From conv_in embedding of ControlNet: {embedding[0, :5, :5, :5]}") embedding = F.silu(embedding) for block in self.blocks: From b08a0a61ce0681a4943594e48d1aa113d213cf70 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 9 Oct 2023 22:03:53 +0200 Subject: [PATCH 117/119] debug --- src/diffusers/models/controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index c3954e9f98f6..dbe35e90dcb2 100644 --- a/src/diffusers/models/controlnet.py +++ 
b/src/diffusers/models/controlnet.py @@ -97,7 +97,7 @@ def __init__( def forward(self, conditioning): embedding = self.conv_in(conditioning) - print(f"From conv_in embedding of ControlNet: {embedding[0, :5, :5, :5]}") + print(f"From conv_in embedding of ControlNet: {embedding[0, :5, :5, -1]}") embedding = F.silu(embedding) for block in self.blocks: From ca6895a1140e2675d566fa99c4f29025d86ae9f4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 9 Oct 2023 22:07:41 +0200 Subject: [PATCH 118/119] debug --- src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index d28993feafaf..627df5b86b62 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1127,7 +1127,6 @@ def __call__( controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] - print(f"ControlNet conditioning dimension: {image.shape}.") down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, From 6dc4d694c4865119c405cd995c567498ce2bf295 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 10 Oct 2023 09:29:01 +0200 Subject: [PATCH 119/119] debug --- src/diffusers/models/controlnet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index dbe35e90dcb2..4c8c1d93ab20 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -771,6 +771,7 @@ def forward( # 2. pre-process sample = self.conv_in(sample) + print(f"From ControlNet conv_in: {sample[0, :5, :5, -1]}") controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) sample = sample + controlnet_cond
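
The patches above move between two ways of evaluating a LoRA layer. The first diff shown in this section (switching to `self.up.weight.data.flatten(...)`) and the `initial_weight` / `initial_bias` path compute a merged weight `W0 + up @ down` from stored copies of the base layer's parameters, the approach used when a checkpoint also carries the full base weights; PATCH 108 then comments that path out and keeps only the plain residual forward that diffusers' standard LoRA layers use. Below is a minimal, self-contained sketch contrasting the two variants. The class names, the 64-dim toy shapes, and the usage at the end are illustrative assumptions, not diffusers' actual `LoRALinearLayer` / `LoRAConv2dLayer` API.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualLoRALinear(nn.Module):
    # Standard residual LoRA: a low-rank update that the caller adds to the
    # frozen base layer's output.
    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        self.rank = rank
        self.network_alpha = network_alpha
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        out = self.up(self.down(hidden_states))
        if self.network_alpha is not None:
            out = out * self.network_alpha / self.rank
        return out


class MergedControlLoRALinear(ResidualLoRALinear):
    # Control-LoRA-style merged forward: fold the low-rank update into a stored
    # copy of the base weight and replace the base layer entirely (the
    # `initial_weight` / `initial_bias` path that PATCH 108 comments out).
    def __init__(self, in_features, out_features, rank=4, network_alpha=None,
                 initial_weight=None, initial_bias=None):
        super().__init__(in_features, out_features, rank, network_alpha)
        self.register_buffer("initial_weight", initial_weight)
        self.register_buffer("initial_bias", initial_bias)

    def forward(self, hidden_states):
        # W = W0 + up.weight @ down.weight -- note the `.weight`: the modules
        # themselves cannot be flattened, which is what the first diff fixes.
        delta = torch.mm(
            self.up.weight.flatten(start_dim=1),
            self.down.weight.flatten(start_dim=1),
        ).reshape(self.initial_weight.shape)
        return F.linear(hidden_states, self.initial_weight + delta, self.initial_bias)


# Usage sketch: both variants map (2, 64) inputs to (2, 64) outputs.
base = nn.Linear(64, 64)
x = torch.randn(2, 64)

residual = ResidualLoRALinear(64, 64, rank=4, network_alpha=4)
y_residual = base(x) + residual(x)

merged = MergedControlLoRALinear(
    64, 64, rank=4,
    initial_weight=base.weight.data.clone(),
    initial_bias=base.bias.data.clone(),
)
y_merged = merged(x)

The practical difference: the residual form leaves the frozen base layer untouched and is what the reverted code keeps, while the merged form bakes the update into a private copy of the base weight, which is why the earlier revisions had to thread `initial_weight` / `initial_bias` (plus stride and padding for the conv case) into the LoRA layers.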
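
Separately, both `load_attn_procs` and the `ControlLoRAMixin.load_lora_weights` path above resolve the module targeted by each grouped LoRA key by walking the dotted key with `getattr`. A small stand-alone sketch of that traversal follows; the helper name and the toy model are made up for illustration.

import torch.nn as nn


def get_submodule_by_dotted_key(root: nn.Module, dotted_key: str) -> nn.Module:
    # Mirrors the `for sub_key in key.split("."): getattr(...)` loop used when
    # attaching LoRA layers: each dotted segment, including numeric indices of
    # ModuleList / Sequential children, names the next child module.
    module = root
    for sub_key in dotted_key.split("."):
        module = getattr(module, sub_key)
    return module


# Toy usage: "blocks.0.proj" resolves to the inner Linear module.
toy = nn.Module()
toy.blocks = nn.ModuleList([nn.Sequential()])
toy.blocks[0].proj = nn.Linear(8, 8)
target = get_submodule_by_dotted_key(toy, "blocks.0.proj")
assert isinstance(target, nn.Linear)

In the patches, the module located this way (a `LoRACompatibleConv` or `LoRACompatibleLinear`) is then paired with a freshly built `LoRAConv2dLayer` or `LoRALinearLayer` whose rank, kernel size, stride, and padding are read off that module before its `value_dict` is loaded with `strict=False`.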